Use http.NewRequestWithContext for outgoing HTTP requests
This simple change mainly affects the distribution client. By respecting the context the caller passes in, timeouts and cancellations will work as expected. Also, transports which rely on the context (such as tracing transports that retrieve a span from the context) will work properly. Signed-off-by: Aaron Lehmann <alehmann@netflix.com>
This commit is contained in:
		
							parent
							
								
									26163d8256
								
							
						
					
					
						commit
						fbdfd1ac35
					
				
					 4 changed files with 32 additions and 20 deletions
				
			
		|  | @ -1,6 +1,7 @@ | |||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
|  | @ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st | |||
| 		}.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	token, err := th.getToken(params, additionalScopes...) | ||||
| 	token, err := th.getToken(req.Context(), params, additionalScopes...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) { | ||||
| func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) { | ||||
| 	th.tokenLock.Lock() | ||||
| 	defer th.tokenLock.Unlock() | ||||
| 	scopes := make([]string, 0, len(th.scopes)+len(additionalScopes)) | ||||
|  | @ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s | |||
| 
 | ||||
| 	now := th.clock.Now() | ||||
| 	if now.After(th.tokenExpiration) || addedScopes { | ||||
| 		token, expiration, err := th.fetchToken(params, scopes) | ||||
| 		token, expiration, err := th.fetchToken(ctx, params, scopes) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
|  | @ -320,7 +321,7 @@ type postTokenResponse struct { | |||
| 	Scope        string    `json:"scope"` | ||||
| } | ||||
| 
 | ||||
| func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) { | ||||
| func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) { | ||||
| 	form := url.Values{} | ||||
| 	form.Set("scope", strings.Join(scopes, " ")) | ||||
| 	form.Set("service", service) | ||||
|  | @ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic | |||
| 		return "", time.Time{}, fmt.Errorf("no supported grant type") | ||||
| 	} | ||||
| 
 | ||||
| 	resp, err := th.client().PostForm(realm.String(), form) | ||||
| 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode())) | ||||
| 	if err != nil { | ||||
| 		return "", time.Time{}, err | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 	resp, err := th.client().Do(req) | ||||
| 	if err != nil { | ||||
| 		return "", time.Time{}, err | ||||
| 	} | ||||
|  | @ -396,9 +402,8 @@ type getTokenResponse struct { | |||
| 	RefreshToken string    `json:"refresh_token"` | ||||
| } | ||||
| 
 | ||||
| func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) { | ||||
| 
 | ||||
| 	req, err := http.NewRequest("GET", realm.String(), nil) | ||||
| func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) { | ||||
| 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil) | ||||
| 	if err != nil { | ||||
| 		return "", time.Time{}, err | ||||
| 	} | ||||
|  | @ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, | |||
| 	return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil | ||||
| } | ||||
| 
 | ||||
| func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) { | ||||
| func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) { | ||||
| 	realm, ok := params["realm"] | ||||
| 	if !ok { | ||||
| 		return "", time.Time{}, errors.New("no realm specified for token auth challenge") | ||||
|  | @ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t | |||
| 	} | ||||
| 
 | ||||
| 	if refreshToken != "" || th.forceOAuth { | ||||
| 		return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes) | ||||
| 		return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes) | ||||
| 	} | ||||
| 
 | ||||
| 	return th.fetchTokenWithBasicAuth(realmURL, service, scopes) | ||||
| 	return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes) | ||||
| } | ||||
| 
 | ||||
| type basicHandler struct { | ||||
|  |  | |||
|  | @ -13,6 +13,8 @@ import ( | |||
| ) | ||||
| 
 | ||||
| type httpBlobUpload struct { | ||||
| 	ctx context.Context | ||||
| 
 | ||||
| 	statter distribution.BlobStatter | ||||
| 	client  *http.Client | ||||
| 
 | ||||
|  | @ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error { | |||
| } | ||||
| 
 | ||||
| func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { | ||||
| 	req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r)) | ||||
| 	req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | @ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { | |||
| } | ||||
| 
 | ||||
| func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) { | ||||
| 	req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p)) | ||||
| 	req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | @ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time { | |||
| 
 | ||||
| func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { | ||||
| 	// TODO(dmcgowan): Check if already finished, if so just fetch | ||||
| 	req, err := http.NewRequest("PUT", hbu.location, nil) | ||||
| 	req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil) | ||||
| 	if err != nil { | ||||
| 		return distribution.Descriptor{}, err | ||||
| 	} | ||||
|  | @ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip | |||
| } | ||||
| 
 | ||||
| func (hbu *httpBlobUpload) Cancel(ctx context.Context) error { | ||||
| 	req, err := http.NewRequest("DELETE", hbu.location, nil) | ||||
| 	req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ package client | |||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
|  | @ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) { | |||
| 	defer c() | ||||
| 
 | ||||
| 	blobUpload := &httpBlobUpload{ | ||||
| 		ctx:    context.Background(), | ||||
| 		client: &http.Client{}, | ||||
| 	} | ||||
| 
 | ||||
|  | @ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) { | |||
| 
 | ||||
| 	// Writing with ReadFrom | ||||
| 	blobUpload := &httpBlobUpload{ | ||||
| 		ctx:      context.Background(), | ||||
| 		client:   &http.Client{}, | ||||
| 		location: e + readFromLocationPath, | ||||
| 	} | ||||
|  | @ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) { | |||
| 
 | ||||
| 	// Writing with Write | ||||
| 	blobUpload = &httpBlobUpload{ | ||||
| 		ctx:      context.Background(), | ||||
| 		client:   &http.Client{}, | ||||
| 		location: e + writeLocationPath, | ||||
| 	} | ||||
|  | @ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) { | |||
| 	defer c() | ||||
| 
 | ||||
| 	blobUpload := &httpBlobUpload{ | ||||
| 		ctx:    context.Background(), | ||||
| 		client: &http.Client{}, | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri | |||
| 			return 0, err | ||||
| 		} | ||||
| 
 | ||||
| 		for cnt := range ctlg.Repositories { | ||||
| 			entries[cnt] = ctlg.Repositories[cnt] | ||||
| 		} | ||||
| 		copy(entries, ctlg.Repositories) | ||||
| 		numFilled = len(ctlg.Repositories) | ||||
| 
 | ||||
| 		link := resp.Header.Get("Link") | ||||
|  | @ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error { | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	req, err := http.NewRequest("DELETE", u, nil) | ||||
| 	req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO | |||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	req, err := http.NewRequest("POST", u, nil) | ||||
| 	req, err := http.NewRequestWithContext(ctx, "POST", u, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO | |||
| 		} | ||||
| 
 | ||||
| 		return &httpBlobUpload{ | ||||
| 			ctx:       ctx, | ||||
| 			statter:   bs.statter, | ||||
| 			client:    bs.client, | ||||
| 			uuid:      uuid, | ||||
|  | @ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter | |||
| 	} | ||||
| 
 | ||||
| 	return &httpBlobUpload{ | ||||
| 		ctx:       ctx, | ||||
| 		statter:   bs.statter, | ||||
| 		client:    bs.client, | ||||
| 		uuid:      id, | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue