1
0
Fork 0
mirror of https://github.com/vbatts/merkle.git synced 2024-12-02 19:15:39 +00:00

stream: tests and fix for Reset()

This commit is contained in:
Vincent Batts 2015-03-31 10:55:25 -04:00
parent 5277ffd66c
commit bedd7a10ff
2 changed files with 70 additions and 15 deletions

View file

@ -10,6 +10,10 @@ import (
// HashMaker for the checksums of the blocks written and the blockSize of each // HashMaker for the checksums of the blocks written and the blockSize of each
// block per node in the tree. // block per node in the tree.
func NewHash(hm HashMaker, merkleBlockLength int) HashTreeer { func NewHash(hm HashMaker, merkleBlockLength int) HashTreeer {
return newMerkleHash(hm, merkleBlockLength)
}
func newMerkleHash(hm HashMaker, merkleBlockLength int) *merkleHash {
mh := new(merkleHash) mh := new(merkleHash)
mh.blockSize = merkleBlockLength mh.blockSize = merkleBlockLength
mh.hm = hm mh.hm = hm
@ -44,6 +48,11 @@ type merkleHash struct {
partialLastNode bool // true when Sum() has appended a Node for a partial block partialLastNode bool // true when Sum() has appended a Node for a partial block
} }
func (mh *merkleHash) Reset() {
mh1 := newMerkleHash(mh.hm, mh.blockSize)
*mh = *mh1
}
func (mh merkleHash) Nodes() []*Node { func (mh merkleHash) Nodes() []*Node {
return mh.tree.Nodes return mh.tree.Nodes
} }
@ -80,6 +89,11 @@ func (mh *merkleHash) Sum(b []byte) []byte {
curBlock = append(curBlock, b...) curBlock = append(curBlock, b...)
} }
// incase we're at a new or reset state
if len(mh.tree.Nodes) == 0 && len(curBlock) == 0 {
return nil
}
for i := 0; i < len(curBlock)/mh.blockSize; i++ { for i := 0; i < len(curBlock)/mh.blockSize; i++ {
n, err := NewNodeHashBlock(mh.hm, curBlock[offset:(offset+mh.blockSize)]) n, err := NewNodeHashBlock(mh.hm, curBlock[offset:(offset+mh.blockSize)])
if err != nil { if err != nil {
@ -172,11 +186,6 @@ func (mh *merkleHash) Write(b []byte) (int, error) {
return numWritten, nil return numWritten, nil
} }
func (mh *merkleHash) Reset() {
mh.tree = &Tree{}
mh.lastBlock = nil
}
// likely not the best to pass this through and not use our own node block // likely not the best to pass this through and not use our own node block
// size, but let's revisit this. // size, but let's revisit this.
func (mh *merkleHash) BlockSize() int { return mh.hm().BlockSize() } func (mh *merkleHash) BlockSize() int { return mh.hm().BlockSize() }

View file

@ -2,12 +2,15 @@ package merkle
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"testing" "testing"
) )
func TestMerkleHashWriter(t *testing.T) { func TestMerkleHashWriter(t *testing.T) {
msg := "the quick brown fox jumps over the lazy dog" msg := "the quick brown fox jumps over the lazy dog"
expectedSum := "48940c1c72636648ad40aa59c162f2208e835b38"
h := NewHash(DefaultHashMaker, 10) h := NewHash(DefaultHashMaker, 10)
i, err := io.Copy(h, bytes.NewBufferString(msg)) i, err := io.Copy(h, bytes.NewBufferString(msg))
if err != nil { if err != nil {
@ -17,26 +20,69 @@ func TestMerkleHashWriter(t *testing.T) {
t.Fatalf("expected to write %d, only wrote %d", len(msg), i) t.Fatalf("expected to write %d, only wrote %d", len(msg), i)
} }
var (
mh *merkleHash
ok bool
)
if mh, ok = h.(*merkleHash); !ok {
t.Fatalf("expected to get merkleHash, but got %#t", h)
}
// We're left with a partial lastBlock // We're left with a partial lastBlock
expectedNum := 4 expectedNum := 4
if len(mh.tree.Nodes) != expectedNum { if len(h.Nodes()) != expectedNum {
t.Errorf("expected %d nodes, got %d", expectedNum, len(mh.tree.Nodes)) t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes()))
} }
// Next test Sum() // Next test Sum()
gotSum := fmt.Sprintf("%x", h.Sum(nil))
if expectedSum != gotSum {
t.Errorf("expected initial checksum %q; got %q", expectedSum, gotSum)
}
// count blocks again, we should get 5 nodes now // count blocks again, we should get 5 nodes now
expectedNum = 5
if len(h.Nodes()) != expectedNum {
t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes()))
}
// Test Sum() again, ensure same sum // Test Sum() again, ensure same sum
gotSum = fmt.Sprintf("%x", h.Sum(nil))
if expectedSum != gotSum {
t.Errorf("expected checksum %q; got %q", expectedSum, gotSum)
}
// test that Reset() nulls us out
h.Reset()
gotSum = fmt.Sprintf("%x", h.Sum(nil))
if expectedSum == gotSum {
t.Errorf("expected reset checksum to not equal %q; got %q", expectedSum, gotSum)
}
// write our msg again and get the same sum
i, err = io.Copy(h, bytes.NewBufferString(msg))
if err != nil {
t.Fatal(err)
}
if i != int64(len(msg)) {
t.Fatalf("expected to write %d, only wrote %d", len(msg), i)
}
// Test Sum(), ensure same sum
gotSum = fmt.Sprintf("%x", h.Sum(nil))
if expectedSum != gotSum {
t.Errorf("expected checksum %q; got %q", expectedSum, gotSum)
}
// Write more. This should pop the last node, and use the lastBlock. // Write more. This should pop the last node, and use the lastBlock.
i, err = io.Copy(h, bytes.NewBufferString(msg))
if err != nil {
t.Fatal(err)
}
if i != int64(len(msg)) {
t.Fatalf("expected to write %d, only wrote %d", len(msg), i)
}
expectedNum = 9
if len(h.Nodes()) != expectedNum {
t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes()))
}
gotSum = fmt.Sprintf("%x", h.Sum(nil))
if expectedSum == gotSum {
t.Errorf("expected reset checksum to not equal %q; got %q", expectedSum, gotSum)
}
if len(h.Nodes()) != expectedNum {
t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes()))
}
} }