diff --git a/digest/algorithm.go b/digest/algorithm.go index 9f74e552..f44b50db 100644 --- a/digest/algorithm.go +++ b/digest/algorithm.go @@ -94,6 +94,11 @@ func (a Algorithm) New() Digester { // method will panic. Check Algorithm.Available() before calling. func (a Algorithm) Hash() hash.Hash { if !a.Available() { + // Empty algorithm string is invalid + if a == "" { + panic(fmt.Sprintf("empty digest algorithm, validate before calling Algorithm.Hash()")) + } + // NOTE(stevvooe): A missing hash is usually a programming error that // must be resolved at compile time. We don't import in the digest // package to allow users to choose their hash implementation (such as diff --git a/digest/digest.go b/digest/digest.go index 7e6e2d89..4616788b 100644 --- a/digest/digest.go +++ b/digest/digest.go @@ -104,16 +104,17 @@ func (d Digest) Validate() error { return ErrDigestInvalidFormat } - switch algorithm := Algorithm(s[:i]); algorithm { - case SHA256, SHA384, SHA512: - if algorithm.Size()*2 != len(s[i+1:]) { - return ErrDigestInvalidLength - } - break - default: + algorithm := Algorithm(s[:i]) + if !algorithm.Available() { return ErrDigestUnsupported } + // Digests much always be hex-encoded, ensuring that their hex portion will + // always be size*2 + if algorithm.Size()*2 != len(s[i+1:]) { + return ErrDigestInvalidLength + } + return nil } @@ -124,17 +125,12 @@ func (d Digest) Algorithm() Algorithm { } // Verifier returns a writer object that can be used to verify a stream of -// content against the digest. If the digest is invalid, an error will be -// returned. -func (d Digest) Verifier() (Verifier, error) { - if err := d.Validate(); err != nil { - return nil, err - } - +// content against the digest. If the digest is invalid, the method will panic. +func (d Digest) Verifier() Verifier { return hashVerifier{ hash: d.Algorithm().Hash(), digest: d, - }, nil + } } // Hex returns the hex digest portion of the digest. This will panic if the @@ -151,7 +147,7 @@ func (d Digest) sepIndex() int { i := strings.Index(string(d), ":") if i < 0 { - panic("could not find ':' in digest: " + d) + panic(fmt.Sprintf("no ':' separator in digest %q", d)) } return i diff --git a/digest/digest_test.go b/digest/digest_test.go index afb4ebf6..c27b8171 100644 --- a/digest/digest_test.go +++ b/digest/digest_test.go @@ -1,6 +1,8 @@ package digest import ( + _ "crypto/sha256" + _ "crypto/sha512" "testing" ) diff --git a/digest/verifiers.go b/digest/verifiers.go index 89c1fb08..df9872cf 100644 --- a/digest/verifiers.go +++ b/digest/verifiers.go @@ -19,14 +19,7 @@ type Verifier interface { // NewDigestVerifier is deprecated. Please use Digest.Verifier. func NewDigestVerifier(d Digest) (Verifier, error) { - if err := d.Validate(); err != nil { - return nil, err - } - - return hashVerifier{ - hash: d.Algorithm().Hash(), - digest: d, - }, nil + return d.Verifier(), nil } type hashVerifier struct { diff --git a/digest/verifiers_test.go b/digest/verifiers_test.go index c342d6e7..251d4fce 100644 --- a/digest/verifiers_test.go +++ b/digest/verifiers_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rand" "io" + "reflect" "testing" ) @@ -12,10 +13,7 @@ func TestDigestVerifier(t *testing.T) { rand.Read(p) digest := FromBytes(p) - verifier, err := NewDigestVerifier(digest) - if err != nil { - t.Fatalf("unexpected error getting digest verifier: %s", err) - } + verifier := digest.Verifier() io.Copy(verifier, bytes.NewReader(p)) @@ -27,23 +25,42 @@ func TestDigestVerifier(t *testing.T) { // TestVerifierUnsupportedDigest ensures that unsupported digest validation is // flowing through verifier creation. func TestVerifierUnsupportedDigest(t *testing.T) { - unsupported := Digest("bean:0123456789abcdef") + for _, testcase := range []struct { + Name string + Digest Digest + Expected interface{} // expected panic target + }{ + { + Name: "Empty", + Digest: "", + Expected: "no ':' separator in digest \"\"", + }, + { + Name: "EmptyAlg", + Digest: ":", + Expected: "empty digest algorithm, validate before calling Algorithm.Hash()", + }, + { + Name: "Unsupported", + Digest: Digest("bean:0123456789abcdef"), + Expected: "bean not available (make sure it is imported)", + }, + { + Name: "Garbage", + Digest: Digest("sha256-garbage:pure"), + Expected: "sha256-garbage not available (make sure it is imported)", + }, + } { + t.Run(testcase.Name, func(t *testing.T) { + expected := testcase.Expected + defer func() { + recovered := recover() + if !reflect.DeepEqual(recovered, expected) { + t.Fatalf("unexpected recover: %v != %v", recovered, expected) + } + }() - _, err := NewDigestVerifier(unsupported) - if err == nil { - t.Fatalf("expected error when creating verifier") - } - - if err != ErrDigestUnsupported { - t.Fatalf("incorrect error for unsupported digest: %v", err) + _ = testcase.Digest.Verifier() + }) } } - -// TODO(stevvooe): Add benchmarks to measure bytes/second throughput for -// DigestVerifier. -// -// The relevant benchmark for comparison can be run with the following -// commands: -// -// go test -bench . crypto/sha1 -// diff --git a/registry/storage/blobwriter.go b/registry/storage/blobwriter.go index 7e42a59b..e2f03f5f 100644 --- a/registry/storage/blobwriter.go +++ b/registry/storage/blobwriter.go @@ -235,11 +235,7 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri // guarantee, so this may be defensive. if !verified { digester := digest.Canonical.New() - - digestVerifier, err := desc.Digest.Verifier() - if err != nil { - return distribution.Descriptor{}, err - } + verifier := desc.Digest.Verifier() // Read the file from the backend driver and validate it. fr, err := newFileReader(ctx, bw.driver, bw.path, desc.Size) @@ -250,12 +246,12 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri tr := io.TeeReader(fr, digester.Hash()) - if _, err := io.Copy(digestVerifier, tr); err != nil { + if _, err := io.Copy(verifier, tr); err != nil { return distribution.Descriptor{}, err } canonical = digester.Digest() - verified = digestVerifier.Verified() + verified = verifier.Verified() } } diff --git a/registry/storage/filereader_test.go b/registry/storage/filereader_test.go index fa363f44..371a410d 100644 --- a/registry/storage/filereader_test.go +++ b/registry/storage/filereader_test.go @@ -41,11 +41,7 @@ func TestSimpleRead(t *testing.T) { t.Fatalf("error allocating file reader: %v", err) } - verifier, err := dgst.Verifier() - if err != nil { - t.Fatalf("error getting digest verifier: %s", err) - } - + verifier := dgst.Verifier() io.Copy(verifier, fr) if !verifier.Verified() {