diff --git a/docs/endpoint.go b/docs/endpoint.go index 6dd4e1f6..2eac41ce 100644 --- a/docs/endpoint.go +++ b/docs/endpoint.go @@ -33,21 +33,15 @@ func scanForApiVersion(hostname string) (string, APIVersion) { return hostname, DefaultAPIVersion } -func NewEndpoint(hostname string, secure bool) (*Endpoint, error) { - var ( - endpoint = Endpoint{secure: secure} - trimmedHostname string - err error - ) - if !strings.HasPrefix(hostname, "http") { - hostname = "https://" + hostname - } - trimmedHostname, endpoint.Version = scanForApiVersion(hostname) - endpoint.URL, err = url.Parse(trimmedHostname) +func NewEndpoint(hostname string, insecureRegistries []string) (*Endpoint, error) { + endpoint, err := newEndpoint(hostname) if err != nil { return nil, err } + secure := isSecure(endpoint.URL.Host, insecureRegistries) + endpoint.secure = secure + // Try HTTPS ping to registry endpoint.URL.Scheme = "https" if _, err := endpoint.Ping(); err != nil { @@ -65,12 +59,28 @@ func NewEndpoint(hostname string, secure bool) (*Endpoint, error) { endpoint.URL.Scheme = "http" _, err2 := endpoint.Ping() if err2 == nil { - return &endpoint, nil + return endpoint, nil } return nil, fmt.Errorf("Invalid registry endpoint %q. HTTPS attempt: %v. HTTP attempt: %v", endpoint, err, err2) } + return endpoint, nil +} +func newEndpoint(hostname string) (*Endpoint, error) { + var ( + endpoint = Endpoint{secure: true} + trimmedHostname string + err error + ) + if !strings.HasPrefix(hostname, "http") { + hostname = "https://" + hostname + } + trimmedHostname, endpoint.Version = scanForApiVersion(hostname) + endpoint.URL, err = url.Parse(trimmedHostname) + if err != nil { + return nil, err + } return &endpoint, nil } @@ -141,9 +151,9 @@ func (e Endpoint) Ping() (RegistryInfo, error) { return info, nil } -// IsSecure returns false if the provided hostname is part of the list of insecure registries. +// isSecure returns false if the provided hostname is part of the list of insecure registries. // Insecure registries accept HTTP and/or accept HTTPS with certificates from unknown CAs. -func IsSecure(hostname string, insecureRegistries []string) bool { +func isSecure(hostname string, insecureRegistries []string) bool { if hostname == IndexServerAddress() { return true } diff --git a/docs/endpoint_test.go b/docs/endpoint_test.go new file mode 100644 index 00000000..0ec1220d --- /dev/null +++ b/docs/endpoint_test.go @@ -0,0 +1,27 @@ +package registry + +import "testing" + +func TestEndpointParse(t *testing.T) { + testData := []struct { + str string + expected string + }{ + {IndexServerAddress(), IndexServerAddress()}, + {"http://0.0.0.0:5000", "http://0.0.0.0:5000/v1/"}, + {"0.0.0.0:5000", "https://0.0.0.0:5000/v1/"}, + } + for _, td := range testData { + e, err := newEndpoint(td.str) + if err != nil { + t.Errorf("%q: %s", td.str, err) + } + if e == nil { + t.Logf("something's fishy, endpoint for %q is nil", td.str) + continue + } + if e.String() != td.expected { + t.Errorf("expected %q, got %q", td.expected, e.String()) + } + } +} diff --git a/docs/registry_test.go b/docs/registry_test.go index c9a9fc81..7e63ee92 100644 --- a/docs/registry_test.go +++ b/docs/registry_test.go @@ -316,3 +316,32 @@ func TestAddRequiredHeadersToRedirectedRequests(t *testing.T) { } } } + +func TestIsSecure(t *testing.T) { + tests := []struct { + addr string + insecureRegistries []string + expected bool + }{ + {"example.com", []string{}, true}, + {"example.com", []string{"example.com"}, false}, + {"localhost", []string{"localhost:5000"}, false}, + {"localhost:5000", []string{"localhost:5000"}, false}, + {"localhost", []string{"example.com"}, false}, + {"127.0.0.1:5000", []string{"127.0.0.1:5000"}, false}, + {"localhost", []string{}, false}, + {"localhost:5000", []string{}, false}, + {"127.0.0.1", []string{}, false}, + {"localhost", []string{"example.com"}, false}, + {"127.0.0.1", []string{"example.com"}, false}, + {"example.com", []string{}, true}, + {"example.com", []string{"example.com"}, false}, + {"127.0.0.1", []string{"example.com"}, false}, + {"127.0.0.1:5000", []string{"example.com"}, false}, + } + for _, tt := range tests { + if sec := isSecure(tt.addr, tt.insecureRegistries); sec != tt.expected { + t.Errorf("isSecure failed for %q %v, expected %v got %v", tt.addr, tt.insecureRegistries, tt.expected, sec) + } + } +} diff --git a/docs/service.go b/docs/service.go index 7051d934..53e8278b 100644 --- a/docs/service.go +++ b/docs/service.go @@ -40,7 +40,7 @@ func (s *Service) Auth(job *engine.Job) engine.Status { job.GetenvJson("authConfig", authConfig) if addr := authConfig.ServerAddress; addr != "" && addr != IndexServerAddress() { - endpoint, err := NewEndpoint(addr, IsSecure(addr, s.insecureRegistries)) + endpoint, err := NewEndpoint(addr, s.insecureRegistries) if err != nil { return job.Error(err) } @@ -92,9 +92,7 @@ func (s *Service) Search(job *engine.Job) engine.Status { return job.Error(err) } - secure := IsSecure(hostname, s.insecureRegistries) - - endpoint, err := NewEndpoint(hostname, secure) + endpoint, err := NewEndpoint(hostname, s.insecureRegistries) if err != nil { return job.Error(err) }