Skip to content

Commit f94dea5

Browse files
authored
perf(types): 3x speedup MakePartSet (#3117)
<!-- Please add a reference to the issue that this PR addresses and indicate which files are most critical to review. If it fully addresses a particular issue, please include "Closes #XXX" (where "XXX" is the issue number). If this PR is non-trivial/large/complex, please ensure that you have either created an issue that the team's had a chance to respond to, or had some discussion with the team prior to submitting substantial pull requests. The team can be reached via GitHub Discussions or the Cosmos Network Discord server in the #cometbft channel. GitHub Discussions is preferred over Discord as it allows us to keep track of conversations topically. https://github.com/cometbft/cometbft/discussions If the work in this PR is not aligned with the team's current priorities, please be advised that it may take some time before it is merged - especially if it has not yet been discussed with the team. See the project board for the team's current priorities: https://github.com/orgs/cometbft/projects/1 --> This PR adds some benchmarks, and significantly speeds up types.MakePartSet, and Partset.AddPart. (Used by the block proposer, and every consensus instance) It does so by doing two things: - Saving mutexes on the newly created bit array, by defaulting every value to True (rather than setting it in a loop that goes through a mutex) - Uses the same hash object throughout, and avoids an extra copy of every leaf. (main speedup) I do the same hash optimization for proof.Verify, which is used in the add block part codepath for both the proposer and every full node. 
New: ``` BenchmarkMakePartSet/nParts=1-12 38616 29817 ns/op 568 B/op 12 allocs/op BenchmarkMakePartSet/nParts=2-12 19888 59866 ns/op 1000 B/op 22 allocs/op BenchmarkMakePartSet/nParts=3-12 12979 95691 ns/op 1528 B/op 33 allocs/op BenchmarkMakePartSet/nParts=4-12 8688 128192 ns/op 2024 B/op 44 allocs/op BenchmarkMakePartSet/nParts=5-12 7308 155224 ns/op 2888 B/op 57 allocs/op ``` Old: ``` BenchmarkMakePartSet/nParts=1-12 16647 106545 ns/op 74169 B/op 12 allocs/op BenchmarkMakePartSet/nParts=2-12 10000 106361 ns/op 148329 B/op 23 allocs/op BenchmarkMakePartSet/nParts=3-12 6992 337644 ns/op 222587 B/op 35 allocs/op BenchmarkMakePartSet/nParts=4-12 3488 480109 ns/op 296811 B/op 47 allocs/op BenchmarkMakePartSet/nParts=5-12 2228 557768 ns/op 371404 B/op 61 allocs/op ``` System wide, this is definitely not our issue (looks like roughly .1ms per blockpart), but still definitely useful time to remove --- #### PR checklist - [x] Tests written/updated - [x] Changelog entry added in `.changelog` (we use [unclog](https://github.com/informalsystems/unclog) to manage our changelog) - [x] Updated relevant documentation (`docs/` or `spec/`) and code comments - [x] Title follows the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) spec
1 parent baca084 commit f94dea5

6 files changed

Lines changed: 58 additions & 15 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- [`types`] Significantly speed up types.MakePartSet and types.AddPart, which are used in creating a block proposal
2+
([\#3117](https://github.com/cometbft/cometbft/issues/3117))

crypto/merkle/bench_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,25 @@ func BenchmarkInnerHash(b *testing.B) {
4040
b.Fatal("Benchmark did not run!")
4141
}
4242
}
43+
44+
// Benchmark the time it takes to hash a 64kb leaf, which is the size of
45+
// a block part.
46+
// This helps determine whether it's worth parallelizing this hash for the proposer.
47+
func BenchmarkLeafHash64kb(b *testing.B) {
48+
b.ReportAllocs()
49+
leaf := make([]byte, 64*1024)
50+
hash := sha256.New()
51+
52+
for i := 0; i < b.N; i++ {
53+
leaf[0] = byte(i)
54+
got := leafHashOpt(hash, leaf)
55+
if g, w := len(got), sha256.Size; g != w {
56+
b.Fatalf("size discrepancy: got %d, want %d", g, w)
57+
}
58+
sink = got
59+
}
60+
61+
if sink == nil {
62+
b.Fatal("Benchmark did not run!")
63+
}
64+
}

crypto/merkle/proof.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"errors"
66
"fmt"
7+
"hash"
78

89
cmtcrypto "github.com/cometbft/cometbft/api/cometbft/crypto/v1"
910
"github.com/cometbft/cometbft/crypto/tmhash"
@@ -91,13 +92,14 @@ func (sp *Proof) Verify(rootHash []byte, leaf []byte) error {
9192
Err: errors.New("negative proof index"),
9293
}
9394
}
94-
leafHash := leafHash(leaf)
95+
hash := tmhash.New()
96+
leafHash := leafHashOpt(hash, leaf)
9597
if !bytes.Equal(sp.LeafHash, leafHash) {
9698
return ErrInvalidHash{
9799
Err: fmt.Errorf("leaf %x, want %x", sp.LeafHash, leafHash),
98100
}
99101
}
100-
computedHash, err := sp.computeRootHash()
102+
computedHash, err := sp.computeRootHash(hash)
101103
if err != nil {
102104
return ErrInvalidHash{
103105
Err: fmt.Errorf("compute root hash: %w", err),
@@ -112,8 +114,9 @@ func (sp *Proof) Verify(rootHash []byte, leaf []byte) error {
112114
}
113115

114116
// Compute the root hash given a leaf hash.
115-
func (sp *Proof) computeRootHash() ([]byte, error) {
117+
func (sp *Proof) computeRootHash(hash hash.Hash) ([]byte, error) {
116118
return computeHashFromAunts(
119+
hash,
117120
sp.Index,
118121
sp.Total,
119122
sp.LeafHash,
@@ -200,7 +203,7 @@ func ProofFromProto(pb *cmtcrypto.Proof) (*Proof, error) {
200203
// Use the leafHash and innerHashes to get the root merkle hash.
201204
// If the length of the innerHashes slice isn't exactly correct, the result is nil.
202205
// Recursive impl.
203-
func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]byte) ([]byte, error) {
206+
func computeHashFromAunts(hash hash.Hash, index, total int64, leafHash []byte, innerHashes [][]byte) ([]byte, error) {
204207
if index >= total || index < 0 || total <= 0 {
205208
return nil, fmt.Errorf("invalid index %d and/or total %d", index, total)
206209
}
@@ -218,18 +221,18 @@ func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]b
218221
}
219222
numLeft := getSplitPoint(total)
220223
if index < numLeft {
221-
leftHash, err := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1])
224+
leftHash, err := computeHashFromAunts(hash, index, numLeft, leafHash, innerHashes[:len(innerHashes)-1])
222225
if err != nil {
223226
return nil, err
224227
}
225228

226-
return innerHash(leftHash, innerHashes[len(innerHashes)-1]), nil
229+
return innerHashOpt(hash, leftHash, innerHashes[len(innerHashes)-1]), nil
227230
}
228-
rightHash, err := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1])
231+
rightHash, err := computeHashFromAunts(hash, index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1])
229232
if err != nil {
230233
return nil, err
231234
}
232-
return innerHash(innerHashes[len(innerHashes)-1], rightHash), nil
235+
return innerHashOpt(hash, innerHashes[len(innerHashes)-1], rightHash), nil
233236
}
234237
}
235238

@@ -266,18 +269,22 @@ func (spn *ProofNode) FlattenAunts() [][]byte {
266269
// trails[0].Hash is the leaf hash for items[0].
267270
// trails[i].Parent.Parent....Parent == root for all i.
268271
func trailsFromByteSlices(items [][]byte) (trails []*ProofNode, root *ProofNode) {
272+
return trailsFromByteSlicesInternal(tmhash.New(), items)
273+
}
274+
275+
func trailsFromByteSlicesInternal(hash hash.Hash, items [][]byte) (trails []*ProofNode, root *ProofNode) {
269276
// Recursive impl.
270277
switch len(items) {
271278
case 0:
272279
return []*ProofNode{}, &ProofNode{emptyHash(), nil, nil, nil}
273280
case 1:
274-
trail := &ProofNode{leafHash(items[0]), nil, nil, nil}
281+
trail := &ProofNode{leafHashOpt(hash, items[0]), nil, nil, nil}
275282
return []*ProofNode{trail}, trail
276283
default:
277284
k := getSplitPoint(int64(len(items)))
278-
lefts, leftRoot := trailsFromByteSlices(items[:k])
279-
rights, rightRoot := trailsFromByteSlices(items[k:])
280-
rootHash := innerHash(leftRoot.Hash, rightRoot.Hash)
285+
lefts, leftRoot := trailsFromByteSlicesInternal(hash, items[:k])
286+
rights, rightRoot := trailsFromByteSlicesInternal(hash, items[k:])
287+
rootHash := innerHashOpt(hash, leftRoot.Hash, rightRoot.Hash)
281288
root := &ProofNode{rootHash, nil, nil, nil}
282289
leftRoot.Parent = root
283290
leftRoot.Right = rightRoot

crypto/merkle/proof_value.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func (op ValueOp) Run(args [][]byte) ([][]byte, error) {
104104
}
105105
}
106106

107-
rootHash, err := op.Proof.computeRootHash()
107+
rootHash, err := op.Proof.computeRootHash(tmhash.New())
108108
if err != nil {
109109
return nil, err
110110
}

types/part_set.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,21 +180,20 @@ func NewPartSetFromData(data []byte, partSize uint32) *PartSet {
180180
total := (uint32(len(data)) + partSize - 1) / partSize
181181
parts := make([]*Part, total)
182182
partsBytes := make([][]byte, total)
183-
partsBitArray := bits.NewBitArray(int(total))
184183
for i := uint32(0); i < total; i++ {
185184
part := &Part{
186185
Index: i,
187186
Bytes: data[i*partSize : cmtmath.MinInt(len(data), int((i+1)*partSize))],
188187
}
189188
parts[i] = part
190189
partsBytes[i] = part.Bytes
191-
partsBitArray.SetIndex(int(i), true)
192190
}
193191
// Compute merkle proofs
194192
root, proofs := merkle.ProofsFromByteSlices(partsBytes)
195193
for i := uint32(0); i < total; i++ {
196194
parts[i].Proof = *proofs[i]
197195
}
196+
partsBitArray := bits.NewBitArrayFromFn(int(total), func(int) bool { return true })
198197
return &PartSet{
199198
total: total,
200199
hash: root,

types/part_set_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package types
22

33
import (
4+
"fmt"
45
"io"
56
"testing"
67

@@ -219,3 +220,15 @@ func TestPartProtoBuf(t *testing.T) {
219220
}
220221
}
221222
}
223+
224+
func BenchmarkMakePartSet(b *testing.B) {
225+
for nParts := 1; nParts <= 5; nParts++ {
226+
b.Run(fmt.Sprintf("nParts=%d", nParts), func(b *testing.B) {
227+
data := cmtrand.Bytes(testPartSize * nParts)
228+
b.ResetTimer()
229+
for i := 0; i < b.N; i++ {
230+
NewPartSetFromData(data, testPartSize)
231+
}
232+
})
233+
}
234+
}

0 commit comments

Comments
 (0)