diff --git a/docs/endpoint.go b/docs/endpoint.go index 9ca9ed8b..72bcce4a 100644 --- a/docs/endpoint.go +++ b/docs/endpoint.go @@ -227,6 +227,21 @@ func (e *Endpoint) pingV2() (RegistryInfo, error) { } defer resp.Body.Close() + // The endpoint may have multiple supported versions. + // Ensure it supports the v2 Registry API. + var supportsV2 bool + + for _, versionName := range resp.Header[http.CanonicalHeaderKey("Docker-Distribution-API-Version")] { + if versionName == "registry/2.0" { + supportsV2 = true + break + } + } + + if !supportsV2 { + return RegistryInfo{}, fmt.Errorf("%s does not appear to be a v2 registry endpoint", e) + } + if resp.StatusCode == http.StatusOK { // It would seem that no authentication/authorization is required. // So we don't need to parse/add any authorization schemes. diff --git a/docs/endpoint_test.go b/docs/endpoint_test.go index f6489034..ef258999 100644 --- a/docs/endpoint_test.go +++ b/docs/endpoint_test.go @@ -1,6 +1,11 @@ package registry -import "testing" +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" +) func TestEndpointParse(t *testing.T) { testData := []struct { @@ -27,3 +32,59 @@ func TestEndpointParse(t *testing.T) { } } } + +// Ensure that a registry endpoint that responds with a 401 only is determined +// to be a v1 registry unless it includes a valid v2 API header. +func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) { + requireBasicAuthHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("WWW-Authenticate", `Basic realm="localhost"`) + w.WriteHeader(http.StatusUnauthorized) + }) + + requireBasicAuthHandlerV2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Docker-Distribution-API-Version", "registry/2.0") + requireBasicAuthHandler.ServeHTTP(w, r) + }) + + // Make a test server which should validate as a v1 server. + testServer := httptest.NewServer(requireBasicAuthHandler) + defer testServer.Close() + + testServerURL, err := url.Parse(testServer.URL) + if err != nil { + t.Fatal(err) + } + + testEndpoint := Endpoint{ + URL: testServerURL, + Version: APIVersionUnknown, + } + + if err = validateEndpoint(&testEndpoint); err != nil { + t.Fatal(err) + } + + if testEndpoint.Version != APIVersion1 { + t.Fatalf("expected endpoint to validate to %s, got %s", APIVersion1, testEndpoint.Version) + } + + // Make a test server which should validate as a v2 server. + testServer = httptest.NewServer(requireBasicAuthHandlerV2) + defer testServer.Close() + + testServerURL, err = url.Parse(testServer.URL) + if err != nil { + t.Fatal(err) + } + + testEndpoint.URL = testServerURL + testEndpoint.Version = APIVersionUnknown + + if err = validateEndpoint(&testEndpoint); err != nil { + t.Fatal(err) + } + + if testEndpoint.Version != APIVersion2 { + t.Fatalf("expected endpoint to validate to %s, got %s", APIVersion2, testEndpoint.Version) + } +}