commit e2d6977b0af3a22b1ec0dd7d06935932f1f0c833 Author: Vincent Batts Date: Wed Feb 5 14:08:37 2025 -0500 *: initial commit from https://github.com/distribution/distribution/commit/29b5e79f8254fd606828c84cda2a2da4d6ae3a11 before the whole ./contrib/ directory was deleted. Note this mostly only for reference. Signed-off-by: Vincent Batts diff --git a/README.md b/README.md new file mode 100644 index 0000000..c4a6d63 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +token-server + +from https://github.com/distribution/distribution/commit/29b5e79f8254fd606828c84cda2a2da4d6ae3a11 before the whole ./contrib/ directory was deleted. + +Note this mostly only for reference. diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..2978323 --- /dev/null +++ b/errors.go @@ -0,0 +1,38 @@ +package main + +import ( + "net/http" + + "github.com/distribution/distribution/v3/registry/api/errcode" +) + +var ( + errGroup = "tokenserver" + + // ErrorBadTokenOption is returned when a token parameter is invalid + ErrorBadTokenOption = errcode.Register(errGroup, errcode.ErrorDescriptor{ + Value: "BAD_TOKEN_OPTION", + Message: "bad token option", + Description: `This error may be returned when a request for a + token contains an option which is not valid`, + HTTPStatusCode: http.StatusBadRequest, + }) + + // ErrorMissingRequiredField is returned when a required form field is missing + ErrorMissingRequiredField = errcode.Register(errGroup, errcode.ErrorDescriptor{ + Value: "MISSING_REQUIRED_FIELD", + Message: "missing required field", + Description: `This error may be returned when a request for a + token does not contain a required form field`, + HTTPStatusCode: http.StatusBadRequest, + }) + + // ErrorUnsupportedValue is returned when a form field has an unsupported value + ErrorUnsupportedValue = errcode.Register(errGroup, errcode.ErrorDescriptor{ + Value: "UNSUPPORTED_VALUE", + Message: "unsupported value", + Description: `This error may be returned when a request for a + token contains a form field with an unsupported value`, + HTTPStatusCode: http.StatusBadRequest, + }) +) diff --git a/main.go b/main.go new file mode 100644 index 0000000..9a42976 --- /dev/null +++ b/main.go @@ -0,0 +1,431 @@ +package main + +import ( + "context" + "crypto/rand" + "encoding/json" + "flag" + "math/big" + "net/http" + "strconv" + "strings" + "time" + + dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/registry/api/errcode" + "github.com/distribution/distribution/v3/registry/auth" + _ "github.com/distribution/distribution/v3/registry/auth/htpasswd" + "github.com/docker/libtrust" + "github.com/gorilla/mux" + "github.com/sirupsen/logrus" +) + +var enforceRepoClass bool + +func main() { + var ( + issuer = &TokenIssuer{} + pkFile string + addr string + debug bool + err error + + passwdFile string + realm string + + cert string + certKey string + ) + + flag.StringVar(&issuer.Issuer, "issuer", "distribution-token-server", "Issuer string for token") + flag.StringVar(&pkFile, "key", "", "Private key file") + flag.StringVar(&addr, "addr", "localhost:8080", "Address to listen on") + flag.BoolVar(&debug, "debug", false, "Debug mode") + + flag.StringVar(&passwdFile, "passwd", ".htpasswd", "Passwd file") + flag.StringVar(&realm, "realm", "", "Authentication realm") + + flag.StringVar(&cert, "tlscert", "", "Certificate file for TLS") + flag.StringVar(&certKey, "tlskey", "", "Certificate key for TLS") + + flag.BoolVar(&enforceRepoClass, "enforce-class", false, "Enforce policy for single repository class") + + flag.Parse() + + if debug { + logrus.SetLevel(logrus.DebugLevel) + } + + if pkFile == "" { + issuer.SigningKey, err = libtrust.GenerateECP256PrivateKey() + if err != nil { + logrus.Fatalf("Error generating private key: %v", err) + } + logrus.Debugf("Using newly generated key with id %s", issuer.SigningKey.KeyID()) + } else { + issuer.SigningKey, err = libtrust.LoadKeyFile(pkFile) + if err != nil { + logrus.Fatalf("Error loading key file %s: %v", pkFile, err) + } + logrus.Debugf("Loaded private key with id %s", issuer.SigningKey.KeyID()) + } + + if realm == "" { + logrus.Fatalf("Must provide realm") + } + + ac, err := auth.GetAccessController("htpasswd", map[string]interface{}{ + "realm": realm, + "path": passwdFile, + }) + if err != nil { + logrus.Fatalf("Error initializing access controller: %v", err) + } + + // TODO: Make configurable + issuer.Expiration = 15 * time.Minute + + ctx := dcontext.Background() + + ts := &tokenServer{ + issuer: issuer, + accessController: ac, + refreshCache: map[string]refreshToken{}, + } + + router := mux.NewRouter() + router.Path("/token/").Methods(http.MethodGet).Handler(handlerWithContext(ctx, ts.getToken)) + router.Path("/token/").Methods(http.MethodPost).Handler(handlerWithContext(ctx, ts.postToken)) + + if cert == "" { + err = http.ListenAndServe(addr, router) + } else if certKey == "" { + logrus.Fatalf("Must provide certficate (-tlscert) and key (-tlskey)") + } else { + err = http.ListenAndServeTLS(addr, cert, certKey, router) + } + + if err != nil { + logrus.Infof("Error serving: %v", err) + } +} + +// handlerWithContext wraps the given context-aware handler by setting up the +// request context from a base context. +func handlerWithContext(ctx context.Context, handler func(context.Context, http.ResponseWriter, *http.Request)) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := dcontext.WithRequest(ctx, r) + logger := dcontext.GetRequestLogger(ctx) + ctx = dcontext.WithLogger(ctx, logger) + + handler(ctx, w, r) + }) +} + +func handleError(ctx context.Context, err error, w http.ResponseWriter) { + ctx, w = dcontext.WithResponseWriter(ctx, w) + + if serveErr := errcode.ServeJSON(w, err); serveErr != nil { + dcontext.GetResponseLogger(ctx).Errorf("error sending error response: %v", serveErr) + return + } + + dcontext.GetResponseLogger(ctx).Info("application error") +} + +var refreshCharacters = []rune("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +const refreshTokenLength = 15 + +func newRefreshToken() string { + s := make([]rune, refreshTokenLength) + max := int64(len(refreshCharacters)) + for i := range s { + randInt, err := rand.Int(rand.Reader, big.NewInt(max)) + // let '0' serves the failure case + if err != nil { + logrus.Infof("Error on making refersh token: %v", err) + randInt = big.NewInt(0) + } + s[i] = refreshCharacters[randInt.Int64()] + } + return string(s) +} + +type refreshToken struct { + subject string + service string +} + +type tokenServer struct { + issuer *TokenIssuer + accessController auth.AccessController + refreshCache map[string]refreshToken +} + +type tokenResponse struct { + Token string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` +} + +var repositoryClassCache = map[string]string{} + +func filterAccessList(ctx context.Context, scope string, requestedAccessList []auth.Access) []auth.Access { + if !strings.HasSuffix(scope, "/") { + scope = scope + "/" + } + grantedAccessList := make([]auth.Access, 0, len(requestedAccessList)) + for _, access := range requestedAccessList { + if access.Type == "repository" { + if !strings.HasPrefix(access.Name, scope) { + dcontext.GetLogger(ctx).Debugf("Resource scope not allowed: %s", access.Name) + continue + } + if enforceRepoClass { + if class, ok := repositoryClassCache[access.Name]; ok { + if class != access.Class { + dcontext.GetLogger(ctx).Debugf("Different repository class: %q, previously %q", access.Class, class) + continue + } + } else if strings.EqualFold(access.Action, "push") { + repositoryClassCache[access.Name] = access.Class + } + } + } else if access.Type == "registry" { + if access.Name != "catalog" { + dcontext.GetLogger(ctx).Debugf("Unknown registry resource: %s", access.Name) + continue + } + // TODO: Limit some actions to "admin" users + } else { + dcontext.GetLogger(ctx).Debugf("Skipping unsupported resource type: %s", access.Type) + continue + } + grantedAccessList = append(grantedAccessList, access) + } + return grantedAccessList +} + +type acctSubject struct{} + +func (acctSubject) String() string { return "acctSubject" } + +type requestedAccess struct{} + +func (requestedAccess) String() string { return "requestedAccess" } + +type grantedAccess struct{} + +func (grantedAccess) String() string { return "grantedAccess" } + +// getToken handles authenticating the request and authorizing access to the +// requested scopes. +func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *http.Request) { + dcontext.GetLogger(ctx).Info("getToken") + + params := r.URL.Query() + service := params.Get("service") + scopeSpecifiers := params["scope"] + var offline bool + if offlineStr := params.Get("offline_token"); offlineStr != "" { + var err error + offline, err = strconv.ParseBool(offlineStr) + if err != nil { + handleError(ctx, ErrorBadTokenOption.WithDetail(err), w) + return + } + } + + requestedAccessList := ResolveScopeSpecifiers(ctx, scopeSpecifiers) + + authorizedCtx, err := ts.accessController.Authorized(ctx, requestedAccessList...) + if err != nil { + challenge, ok := err.(auth.Challenge) + if !ok { + handleError(ctx, err, w) + return + } + + // Get response context. + ctx, w = dcontext.WithResponseWriter(ctx, w) + + challenge.SetHeaders(r, w) + handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail(challenge.Error()), w) + + dcontext.GetResponseLogger(ctx).Info("get token authentication challenge") + + return + } + ctx = authorizedCtx + + username := dcontext.GetStringValue(ctx, "auth.user.name") + + ctx = context.WithValue(ctx, acctSubject{}, username) + ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, acctSubject{})) + + dcontext.GetLogger(ctx).Info("authenticated client") + + ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList) + ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, requestedAccess{})) + + grantedAccessList := filterAccessList(ctx, username, requestedAccessList) + ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList) + ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, grantedAccess{})) + + token, err := ts.issuer.CreateJWT(username, service, grantedAccessList) + if err != nil { + handleError(ctx, err, w) + return + } + + dcontext.GetLogger(ctx).Info("authorized client") + + response := tokenResponse{ + Token: token, + ExpiresIn: int(ts.issuer.Expiration.Seconds()), + } + + if offline { + response.RefreshToken = newRefreshToken() + ts.refreshCache[response.RefreshToken] = refreshToken{ + subject: username, + service: service, + } + } + + ctx, w = dcontext.WithResponseWriter(ctx, w) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + + dcontext.GetResponseLogger(ctx).Info("get token complete") +} + +type postTokenResponse struct { + Token string `json:"access_token"` + Scope string `json:"scope,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` + IssuedAt string `json:"issued_at,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` +} + +// postToken handles authenticating the request and authorizing access to the +// requested scopes. +func (ts *tokenServer) postToken(ctx context.Context, w http.ResponseWriter, r *http.Request) { + grantType := r.PostFormValue("grant_type") + if grantType == "" { + handleError(ctx, ErrorMissingRequiredField.WithDetail("missing grant_type value"), w) + return + } + + service := r.PostFormValue("service") + if service == "" { + handleError(ctx, ErrorMissingRequiredField.WithDetail("missing service value"), w) + return + } + + clientID := r.PostFormValue("client_id") + if clientID == "" { + handleError(ctx, ErrorMissingRequiredField.WithDetail("missing client_id value"), w) + return + } + + var offline bool + switch r.PostFormValue("access_type") { + case "", "online": + case "offline": + offline = true + default: + handleError(ctx, ErrorUnsupportedValue.WithDetail("unknown access_type value"), w) + return + } + + requestedAccessList := ResolveScopeList(ctx, r.PostFormValue("scope")) + + var subject string + var rToken string + switch grantType { + case "refresh_token": + rToken = r.PostFormValue("refresh_token") + if rToken == "" { + handleError(ctx, ErrorUnsupportedValue.WithDetail("missing refresh_token value"), w) + return + } + rt, ok := ts.refreshCache[rToken] + if !ok || rt.service != service { + handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail("invalid refresh token"), w) + return + } + subject = rt.subject + case "password": + ca, ok := ts.accessController.(auth.CredentialAuthenticator) + if !ok { + handleError(ctx, ErrorUnsupportedValue.WithDetail("password grant type not supported"), w) + return + } + subject = r.PostFormValue("username") + if subject == "" { + handleError(ctx, ErrorUnsupportedValue.WithDetail("missing username value"), w) + return + } + password := r.PostFormValue("password") + if password == "" { + handleError(ctx, ErrorUnsupportedValue.WithDetail("missing password value"), w) + return + } + if err := ca.AuthenticateUser(subject, password); err != nil { + handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail("invalid credentials"), w) + return + } + default: + handleError(ctx, ErrorUnsupportedValue.WithDetail("unknown grant_type value"), w) + return + } + + ctx = context.WithValue(ctx, acctSubject{}, subject) + ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, acctSubject{})) + + dcontext.GetLogger(ctx).Info("authenticated client") + + ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList) + ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, requestedAccess{})) + + grantedAccessList := filterAccessList(ctx, subject, requestedAccessList) + ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList) + ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, grantedAccess{})) + + token, err := ts.issuer.CreateJWT(subject, service, grantedAccessList) + if err != nil { + handleError(ctx, err, w) + return + } + + dcontext.GetLogger(ctx).Info("authorized client") + + response := postTokenResponse{ + Token: token, + ExpiresIn: int(ts.issuer.Expiration.Seconds()), + IssuedAt: time.Now().UTC().Format(time.RFC3339), + Scope: ToScopeList(grantedAccessList), + } + + if offline { + rToken = newRefreshToken() + ts.refreshCache[rToken] = refreshToken{ + subject: subject, + service: service, + } + } + + if rToken != "" { + response.RefreshToken = rToken + } + + ctx, w = dcontext.WithResponseWriter(ctx, w) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + + dcontext.GetResponseLogger(ctx).Info("post token complete") +} diff --git a/token.go b/token.go new file mode 100644 index 0000000..dc0956d --- /dev/null +++ b/token.go @@ -0,0 +1,220 @@ +package main + +import ( + "context" + "crypto" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "regexp" + "strings" + "time" + + dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/registry/auth" + "github.com/distribution/distribution/v3/registry/auth/token" + "github.com/docker/libtrust" +) + +// ResolveScopeSpecifiers converts a list of scope specifiers from a token +// request's `scope` query parameters into a list of standard access objects. +func ResolveScopeSpecifiers(ctx context.Context, scopeSpecs []string) []auth.Access { + requestedAccessSet := make(map[auth.Access]struct{}, 2*len(scopeSpecs)) + + for _, scopeSpecifier := range scopeSpecs { + // There should be 3 parts, separated by a `:` character. + parts := strings.SplitN(scopeSpecifier, ":", 3) + + if len(parts) != 3 { + dcontext.GetLogger(ctx).Infof("ignoring unsupported scope format %s", scopeSpecifier) + continue + } + + resourceType, resourceName, actions := parts[0], parts[1], parts[2] + + resourceType, resourceClass := splitResourceClass(resourceType) + if resourceType == "" { + continue + } + + // Actions should be a comma-separated list of actions. + for _, action := range strings.Split(actions, ",") { + requestedAccess := auth.Access{ + Resource: auth.Resource{ + Type: resourceType, + Class: resourceClass, + Name: resourceName, + }, + Action: action, + } + + // Add this access to the requested access set. + requestedAccessSet[requestedAccess] = struct{}{} + } + } + + requestedAccessList := make([]auth.Access, 0, len(requestedAccessSet)) + for requestedAccess := range requestedAccessSet { + requestedAccessList = append(requestedAccessList, requestedAccess) + } + + return requestedAccessList +} + +var typeRegexp = regexp.MustCompile(`^([a-z0-9]+)(\([a-z0-9]+\))?$`) + +func splitResourceClass(t string) (string, string) { + matches := typeRegexp.FindStringSubmatch(t) + if len(matches) < 2 { + return "", "" + } + if len(matches) == 2 || len(matches[2]) < 2 { + return matches[1], "" + } + return matches[1], matches[2][1 : len(matches[2])-1] +} + +// ResolveScopeList converts a scope list from a token request's +// `scope` parameter into a list of standard access objects. +func ResolveScopeList(ctx context.Context, scopeList string) []auth.Access { + scopes := strings.Split(scopeList, " ") + return ResolveScopeSpecifiers(ctx, scopes) +} + +func scopeString(a auth.Access) string { + if a.Class != "" { + return fmt.Sprintf("%s(%s):%s:%s", a.Type, a.Class, a.Name, a.Action) + } + return fmt.Sprintf("%s:%s:%s", a.Type, a.Name, a.Action) +} + +// ToScopeList converts a list of access to a +// scope list string +func ToScopeList(access []auth.Access) string { + var s []string + for _, a := range access { + s = append(s, scopeString(a)) + } + return strings.Join(s, ",") +} + +// TokenIssuer represents an issuer capable of generating JWT tokens +type TokenIssuer struct { + Issuer string + SigningKey libtrust.PrivateKey + Expiration time.Duration +} + +// CreateJWT creates and signs a JSON Web Token for the given subject and +// audience with the granted access. +func (issuer *TokenIssuer) CreateJWT(subject string, audience string, grantedAccessList []auth.Access) (string, error) { + // Make a set of access entries to put in the token's claimset. + resourceActionSets := make(map[auth.Resource]map[string]struct{}, len(grantedAccessList)) + for _, access := range grantedAccessList { + actionSet, exists := resourceActionSets[access.Resource] + if !exists { + actionSet = map[string]struct{}{} + resourceActionSets[access.Resource] = actionSet + } + actionSet[access.Action] = struct{}{} + } + + accessEntries := make([]*token.ResourceActions, 0, len(resourceActionSets)) + for resource, actionSet := range resourceActionSets { + actions := make([]string, 0, len(actionSet)) + for action := range actionSet { + actions = append(actions, action) + } + + accessEntries = append(accessEntries, &token.ResourceActions{ + Type: resource.Type, + Class: resource.Class, + Name: resource.Name, + Actions: actions, + }) + } + + randomBytes := make([]byte, 15) + _, err := io.ReadFull(rand.Reader, randomBytes) + if err != nil { + return "", err + } + randomID := base64.URLEncoding.EncodeToString(randomBytes) + + now := time.Now() + + signingHash := crypto.SHA256 + var alg string + switch issuer.SigningKey.KeyType() { + case "RSA": + alg = "RS256" + case "EC": + alg = "ES256" + default: + panic(fmt.Errorf("unsupported signing key type %q", issuer.SigningKey.KeyType())) + } + + joseHeader := token.Header{ + Type: "JWT", + SigningAlg: alg, + } + + if x5c := issuer.SigningKey.GetExtendedField("x5c"); x5c != nil { + joseHeader.X5c = x5c.([]string) + } else { + var jwkMessage json.RawMessage + jwkMessage, err = issuer.SigningKey.PublicKey().MarshalJSON() + if err != nil { + return "", err + } + joseHeader.RawJWK = &jwkMessage + } + + exp := issuer.Expiration + if exp == 0 { + exp = 5 * time.Minute + } + + claimSet := token.ClaimSet{ + Issuer: issuer.Issuer, + Subject: subject, + Audience: []string{audience}, + Expiration: now.Add(exp).Unix(), + NotBefore: now.Unix(), + IssuedAt: now.Unix(), + JWTID: randomID, + + Access: accessEntries, + } + + var ( + joseHeaderBytes []byte + claimSetBytes []byte + ) + + if joseHeaderBytes, err = json.Marshal(joseHeader); err != nil { + return "", fmt.Errorf("unable to encode jose header: %s", err) + } + if claimSetBytes, err = json.Marshal(claimSet); err != nil { + return "", fmt.Errorf("unable to encode claim set: %s", err) + } + + encodedJoseHeader := joseBase64Encode(joseHeaderBytes) + encodedClaimSet := joseBase64Encode(claimSetBytes) + encodingToSign := fmt.Sprintf("%s.%s", encodedJoseHeader, encodedClaimSet) + + var signatureBytes []byte + if signatureBytes, _, err = issuer.SigningKey.Sign(strings.NewReader(encodingToSign), signingHash); err != nil { + return "", fmt.Errorf("unable to sign jwt payload: %s", err) + } + + signature := joseBase64Encode(signatureBytes) + + return fmt.Sprintf("%s.%s", encodingToSign, signature), nil +} + +func joseBase64Encode(data []byte) string { + return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=") +} diff --git a/token_test.go b/token_test.go new file mode 100644 index 0000000..ea93ad4 --- /dev/null +++ b/token_test.go @@ -0,0 +1,78 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "errors" + "strings" + "testing" + "time" + + "github.com/distribution/distribution/v3/registry/auth" + "github.com/docker/libtrust" +) + +func TestCreateJWTSuccessWithEmptyACL(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + pk, err := libtrust.FromCryptoPrivateKey(key) + if err != nil { + t.Fatal(err) + } + tokenIssuer := TokenIssuer{ + Expiration: time.Duration(100), + Issuer: "localhost", + SigningKey: pk, + } + + grantedAccessList := make([]auth.Access, 0) + token, err := tokenIssuer.CreateJWT("test", "test", grantedAccessList) + if err != nil { + t.Fatal(err) + } + + tokens := strings.Split(token, ".") + + if len(token) == 0 { + t.Fatal("token not generated.") + } + + json, err := decodeJWT(tokens[1]) + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(json, "test") { + t.Fatal("Valid token was not generated.") + } +} + +func decodeJWT(rawToken string) (string, error) { + data, err := joseBase64Decode(rawToken) + if err != nil { + return "", errors.New("Error in Decoding base64 String") + } + return data, nil +} + +func joseBase64Decode(s string) (string, error) { + switch len(s) % 4 { + case 0: + case 2: + s += "==" + case 3: + s += "=" + default: + { + return "", errors.New("Invalid base64 String") + } + } + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return "", err // errors.New("Error in Decoding base64 String") + } + return string(data), nil +}