diff --git a/digest/algorithm.go b/digest/algorithm.go index f44b50db..b71314e6 100644 --- a/digest/algorithm.go +++ b/digest/algorithm.go @@ -72,6 +72,10 @@ func (a *Algorithm) Set(value string) error { *a = Algorithm(value) } + if !a.Available() { + return ErrDigestUnsupported + } + return nil } diff --git a/digest/algorithm_test.go b/digest/algorithm_test.go new file mode 100644 index 00000000..12db5d13 --- /dev/null +++ b/digest/algorithm_test.go @@ -0,0 +1,100 @@ +package digest + +import ( + "bytes" + "crypto/rand" + _ "crypto/sha256" + _ "crypto/sha512" + "flag" + "fmt" + "strings" + "testing" +) + +func TestFlagInterface(t *testing.T) { + var ( + alg Algorithm + flagSet flag.FlagSet + ) + + flagSet.Var(&alg, "algorithm", "set the digest algorithm") + for _, testcase := range []struct { + Name string + Args []string + Err error + Expected Algorithm + }{ + { + Name: "Invalid", + Args: []string{"-algorithm", "bean"}, + Err: ErrDigestUnsupported, + }, + { + Name: "Default", + Args: []string{"unrelated"}, + Expected: "sha256", + }, + { + Name: "Other", + Args: []string{"-algorithm", "sha512"}, + Expected: "sha512", + }, + } { + t.Run(testcase.Name, func(t *testing.T) { + alg = Canonical + if err := flagSet.Parse(testcase.Args); err != testcase.Err { + if testcase.Err == nil { + t.Fatal("unexpected error", err) + } + + // check that flag package returns correct error + if !strings.Contains(err.Error(), testcase.Err.Error()) { + t.Fatalf("unexpected error: %v != %v", err, testcase.Err) + } + return + } + + if alg != testcase.Expected { + t.Fatalf("unexpected algorithm: %v != %v", alg, testcase.Expected) + } + }) + } +} + +func TestFroms(t *testing.T) { + p := make([]byte, 1<<20) + rand.Read(p) + + for alg := range algorithms { + h := alg.Hash() + h.Write(p) + expected := Digest(fmt.Sprintf("%s:%x", alg, h.Sum(nil))) + readerDgst, err := alg.FromReader(bytes.NewReader(p)) + if err != nil { + t.Fatalf("error calculating hash from reader: %v", err) + } + + dgsts := []Digest{ + alg.FromBytes(p), + alg.FromString(string(p)), + readerDgst, + } + + if alg == Canonical { + readerDgst, err := FromReader(bytes.NewReader(p)) + if err != nil { + t.Fatalf("error calculating hash from reader: %v", err) + } + + dgsts = append(dgsts, + FromBytes(p), + FromString(string(p)), + readerDgst) + } + for _, dgst := range dgsts { + if dgst != expected { + t.Fatalf("unexpected digest %v != %v", dgst, expected) + } + } + } +} diff --git a/digest/digest.go b/digest/digest.go index 761f2e54..a9ab8e51 100644 --- a/digest/digest.go +++ b/digest/digest.go @@ -94,17 +94,10 @@ func FromString(s string) Digest { func (d Digest) Validate() error { s := string(d) - if !DigestRegexpAnchored.MatchString(s) { - return ErrDigestInvalidFormat - } - i := strings.Index(s, ":") - if i < 0 { - return ErrDigestInvalidFormat - } - // case: "sha256:" with no hex. - if i+1 == len(s) { + // validate i then run through regexp + if i < 0 || i+1 == len(s) || !DigestRegexpAnchored.MatchString(s) { return ErrDigestInvalidFormat } diff --git a/digest/digest_test.go b/digest/digest_test.go index 62158add..d85c476f 100644 --- a/digest/digest_test.go +++ b/digest/digest_test.go @@ -1,8 +1,6 @@ package digest import ( - _ "crypto/sha256" - _ "crypto/sha512" "testing" ) @@ -28,6 +26,11 @@ func TestParseDigest(t *testing.T) { input: "sha256:", err: ErrDigestInvalidFormat, }, + { + // empty hex + input: ":", + err: ErrDigestInvalidFormat, + }, { // just hex input: "d41d8cd98f00b204e9800998ecf8427e", @@ -80,5 +83,10 @@ func TestParseDigest(t *testing.T) { if newParsed != digest { t.Fatalf("expected equal: %q != %q", newParsed, digest) } + + newFromHex := NewDigestFromHex(newParsed.Algorithm().String(), newParsed.Hex()) + if newFromHex != digest { + t.Fatalf("%v != %v", newFromHex, digest) + } } }