mirror of
https://github.com/vbatts/merkle.git
synced 2024-12-02 19:15:39 +00:00
stream: tests and fix for Reset()
This commit is contained in:
parent
5277ffd66c
commit
bedd7a10ff
2 changed files with 70 additions and 15 deletions
19
stream.go
19
stream.go
|
@ -10,6 +10,10 @@ import (
|
||||||
// HashMaker for the checksums of the blocks written and the blockSize of each
|
// HashMaker for the checksums of the blocks written and the blockSize of each
|
||||||
// block per node in the tree.
|
// block per node in the tree.
|
||||||
func NewHash(hm HashMaker, merkleBlockLength int) HashTreeer {
|
func NewHash(hm HashMaker, merkleBlockLength int) HashTreeer {
|
||||||
|
return newMerkleHash(hm, merkleBlockLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMerkleHash(hm HashMaker, merkleBlockLength int) *merkleHash {
|
||||||
mh := new(merkleHash)
|
mh := new(merkleHash)
|
||||||
mh.blockSize = merkleBlockLength
|
mh.blockSize = merkleBlockLength
|
||||||
mh.hm = hm
|
mh.hm = hm
|
||||||
|
@ -44,6 +48,11 @@ type merkleHash struct {
|
||||||
partialLastNode bool // true when Sum() has appended a Node for a partial block
|
partialLastNode bool // true when Sum() has appended a Node for a partial block
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mh *merkleHash) Reset() {
|
||||||
|
mh1 := newMerkleHash(mh.hm, mh.blockSize)
|
||||||
|
*mh = *mh1
|
||||||
|
}
|
||||||
|
|
||||||
func (mh merkleHash) Nodes() []*Node {
|
func (mh merkleHash) Nodes() []*Node {
|
||||||
return mh.tree.Nodes
|
return mh.tree.Nodes
|
||||||
}
|
}
|
||||||
|
@ -80,6 +89,11 @@ func (mh *merkleHash) Sum(b []byte) []byte {
|
||||||
curBlock = append(curBlock, b...)
|
curBlock = append(curBlock, b...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// incase we're at a new or reset state
|
||||||
|
if len(mh.tree.Nodes) == 0 && len(curBlock) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < len(curBlock)/mh.blockSize; i++ {
|
for i := 0; i < len(curBlock)/mh.blockSize; i++ {
|
||||||
n, err := NewNodeHashBlock(mh.hm, curBlock[offset:(offset+mh.blockSize)])
|
n, err := NewNodeHashBlock(mh.hm, curBlock[offset:(offset+mh.blockSize)])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -172,11 +186,6 @@ func (mh *merkleHash) Write(b []byte) (int, error) {
|
||||||
return numWritten, nil
|
return numWritten, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mh *merkleHash) Reset() {
|
|
||||||
mh.tree = &Tree{}
|
|
||||||
mh.lastBlock = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// likely not the best to pass this through and not use our own node block
|
// likely not the best to pass this through and not use our own node block
|
||||||
// size, but let's revisit this.
|
// size, but let's revisit this.
|
||||||
func (mh *merkleHash) BlockSize() int { return mh.hm().BlockSize() }
|
func (mh *merkleHash) BlockSize() int { return mh.hm().BlockSize() }
|
||||||
|
|
|
@ -2,12 +2,15 @@ package merkle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMerkleHashWriter(t *testing.T) {
|
func TestMerkleHashWriter(t *testing.T) {
|
||||||
msg := "the quick brown fox jumps over the lazy dog"
|
msg := "the quick brown fox jumps over the lazy dog"
|
||||||
|
expectedSum := "48940c1c72636648ad40aa59c162f2208e835b38"
|
||||||
|
|
||||||
h := NewHash(DefaultHashMaker, 10)
|
h := NewHash(DefaultHashMaker, 10)
|
||||||
i, err := io.Copy(h, bytes.NewBufferString(msg))
|
i, err := io.Copy(h, bytes.NewBufferString(msg))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -17,26 +20,69 @@ func TestMerkleHashWriter(t *testing.T) {
|
||||||
t.Fatalf("expected to write %d, only wrote %d", len(msg), i)
|
t.Fatalf("expected to write %d, only wrote %d", len(msg), i)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
mh *merkleHash
|
|
||||||
ok bool
|
|
||||||
)
|
|
||||||
if mh, ok = h.(*merkleHash); !ok {
|
|
||||||
t.Fatalf("expected to get merkleHash, but got %#t", h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We're left with a partial lastBlock
|
// We're left with a partial lastBlock
|
||||||
expectedNum := 4
|
expectedNum := 4
|
||||||
if len(mh.tree.Nodes) != expectedNum {
|
if len(h.Nodes()) != expectedNum {
|
||||||
t.Errorf("expected %d nodes, got %d", expectedNum, len(mh.tree.Nodes))
|
t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next test Sum()
|
// Next test Sum()
|
||||||
|
gotSum := fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
if expectedSum != gotSum {
|
||||||
|
t.Errorf("expected initial checksum %q; got %q", expectedSum, gotSum)
|
||||||
|
}
|
||||||
|
|
||||||
// count blocks again, we should get 5 nodes now
|
// count blocks again, we should get 5 nodes now
|
||||||
|
expectedNum = 5
|
||||||
|
if len(h.Nodes()) != expectedNum {
|
||||||
|
t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes()))
|
||||||
|
}
|
||||||
|
|
||||||
// Test Sum() again, ensure same sum
|
// Test Sum() again, ensure same sum
|
||||||
|
gotSum = fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
if expectedSum != gotSum {
|
||||||
|
t.Errorf("expected checksum %q; got %q", expectedSum, gotSum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test that Reset() nulls us out
|
||||||
|
h.Reset()
|
||||||
|
gotSum = fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
if expectedSum == gotSum {
|
||||||
|
t.Errorf("expected reset checksum to not equal %q; got %q", expectedSum, gotSum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// write our msg again and get the same sum
|
||||||
|
i, err = io.Copy(h, bytes.NewBufferString(msg))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if i != int64(len(msg)) {
|
||||||
|
t.Fatalf("expected to write %d, only wrote %d", len(msg), i)
|
||||||
|
}
|
||||||
|
// Test Sum(), ensure same sum
|
||||||
|
gotSum = fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
if expectedSum != gotSum {
|
||||||
|
t.Errorf("expected checksum %q; got %q", expectedSum, gotSum)
|
||||||
|
}
|
||||||
|
|
||||||
// Write more. This should pop the last node, and use the lastBlock.
|
// Write more. This should pop the last node, and use the lastBlock.
|
||||||
|
i, err = io.Copy(h, bytes.NewBufferString(msg))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if i != int64(len(msg)) {
|
||||||
|
t.Fatalf("expected to write %d, only wrote %d", len(msg), i)
|
||||||
|
}
|
||||||
|
expectedNum = 9
|
||||||
|
if len(h.Nodes()) != expectedNum {
|
||||||
|
t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes()))
|
||||||
|
}
|
||||||
|
gotSum = fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
if expectedSum == gotSum {
|
||||||
|
t.Errorf("expected reset checksum to not equal %q; got %q", expectedSum, gotSum)
|
||||||
|
}
|
||||||
|
if len(h.Nodes()) != expectedNum {
|
||||||
|
t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes()))
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue