package token import ( "crypto" "crypto/rand" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "io/ioutil" "net/http" "os" "strings" "testing" "time" "github.com/docker/docker-registry/auth" "github.com/docker/libtrust" ) func makeRootKeys(numKeys int) ([]libtrust.PrivateKey, error) { keys := make([]libtrust.PrivateKey, 0, numKeys) for i := 0; i < numKeys; i++ { key, err := libtrust.GenerateECP256PrivateKey() if err != nil { return nil, err } keys = append(keys, key) } return keys, nil } func makeSigningKeyWithChain(rootKey libtrust.PrivateKey, depth int) (libtrust.PrivateKey, error) { if depth == 0 { // Don't need to build a chain. return rootKey, nil } var ( x5c = make([]string, depth) parentKey = rootKey key libtrust.PrivateKey cert *x509.Certificate err error ) for depth > 0 { if key, err = libtrust.GenerateECP256PrivateKey(); err != nil { return nil, err } if cert, err = libtrust.GenerateCACert(parentKey, key); err != nil { return nil, err } depth-- x5c[depth] = base64.StdEncoding.EncodeToString(cert.Raw) parentKey = key } key.AddExtendedField("x5c", x5c) return key, nil } func makeRootCerts(rootKeys []libtrust.PrivateKey) ([]*x509.Certificate, error) { certs := make([]*x509.Certificate, 0, len(rootKeys)) for _, key := range rootKeys { cert, err := libtrust.GenerateCACert(key, key) if err != nil { return nil, err } certs = append(certs, cert) } return certs, nil } func makeTrustedKeyMap(rootKeys []libtrust.PrivateKey) map[string]libtrust.PublicKey { trustedKeys := make(map[string]libtrust.PublicKey, len(rootKeys)) for _, key := range rootKeys { trustedKeys[key.KeyID()] = key.PublicKey() } return trustedKeys } func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey libtrust.PrivateKey, depth int) (*Token, error) { signingKey, err := makeSigningKeyWithChain(rootKey, depth) if err != nil { return nil, fmt.Errorf("unable to amke signing key with chain: %s", err) } rawJWK, err := signingKey.PublicKey().MarshalJSON() if err != nil { return nil, fmt.Errorf("unable to marshal signing key to JSON: %s", err) } joseHeader := &Header{ Type: "JWT", SigningAlg: "ES256", RawJWK: json.RawMessage(rawJWK), } now := time.Now() randomBytes := make([]byte, 15) if _, err = rand.Read(randomBytes); err != nil { return nil, fmt.Errorf("unable to read random bytes for jwt id: %s", err) } claimSet := &ClaimSet{ Issuer: issuer, Subject: "foo", Audience: audience, Expiration: now.Add(5 * time.Minute).Unix(), NotBefore: now.Unix(), IssuedAt: now.Unix(), JWTID: base64.URLEncoding.EncodeToString(randomBytes), Access: access, } var joseHeaderBytes, claimSetBytes []byte if joseHeaderBytes, err = json.Marshal(joseHeader); err != nil { return nil, fmt.Errorf("unable to marshal jose header: %s", err) } if claimSetBytes, err = json.Marshal(claimSet); err != nil { return nil, fmt.Errorf("unable to marshal claim set: %s", err) } encodedJoseHeader := joseBase64UrlEncode(joseHeaderBytes) encodedClaimSet := joseBase64UrlEncode(claimSetBytes) encodingToSign := fmt.Sprintf("%s.%s", encodedJoseHeader, encodedClaimSet) var signatureBytes []byte if signatureBytes, _, err = signingKey.Sign(strings.NewReader(encodingToSign), crypto.SHA256); err != nil { return nil, fmt.Errorf("unable to sign jwt payload: %s", err) } signature := joseBase64UrlEncode(signatureBytes) tokenString := fmt.Sprintf("%s.%s", encodingToSign, signature) return NewToken(tokenString) } // This test makes 4 tokens with a varying number of intermediate // certificates ranging from no intermediate chain to a length of 3 // intermediates. func TestTokenVerify(t *testing.T) { var ( numTokens = 4 issuer = "test-issuer" audience = "test-audience" access = []*ResourceActions{ { Type: "repository", Name: "foo/bar", Actions: []string{"pull", "push"}, }, } ) rootKeys, err := makeRootKeys(numTokens) if err != nil { t.Fatal(err) } rootCerts, err := makeRootCerts(rootKeys) if err != nil { t.Fatal(err) } rootPool := x509.NewCertPool() for _, rootCert := range rootCerts { rootPool.AddCert(rootCert) } trustedKeys := makeTrustedKeyMap(rootKeys) tokens := make([]*Token, 0, numTokens) for i := 0; i < numTokens; i++ { token, err := makeTestToken(issuer, audience, access, rootKeys[i], i) if err != nil { t.Fatal(err) } tokens = append(tokens, token) } verifyOps := VerifyOptions{ TrustedIssuers: newStringSet(issuer), AccpetedAudiences: newStringSet(audience), Roots: rootPool, TrustedKeys: trustedKeys, } for _, token := range tokens { if err := token.Verify(verifyOps); err != nil { t.Fatal(err) } if !token.Valid { t.Fatal("token not marked as Valid") } } } func writeTempRootCerts(rootKeys []libtrust.PrivateKey) (filename string, err error) { rootCerts, err := makeRootCerts(rootKeys) if err != nil { return "", err } tempFile, err := ioutil.TempFile("", "rootCertBundle") if err != nil { return "", err } defer tempFile.Close() for _, cert := range rootCerts { if err = pem.Encode(tempFile, &pem.Block{ Type: "CERTIFICATE", Bytes: cert.Raw, }); err != nil { os.Remove(tempFile.Name()) return "", err } } return tempFile.Name(), nil } // TestAccessController tests complete integration of the token auth package. // It starts by mocking the options for a token auth accessController which // it creates. It then tries a few mock requests: // - don't supply a token; should error with challenge // - supply an invalid token; should error with challenge // - supply a token with insufficient access; should error with challenge // - supply a valid token; should not error func TestAccessController(t *testing.T) { // Make 2 keys; only the first is to be a trusted root key. rootKeys, err := makeRootKeys(2) if err != nil { t.Fatal(err) } rootCertBundleFilename, err := writeTempRootCerts(rootKeys[:1]) if err != nil { t.Fatal(err) } defer os.Remove(rootCertBundleFilename) realm := "https://auth.example.com/token/" issuer := "test-issuer.example.com" service := "test-service.example.com" options := map[string]interface{}{ "realm": realm, "issuer": issuer, "service": service, "rootCertBundle": rootCertBundleFilename, } accessController, err := newAccessController(options) if err != nil { t.Fatal(err) } // 1. Make a mock http.Request with no token. req, err := http.NewRequest("GET", "http://example.com/foo", nil) if err != nil { t.Fatal(err) } testAccess := auth.Access{ Resource: auth.Resource{ Type: "foo", Name: "bar", }, Action: "baz", } err = accessController.Authorized(req, testAccess) challenge, ok := err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") } if challenge.Error() != ErrTokenRequired.Error() { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired) } // 2. Supply an invalid token. token, err := makeTestToken( issuer, service, []*ResourceActions{{ Type: testAccess.Type, Name: testAccess.Name, Actions: []string{testAccess.Action}, }}, rootKeys[1], 1, // Everything is valid except the key which signed it. ) if err != nil { t.Fatal(err) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw())) err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") } if challenge.Error() != ErrInvalidToken.Error() { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired) } // 3. Supply a token with insufficient access. token, err = makeTestToken( issuer, service, []*ResourceActions{}, // No access specified. rootKeys[0], 1, ) if err != nil { t.Fatal(err) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw())) err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") } if challenge.Error() != ErrInsufficientScope.Error() { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope) } // 4. Supply the token we need, or deserve, or whatever. token, err = makeTestToken( issuer, service, []*ResourceActions{{ Type: testAccess.Type, Name: testAccess.Name, Actions: []string{testAccess.Action}, }}, rootKeys[0], 1, // Everything is valid except the key which signed it. ) if err != nil { t.Fatal(err) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw())) if err = accessController.Authorized(req, testAccess); err != nil { t.Fatalf("accessController returned unexpected error: %s", err) } }