Add unit tests for auth challenge and endpoint
Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
This commit is contained in:
parent
174a732c94
commit
b1ba2183ee
7 changed files with 315 additions and 14 deletions
|
@ -127,7 +127,7 @@ func expectTokenOrQuoted(s string) (value string, rest string) {
|
|||
p := make([]byte, len(s)-1)
|
||||
j := copy(p, s[:i])
|
||||
escape := true
|
||||
for i = i + i; i < len(s); i++ {
|
||||
for i = i + 1; i < len(s); i++ {
|
||||
b := s[i]
|
||||
switch {
|
||||
case escape:
|
||||
|
|
37
registry/client/authchallenge_test.go
Normal file
37
registry/client/authchallenge_test.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthChallengeParse(t *testing.T) {
|
||||
header := http.Header{}
|
||||
header.Add("WWW-Authenticate", `Bearer realm="https://auth.example.com/token",service="registry.example.com",other=fun,slashed="he\"\l\lo"`)
|
||||
|
||||
challenges := parseAuthHeader(header)
|
||||
if len(challenges) != 1 {
|
||||
t.Fatalf("Unexpected number of auth challenges: %d, expected 1", len(challenges))
|
||||
}
|
||||
|
||||
if expected := "bearer"; challenges[0].Scheme != expected {
|
||||
t.Fatalf("Unexpected scheme: %s, expected: %s", challenges[0].Scheme, expected)
|
||||
}
|
||||
|
||||
if expected := "https://auth.example.com/token"; challenges[0].Parameters["realm"] != expected {
|
||||
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["realm"], expected)
|
||||
}
|
||||
|
||||
if expected := "registry.example.com"; challenges[0].Parameters["service"] != expected {
|
||||
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["service"], expected)
|
||||
}
|
||||
|
||||
if expected := "fun"; challenges[0].Parameters["other"] != expected {
|
||||
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["other"], expected)
|
||||
}
|
||||
|
||||
if expected := "he\"llo"; challenges[0].Parameters["slashed"] != expected {
|
||||
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["slashed"], expected)
|
||||
}
|
||||
|
||||
}
|
|
@ -117,6 +117,8 @@ func (e *RepositoryEndpoint) URLBuilder() (*v2.URLBuilder, error) {
|
|||
|
||||
// HTTPClient returns a new HTTP client configured for this endpoint
|
||||
func (e *RepositoryEndpoint) HTTPClient(name string) (*http.Client, error) {
|
||||
// TODO(dmcgowan): create http.Transport
|
||||
|
||||
transport := &repositoryTransport{
|
||||
Header: e.Header,
|
||||
}
|
||||
|
|
259
registry/client/endpoint_test.go
Normal file
259
registry/client/endpoint_test.go
Normal file
|
@ -0,0 +1,259 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/docker/distribution/testutil"
|
||||
)
|
||||
|
||||
type testAuthenticationWrapper struct {
|
||||
headers http.Header
|
||||
authCheck func(string) bool
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
func (w *testAuthenticationWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth == "" || !w.authCheck(auth) {
|
||||
h := rw.Header()
|
||||
for k, values := range w.headers {
|
||||
h[k] = values
|
||||
}
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.next.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (*RepositoryEndpoint, func()) {
|
||||
h := testutil.NewHandler(rrm)
|
||||
wrapper := &testAuthenticationWrapper{
|
||||
|
||||
headers: http.Header(map[string][]string{
|
||||
"Docker-Distribution-API-Version": {"registry/2.0"},
|
||||
"WWW-Authenticate": {authenticate},
|
||||
}),
|
||||
authCheck: authCheck,
|
||||
next: h,
|
||||
}
|
||||
|
||||
s := httptest.NewServer(wrapper)
|
||||
e := RepositoryEndpoint{Endpoint: s.URL, Mirror: false}
|
||||
return &e, s.Close
|
||||
}
|
||||
|
||||
type testCredentialStore struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
func (tcs *testCredentialStore) Basic(*url.URL) (string, string) {
|
||||
return tcs.username, tcs.password
|
||||
}
|
||||
|
||||
func TestEndpointAuthorizeToken(t *testing.T) {
|
||||
service := "localhost.localdomain"
|
||||
repo1 := "some/registry"
|
||||
repo2 := "other/registry"
|
||||
scope1 := fmt.Sprintf("repository:%s:pull,push", repo1)
|
||||
scope2 := fmt.Sprintf("repository:%s:pull,push", repo2)
|
||||
|
||||
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||
{
|
||||
Request: testutil.Request{
|
||||
Method: "GET",
|
||||
Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope1), service),
|
||||
},
|
||||
Response: testutil.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: []byte(`{"token":"statictoken"}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
Request: testutil.Request{
|
||||
Method: "GET",
|
||||
Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope2), service),
|
||||
},
|
||||
Response: testutil.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: []byte(`{"token":"badtoken"}`),
|
||||
},
|
||||
},
|
||||
})
|
||||
te, tc := testServer(tokenMap)
|
||||
defer tc()
|
||||
|
||||
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||
{
|
||||
Request: testutil.Request{
|
||||
Method: "GET",
|
||||
Route: "/hello",
|
||||
},
|
||||
Response: testutil.Response{
|
||||
StatusCode: http.StatusAccepted,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service)
|
||||
validCheck := func(a string) bool {
|
||||
return a == "Bearer statictoken"
|
||||
}
|
||||
e, c := testServerWithAuth(m, authenicate, validCheck)
|
||||
defer c()
|
||||
|
||||
client, err := e.HTTPClient(repo1)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating http client: %s", err)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Error sending get request: %s", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
|
||||
}
|
||||
|
||||
badCheck := func(a string) bool {
|
||||
return a == "Bearer statictoken"
|
||||
}
|
||||
e2, c2 := testServerWithAuth(m, authenicate, badCheck)
|
||||
defer c2()
|
||||
|
||||
client2, err := e2.HTTPClient(repo2)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating http client: %s", err)
|
||||
}
|
||||
|
||||
req, _ = http.NewRequest("GET", e.Endpoint+"/hello", nil)
|
||||
resp, err = client2.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Error sending get request: %s", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
func basicAuth(username, password string) string {
|
||||
auth := username + ":" + password
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
func TestEndpointAuthorizeTokenBasic(t *testing.T) {
|
||||
service := "localhost.localdomain"
|
||||
repo := "some/fun/registry"
|
||||
scope := fmt.Sprintf("repository:%s:pull,push", repo)
|
||||
username := "tokenuser"
|
||||
password := "superSecretPa$$word"
|
||||
|
||||
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||
{
|
||||
Request: testutil.Request{
|
||||
Method: "GET",
|
||||
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
|
||||
},
|
||||
Response: testutil.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: []byte(`{"token":"statictoken"}`),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
authenicate1 := fmt.Sprintf("Basic realm=localhost")
|
||||
basicCheck := func(a string) bool {
|
||||
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
|
||||
}
|
||||
te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck)
|
||||
defer tc()
|
||||
|
||||
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||
{
|
||||
Request: testutil.Request{
|
||||
Method: "GET",
|
||||
Route: "/hello",
|
||||
},
|
||||
Response: testutil.Response{
|
||||
StatusCode: http.StatusAccepted,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service)
|
||||
bearerCheck := func(a string) bool {
|
||||
return a == "Bearer statictoken"
|
||||
}
|
||||
e, c := testServerWithAuth(m, authenicate2, bearerCheck)
|
||||
defer c()
|
||||
|
||||
e.Credentials = &testCredentialStore{
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
|
||||
client, err := e.HTTPClient(repo)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating http client: %s", err)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Error sending get request: %s", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndpointAuthorizeBasic(t *testing.T) {
|
||||
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||
{
|
||||
Request: testutil.Request{
|
||||
Method: "GET",
|
||||
Route: "/hello",
|
||||
},
|
||||
Response: testutil.Response{
|
||||
StatusCode: http.StatusAccepted,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
username := "user1"
|
||||
password := "funSecretPa$$word"
|
||||
authenicate := fmt.Sprintf("Basic realm=localhost")
|
||||
validCheck := func(a string) bool {
|
||||
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
|
||||
}
|
||||
e, c := testServerWithAuth(m, authenicate, validCheck)
|
||||
defer c()
|
||||
e.Credentials = &testCredentialStore{
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
|
||||
client, err := e.HTTPClient("test/repo/basic")
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating http client: %s", err)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Error sending get request: %s", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
|
||||
}
|
||||
}
|
|
@ -25,8 +25,8 @@ import (
|
|||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// NewRepositoryClient creates a new Repository for the given repository name and endpoint
|
||||
func NewRepositoryClient(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) {
|
||||
// NewRepository creates a new Repository for the given repository name and endpoint
|
||||
func NewRepository(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) {
|
||||
if err := v2.ValidateRespositoryName(name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -97,7 +97,7 @@ func TestLayerFetch(t *testing.T) {
|
|||
e, c := testServer(m)
|
||||
defer c()
|
||||
|
||||
r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e)
|
||||
r, err := NewRepository(context.Background(), "test.example.com/repo1", e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ func TestLayerExists(t *testing.T) {
|
|||
e, c := testServer(m)
|
||||
defer c()
|
||||
|
||||
r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e)
|
||||
r, err := NewRepository(context.Background(), "test.example.com/repo1", e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -227,7 +227,7 @@ func TestLayerUploadChunked(t *testing.T) {
|
|||
e, c := testServer(m)
|
||||
defer c()
|
||||
|
||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
||||
r, err := NewRepository(context.Background(), repo, e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -334,7 +334,7 @@ func TestLayerUploadMonolithic(t *testing.T) {
|
|||
e, c := testServer(m)
|
||||
defer c()
|
||||
|
||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
||||
r, err := NewRepository(context.Background(), repo, e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -475,7 +475,7 @@ func TestManifestFetch(t *testing.T) {
|
|||
e, c := testServer(m)
|
||||
defer c()
|
||||
|
||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
||||
r, err := NewRepository(context.Background(), repo, e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -508,7 +508,7 @@ func TestManifestFetchByTag(t *testing.T) {
|
|||
e, c := testServer(m)
|
||||
defer c()
|
||||
|
||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
||||
r, err := NewRepository(context.Background(), repo, e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -553,7 +553,7 @@ func TestManifestDelete(t *testing.T) {
|
|||
e, c := testServer(m)
|
||||
defer c()
|
||||
|
||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
||||
r, err := NewRepository(context.Background(), repo, e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -591,7 +591,7 @@ func TestManifestPut(t *testing.T) {
|
|||
e, c := testServer(m)
|
||||
defer c()
|
||||
|
||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
||||
r, err := NewRepository(context.Background(), repo, e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
@ -40,16 +41,18 @@ type Request struct {
|
|||
func (r Request) String() string {
|
||||
queryString := ""
|
||||
if len(r.QueryParams) > 0 {
|
||||
queryString = "?"
|
||||
keys := make([]string, 0, len(r.QueryParams))
|
||||
queryParts := make([]string, 0, len(r.QueryParams))
|
||||
for k := range r.QueryParams {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
queryString += strings.Join(r.QueryParams[k], "&") + "&"
|
||||
for _, val := range r.QueryParams[k] {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, url.QueryEscape(val)))
|
||||
}
|
||||
queryString = queryString[:len(queryString)-1]
|
||||
}
|
||||
queryString = "?" + strings.Join(queryParts, "&")
|
||||
}
|
||||
return fmt.Sprintf("%s %s%s\n%s", r.Method, r.Route, queryString, r.Body)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue