From bedd7a10fff556ea9b6ceb5e40799c603a3c0a74 Mon Sep 17 00:00:00 2001 From: Vincent Batts Date: Tue, 31 Mar 2015 10:55:25 -0400 Subject: [PATCH] stream: tests and fix for Reset() --- stream.go | 19 +++++++++++---- stream_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/stream.go b/stream.go index 88bdfc7..c1f51e7 100644 --- a/stream.go +++ b/stream.go @@ -10,6 +10,10 @@ import ( // HashMaker for the checksums of the blocks written and the blockSize of each // block per node in the tree. func NewHash(hm HashMaker, merkleBlockLength int) HashTreeer { + return newMerkleHash(hm, merkleBlockLength) +} + +func newMerkleHash(hm HashMaker, merkleBlockLength int) *merkleHash { mh := new(merkleHash) mh.blockSize = merkleBlockLength mh.hm = hm @@ -44,6 +48,11 @@ type merkleHash struct { partialLastNode bool // true when Sum() has appended a Node for a partial block } +func (mh *merkleHash) Reset() { + mh1 := newMerkleHash(mh.hm, mh.blockSize) + *mh = *mh1 +} + func (mh merkleHash) Nodes() []*Node { return mh.tree.Nodes } @@ -80,6 +89,11 @@ func (mh *merkleHash) Sum(b []byte) []byte { curBlock = append(curBlock, b...) } + // incase we're at a new or reset state + if len(mh.tree.Nodes) == 0 && len(curBlock) == 0 { + return nil + } + for i := 0; i < len(curBlock)/mh.blockSize; i++ { n, err := NewNodeHashBlock(mh.hm, curBlock[offset:(offset+mh.blockSize)]) if err != nil { @@ -172,11 +186,6 @@ func (mh *merkleHash) Write(b []byte) (int, error) { return numWritten, nil } -func (mh *merkleHash) Reset() { - mh.tree = &Tree{} - mh.lastBlock = nil -} - // likely not the best to pass this through and not use our own node block // size, but let's revisit this. func (mh *merkleHash) BlockSize() int { return mh.hm().BlockSize() } diff --git a/stream_test.go b/stream_test.go index 5c9ebad..4afa42e 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,12 +2,15 @@ package merkle import ( "bytes" + "fmt" "io" "testing" ) func TestMerkleHashWriter(t *testing.T) { msg := "the quick brown fox jumps over the lazy dog" + expectedSum := "48940c1c72636648ad40aa59c162f2208e835b38" + h := NewHash(DefaultHashMaker, 10) i, err := io.Copy(h, bytes.NewBufferString(msg)) if err != nil { @@ -17,26 +20,69 @@ func TestMerkleHashWriter(t *testing.T) { t.Fatalf("expected to write %d, only wrote %d", len(msg), i) } - var ( - mh *merkleHash - ok bool - ) - if mh, ok = h.(*merkleHash); !ok { - t.Fatalf("expected to get merkleHash, but got %#t", h) - } - // We're left with a partial lastBlock expectedNum := 4 - if len(mh.tree.Nodes) != expectedNum { - t.Errorf("expected %d nodes, got %d", expectedNum, len(mh.tree.Nodes)) + if len(h.Nodes()) != expectedNum { + t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes())) } // Next test Sum() + gotSum := fmt.Sprintf("%x", h.Sum(nil)) + if expectedSum != gotSum { + t.Errorf("expected initial checksum %q; got %q", expectedSum, gotSum) + } // count blocks again, we should get 5 nodes now + expectedNum = 5 + if len(h.Nodes()) != expectedNum { + t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes())) + } // Test Sum() again, ensure same sum + gotSum = fmt.Sprintf("%x", h.Sum(nil)) + if expectedSum != gotSum { + t.Errorf("expected checksum %q; got %q", expectedSum, gotSum) + } + + // test that Reset() nulls us out + h.Reset() + gotSum = fmt.Sprintf("%x", h.Sum(nil)) + if expectedSum == gotSum { + t.Errorf("expected reset checksum to not equal %q; got %q", expectedSum, gotSum) + } + + // write our msg again and get the same sum + i, err = io.Copy(h, bytes.NewBufferString(msg)) + if err != nil { + t.Fatal(err) + } + if i != int64(len(msg)) { + t.Fatalf("expected to write %d, only wrote %d", len(msg), i) + } + // Test Sum(), ensure same sum + gotSum = fmt.Sprintf("%x", h.Sum(nil)) + if expectedSum != gotSum { + t.Errorf("expected checksum %q; got %q", expectedSum, gotSum) + } // Write more. This should pop the last node, and use the lastBlock. + i, err = io.Copy(h, bytes.NewBufferString(msg)) + if err != nil { + t.Fatal(err) + } + if i != int64(len(msg)) { + t.Fatalf("expected to write %d, only wrote %d", len(msg), i) + } + expectedNum = 9 + if len(h.Nodes()) != expectedNum { + t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes())) + } + gotSum = fmt.Sprintf("%x", h.Sum(nil)) + if expectedSum == gotSum { + t.Errorf("expected reset checksum to not equal %q; got %q", expectedSum, gotSum) + } + if len(h.Nodes()) != expectedNum { + t.Errorf("expected %d nodes, got %d", expectedNum, len(h.Nodes())) + } }