diff --git a/registry/client/authchallenge.go b/registry/client/authchallenge.go index f45704b1..a9cce3cc 100644 --- a/registry/client/authchallenge.go +++ b/registry/client/authchallenge.go @@ -8,9 +8,9 @@ import ( // Octet types from RFC 2616. type octetType byte -// AuthorizationChallenge carries information +// authorizationChallenge carries information // from a WWW-Authenticate response header. -type AuthorizationChallenge struct { +type authorizationChallenge struct { Scheme string Parameters map[string]string } @@ -54,12 +54,12 @@ func init() { } } -func parseAuthHeader(header http.Header) []AuthorizationChallenge { - var challenges []AuthorizationChallenge +func parseAuthHeader(header http.Header) []authorizationChallenge { + var challenges []authorizationChallenge for _, h := range header[http.CanonicalHeaderKey("WWW-Authenticate")] { v, p := parseValueAndParams(h) if v != "" { - challenges = append(challenges, AuthorizationChallenge{Scheme: v, Parameters: p}) + challenges = append(challenges, authorizationChallenge{Scheme: v, Parameters: p}) } } return challenges diff --git a/registry/client/endpoint.go b/registry/client/endpoint.go deleted file mode 100644 index 9889dc66..00000000 --- a/registry/client/endpoint.go +++ /dev/null @@ -1,268 +0,0 @@ -package client - -import ( - "fmt" - "net/http" - "net/url" - "strings" - "sync" - "time" - - "github.com/Sirupsen/logrus" - "github.com/docker/distribution/registry/api/v2" -) - -// Authorizer is used to apply Authorization to an HTTP request -type Authorizer interface { - // Authorizer updates an HTTP request with the needed authorization - Authorize(req *http.Request) error -} - -// CredentialStore is an interface for getting credentials for -// a given URL -type CredentialStore interface { - // Basic returns basic auth for the given URL - Basic(*url.URL) (string, string) -} - -// RepositoryEndpoint represents a single host endpoint serving up -// the distribution API. -type RepositoryEndpoint struct { - Endpoint string - Mirror bool - - Header http.Header - Credentials CredentialStore - - ub *v2.URLBuilder -} - -type nullAuthorizer struct{} - -func (na nullAuthorizer) Authorize(req *http.Request) error { - return nil -} - -type repositoryTransport struct { - Transport http.RoundTripper - Header http.Header - Authorizer Authorizer -} - -func (rt *repositoryTransport) RoundTrip(req *http.Request) (*http.Response, error) { - reqCopy := new(http.Request) - *reqCopy = *req - - // Copy existing headers then static headers - reqCopy.Header = make(http.Header, len(req.Header)+len(rt.Header)) - for k, s := range req.Header { - reqCopy.Header[k] = append([]string(nil), s...) - } - for k, s := range rt.Header { - reqCopy.Header[k] = append(reqCopy.Header[k], s...) - } - - if rt.Authorizer != nil { - if err := rt.Authorizer.Authorize(reqCopy); err != nil { - return nil, err - } - } - - logrus.Debugf("HTTP: %s %s", req.Method, req.URL) - - if rt.Transport != nil { - return rt.Transport.RoundTrip(reqCopy) - } - return http.DefaultTransport.RoundTrip(reqCopy) -} - -type authTransport struct { - Transport http.RoundTripper - Header http.Header -} - -func (rt *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { - reqCopy := new(http.Request) - *reqCopy = *req - - // Copy existing headers then static headers - reqCopy.Header = make(http.Header, len(req.Header)+len(rt.Header)) - for k, s := range req.Header { - reqCopy.Header[k] = append([]string(nil), s...) - } - for k, s := range rt.Header { - reqCopy.Header[k] = append(reqCopy.Header[k], s...) - } - - logrus.Debugf("HTTP: %s %s", req.Method, req.URL) - - if rt.Transport != nil { - return rt.Transport.RoundTrip(reqCopy) - } - return http.DefaultTransport.RoundTrip(reqCopy) -} - -// URLBuilder returns a new URL builder -func (e *RepositoryEndpoint) URLBuilder() (*v2.URLBuilder, error) { - if e.ub == nil { - var err error - e.ub, err = v2.NewURLBuilderFromString(e.Endpoint) - if err != nil { - return nil, err - } - } - - return e.ub, nil -} - -// HTTPClient returns a new HTTP client configured for this endpoint -func (e *RepositoryEndpoint) HTTPClient(name string) (*http.Client, error) { - // TODO(dmcgowan): create http.Transport - - transport := &repositoryTransport{ - Header: e.Header, - } - client := &http.Client{ - Transport: transport, - } - - challenges, err := e.ping(client) - if err != nil { - return nil, err - } - actions := []string{"pull"} - if !e.Mirror { - actions = append(actions, "push") - } - - transport.Authorizer = &endpointAuthorizer{ - client: &http.Client{Transport: &authTransport{Header: e.Header}}, - challenges: challenges, - creds: e.Credentials, - resource: "repository", - scope: name, - actions: actions, - } - - return client, nil -} - -func (e *RepositoryEndpoint) ping(client *http.Client) ([]AuthorizationChallenge, error) { - ub, err := e.URLBuilder() - if err != nil { - return nil, err - } - u, err := ub.BuildBaseURL() - if err != nil { - return nil, err - } - - req, err := http.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - req.Header = make(http.Header, len(e.Header)) - for k, s := range e.Header { - req.Header[k] = append([]string(nil), s...) - } - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - var supportsV2 bool -HeaderLoop: - for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey("Docker-Distribution-API-Version")] { - for _, versionName := range strings.Fields(supportedVersions) { - if versionName == "registry/2.0" { - supportsV2 = true - break HeaderLoop - } - } - } - - if !supportsV2 { - return nil, fmt.Errorf("%s does not appear to be a v2 registry endpoint", e.Endpoint) - } - - if resp.StatusCode == http.StatusUnauthorized { - // Parse the WWW-Authenticate Header and store the challenges - // on this endpoint object. - return parseAuthHeader(resp.Header), nil - } else if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unable to get valid ping response: %d", resp.StatusCode) - } - - return nil, nil -} - -type endpointAuthorizer struct { - client *http.Client - challenges []AuthorizationChallenge - creds CredentialStore - - resource string - scope string - actions []string - - tokenLock sync.Mutex - tokenCache string - tokenExpiration time.Time -} - -func (ta *endpointAuthorizer) Authorize(req *http.Request) error { - token, err := ta.getToken() - if err != nil { - return err - } - if token != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - } else if ta.creds != nil { - username, password := ta.creds.Basic(req.URL) - if username != "" && password != "" { - req.SetBasicAuth(username, password) - } - } - return nil -} - -func (ta *endpointAuthorizer) getToken() (string, error) { - ta.tokenLock.Lock() - defer ta.tokenLock.Unlock() - now := time.Now() - if now.Before(ta.tokenExpiration) { - //log.Debugf("Using cached token for %q", ta.auth.Username) - return ta.tokenCache, nil - } - - for _, challenge := range ta.challenges { - switch strings.ToLower(challenge.Scheme) { - case "basic": - // no token necessary - case "bearer": - //log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username) - params := map[string]string{} - for k, v := range challenge.Parameters { - params[k] = v - } - params["scope"] = fmt.Sprintf("%s:%s:%s", ta.resource, ta.scope, strings.Join(ta.actions, ",")) - token, err := getToken(ta.creds, params, ta.client) - if err != nil { - return "", err - } - ta.tokenCache = token - ta.tokenExpiration = now.Add(time.Minute) - - return token, nil - default: - //log.Infof("Unsupported auth scheme: %q", challenge.Scheme) - } - } - - // Do not expire cache since there are no challenges which use a token - ta.tokenExpiration = time.Now().Add(time.Hour * 24) - - return "", nil -} diff --git a/registry/client/layer_upload_test.go b/registry/client/layer_upload_test.go index 1aa5cf1e..9e22cb7c 100644 --- a/registry/client/layer_upload_test.go +++ b/registry/client/layer_upload_test.go @@ -124,7 +124,8 @@ func TestUploadReadFrom(t *testing.T) { e, c := testServer(m) defer c() - client, err := e.HTTPClient(repo) + repoConfig := &RepositoryConfig{} + client, err := repoConfig.HTTPClient() if err != nil { t.Fatalf("Error creating client: %s", err) } @@ -133,7 +134,7 @@ func TestUploadReadFrom(t *testing.T) { } // Valid case - layerUpload.location = e.Endpoint + locationPath + layerUpload.location = e + locationPath n, err := layerUpload.ReadFrom(bytes.NewReader(b)) if err != nil { t.Fatalf("Error calling ReadFrom: %s", err) @@ -143,26 +144,26 @@ func TestUploadReadFrom(t *testing.T) { } // Bad range - layerUpload.location = e.Endpoint + locationPath + layerUpload.location = e + locationPath _, err = layerUpload.ReadFrom(bytes.NewReader(b)) if err == nil { t.Fatalf("Expected error when bad range received") } // 404 - layerUpload.location = e.Endpoint + locationPath + layerUpload.location = e + locationPath _, err = layerUpload.ReadFrom(bytes.NewReader(b)) if err == nil { t.Fatalf("Expected error when not found") } if blobErr, ok := err.(*BlobUploadNotFoundError); !ok { t.Fatalf("Wrong error type %T: %s", err, err) - } else if expected := e.Endpoint + locationPath; blobErr.Location != expected { + } else if expected := e + locationPath; blobErr.Location != expected { t.Fatalf("Unexpected location: %s, expected %s", blobErr.Location, expected) } // 400 valid json - layerUpload.location = e.Endpoint + locationPath + layerUpload.location = e + locationPath _, err = layerUpload.ReadFrom(bytes.NewReader(b)) if err == nil { t.Fatalf("Expected error when not found") @@ -185,7 +186,7 @@ func TestUploadReadFrom(t *testing.T) { } // 400 invalid json - layerUpload.location = e.Endpoint + locationPath + layerUpload.location = e + locationPath _, err = layerUpload.ReadFrom(bytes.NewReader(b)) if err == nil { t.Fatalf("Expected error when not found") @@ -200,7 +201,7 @@ func TestUploadReadFrom(t *testing.T) { } // 500 - layerUpload.location = e.Endpoint + locationPath + layerUpload.location = e + locationPath _, err = layerUpload.ReadFrom(bytes.NewReader(b)) if err == nil { t.Fatalf("Expected error when not found") diff --git a/registry/client/repository.go b/registry/client/repository.go index 22a02373..d5f75bda 100644 --- a/registry/client/repository.go +++ b/registry/client/repository.go @@ -19,17 +19,17 @@ import ( ) // NewRepository creates a new Repository for the given repository name and endpoint -func NewRepository(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) { +func NewRepository(ctx context.Context, name, endpoint string, repoConfig *RepositoryConfig) (distribution.Repository, error) { if err := v2.ValidateRespositoryName(name); err != nil { return nil, err } - ub, err := endpoint.URLBuilder() + ub, err := v2.NewURLBuilderFromString(endpoint) if err != nil { return nil, err } - client, err := endpoint.HTTPClient(name) + client, err := repoConfig.HTTPClient() if err != nil { return nil, err } @@ -39,7 +39,7 @@ func NewRepository(ctx context.Context, name string, endpoint *RepositoryEndpoin ub: ub, name: name, context: ctx, - mirror: endpoint.Mirror, + mirror: repoConfig.AllowMirrors, }, nil } diff --git a/registry/client/repository_test.go b/registry/client/repository_test.go index b96c52e5..1674213d 100644 --- a/registry/client/repository_test.go +++ b/registry/client/repository_test.go @@ -20,11 +20,10 @@ import ( "golang.org/x/net/context" ) -func testServer(rrm testutil.RequestResponseMap) (*RepositoryEndpoint, func()) { +func testServer(rrm testutil.RequestResponseMap) (string, func()) { h := testutil.NewHandler(rrm) s := httptest.NewServer(h) - e := RepositoryEndpoint{Endpoint: s.URL, Mirror: false} - return &e, s.Close + return s.URL, s.Close } func newRandomBlob(size int) (digest.Digest, []byte) { @@ -97,7 +96,7 @@ func TestLayerFetch(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), "test.example.com/repo1", e) + r, err := NewRepository(context.Background(), "test.example.com/repo1", e, &RepositoryConfig{}) if err != nil { t.Fatal(err) } @@ -127,7 +126,7 @@ func TestLayerExists(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), "test.example.com/repo1", e) + r, err := NewRepository(context.Background(), "test.example.com/repo1", e, &RepositoryConfig{}) if err != nil { t.Fatal(err) } @@ -227,7 +226,7 @@ func TestLayerUploadChunked(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) if err != nil { t.Fatal(err) } @@ -334,7 +333,7 @@ func TestLayerUploadMonolithic(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) if err != nil { t.Fatal(err) } @@ -475,7 +474,7 @@ func TestManifestFetch(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) if err != nil { t.Fatal(err) } @@ -508,7 +507,7 @@ func TestManifestFetchByTag(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) if err != nil { t.Fatal(err) } @@ -553,7 +552,7 @@ func TestManifestDelete(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) if err != nil { t.Fatal(err) } @@ -591,7 +590,7 @@ func TestManifestPut(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) if err != nil { t.Fatal(err) } diff --git a/registry/client/session.go b/registry/client/session.go new file mode 100644 index 00000000..bd8abe0f --- /dev/null +++ b/registry/client/session.go @@ -0,0 +1,282 @@ +package client + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// Authorizer is used to apply Authorization to an HTTP request +type Authorizer interface { + // Authorizer updates an HTTP request with the needed authorization + Authorize(req *http.Request) error +} + +// CredentialStore is an interface for getting credentials for +// a given URL +type CredentialStore interface { + // Basic returns basic auth for the given URL + Basic(*url.URL) (string, string) +} + +// RepositoryConfig holds the base configuration needed to communicate +// with a registry including a method of authorization and HTTP headers. +type RepositoryConfig struct { + Header http.Header + AuthSource Authorizer + AllowMirrors bool +} + +// HTTPClient returns a new HTTP client configured for this configuration +func (rc *RepositoryConfig) HTTPClient() (*http.Client, error) { + // TODO(dmcgowan): create base http.Transport with proper TLS configuration + + transport := &Transport{ + ExtraHeader: rc.Header, + AuthSource: rc.AuthSource, + } + + client := &http.Client{ + Transport: transport, + } + + return client, nil +} + +// TokenScope represents the scope at which a token will be requested. +// This represents a specific action on a registry resource. +type TokenScope struct { + Resource string + Scope string + Actions []string +} + +func (ts TokenScope) String() string { + return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ",")) +} + +// NewTokenAuthorizer returns an authorizer which is capable of getting a token +// from a token server. The expected authorization method will be discovered +// by the authorizer, getting the token server endpoint from the URL being +// requested. Basic authentication may either be done to the token source or +// directly with the requested endpoint depending on the endpoint's +// WWW-Authenticate header. +func NewTokenAuthorizer(creds CredentialStore, header http.Header, scope TokenScope) Authorizer { + return &tokenAuthorizer{ + header: header, + creds: creds, + scope: scope, + challenges: map[string][]authorizationChallenge{}, + } +} + +type tokenAuthorizer struct { + header http.Header + challenges map[string][]authorizationChallenge + creds CredentialStore + scope TokenScope + + tokenLock sync.Mutex + tokenCache string + tokenExpiration time.Time +} + +func (ta *tokenAuthorizer) ping(endpoint string) ([]authorizationChallenge, error) { + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return nil, err + } + + resp, err := ta.client().Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var supportsV2 bool +HeaderLoop: + for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey("Docker-Distribution-API-Version")] { + for _, versionName := range strings.Fields(supportedVersions) { + if versionName == "registry/2.0" { + supportsV2 = true + break HeaderLoop + } + } + } + + if !supportsV2 { + return nil, fmt.Errorf("%s does not appear to be a v2 registry endpoint", endpoint) + } + + if resp.StatusCode == http.StatusUnauthorized { + // Parse the WWW-Authenticate Header and store the challenges + // on this endpoint object. + return parseAuthHeader(resp.Header), nil + } else if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unable to get valid ping response: %d", resp.StatusCode) + } + + return nil, nil +} + +func (ta *tokenAuthorizer) Authorize(req *http.Request) error { + v2Root := strings.Index(req.URL.Path, "/v2/") + if v2Root == -1 { + return nil + } + + ping := url.URL{ + Host: req.URL.Host, + Scheme: req.URL.Scheme, + Path: req.URL.Path[:v2Root+4], + } + + pingEndpoint := ping.String() + + challenges, ok := ta.challenges[pingEndpoint] + if !ok { + var err error + challenges, err = ta.ping(pingEndpoint) + if err != nil { + return err + } + ta.challenges[pingEndpoint] = challenges + } + + return ta.setAuth(challenges, req) +} + +func (ta *tokenAuthorizer) client() *http.Client { + // TODO(dmcgowan): Use same transport which has properly configured TLS + return &http.Client{Transport: &Transport{ExtraHeader: ta.header}} +} + +func (ta *tokenAuthorizer) setAuth(challenges []authorizationChallenge, req *http.Request) error { + var useBasic bool + for _, challenge := range challenges { + switch strings.ToLower(challenge.Scheme) { + case "basic": + useBasic = true + case "bearer": + if err := ta.refreshToken(challenge); err != nil { + return err + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ta.tokenCache)) + + return nil + default: + //log.Infof("Unsupported auth scheme: %q", challenge.Scheme) + } + } + + // Only use basic when no token auth challenges found + if useBasic { + if ta.creds != nil { + username, password := ta.creds.Basic(req.URL) + if username != "" && password != "" { + req.SetBasicAuth(username, password) + return nil + } + } + return errors.New("no basic auth credentials") + } + + return nil +} + +func (ta *tokenAuthorizer) refreshToken(challenge authorizationChallenge) error { + ta.tokenLock.Lock() + defer ta.tokenLock.Unlock() + now := time.Now() + if now.After(ta.tokenExpiration) { + token, err := ta.fetchToken(challenge) + if err != nil { + return err + } + ta.tokenCache = token + ta.tokenExpiration = now.Add(time.Minute) + } + + return nil +} + +type tokenResponse struct { + Token string `json:"token"` +} + +func (ta *tokenAuthorizer) fetchToken(challenge authorizationChallenge) (token string, err error) { + //log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username) + params := map[string]string{} + for k, v := range challenge.Parameters { + params[k] = v + } + params["scope"] = ta.scope.String() + + realm, ok := params["realm"] + if !ok { + return "", errors.New("no realm specified for token auth challenge") + } + + realmURL, err := url.Parse(realm) + if err != nil { + return "", fmt.Errorf("invalid token auth challenge realm: %s", err) + } + + // TODO(dmcgowan): Handle empty scheme + + req, err := http.NewRequest("GET", realmURL.String(), nil) + if err != nil { + return "", err + } + + reqParams := req.URL.Query() + service := params["service"] + scope := params["scope"] + + if service != "" { + reqParams.Add("service", service) + } + + for _, scopeField := range strings.Fields(scope) { + reqParams.Add("scope", scopeField) + } + + if ta.creds != nil { + username, password := ta.creds.Basic(realmURL) + if username != "" && password != "" { + reqParams.Add("account", username) + req.SetBasicAuth(username, password) + } + } + + req.URL.RawQuery = reqParams.Encode() + + resp, err := ta.client().Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("token auth attempt for registry: %s request failed with status: %d %s", req.URL, resp.StatusCode, http.StatusText(resp.StatusCode)) + } + + decoder := json.NewDecoder(resp.Body) + + tr := new(tokenResponse) + if err = decoder.Decode(tr); err != nil { + return "", fmt.Errorf("unable to decode token response: %s", err) + } + + if tr.Token == "" { + return "", errors.New("authorization server did not include a token in the response") + } + + return tr.Token, nil +} diff --git a/registry/client/endpoint_test.go b/registry/client/session_test.go similarity index 79% rename from registry/client/endpoint_test.go rename to registry/client/session_test.go index 42bdc357..87e1e66e 100644 --- a/registry/client/endpoint_test.go +++ b/registry/client/session_test.go @@ -30,7 +30,7 @@ func (w *testAuthenticationWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Re w.next.ServeHTTP(rw, r) } -func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (*RepositoryEndpoint, func()) { +func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (string, func()) { h := testutil.NewHandler(rrm) wrapper := &testAuthenticationWrapper{ @@ -43,8 +43,7 @@ func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, au } s := httptest.NewServer(wrapper) - e := RepositoryEndpoint{Endpoint: s.URL, Mirror: false} - return &e, s.Close + return s.URL, s.Close } type testCredentialStore struct { @@ -62,6 +61,16 @@ func TestEndpointAuthorizeToken(t *testing.T) { repo2 := "other/registry" scope1 := fmt.Sprintf("repository:%s:pull,push", repo1) scope2 := fmt.Sprintf("repository:%s:pull,push", repo2) + tokenScope1 := TokenScope{ + Resource: "repository", + Scope: repo1, + Actions: []string{"pull", "push"}, + } + tokenScope2 := TokenScope{ + Resource: "repository", + Scope: repo2, + Actions: []string{"pull", "push"}, + } tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ { @@ -92,7 +101,7 @@ func TestEndpointAuthorizeToken(t *testing.T) { { Request: testutil.Request{ Method: "GET", - Route: "/hello", + Route: "/v2/hello", }, Response: testutil.Response{ StatusCode: http.StatusAccepted, @@ -100,19 +109,23 @@ func TestEndpointAuthorizeToken(t *testing.T) { }, }) - authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service) + authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service) validCheck := func(a string) bool { return a == "Bearer statictoken" } e, c := testServerWithAuth(m, authenicate, validCheck) defer c() - client, err := e.HTTPClient(repo1) + repo1Config := &RepositoryConfig{ + AuthSource: NewTokenAuthorizer(nil, nil, tokenScope1), + } + + client, err := repo1Config.HTTPClient() if err != nil { t.Fatalf("Error creating http client: %s", err) } - req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) resp, err := client.Do(req) if err != nil { t.Fatalf("Error sending get request: %s", err) @@ -128,12 +141,15 @@ func TestEndpointAuthorizeToken(t *testing.T) { e2, c2 := testServerWithAuth(m, authenicate, badCheck) defer c2() - client2, err := e2.HTTPClient(repo2) + repo2Config := &RepositoryConfig{ + AuthSource: NewTokenAuthorizer(nil, nil, tokenScope2), + } + client2, err := repo2Config.HTTPClient() if err != nil { t.Fatalf("Error creating http client: %s", err) } - req, _ = http.NewRequest("GET", e.Endpoint+"/hello", nil) + req, _ = http.NewRequest("GET", e2+"/v2/hello", nil) resp, err = client2.Do(req) if err != nil { t.Fatalf("Error sending get request: %s", err) @@ -155,6 +171,11 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { scope := fmt.Sprintf("repository:%s:pull,push", repo) username := "tokenuser" password := "superSecretPa$$word" + tokenScope := TokenScope{ + Resource: "repository", + Scope: repo, + Actions: []string{"pull", "push"}, + } tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ { @@ -180,7 +201,7 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { { Request: testutil.Request{ Method: "GET", - Route: "/hello", + Route: "/v2/hello", }, Response: testutil.Response{ StatusCode: http.StatusAccepted, @@ -188,24 +209,27 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { }, }) - authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service) + authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service) bearerCheck := func(a string) bool { return a == "Bearer statictoken" } e, c := testServerWithAuth(m, authenicate2, bearerCheck) defer c() - e.Credentials = &testCredentialStore{ + creds := &testCredentialStore{ username: username, password: password, } + repoConfig := &RepositoryConfig{ + AuthSource: NewTokenAuthorizer(creds, nil, tokenScope), + } - client, err := e.HTTPClient(repo) + client, err := repoConfig.HTTPClient() if err != nil { t.Fatalf("Error creating http client: %s", err) } - req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) resp, err := client.Do(req) if err != nil { t.Fatalf("Error sending get request: %s", err) @@ -221,7 +245,7 @@ func TestEndpointAuthorizeBasic(t *testing.T) { { Request: testutil.Request{ Method: "GET", - Route: "/hello", + Route: "/v2/hello", }, Response: testutil.Response{ StatusCode: http.StatusAccepted, @@ -237,17 +261,20 @@ func TestEndpointAuthorizeBasic(t *testing.T) { } e, c := testServerWithAuth(m, authenicate, validCheck) defer c() - e.Credentials = &testCredentialStore{ + creds := &testCredentialStore{ username: username, password: password, } + repoConfig := &RepositoryConfig{ + AuthSource: NewTokenAuthorizer(creds, nil, TokenScope{}), + } - client, err := e.HTTPClient("test/repo/basic") + client, err := repoConfig.HTTPClient() if err != nil { t.Fatalf("Error creating http client: %s", err) } - req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) resp, err := client.Do(req) if err != nil { t.Fatalf("Error sending get request: %s", err) diff --git a/registry/client/token.go b/registry/client/token.go deleted file mode 100644 index 6439e01e..00000000 --- a/registry/client/token.go +++ /dev/null @@ -1,78 +0,0 @@ -package client - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "strings" -) - -type tokenResponse struct { - Token string `json:"token"` -} - -func getToken(creds CredentialStore, params map[string]string, client *http.Client) (token string, err error) { - realm, ok := params["realm"] - if !ok { - return "", errors.New("no realm specified for token auth challenge") - } - - realmURL, err := url.Parse(realm) - if err != nil { - return "", fmt.Errorf("invalid token auth challenge realm: %s", err) - } - - // TODO(dmcgowan): Handle empty scheme - - req, err := http.NewRequest("GET", realmURL.String(), nil) - if err != nil { - return "", err - } - - reqParams := req.URL.Query() - service := params["service"] - scope := params["scope"] - - if service != "" { - reqParams.Add("service", service) - } - - for _, scopeField := range strings.Fields(scope) { - reqParams.Add("scope", scopeField) - } - - if creds != nil { - username, password := creds.Basic(realmURL) - if username != "" && password != "" { - reqParams.Add("account", username) - req.SetBasicAuth(username, password) - } - } - - req.URL.RawQuery = reqParams.Encode() - - resp, err := client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("token auth attempt for registry: %s request failed with status: %d %s", req.URL, resp.StatusCode, http.StatusText(resp.StatusCode)) - } - - decoder := json.NewDecoder(resp.Body) - - tr := new(tokenResponse) - if err = decoder.Decode(tr); err != nil { - return "", fmt.Errorf("unable to decode token response: %s", err) - } - - if tr.Token == "" { - return "", errors.New("authorization server did not include a token in the response") - } - - return tr.Token, nil -} diff --git a/registry/client/transport.go b/registry/client/transport.go new file mode 100644 index 00000000..e92ba543 --- /dev/null +++ b/registry/client/transport.go @@ -0,0 +1,120 @@ +package client + +import ( + "io" + "net/http" + "sync" +) + +// Transport is an http.RoundTripper that makes registry HTTP requests, +// wrapping a base RoundTripper and adding an Authorization header +// from an Auth source +type Transport struct { + AuthSource Authorizer + ExtraHeader http.Header + + Base http.RoundTripper + + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified +} + +// RoundTrip authorizes and authenticates the request with an +// access token. If no token exists or token is expired, +// tries to refresh/fetch a new token. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := t.cloneRequest(req) + if t.AuthSource != nil { + if err := t.AuthSource.Authorize(req2); err != nil { + return nil, err + } + } + t.setModReq(req, req2) + res, err := t.base().RoundTrip(req2) + if err != nil { + t.setModReq(req, nil) + return nil, err + } + res.Body = &onEOFReader{ + rc: res.Body, + fn: func() { t.setModReq(req, nil) }, + } + return res, nil +} + +// CancelRequest cancels an in-flight request by closing its connection. +func (t *Transport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := t.base().(canceler); ok { + t.mu.Lock() + modReq := t.modReq[req] + delete(t.modReq, req) + t.mu.Unlock() + cr.CancelRequest(modReq) + } +} + +func (t *Transport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func (t *Transport) setModReq(orig, mod *http.Request) { + t.mu.Lock() + defer t.mu.Unlock() + if t.modReq == nil { + t.modReq = make(map[*http.Request]*http.Request) + } + if mod == nil { + delete(t.modReq, orig) + } else { + t.modReq[orig] = mod + } +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +func (t *Transport) cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + for k, s := range t.ExtraHeader { + r2.Header[k] = append(r2.Header[k], s...) + } + return r2 +} + +type onEOFReader struct { + rc io.ReadCloser + fn func() +} + +func (r *onEOFReader) Read(p []byte) (n int, err error) { + n, err = r.rc.Read(p) + if err == io.EOF { + r.runFunc() + } + return +} + +func (r *onEOFReader) Close() error { + err := r.rc.Close() + r.runFunc() + return err +} + +func (r *onEOFReader) runFunc() { + if fn := r.fn; fn != nil { + fn() + r.fn = nil + } +}