diff --git a/tarsum/tarsum.go b/tarsum/tarsum.go index 4ae71f0..fe15287 100644 --- a/tarsum/tarsum.go +++ b/tarsum/tarsum.go @@ -35,6 +35,14 @@ func NewTarSum(r io.Reader, dc bool, v Version) (TarSum, error) { return &tarSum{Reader: r, DisableCompression: dc, tarSumVersion: v}, nil } +// Create a new TarSum, providing a THash to use rather than the DefaultTHash +func NewTarSumHash(r io.Reader, dc bool, v Version, tHash THash) (TarSum, error) { + if _, ok := tarSumVersions[v]; !ok { + return nil, ErrVersionNotImplemented + } + return &tarSum{Reader: r, DisableCompression: dc, tarSumVersion: v, tHash: tHash}, nil +} + // TarSum is the generic interface for calculating fixed time // checksums of a tar archive type TarSum interface { @@ -42,6 +50,7 @@ type TarSum interface { GetSums() FileInfoSums Sum([]byte) string Version() Version + Hash() THash } // tarSum struct is the structure for a Version0 checksum calculation @@ -49,11 +58,12 @@ type tarSum struct { io.Reader tarR *tar.Reader tarW *tar.Writer - gz writeCloseFlusher + writer writeCloseFlusher bufTar *bytes.Buffer - bufGz *bytes.Buffer + bufWriter *bytes.Buffer bufData []byte h hash.Hash + tHash THash sums FileInfoSums fileCounter int64 currentFile string @@ -63,10 +73,36 @@ type tarSum struct { tarSumVersion Version // this field is not exported so it can not be mutated during use } +func (ts tarSum) Hash() THash { + return ts.tHash +} + func (ts tarSum) Version() Version { return ts.tarSumVersion } +// A hash.Hash type generator and its name +type THash interface { + Hash() hash.Hash + Name() string +} + +// Convenience method for creating a THash +func NewTHash(name string, h func() hash.Hash) THash { + return simpleTHash{n: name, h: h} +} + +// TarSum default is "sha256" +var DefaultTHash = NewTHash("sha256", sha256.New) + +type simpleTHash struct { + n string + h func() hash.Hash +} + +func (sth simpleTHash) Name() string { return sth.n } +func (sth simpleTHash) Hash() hash.Hash { return sth.h() } + func (ts tarSum) selectHeaders(h *tar.Header, v Version) (set [][2]string) { for _, elem := range [][2]string{ {"name", h.Name}, @@ -113,25 +149,35 @@ func (ts *tarSum) encodeHeader(h *tar.Header) error { return nil } +func (ts *tarSum) initTarSum() error { + ts.bufTar = bytes.NewBuffer([]byte{}) + ts.bufWriter = bytes.NewBuffer([]byte{}) + ts.tarR = tar.NewReader(ts.Reader) + ts.tarW = tar.NewWriter(ts.bufTar) + if !ts.DisableCompression { + ts.writer = gzip.NewWriter(ts.bufWriter) + } else { + ts.writer = &nopCloseFlusher{Writer: ts.bufWriter} + } + if ts.tHash == nil { + ts.tHash = DefaultTHash + } + ts.h = ts.tHash.Hash() + ts.h.Reset() + ts.first = true + ts.sums = FileInfoSums{} + return nil +} + func (ts *tarSum) Read(buf []byte) (int, error) { - if ts.gz == nil { - ts.bufTar = bytes.NewBuffer([]byte{}) - ts.bufGz = bytes.NewBuffer([]byte{}) - ts.tarR = tar.NewReader(ts.Reader) - ts.tarW = tar.NewWriter(ts.bufTar) - if !ts.DisableCompression { - ts.gz = gzip.NewWriter(ts.bufGz) - } else { - ts.gz = &nopCloseFlusher{Writer: ts.bufGz} + if ts.writer == nil { + if err := ts.initTarSum(); err != nil { + return 0, err } - ts.h = sha256.New() - ts.h.Reset() - ts.first = true - ts.sums = FileInfoSums{} } if ts.finished { - return ts.bufGz.Read(buf) + return ts.bufWriter.Read(buf) } if ts.bufData == nil { switch { @@ -167,10 +213,10 @@ func (ts *tarSum) Read(buf []byte) (int, error) { if err := ts.tarW.Close(); err != nil { return 0, err } - if _, err := io.Copy(ts.gz, ts.bufTar); err != nil { + if _, err := io.Copy(ts.writer, ts.bufTar); err != nil { return 0, err } - if err := ts.gz.Close(); err != nil { + if err := ts.writer.Close(); err != nil { return 0, err } ts.finished = true @@ -189,12 +235,12 @@ func (ts *tarSum) Read(buf []byte) (int, error) { return 0, err } ts.tarW.Flush() - if _, err := io.Copy(ts.gz, ts.bufTar); err != nil { + if _, err := io.Copy(ts.writer, ts.bufTar); err != nil { return 0, err } - ts.gz.Flush() + ts.writer.Flush() - return ts.bufGz.Read(buf) + return ts.bufWriter.Read(buf) } return n, err } @@ -210,18 +256,18 @@ func (ts *tarSum) Read(buf []byte) (int, error) { } ts.tarW.Flush() - // Filling the gz writter - if _, err = io.Copy(ts.gz, ts.bufTar); err != nil { + // Filling the output writer + if _, err = io.Copy(ts.writer, ts.bufTar); err != nil { return 0, err } - ts.gz.Flush() + ts.writer.Flush() - return ts.bufGz.Read(buf) + return ts.bufWriter.Read(buf) } func (ts *tarSum) Sum(extra []byte) string { ts.sums.SortBySums() - h := sha256.New() + h := ts.tHash.Hash() if extra != nil { h.Write(extra) } @@ -229,7 +275,7 @@ func (ts *tarSum) Sum(extra []byte) string { log.Debugf("-->%s<--", fis.Sum()) h.Write([]byte(fis.Sum())) } - checksum := ts.Version().String() + "+sha256:" + hex.EncodeToString(h.Sum(nil)) + checksum := ts.Version().String() + "+" + ts.tHash.Name() + ":" + hex.EncodeToString(h.Sum(nil)) log.Debugf("checksum processed: %s", checksum) return checksum } diff --git a/tarsum/tarsum_test.go b/tarsum/tarsum_test.go index d0b4c94..58dbdda 100644 --- a/tarsum/tarsum_test.go +++ b/tarsum/tarsum_test.go @@ -3,8 +3,11 @@ package tarsum import ( "bytes" "compress/gzip" + "crypto/md5" "crypto/rand" + "crypto/sha1" "crypto/sha256" + "crypto/sha512" "encoding/hex" "fmt" "io" @@ -22,6 +25,7 @@ type testLayer struct { gzip bool tarsum string version Version + hash THash } var testLayers = []testLayer{ @@ -75,6 +79,31 @@ var testLayers = []testLayer{ // this tar has newer of collider-1.tar, ensuring is has different hash filename: "testdata/collision/collision-3.tar", tarsum: "tarsum+sha256:f886e431c08143164a676805205979cd8fa535dfcef714db5515650eea5a7c0f"}, + { + options: &sizedOptions{1, 1024 * 1024, false, false}, // a 1mb file (in memory) + tarsum: "tarsum+md5:0d7529ec7a8360155b48134b8e599f53", + hash: md5THash, + }, + { + options: &sizedOptions{1, 1024 * 1024, false, false}, // a 1mb file (in memory) + tarsum: "tarsum+sha1:f1fee39c5925807ff75ef1925e7a23be444ba4df", + hash: sha1Hash, + }, + { + options: &sizedOptions{1, 1024 * 1024, false, false}, // a 1mb file (in memory) + tarsum: "tarsum+sha224:6319390c0b061d639085d8748b14cd55f697cf9313805218b21cf61c", + hash: sha224Hash, + }, + { + options: &sizedOptions{1, 1024 * 1024, false, false}, // a 1mb file (in memory) + tarsum: "tarsum+sha384:a578ce3ce29a2ae03b8ed7c26f47d0f75b4fc849557c62454be4b5ffd66ba021e713b48ce71e947b43aab57afd5a7636", + hash: sha384Hash, + }, + { + options: &sizedOptions{1, 1024 * 1024, false, false}, // a 1mb file (in memory) + tarsum: "tarsum+sha512:e9bfb90ca5a4dfc93c46ee061a5cf9837de6d2fdf82544d6460d3147290aecfabf7b5e415b9b6e72db9b8941f149d5d69fb17a394cbfaf2eac523bd9eae21855", + hash: sha512Hash, + }, } type sizedOptions struct { @@ -203,6 +232,14 @@ func TestEmptyTar(t *testing.T) { } } +var ( + md5THash = NewTHash("md5", md5.New) + sha1Hash = NewTHash("sha1", sha1.New) + sha224Hash = NewTHash("sha224", sha256.New224) + sha384Hash = NewTHash("sha384", sha512.New384) + sha512Hash = NewTHash("sha512", sha512.New) +) + func TestTarSums(t *testing.T) { for _, layer := range testLayers { var ( @@ -226,8 +263,13 @@ func TestTarSums(t *testing.T) { defer file.Close() } - // double negatives! - ts, err := NewTarSum(fh, !layer.gzip, layer.version) + var ts TarSum + if layer.hash == nil { + // double negatives! + ts, err = NewTarSum(fh, !layer.gzip, layer.version) + } else { + ts, err = NewTarSumHash(fh, !layer.gzip, layer.version, layer.hash) + } if err != nil { t.Errorf("%q :: %q", err, layer.filename) continue