Merge pull request #1 from docker/master

update
This commit is contained in:
Eric Yang 2016-06-07 10:58:38 +08:00
commit a4f7559f6c
80 changed files with 3871 additions and 4693 deletions

View file

@ -63,6 +63,19 @@ var (
Description: "Returned when a service is not available", Description: "Returned when a service is not available",
HTTPStatusCode: http.StatusServiceUnavailable, HTTPStatusCode: http.StatusServiceUnavailable,
}) })
// ErrorCodeTooManyRequests is returned if a client attempts too many
// times to contact a service endpoint.
ErrorCodeTooManyRequests = Register("errcode", ErrorDescriptor{
Value: "TOOMANYREQUESTS",
Message: "too many requests",
Description: `Returned when a client attempts to contact a
service too many times`,
// FIXME: go1.5 doesn't export http.StatusTooManyRequests while
// go1.6 does. Update the hardcoded value to the constant once
// Docker updates golang version to 1.6.
HTTPStatusCode: 429,
})
) )
var nextCode = 1000 var nextCode = 1000

View file

@ -17,33 +17,35 @@ import (
// under "/foo/v2/...". Most application will only provide a schema, host and // under "/foo/v2/...". Most application will only provide a schema, host and
// port, such as "https://localhost:5000/". // port, such as "https://localhost:5000/".
type URLBuilder struct { type URLBuilder struct {
root *url.URL // url root (ie http://localhost/) root *url.URL // url root (ie http://localhost/)
router *mux.Router router *mux.Router
relative bool
} }
// NewURLBuilder creates a URLBuilder with provided root url object. // NewURLBuilder creates a URLBuilder with provided root url object.
func NewURLBuilder(root *url.URL) *URLBuilder { func NewURLBuilder(root *url.URL, relative bool) *URLBuilder {
return &URLBuilder{ return &URLBuilder{
root: root, root: root,
router: Router(), router: Router(),
relative: relative,
} }
} }
// NewURLBuilderFromString workes identically to NewURLBuilder except it takes // NewURLBuilderFromString workes identically to NewURLBuilder except it takes
// a string argument for the root, returning an error if it is not a valid // a string argument for the root, returning an error if it is not a valid
// url. // url.
func NewURLBuilderFromString(root string) (*URLBuilder, error) { func NewURLBuilderFromString(root string, relative bool) (*URLBuilder, error) {
u, err := url.Parse(root) u, err := url.Parse(root)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewURLBuilder(u), nil return NewURLBuilder(u, relative), nil
} }
// NewURLBuilderFromRequest uses information from an *http.Request to // NewURLBuilderFromRequest uses information from an *http.Request to
// construct the root url. // construct the root url.
func NewURLBuilderFromRequest(r *http.Request) *URLBuilder { func NewURLBuilderFromRequest(r *http.Request, relative bool) *URLBuilder {
var scheme string var scheme string
forwardedProto := r.Header.Get("X-Forwarded-Proto") forwardedProto := r.Header.Get("X-Forwarded-Proto")
@ -85,7 +87,7 @@ func NewURLBuilderFromRequest(r *http.Request) *URLBuilder {
u.Path = requestPath[0 : index+1] u.Path = requestPath[0 : index+1]
} }
return NewURLBuilder(u) return NewURLBuilder(u, relative)
} }
// BuildBaseURL constructs a base url for the API, typically just "/v2/". // BuildBaseURL constructs a base url for the API, typically just "/v2/".
@ -194,12 +196,13 @@ func (ub *URLBuilder) cloneRoute(name string) clonedRoute {
*route = *ub.router.GetRoute(name) // clone the route *route = *ub.router.GetRoute(name) // clone the route
*root = *ub.root *root = *ub.root
return clonedRoute{Route: route, root: root} return clonedRoute{Route: route, root: root, relative: ub.relative}
} }
type clonedRoute struct { type clonedRoute struct {
*mux.Route *mux.Route
root *url.URL root *url.URL
relative bool
} }
func (cr clonedRoute) URL(pairs ...string) (*url.URL, error) { func (cr clonedRoute) URL(pairs ...string) (*url.URL, error) {
@ -208,6 +211,10 @@ func (cr clonedRoute) URL(pairs ...string) (*url.URL, error) {
return nil, err return nil, err
} }
if cr.relative {
return routeURL, nil
}
if routeURL.Scheme == "" && routeURL.User == nil && routeURL.Host == "" { if routeURL.Scheme == "" && routeURL.User == nil && routeURL.Host == "" {
routeURL.Path = routeURL.Path[1:] routeURL.Path = routeURL.Path[1:]
} }

View file

@ -92,25 +92,31 @@ func TestURLBuilder(t *testing.T) {
"https://localhost:5443", "https://localhost:5443",
} }
for _, root := range roots { doTest := func(relative bool) {
urlBuilder, err := NewURLBuilderFromString(root) for _, root := range roots {
if err != nil { urlBuilder, err := NewURLBuilderFromString(root, relative)
t.Fatalf("unexpected error creating urlbuilder: %v", err)
}
for _, testCase := range makeURLBuilderTestCases(urlBuilder) {
url, err := testCase.build()
if err != nil { if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err) t.Fatalf("unexpected error creating urlbuilder: %v", err)
} }
expectedURL := root + testCase.expectedPath for _, testCase := range makeURLBuilderTestCases(urlBuilder) {
url, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
expectedURL := testCase.expectedPath
if !relative {
expectedURL = root + expectedURL
}
if url != expectedURL { if url != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL) t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
}
} }
} }
} }
doTest(true)
doTest(false)
} }
func TestURLBuilderWithPrefix(t *testing.T) { func TestURLBuilderWithPrefix(t *testing.T) {
@ -121,25 +127,31 @@ func TestURLBuilderWithPrefix(t *testing.T) {
"https://localhost:5443/prefix/", "https://localhost:5443/prefix/",
} }
for _, root := range roots { doTest := func(relative bool) {
urlBuilder, err := NewURLBuilderFromString(root) for _, root := range roots {
if err != nil { urlBuilder, err := NewURLBuilderFromString(root, relative)
t.Fatalf("unexpected error creating urlbuilder: %v", err)
}
for _, testCase := range makeURLBuilderTestCases(urlBuilder) {
url, err := testCase.build()
if err != nil { if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err) t.Fatalf("unexpected error creating urlbuilder: %v", err)
} }
expectedURL := root[0:len(root)-1] + testCase.expectedPath for _, testCase := range makeURLBuilderTestCases(urlBuilder) {
url, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
if url != expectedURL { expectedURL := testCase.expectedPath
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL) if !relative {
expectedURL = root[0:len(root)-1] + expectedURL
}
if url != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
}
} }
} }
} }
doTest(true)
doTest(false)
} }
type builderFromRequestTestCase struct { type builderFromRequestTestCase struct {
@ -197,39 +209,48 @@ func TestBuilderFromRequest(t *testing.T) {
}, },
}, },
} }
doTest := func(relative bool) {
for _, tr := range testRequests { for _, tr := range testRequests {
var builder *URLBuilder var builder *URLBuilder
if tr.configHost.Scheme != "" && tr.configHost.Host != "" { if tr.configHost.Scheme != "" && tr.configHost.Host != "" {
builder = NewURLBuilder(&tr.configHost) builder = NewURLBuilder(&tr.configHost, relative)
} else {
builder = NewURLBuilderFromRequest(tr.request)
}
for _, testCase := range makeURLBuilderTestCases(builder) {
buildURL, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
var expectedURL string
proto, ok := tr.request.Header["X-Forwarded-Proto"]
if !ok {
expectedURL = tr.base + testCase.expectedPath
} else { } else {
urlBase, err := url.Parse(tr.base) builder = NewURLBuilderFromRequest(tr.request, relative)
if err != nil {
t.Fatal(err)
}
urlBase.Scheme = proto[0]
expectedURL = urlBase.String() + testCase.expectedPath
} }
if buildURL != expectedURL { for _, testCase := range makeURLBuilderTestCases(builder) {
t.Fatalf("%s: %q != %q", testCase.description, buildURL, expectedURL) buildURL, err := testCase.build()
if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err)
}
var expectedURL string
proto, ok := tr.request.Header["X-Forwarded-Proto"]
if !ok {
expectedURL = testCase.expectedPath
if !relative {
expectedURL = tr.base + expectedURL
}
} else {
urlBase, err := url.Parse(tr.base)
if err != nil {
t.Fatal(err)
}
urlBase.Scheme = proto[0]
expectedURL = testCase.expectedPath
if !relative {
expectedURL = urlBase.String() + expectedURL
}
}
if buildURL != expectedURL {
t.Fatalf("%s: %q != %q", testCase.description, buildURL, expectedURL)
}
} }
} }
} }
doTest(true)
doTest(false)
} }
func TestBuilderFromRequestWithPrefix(t *testing.T) { func TestBuilderFromRequestWithPrefix(t *testing.T) {
@ -270,12 +291,13 @@ func TestBuilderFromRequestWithPrefix(t *testing.T) {
}, },
} }
var relative bool
for _, tr := range testRequests { for _, tr := range testRequests {
var builder *URLBuilder var builder *URLBuilder
if tr.configHost.Scheme != "" && tr.configHost.Host != "" { if tr.configHost.Scheme != "" && tr.configHost.Host != "" {
builder = NewURLBuilder(&tr.configHost) builder = NewURLBuilder(&tr.configHost, false)
} else { } else {
builder = NewURLBuilderFromRequest(tr.request) builder = NewURLBuilderFromRequest(tr.request, false)
} }
for _, testCase := range makeURLBuilderTestCases(builder) { for _, testCase := range makeURLBuilderTestCases(builder) {
@ -283,17 +305,25 @@ func TestBuilderFromRequestWithPrefix(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("%s: error building url: %v", testCase.description, err) t.Fatalf("%s: error building url: %v", testCase.description, err)
} }
var expectedURL string var expectedURL string
proto, ok := tr.request.Header["X-Forwarded-Proto"] proto, ok := tr.request.Header["X-Forwarded-Proto"]
if !ok { if !ok {
expectedURL = tr.base[0:len(tr.base)-1] + testCase.expectedPath expectedURL = testCase.expectedPath
if !relative {
expectedURL = tr.base[0:len(tr.base)-1] + expectedURL
}
} else { } else {
urlBase, err := url.Parse(tr.base) urlBase, err := url.Parse(tr.base)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
urlBase.Scheme = proto[0] urlBase.Scheme = proto[0]
expectedURL = urlBase.String()[0:len(urlBase.String())-1] + testCase.expectedPath expectedURL = testCase.expectedPath
if !relative {
expectedURL = urlBase.String()[0:len(urlBase.String())-1] + expectedURL
}
} }
if buildURL != expectedURL { if buildURL != expectedURL {

View file

@ -46,7 +46,7 @@ func (htpasswd *htpasswd) authenticateUser(username string, password string) err
// parseHTPasswd parses the contents of htpasswd. This will read all the // parseHTPasswd parses the contents of htpasswd. This will read all the
// entries in the file, whether or not they are needed. An error is returned // entries in the file, whether or not they are needed. An error is returned
// if an syntax errors are encountered or if the reader fails. // if a syntax errors are encountered or if the reader fails.
func parseHTPasswd(rd io.Reader) (map[string][]byte, error) { func parseHTPasswd(rd io.Reader) (map[string][]byte, error) {
entries := map[string][]byte{} entries := map[string][]byte{}
scanner := bufio.NewScanner(rd) scanner := bufio.NewScanner(rd)

View file

@ -25,7 +25,7 @@ type Challenge struct {
type ChallengeManager interface { type ChallengeManager interface {
// GetChallenges returns the challenges for the given // GetChallenges returns the challenges for the given
// endpoint URL. // endpoint URL.
GetChallenges(endpoint string) ([]Challenge, error) GetChallenges(endpoint url.URL) ([]Challenge, error)
// AddResponse adds the response to the challenge // AddResponse adds the response to the challenge
// manager. The challenges will be parsed out of // manager. The challenges will be parsed out of
@ -48,8 +48,10 @@ func NewSimpleChallengeManager() ChallengeManager {
type simpleChallengeManager map[string][]Challenge type simpleChallengeManager map[string][]Challenge
func (m simpleChallengeManager) GetChallenges(endpoint string) ([]Challenge, error) { func (m simpleChallengeManager) GetChallenges(endpoint url.URL) ([]Challenge, error) {
challenges := m[endpoint] endpoint.Host = strings.ToLower(endpoint.Host)
challenges := m[endpoint.String()]
return challenges, nil return challenges, nil
} }
@ -60,11 +62,10 @@ func (m simpleChallengeManager) AddResponse(resp *http.Response) error {
} }
urlCopy := url.URL{ urlCopy := url.URL{
Path: resp.Request.URL.Path, Path: resp.Request.URL.Path,
Host: resp.Request.URL.Host, Host: strings.ToLower(resp.Request.URL.Host),
Scheme: resp.Request.URL.Scheme, Scheme: resp.Request.URL.Scheme,
} }
m[urlCopy.String()] = challenges m[urlCopy.String()] = challenges
return nil return nil
} }

View file

@ -1,7 +1,10 @@
package auth package auth
import ( import (
"fmt"
"net/http" "net/http"
"net/url"
"strings"
"testing" "testing"
) )
@ -36,3 +39,43 @@ func TestAuthChallengeParse(t *testing.T) {
} }
} }
func TestAuthChallengeNormalization(t *testing.T) {
testAuthChallengeNormalization(t, "reg.EXAMPLE.com")
testAuthChallengeNormalization(t, "bɿɒʜɔiɿ-ɿɘƚƨim-ƚol-ɒ-ƨʞnɒʜƚ.com")
}
func testAuthChallengeNormalization(t *testing.T, host string) {
scm := NewSimpleChallengeManager()
url, err := url.Parse(fmt.Sprintf("http://%s/v2/", host))
if err != nil {
t.Fatal(err)
}
resp := &http.Response{
Request: &http.Request{
URL: url,
},
Header: make(http.Header),
StatusCode: http.StatusUnauthorized,
}
resp.Header.Add("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"https://%s/token\",service=\"registry.example.com\"", host))
err = scm.AddResponse(resp)
if err != nil {
t.Fatal(err)
}
lowered := *url
lowered.Host = strings.ToLower(lowered.Host)
c, err := scm.GetChallenges(lowered)
if err != nil {
t.Fatal(err)
}
if len(c) == 0 {
t.Fatal("Expected challenge for lower-cased-host URL")
}
}

View file

@ -15,9 +15,17 @@ import (
"github.com/docker/distribution/registry/client/transport" "github.com/docker/distribution/registry/client/transport"
) )
// ErrNoBasicAuthCredentials is returned if a request can't be authorized with var (
// basic auth due to lack of credentials. // ErrNoBasicAuthCredentials is returned if a request can't be authorized with
var ErrNoBasicAuthCredentials = errors.New("no basic auth credentials") // basic auth due to lack of credentials.
ErrNoBasicAuthCredentials = errors.New("no basic auth credentials")
// ErrNoToken is returned if a request is successful but the body does not
// contain an authorization token.
ErrNoToken = errors.New("authorization server did not include a token in the response")
)
const defaultClientID = "registry-client"
// AuthenticationHandler is an interface for authorizing a request from // AuthenticationHandler is an interface for authorizing a request from
// params from a "WWW-Authenicate" header for a single scheme. // params from a "WWW-Authenicate" header for a single scheme.
@ -36,6 +44,14 @@ type AuthenticationHandler interface {
type CredentialStore interface { type CredentialStore interface {
// Basic returns basic auth for the given URL // Basic returns basic auth for the given URL
Basic(*url.URL) (string, string) Basic(*url.URL) (string, string)
// RefreshToken returns a refresh token for the
// given URL and service
RefreshToken(*url.URL, string) string
// SetRefreshToken sets the refresh token if none
// is provided for the given url and service
SetRefreshToken(realm *url.URL, service, token string)
} }
// NewAuthorizer creates an authorizer which can handle multiple authentication // NewAuthorizer creates an authorizer which can handle multiple authentication
@ -67,9 +83,7 @@ func (ea *endpointAuthorizer) ModifyRequest(req *http.Request) error {
Path: req.URL.Path[:v2Root+4], Path: req.URL.Path[:v2Root+4],
} }
pingEndpoint := ping.String() challenges, err := ea.challenges.GetChallenges(ping)
challenges, err := ea.challenges.GetChallenges(pingEndpoint)
if err != nil { if err != nil {
return err return err
} }
@ -105,27 +119,47 @@ type clock interface {
type tokenHandler struct { type tokenHandler struct {
header http.Header header http.Header
creds CredentialStore creds CredentialStore
scope tokenScope
transport http.RoundTripper transport http.RoundTripper
clock clock clock clock
offlineAccess bool
forceOAuth bool
clientID string
scopes []Scope
tokenLock sync.Mutex tokenLock sync.Mutex
tokenCache string tokenCache string
tokenExpiration time.Time tokenExpiration time.Time
additionalScopes map[string]struct{}
} }
// tokenScope represents the scope at which a token will be requested. // Scope is a type which is serializable to a string
// This represents a specific action on a registry resource. // using the allow scope grammar.
type tokenScope struct { type Scope interface {
Resource string String() string
Scope string
Actions []string
} }
func (ts tokenScope) String() string { // RepositoryScope represents a token scope for access
return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ",")) // to a repository.
type RepositoryScope struct {
Repository string
Actions []string
}
// String returns the string representation of the repository
// using the scope grammar
func (rs RepositoryScope) String() string {
return fmt.Sprintf("repository:%s:%s", rs.Repository, strings.Join(rs.Actions, ","))
}
// TokenHandlerOptions is used to configure a new token handler
type TokenHandlerOptions struct {
Transport http.RoundTripper
Credentials CredentialStore
OfflineAccess bool
ForceOAuth bool
ClientID string
Scopes []Scope
} }
// An implementation of clock for providing real time data. // An implementation of clock for providing real time data.
@ -137,22 +171,33 @@ func (realClock) Now() time.Time { return time.Now() }
// NewTokenHandler creates a new AuthenicationHandler which supports // NewTokenHandler creates a new AuthenicationHandler which supports
// fetching tokens from a remote token server. // fetching tokens from a remote token server.
func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope string, actions ...string) AuthenticationHandler { func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope string, actions ...string) AuthenticationHandler {
return newTokenHandler(transport, creds, realClock{}, scope, actions...) // Create options...
return NewTokenHandlerWithOptions(TokenHandlerOptions{
Transport: transport,
Credentials: creds,
Scopes: []Scope{
RepositoryScope{
Repository: scope,
Actions: actions,
},
},
})
} }
// newTokenHandler exposes the option to provide a clock to manipulate time in unit testing. // NewTokenHandlerWithOptions creates a new token handler using the provided
func newTokenHandler(transport http.RoundTripper, creds CredentialStore, c clock, scope string, actions ...string) AuthenticationHandler { // options structure.
return &tokenHandler{ func NewTokenHandlerWithOptions(options TokenHandlerOptions) AuthenticationHandler {
transport: transport, handler := &tokenHandler{
creds: creds, transport: options.Transport,
clock: c, creds: options.Credentials,
scope: tokenScope{ offlineAccess: options.OfflineAccess,
Resource: "repository", forceOAuth: options.ForceOAuth,
Scope: scope, clientID: options.ClientID,
Actions: actions, scopes: options.Scopes,
}, clock: realClock{},
additionalScopes: map[string]struct{}{},
} }
return handler
} }
func (th *tokenHandler) client() *http.Client { func (th *tokenHandler) client() *http.Client {
@ -169,122 +214,110 @@ func (th *tokenHandler) Scheme() string {
func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error { func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
var additionalScopes []string var additionalScopes []string
if fromParam := req.URL.Query().Get("from"); fromParam != "" { if fromParam := req.URL.Query().Get("from"); fromParam != "" {
additionalScopes = append(additionalScopes, tokenScope{ additionalScopes = append(additionalScopes, RepositoryScope{
Resource: "repository", Repository: fromParam,
Scope: fromParam, Actions: []string{"pull"},
Actions: []string{"pull"},
}.String()) }.String())
} }
if err := th.refreshToken(params, additionalScopes...); err != nil {
token, err := th.getToken(params, additionalScopes...)
if err != nil {
return err return err
} }
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.tokenCache)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
return nil return nil
} }
func (th *tokenHandler) refreshToken(params map[string]string, additionalScopes ...string) error { func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) {
th.tokenLock.Lock() th.tokenLock.Lock()
defer th.tokenLock.Unlock() defer th.tokenLock.Unlock()
scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
for _, scope := range th.scopes {
scopes = append(scopes, scope.String())
}
var addedScopes bool var addedScopes bool
for _, scope := range additionalScopes { for _, scope := range additionalScopes {
if _, ok := th.additionalScopes[scope]; !ok { scopes = append(scopes, scope)
th.additionalScopes[scope] = struct{}{} addedScopes = true
addedScopes = true
}
} }
now := th.clock.Now() now := th.clock.Now()
if now.After(th.tokenExpiration) || addedScopes { if now.After(th.tokenExpiration) || addedScopes {
tr, err := th.fetchToken(params) token, expiration, err := th.fetchToken(params, scopes)
if err != nil { if err != nil {
return err return "", err
} }
th.tokenCache = tr.Token
th.tokenExpiration = tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second) // do not update cache for added scope tokens
if !addedScopes {
th.tokenCache = token
th.tokenExpiration = expiration
}
return token, nil
} }
return nil return th.tokenCache, nil
} }
type tokenResponse struct { type postTokenResponse struct {
Token string `json:"token"` AccessToken string `json:"access_token"`
AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"` IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"`
} }
func (th *tokenHandler) fetchToken(params map[string]string) (token *tokenResponse, err error) { func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"] form := url.Values{}
if !ok { form.Set("scope", strings.Join(scopes, " "))
return nil, errors.New("no realm specified for token auth challenge") form.Set("service", service)
clientID := th.clientID
if clientID == "" {
// Use default client, this is a required field
clientID = defaultClientID
}
form.Set("client_id", clientID)
if refreshToken != "" {
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
} else if th.creds != nil {
form.Set("grant_type", "password")
username, password := th.creds.Basic(realm)
form.Set("username", username)
form.Set("password", password)
// attempt to get a refresh token
form.Set("access_type", "offline")
} else {
// refuse to do oauth without a grant type
return "", time.Time{}, fmt.Errorf("no supported grant type")
} }
// TODO(dmcgowan): Handle empty scheme resp, err := th.client().PostForm(realm.String(), form)
realmURL, err := url.Parse(realm)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid token auth challenge realm: %s", err) return "", time.Time{}, err
}
req, err := http.NewRequest("GET", realmURL.String(), nil)
if err != nil {
return nil, err
}
reqParams := req.URL.Query()
service := params["service"]
scope := th.scope.String()
if service != "" {
reqParams.Add("service", service)
}
for _, scopeField := range strings.Fields(scope) {
reqParams.Add("scope", scopeField)
}
for scope := range th.additionalScopes {
reqParams.Add("scope", scope)
}
if th.creds != nil {
username, password := th.creds.Basic(realmURL)
if username != "" && password != "" {
reqParams.Add("account", username)
req.SetBasicAuth(username, password)
}
}
req.URL.RawQuery = reqParams.Encode()
resp, err := th.client().Do(req)
if err != nil {
return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if !client.SuccessStatus(resp.StatusCode) { if !client.SuccessStatus(resp.StatusCode) {
err := client.HandleErrorResponse(resp) err := client.HandleErrorResponse(resp)
return nil, err return "", time.Time{}, err
} }
decoder := json.NewDecoder(resp.Body) decoder := json.NewDecoder(resp.Body)
tr := new(tokenResponse) var tr postTokenResponse
if err = decoder.Decode(tr); err != nil { if err = decoder.Decode(&tr); err != nil {
return nil, fmt.Errorf("unable to decode token response: %s", err) return "", time.Time{}, fmt.Errorf("unable to decode token response: %s", err)
} }
// `access_token` is equivalent to `token` and if both are specified if tr.RefreshToken != "" && tr.RefreshToken != refreshToken {
// the choice is undefined. Canonicalize `access_token` by sticking th.creds.SetRefreshToken(realm, service, tr.RefreshToken)
// things in `token`.
if tr.AccessToken != "" {
tr.Token = tr.AccessToken
}
if tr.Token == "" {
return nil, errors.New("authorization server did not include a token in the response")
} }
if tr.ExpiresIn < minimumTokenLifetimeSeconds { if tr.ExpiresIn < minimumTokenLifetimeSeconds {
@ -295,10 +328,128 @@ func (th *tokenHandler) fetchToken(params map[string]string) (token *tokenRespon
if tr.IssuedAt.IsZero() { if tr.IssuedAt.IsZero() {
// issued_at is optional in the token response. // issued_at is optional in the token response.
tr.IssuedAt = th.clock.Now() tr.IssuedAt = th.clock.Now().UTC()
} }
return tr, nil return tr.AccessToken, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}
type getTokenResponse struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"`
}
func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequest("GET", realm.String(), nil)
if err != nil {
return "", time.Time{}, err
}
reqParams := req.URL.Query()
if service != "" {
reqParams.Add("service", service)
}
for _, scope := range scopes {
reqParams.Add("scope", scope)
}
if th.offlineAccess {
reqParams.Add("offline_token", "true")
clientID := th.clientID
if clientID == "" {
clientID = defaultClientID
}
reqParams.Add("client_id", clientID)
}
if th.creds != nil {
username, password := th.creds.Basic(realm)
if username != "" && password != "" {
reqParams.Add("account", username)
req.SetBasicAuth(username, password)
}
}
req.URL.RawQuery = reqParams.Encode()
resp, err := th.client().Do(req)
if err != nil {
return "", time.Time{}, err
}
defer resp.Body.Close()
if !client.SuccessStatus(resp.StatusCode) {
err := client.HandleErrorResponse(resp)
return "", time.Time{}, err
}
decoder := json.NewDecoder(resp.Body)
var tr getTokenResponse
if err = decoder.Decode(&tr); err != nil {
return "", time.Time{}, fmt.Errorf("unable to decode token response: %s", err)
}
if tr.RefreshToken != "" && th.creds != nil {
th.creds.SetRefreshToken(realm, service, tr.RefreshToken)
}
// `access_token` is equivalent to `token` and if both are specified
// the choice is undefined. Canonicalize `access_token` by sticking
// things in `token`.
if tr.AccessToken != "" {
tr.Token = tr.AccessToken
}
if tr.Token == "" {
return "", time.Time{}, ErrNoToken
}
if tr.ExpiresIn < minimumTokenLifetimeSeconds {
// The default/minimum lifetime.
tr.ExpiresIn = minimumTokenLifetimeSeconds
logrus.Debugf("Increasing token expiration to: %d seconds", tr.ExpiresIn)
}
if tr.IssuedAt.IsZero() {
// issued_at is optional in the token response.
tr.IssuedAt = th.clock.Now().UTC()
}
return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}
func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"]
if !ok {
return "", time.Time{}, errors.New("no realm specified for token auth challenge")
}
// TODO(dmcgowan): Handle empty scheme and relative realm
realmURL, err := url.Parse(realm)
if err != nil {
return "", time.Time{}, fmt.Errorf("invalid token auth challenge realm: %s", err)
}
service := params["service"]
var refreshToken string
if th.creds != nil {
refreshToken = th.creds.RefreshToken(realmURL, service)
}
if refreshToken != "" || th.forceOAuth {
return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes)
}
return th.fetchTokenWithBasicAuth(realmURL, service, scopes)
} }
type basicHandler struct { type basicHandler struct {

View file

@ -80,14 +80,25 @@ func ping(manager ChallengeManager, endpoint, versionHeader string) ([]APIVersio
} }
type testCredentialStore struct { type testCredentialStore struct {
username string username string
password string password string
refreshTokens map[string]string
} }
func (tcs *testCredentialStore) Basic(*url.URL) (string, string) { func (tcs *testCredentialStore) Basic(*url.URL) (string, string) {
return tcs.username, tcs.password return tcs.username, tcs.password
} }
func (tcs *testCredentialStore) RefreshToken(u *url.URL, service string) string {
return tcs.refreshTokens[service]
}
func (tcs *testCredentialStore) SetRefreshToken(u *url.URL, service string, token string) {
if tcs.refreshTokens != nil {
tcs.refreshTokens[service] = token
}
}
func TestEndpointAuthorizeToken(t *testing.T) { func TestEndpointAuthorizeToken(t *testing.T) {
service := "localhost.localdomain" service := "localhost.localdomain"
repo1 := "some/registry" repo1 := "some/registry"
@ -162,14 +173,11 @@ func TestEndpointAuthorizeToken(t *testing.T) {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
} }
badCheck := func(a string) bool { e2, c2 := testServerWithAuth(m, authenicate, validCheck)
return a == "Bearer statictoken"
}
e2, c2 := testServerWithAuth(m, authenicate, badCheck)
defer c2() defer c2()
challengeManager2 := NewSimpleChallengeManager() challengeManager2 := NewSimpleChallengeManager()
versions, err = ping(challengeManager2, e+"/v2/", "x-multi-api-version") versions, err = ping(challengeManager2, e2+"/v2/", "x-multi-api-version")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -199,6 +207,161 @@ func TestEndpointAuthorizeToken(t *testing.T) {
} }
} }
func TestEndpointAuthorizeRefreshToken(t *testing.T) {
service := "localhost.localdomain"
repo1 := "some/registry"
repo2 := "other/registry"
scope1 := fmt.Sprintf("repository:%s:pull,push", repo1)
scope2 := fmt.Sprintf("repository:%s:pull,push", repo2)
refreshToken1 := "0123456790abcdef"
refreshToken2 := "0123456790fedcba"
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope1), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken1)),
},
},
{
// In the future this test may fail and require using basic auth to get a different refresh token
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope2), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken2)),
},
},
{
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken2, url.QueryEscape(scope2), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"access_token":"badtoken","refresh_token":"%s"}`),
},
},
})
te, tc := testServer(tokenMap)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
validCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate, validCheck)
defer c()
challengeManager1 := NewSimpleChallengeManager()
versions, err := ping(challengeManager1, e+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
creds := &testCredentialStore{
refreshTokens: map[string]string{
service: refreshToken1,
},
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager1, NewTokenHandler(nil, creds, repo1, "pull", "push")))
client := &http.Client{Transport: transport1}
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
// Try with refresh token setting
e2, c2 := testServerWithAuth(m, authenicate, validCheck)
defer c2()
challengeManager2 := NewSimpleChallengeManager()
versions, err = ping(challengeManager2, e2+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
transport2 := transport.NewTransport(nil, NewAuthorizer(challengeManager2, NewTokenHandler(nil, creds, repo2, "pull", "push")))
client2 := &http.Client{Transport: transport2}
req, _ = http.NewRequest("GET", e2+"/v2/hello", nil)
resp, err = client2.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
}
if creds.refreshTokens[service] != refreshToken2 {
t.Fatalf("Refresh token not set after change")
}
// Try with bad token
e3, c3 := testServerWithAuth(m, authenicate, validCheck)
defer c3()
challengeManager3 := NewSimpleChallengeManager()
versions, err = ping(challengeManager3, e3+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
transport3 := transport.NewTransport(nil, NewAuthorizer(challengeManager3, NewTokenHandler(nil, creds, repo2, "pull", "push")))
client3 := &http.Client{Transport: transport3}
req, _ = http.NewRequest("GET", e3+"/v2/hello", nil)
resp, err = client3.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
}
}
func basicAuth(username, password string) string { func basicAuth(username, password string) string {
auth := username + ":" + password auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth)) return base64.StdEncoding.EncodeToString([]byte(auth))
@ -379,7 +542,19 @@ func TestEndpointAuthorizeTokenBasicWithExpiresIn(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
clock := &fakeClock{current: time.Now()} clock := &fakeClock{current: time.Now()}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds))) options := TokenHandlerOptions{
Transport: nil,
Credentials: creds,
Scopes: []Scope{
RepositoryScope{
Repository: repo,
Actions: []string{"pull", "push"},
},
},
}
tHandler := NewTokenHandlerWithOptions(options)
tHandler.(*tokenHandler).clock = clock
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, tHandler, NewBasicHandler(creds)))
client := &http.Client{Transport: transport1} client := &http.Client{Transport: transport1}
// First call should result in a token exchange // First call should result in a token exchange
@ -517,7 +692,20 @@ func TestEndpointAuthorizeTokenBasicWithExpiresInAndIssuedAt(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds)))
options := TokenHandlerOptions{
Transport: nil,
Credentials: creds,
Scopes: []Scope{
RepositoryScope{
Repository: repo,
Actions: []string{"pull", "push"},
},
},
}
tHandler := NewTokenHandlerWithOptions(options)
tHandler.(*tokenHandler).clock = clock
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, tHandler, NewBasicHandler(creds)))
client := &http.Client{Transport: transport1} client := &http.Client{Transport: transport1}
// First call should result in a token exchange // First call should result in a token exchange

View file

@ -6,7 +6,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os"
"time" "time"
"github.com/docker/distribution" "github.com/docker/distribution"
@ -104,21 +103,8 @@ func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
} }
func (hbu *httpBlobUpload) Seek(offset int64, whence int) (int64, error) { func (hbu *httpBlobUpload) Size() int64 {
newOffset := hbu.offset return hbu.offset
switch whence {
case os.SEEK_CUR:
newOffset += int64(offset)
case os.SEEK_END:
newOffset += int64(offset)
case os.SEEK_SET:
newOffset = int64(offset)
}
hbu.offset = newOffset
return hbu.offset, nil
} }
func (hbu *httpBlobUpload) ID() string { func (hbu *httpBlobUpload) ID() string {

View file

@ -2,6 +2,7 @@ package client
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -10,6 +11,10 @@ import (
"github.com/docker/distribution/registry/api/errcode" "github.com/docker/distribution/registry/api/errcode"
) )
// ErrNoErrorsInBody is returned when an HTTP response body parses to an empty
// errcode.Errors slice.
var ErrNoErrorsInBody = errors.New("no error details found in HTTP response body")
// UnexpectedHTTPStatusError is returned when an unexpected HTTP status is // UnexpectedHTTPStatusError is returned when an unexpected HTTP status is
// returned when making a registry api call. // returned when making a registry api call.
type UnexpectedHTTPStatusError struct { type UnexpectedHTTPStatusError struct {
@ -17,18 +22,19 @@ type UnexpectedHTTPStatusError struct {
} }
func (e *UnexpectedHTTPStatusError) Error() string { func (e *UnexpectedHTTPStatusError) Error() string {
return fmt.Sprintf("Received unexpected HTTP status: %s", e.Status) return fmt.Sprintf("received unexpected HTTP status: %s", e.Status)
} }
// UnexpectedHTTPResponseError is returned when an expected HTTP status code // UnexpectedHTTPResponseError is returned when an expected HTTP status code
// is returned, but the content was unexpected and failed to be parsed. // is returned, but the content was unexpected and failed to be parsed.
type UnexpectedHTTPResponseError struct { type UnexpectedHTTPResponseError struct {
ParseErr error ParseErr error
Response []byte StatusCode int
Response []byte
} }
func (e *UnexpectedHTTPResponseError) Error() string { func (e *UnexpectedHTTPResponseError) Error() string {
return fmt.Sprintf("Error parsing HTTP response: %s: %q", e.ParseErr.Error(), string(e.Response)) return fmt.Sprintf("error parsing HTTP %d response body: %s: %q", e.StatusCode, e.ParseErr.Error(), string(e.Response))
} }
func parseHTTPErrorResponse(statusCode int, r io.Reader) error { func parseHTTPErrorResponse(statusCode int, r io.Reader) error {
@ -45,18 +51,37 @@ func parseHTTPErrorResponse(statusCode int, r io.Reader) error {
} }
err = json.Unmarshal(body, &detailsErr) err = json.Unmarshal(body, &detailsErr)
if err == nil && detailsErr.Details != "" { if err == nil && detailsErr.Details != "" {
if statusCode == http.StatusUnauthorized { switch statusCode {
case http.StatusUnauthorized:
return errcode.ErrorCodeUnauthorized.WithMessage(detailsErr.Details) return errcode.ErrorCodeUnauthorized.WithMessage(detailsErr.Details)
// FIXME: go1.5 doesn't export http.StatusTooManyRequests while
// go1.6 does. Update the hardcoded value to the constant once
// Docker updates golang version to 1.6.
case 429:
return errcode.ErrorCodeTooManyRequests.WithMessage(detailsErr.Details)
default:
return errcode.ErrorCodeUnknown.WithMessage(detailsErr.Details)
} }
return errcode.ErrorCodeUnknown.WithMessage(detailsErr.Details)
} }
if err := json.Unmarshal(body, &errors); err != nil { if err := json.Unmarshal(body, &errors); err != nil {
return &UnexpectedHTTPResponseError{ return &UnexpectedHTTPResponseError{
ParseErr: err, ParseErr: err,
Response: body, StatusCode: statusCode,
Response: body,
} }
} }
if len(errors) == 0 {
// If there was no error specified in the body, return
// UnexpectedHTTPResponseError.
return &UnexpectedHTTPResponseError{
ParseErr: ErrNoErrorsInBody,
StatusCode: statusCode,
Response: body,
}
}
return errors return errors
} }

View file

@ -59,6 +59,21 @@ func TestHandleErrorResponseExpectedStatusCode400ValidBody(t *testing.T) {
} }
} }
func TestHandleErrorResponseExpectedStatusCode404EmptyErrorSlice(t *testing.T) {
json := `{"randomkey": "randomvalue"}`
response := &http.Response{
Status: "404 Not Found",
StatusCode: 404,
Body: nopCloser{bytes.NewBufferString(json)},
}
err := HandleErrorResponse(response)
expectedMsg := `error parsing HTTP 404 response body: no error details found in HTTP response body: "{\"randomkey\": \"randomvalue\"}"`
if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
}
}
func TestHandleErrorResponseExpectedStatusCode404InvalidBody(t *testing.T) { func TestHandleErrorResponseExpectedStatusCode404InvalidBody(t *testing.T) {
json := "{invalid json}" json := "{invalid json}"
response := &http.Response{ response := &http.Response{
@ -68,7 +83,7 @@ func TestHandleErrorResponseExpectedStatusCode404InvalidBody(t *testing.T) {
} }
err := HandleErrorResponse(response) err := HandleErrorResponse(response)
expectedMsg := "Error parsing HTTP response: invalid character 'i' looking for beginning of object key string: \"{invalid json}\"" expectedMsg := "error parsing HTTP 404 response body: invalid character 'i' looking for beginning of object key string: \"{invalid json}\""
if !strings.Contains(err.Error(), expectedMsg) { if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error()) t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
} }
@ -82,7 +97,7 @@ func TestHandleErrorResponseUnexpectedStatusCode501(t *testing.T) {
} }
err := HandleErrorResponse(response) err := HandleErrorResponse(response)
expectedMsg := "Received unexpected HTTP status: 501 Not Implemented" expectedMsg := "received unexpected HTTP status: 501 Not Implemented"
if !strings.Contains(err.Error(), expectedMsg) { if !strings.Contains(err.Error(), expectedMsg) {
t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error()) t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
} }

View file

@ -62,7 +62,7 @@ func checkHTTPRedirect(req *http.Request, via []*http.Request) error {
// NewRegistry creates a registry namespace which can be used to get a listing of repositories // NewRegistry creates a registry namespace which can be used to get a listing of repositories
func NewRegistry(ctx context.Context, baseURL string, transport http.RoundTripper) (Registry, error) { func NewRegistry(ctx context.Context, baseURL string, transport http.RoundTripper) (Registry, error) {
ub, err := v2.NewURLBuilderFromString(baseURL) ub, err := v2.NewURLBuilderFromString(baseURL, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -133,7 +133,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri
// NewRepository creates a new Repository for the given repository name and base URL. // NewRepository creates a new Repository for the given repository name and base URL.
func NewRepository(ctx context.Context, name reference.Named, baseURL string, transport http.RoundTripper) (distribution.Repository, error) { func NewRepository(ctx context.Context, name reference.Named, baseURL string, transport http.RoundTripper) (distribution.Repository, error) {
ub, err := v2.NewURLBuilderFromString(baseURL) ub, err := v2.NewURLBuilderFromString(baseURL, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -308,6 +308,7 @@ check:
if err != nil { if err != nil {
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }
defer resp.Body.Close()
switch { switch {
case resp.StatusCode >= 200 && resp.StatusCode < 400: case resp.StatusCode >= 200 && resp.StatusCode < 400:
@ -401,9 +402,9 @@ func (ms *manifests) Get(ctx context.Context, dgst digest.Digest, options ...dis
) )
for _, option := range options { for _, option := range options {
if opt, ok := option.(withTagOption); ok { if opt, ok := option.(distribution.WithTagOption); ok {
digestOrTag = opt.tag digestOrTag = opt.Tag
ref, err = reference.WithTag(ms.name, opt.tag) ref, err = reference.WithTag(ms.name, opt.Tag)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -464,21 +465,6 @@ func (ms *manifests) Get(ctx context.Context, dgst digest.Digest, options ...dis
return nil, HandleErrorResponse(resp) return nil, HandleErrorResponse(resp)
} }
// WithTag allows a tag to be passed into Put which enables the client
// to build a correct URL.
func WithTag(tag string) distribution.ManifestServiceOption {
return withTagOption{tag}
}
type withTagOption struct{ tag string }
func (o withTagOption) Apply(m distribution.ManifestService) error {
if _, ok := m.(*manifests); ok {
return nil
}
return fmt.Errorf("withTagOption is a client-only option")
}
// Put puts a manifest. A tag can be specified using an options parameter which uses some shared state to hold the // Put puts a manifest. A tag can be specified using an options parameter which uses some shared state to hold the
// tag name in order to build the correct upload URL. // tag name in order to build the correct upload URL.
func (ms *manifests) Put(ctx context.Context, m distribution.Manifest, options ...distribution.ManifestServiceOption) (digest.Digest, error) { func (ms *manifests) Put(ctx context.Context, m distribution.Manifest, options ...distribution.ManifestServiceOption) (digest.Digest, error) {
@ -486,9 +472,9 @@ func (ms *manifests) Put(ctx context.Context, m distribution.Manifest, options .
var tagged bool var tagged bool
for _, option := range options { for _, option := range options {
if opt, ok := option.(withTagOption); ok { if opt, ok := option.(distribution.WithTagOption); ok {
var err error var err error
ref, err = reference.WithTag(ref, opt.tag) ref, err = reference.WithTag(ref, opt.Tag)
if err != nil { if err != nil {
return "", err return "", err
} }

View file

@ -710,7 +710,7 @@ func TestV1ManifestFetch(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
manifest, err = ms.Get(ctx, dgst, WithTag("latest")) manifest, err = ms.Get(ctx, dgst, distribution.WithTag("latest"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -723,7 +723,7 @@ func TestV1ManifestFetch(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
manifest, err = ms.Get(ctx, dgst, WithTag("badcontenttype")) manifest, err = ms.Get(ctx, dgst, distribution.WithTag("badcontenttype"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -761,7 +761,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
if !ok { if !ok {
panic("wrong type for client manifest service") panic("wrong type for client manifest service")
} }
_, err = clientManifestService.Get(ctx, d1, WithTag("latest"), AddEtagToTag("latest", d1.String())) _, err = clientManifestService.Get(ctx, d1, distribution.WithTag("latest"), AddEtagToTag("latest", d1.String()))
if err != distribution.ErrManifestNotModified { if err != distribution.ErrManifestNotModified {
t.Fatal(err) t.Fatal(err)
} }
@ -861,7 +861,7 @@ func TestManifestPut(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if _, err := ms.Put(ctx, m1, WithTag(m1.Tag)); err != nil { if _, err := ms.Put(ctx, m1, distribution.WithTag(m1.Tag)); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -43,7 +43,6 @@ var headerConfig = http.Header{
// 200 OK response. // 200 OK response.
func TestCheckAPI(t *testing.T) { func TestCheckAPI(t *testing.T) {
env := newTestEnv(t, false) env := newTestEnv(t, false)
baseURL, err := env.builder.BuildBaseURL() baseURL, err := env.builder.BuildBaseURL()
if err != nil { if err != nil {
t.Fatalf("unexpected error building base url: %v", err) t.Fatalf("unexpected error building base url: %v", err)
@ -294,6 +293,79 @@ func TestBlobDelete(t *testing.T) {
testBlobDelete(t, env, args) testBlobDelete(t, env, args)
} }
func TestRelativeURL(t *testing.T) {
config := configuration.Configuration{
Storage: configuration.Storage{
"inmemory": configuration.Parameters{},
},
}
config.HTTP.Headers = headerConfig
config.HTTP.RelativeURLs = false
env := newTestEnvWithConfig(t, &config)
ref, _ := reference.WithName("foo/bar")
uploadURLBaseAbs, _ := startPushLayer(t, env, ref)
u, err := url.Parse(uploadURLBaseAbs)
if err != nil {
t.Fatal(err)
}
if !u.IsAbs() {
t.Fatal("Relative URL returned from blob upload chunk with non-relative configuration")
}
args := makeBlobArgs(t)
resp, err := doPushLayer(t, env.builder, ref, args.layerDigest, uploadURLBaseAbs, args.layerFile)
if err != nil {
t.Fatalf("unexpected error doing layer push relative url: %v", err)
}
checkResponse(t, "relativeurl blob upload", resp, http.StatusCreated)
u, err = url.Parse(resp.Header.Get("Location"))
if err != nil {
t.Fatal(err)
}
if !u.IsAbs() {
t.Fatal("Relative URL returned from blob upload with non-relative configuration")
}
config.HTTP.RelativeURLs = true
args = makeBlobArgs(t)
uploadURLBaseRelative, _ := startPushLayer(t, env, ref)
u, err = url.Parse(uploadURLBaseRelative)
if err != nil {
t.Fatal(err)
}
if u.IsAbs() {
t.Fatal("Absolute URL returned from blob upload chunk with relative configuration")
}
// Start a new upload in absolute mode to get a valid base URL
config.HTTP.RelativeURLs = false
uploadURLBaseAbs, _ = startPushLayer(t, env, ref)
u, err = url.Parse(uploadURLBaseAbs)
if err != nil {
t.Fatal(err)
}
if !u.IsAbs() {
t.Fatal("Relative URL returned from blob upload chunk with non-relative configuration")
}
// Complete upload with relative URLs enabled to ensure the final location is relative
config.HTTP.RelativeURLs = true
resp, err = doPushLayer(t, env.builder, ref, args.layerDigest, uploadURLBaseAbs, args.layerFile)
if err != nil {
t.Fatalf("unexpected error doing layer push relative url: %v", err)
}
checkResponse(t, "relativeurl blob upload", resp, http.StatusCreated)
u, err = url.Parse(resp.Header.Get("Location"))
if err != nil {
t.Fatal(err)
}
if u.IsAbs() {
t.Fatal("Relative URL returned from blob upload with non-relative configuration")
}
}
func TestBlobDeleteDisabled(t *testing.T) { func TestBlobDeleteDisabled(t *testing.T) {
deleteEnabled := false deleteEnabled := false
env := newTestEnv(t, deleteEnabled) env := newTestEnv(t, deleteEnabled)
@ -349,7 +421,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
// ------------------------------------------ // ------------------------------------------
// Start an upload, check the status then cancel // Start an upload, check the status then cancel
uploadURLBase, uploadUUID := startPushLayer(t, env.builder, imageName) uploadURLBase, uploadUUID := startPushLayer(t, env, imageName)
// A status check should work // A status check should work
resp, err = http.Get(uploadURLBase) resp, err = http.Get(uploadURLBase)
@ -384,7 +456,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
// ----------------------------------------- // -----------------------------------------
// Do layer push with an empty body and different digest // Do layer push with an empty body and different digest
uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName) uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
resp, err = doPushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, bytes.NewReader([]byte{})) resp, err = doPushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, bytes.NewReader([]byte{}))
if err != nil { if err != nil {
t.Fatalf("unexpected error doing bad layer push: %v", err) t.Fatalf("unexpected error doing bad layer push: %v", err)
@ -400,7 +472,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
t.Fatalf("unexpected error digesting empty buffer: %v", err) t.Fatalf("unexpected error digesting empty buffer: %v", err)
} }
uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName) uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, zeroDigest, uploadURLBase, bytes.NewReader([]byte{})) pushLayer(t, env.builder, imageName, zeroDigest, uploadURLBase, bytes.NewReader([]byte{}))
// ----------------------------------------- // -----------------------------------------
@ -413,7 +485,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
t.Fatalf("unexpected error digesting empty tar: %v", err) t.Fatalf("unexpected error digesting empty tar: %v", err)
} }
uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName) uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, emptyDigest, uploadURLBase, bytes.NewReader(emptyTar)) pushLayer(t, env.builder, imageName, emptyDigest, uploadURLBase, bytes.NewReader(emptyTar))
// ------------------------------------------ // ------------------------------------------
@ -421,7 +493,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
layerLength, _ := layerFile.Seek(0, os.SEEK_END) layerLength, _ := layerFile.Seek(0, os.SEEK_END)
layerFile.Seek(0, os.SEEK_SET) layerFile.Seek(0, os.SEEK_SET)
uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName) uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile) pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)
// ------------------------------------------ // ------------------------------------------
@ -435,7 +507,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
canonicalDigest := canonicalDigester.Digest() canonicalDigest := canonicalDigester.Digest()
layerFile.Seek(0, 0) layerFile.Seek(0, 0)
uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName) uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
uploadURLBase, dgst := pushChunk(t, env.builder, imageName, uploadURLBase, layerFile, layerLength) uploadURLBase, dgst := pushChunk(t, env.builder, imageName, uploadURLBase, layerFile, layerLength)
finishUpload(t, env.builder, imageName, uploadURLBase, dgst) finishUpload(t, env.builder, imageName, uploadURLBase, dgst)
@ -585,7 +657,7 @@ func testBlobDelete(t *testing.T, env *testEnv, args blobArgs) {
// Reupload previously deleted blob // Reupload previously deleted blob
layerFile.Seek(0, os.SEEK_SET) layerFile.Seek(0, os.SEEK_SET)
uploadURLBase, _ := startPushLayer(t, env.builder, imageName) uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile) pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)
layerFile.Seek(0, os.SEEK_SET) layerFile.Seek(0, os.SEEK_SET)
@ -625,7 +697,7 @@ func TestDeleteDisabled(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Error building blob URL") t.Fatalf("Error building blob URL")
} }
uploadURLBase, _ := startPushLayer(t, env.builder, imageName) uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile) pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)
resp, err := httpDelete(layerURL) resp, err := httpDelete(layerURL)
@ -651,7 +723,7 @@ func TestDeleteReadOnly(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Error building blob URL") t.Fatalf("Error building blob URL")
} }
uploadURLBase, _ := startPushLayer(t, env.builder, imageName) uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile) pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)
env.app.readOnly = true env.app.readOnly = true
@ -854,7 +926,7 @@ func testManifestAPISchema1(t *testing.T, env *testEnv, imageName reference.Name
} }
// TODO(stevvooe): Add a test case where we take a mostly valid registry, // TODO(stevvooe): Add a test case where we take a mostly valid registry,
// tamper with the content and ensure that we get a unverified manifest // tamper with the content and ensure that we get an unverified manifest
// error. // error.
// Push 2 random layers // Push 2 random layers
@ -871,7 +943,7 @@ func testManifestAPISchema1(t *testing.T, env *testEnv, imageName reference.Name
expectedLayers[dgst] = rs expectedLayers[dgst] = rs
unsignedManifest.FSLayers[i].BlobSum = dgst unsignedManifest.FSLayers[i].BlobSum = dgst
uploadURLBase, _ := startPushLayer(t, env.builder, imageName) uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, dgst, uploadURLBase, rs) pushLayer(t, env.builder, imageName, dgst, uploadURLBase, rs)
} }
@ -995,13 +1067,13 @@ func testManifestAPISchema1(t *testing.T, env *testEnv, imageName reference.Name
t.Fatalf("error decoding fetched manifest: %v", err) t.Fatalf("error decoding fetched manifest: %v", err)
} }
// check two signatures were roundtripped // check only 1 signature is returned
signatures, err = fetchedManifestByDigest.Signatures() signatures, err = fetchedManifestByDigest.Signatures()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(signatures) != 2 { if len(signatures) != 1 {
t.Fatalf("expected 2 signature from manifest, got: %d", len(signatures)) t.Fatalf("expected 2 signature from manifest, got: %d", len(signatures))
} }
@ -1177,7 +1249,7 @@ func testManifestAPISchema2(t *testing.T, env *testEnv, imageName reference.Name
}`) }`)
sampleConfigDigest := digest.FromBytes(sampleConfig) sampleConfigDigest := digest.FromBytes(sampleConfig)
uploadURLBase, _ := startPushLayer(t, env.builder, imageName) uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, sampleConfigDigest, uploadURLBase, bytes.NewReader(sampleConfig)) pushLayer(t, env.builder, imageName, sampleConfigDigest, uploadURLBase, bytes.NewReader(sampleConfig))
manifest.Config.Digest = sampleConfigDigest manifest.Config.Digest = sampleConfigDigest
manifest.Config.Size = int64(len(sampleConfig)) manifest.Config.Size = int64(len(sampleConfig))
@ -1210,7 +1282,7 @@ func testManifestAPISchema2(t *testing.T, env *testEnv, imageName reference.Name
expectedLayers[dgst] = rs expectedLayers[dgst] = rs
manifest.Layers[i].Digest = dgst manifest.Layers[i].Digest = dgst
uploadURLBase, _ := startPushLayer(t, env.builder, imageName) uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, dgst, uploadURLBase, rs) pushLayer(t, env.builder, imageName, dgst, uploadURLBase, rs)
} }
@ -1842,7 +1914,7 @@ func newTestEnvWithConfig(t *testing.T, config *configuration.Configuration) *te
app := NewApp(ctx, config) app := NewApp(ctx, config)
server := httptest.NewServer(handlers.CombinedLoggingHandler(os.Stderr, app)) server := httptest.NewServer(handlers.CombinedLoggingHandler(os.Stderr, app))
builder, err := v2.NewURLBuilderFromString(server.URL + config.HTTP.Prefix) builder, err := v2.NewURLBuilderFromString(server.URL+config.HTTP.Prefix, false)
if err != nil { if err != nil {
t.Fatalf("error creating url builder: %v", err) t.Fatalf("error creating url builder: %v", err)
@ -1904,21 +1976,33 @@ func putManifest(t *testing.T, msg, url, contentType string, v interface{}) *htt
return resp return resp
} }
func startPushLayer(t *testing.T, ub *v2.URLBuilder, name reference.Named) (location string, uuid string) { func startPushLayer(t *testing.T, env *testEnv, name reference.Named) (location string, uuid string) {
layerUploadURL, err := ub.BuildBlobUploadURL(name) layerUploadURL, err := env.builder.BuildBlobUploadURL(name)
if err != nil { if err != nil {
t.Fatalf("unexpected error building layer upload url: %v", err) t.Fatalf("unexpected error building layer upload url: %v", err)
} }
u, err := url.Parse(layerUploadURL)
if err != nil {
t.Fatalf("error parsing layer upload URL: %v", err)
}
base, err := url.Parse(env.server.URL)
if err != nil {
t.Fatalf("error parsing server URL: %v", err)
}
layerUploadURL = base.ResolveReference(u).String()
resp, err := http.Post(layerUploadURL, "", nil) resp, err := http.Post(layerUploadURL, "", nil)
if err != nil { if err != nil {
t.Fatalf("unexpected error starting layer push: %v", err) t.Fatalf("unexpected error starting layer push: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
checkResponse(t, fmt.Sprintf("pushing starting layer push %v", name.String()), resp, http.StatusAccepted) checkResponse(t, fmt.Sprintf("pushing starting layer push %v", name.String()), resp, http.StatusAccepted)
u, err := url.Parse(resp.Header.Get("Location")) u, err = url.Parse(resp.Header.Get("Location"))
if err != nil { if err != nil {
t.Fatalf("error parsing location header: %v", err) t.Fatalf("error parsing location header: %v", err)
} }
@ -1943,7 +2027,6 @@ func doPushLayer(t *testing.T, ub *v2.URLBuilder, name reference.Named, dgst dig
u.RawQuery = url.Values{ u.RawQuery = url.Values{
"_state": u.Query()["_state"], "_state": u.Query()["_state"],
"digest": []string{dgst.String()}, "digest": []string{dgst.String()},
}.Encode() }.Encode()
@ -2211,8 +2294,7 @@ func createRepository(env *testEnv, t *testing.T, imageName string, tag string)
expectedLayers[dgst] = rs expectedLayers[dgst] = rs
unsignedManifest.FSLayers[i].BlobSum = dgst unsignedManifest.FSLayers[i].BlobSum = dgst
uploadURLBase, _ := startPushLayer(t, env, imageNameRef)
uploadURLBase, _ := startPushLayer(t, env.builder, imageNameRef)
pushLayer(t, env.builder, imageNameRef, dgst, uploadURLBase, rs) pushLayer(t, env.builder, imageNameRef, dgst, uploadURLBase, rs)
} }

View file

@ -155,6 +155,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
app.configureRedis(config) app.configureRedis(config)
app.configureLogHook(config) app.configureLogHook(config)
options := registrymiddleware.GetRegistryOptions()
if config.Compatibility.Schema1.TrustKey != "" { if config.Compatibility.Schema1.TrustKey != "" {
app.trustKey, err = libtrust.LoadKeyFile(config.Compatibility.Schema1.TrustKey) app.trustKey, err = libtrust.LoadKeyFile(config.Compatibility.Schema1.TrustKey)
if err != nil { if err != nil {
@ -169,6 +170,8 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
} }
} }
options = append(options, storage.Schema1SigningKey(app.trustKey))
if config.HTTP.Host != "" { if config.HTTP.Host != "" {
u, err := url.Parse(config.HTTP.Host) u, err := url.Parse(config.HTTP.Host)
if err != nil { if err != nil {
@ -177,17 +180,10 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
app.httpHost = *u app.httpHost = *u
} }
options := []storage.RegistryOption{}
if app.isCache { if app.isCache {
options = append(options, storage.DisableDigestResumption) options = append(options, storage.DisableDigestResumption)
} }
if config.Compatibility.Schema1.DisableSignatureStore {
options = append(options, storage.DisableSchema1Signatures)
options = append(options, storage.Schema1SigningKey(app.trustKey))
}
// configure deletion // configure deletion
if d, ok := config.Storage["delete"]; ok { if d, ok := config.Storage["delete"]; ok {
e, ok := d["enabled"] e, ok := d["enabled"]
@ -258,7 +254,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
} }
} }
app.registry, err = applyRegistryMiddleware(app.Context, app.registry, config.Middleware["registry"]) app.registry, err = applyRegistryMiddleware(app, app.registry, config.Middleware["registry"])
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -634,6 +630,8 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
context.Errors = append(context.Errors, v2.ErrorCodeNameUnknown.WithDetail(err)) context.Errors = append(context.Errors, v2.ErrorCodeNameUnknown.WithDetail(err))
case distribution.ErrRepositoryNameInvalid: case distribution.ErrRepositoryNameInvalid:
context.Errors = append(context.Errors, v2.ErrorCodeNameInvalid.WithDetail(err)) context.Errors = append(context.Errors, v2.ErrorCodeNameInvalid.WithDetail(err))
case errcode.Error:
context.Errors = append(context.Errors, err)
} }
if err := errcode.ServeJSON(w, context.Errors); err != nil { if err := errcode.ServeJSON(w, context.Errors); err != nil {
@ -647,7 +645,7 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
repository, repository,
app.eventBridge(context, r)) app.eventBridge(context, r))
context.Repository, err = applyRepoMiddleware(context.Context, context.Repository, app.Config.Middleware["repository"]) context.Repository, err = applyRepoMiddleware(app, context.Repository, app.Config.Middleware["repository"])
if err != nil { if err != nil {
ctxu.GetLogger(context).Errorf("error initializing repository middleware: %v", err) ctxu.GetLogger(context).Errorf("error initializing repository middleware: %v", err)
context.Errors = append(context.Errors, errcode.ErrorCodeUnknown.WithDetail(err)) context.Errors = append(context.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
@ -721,9 +719,9 @@ func (app *App) context(w http.ResponseWriter, r *http.Request) *Context {
// A "host" item in the configuration takes precedence over // A "host" item in the configuration takes precedence over
// X-Forwarded-Proto and X-Forwarded-Host headers, and the // X-Forwarded-Proto and X-Forwarded-Host headers, and the
// hostname in the request. // hostname in the request.
context.urlBuilder = v2.NewURLBuilder(&app.httpHost) context.urlBuilder = v2.NewURLBuilder(&app.httpHost, false)
} else { } else {
context.urlBuilder = v2.NewURLBuilderFromRequest(r) context.urlBuilder = v2.NewURLBuilderFromRequest(r, app.Config.HTTP.RelativeURLs)
} }
return context return context

View file

@ -160,7 +160,7 @@ func TestNewApp(t *testing.T) {
app := NewApp(ctx, &config) app := NewApp(ctx, &config)
server := httptest.NewServer(app) server := httptest.NewServer(app)
builder, err := v2.NewURLBuilderFromString(server.URL) builder, err := v2.NewURLBuilderFromString(server.URL, false)
if err != nil { if err != nil {
t.Fatalf("error creating urlbuilder: %v", err) t.Fatalf("error creating urlbuilder: %v", err)
} }

View file

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"os"
"github.com/docker/distribution" "github.com/docker/distribution"
ctxu "github.com/docker/distribution/context" ctxu "github.com/docker/distribution/context"
@ -76,28 +75,14 @@ func blobUploadDispatcher(ctx *Context, r *http.Request) http.Handler {
} }
buh.Upload = upload buh.Upload = upload
if state.Offset > 0 { if size := upload.Size(); size != buh.State.Offset {
// Seek the blob upload to the correct spot if it's non-zero. defer upload.Close()
// These error conditions should be rare and demonstrate really ctxu.GetLogger(ctx).Infof("upload resumed at wrong offest: %d != %d", size, buh.State.Offset)
// problems. We basically cancel the upload and tell the client to return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// start over. buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
if nn, err := upload.Seek(buh.State.Offset, os.SEEK_SET); err != nil { upload.Cancel(buh)
defer upload.Close() })
ctxu.GetLogger(ctx).Infof("error seeking blob upload: %v", err)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
upload.Cancel(buh)
})
} else if nn != buh.State.Offset {
defer upload.Close()
ctxu.GetLogger(ctx).Infof("seek to wrong offest: %d != %d", nn, buh.State.Offset)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
upload.Cancel(buh)
})
}
} }
return closeResources(handler, buh.Upload) return closeResources(handler, buh.Upload)
} }
@ -239,10 +224,7 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht
return return
} }
size := buh.State.Offset size := buh.Upload.Size()
if offset, err := buh.Upload.Seek(0, os.SEEK_CUR); err == nil {
size = offset
}
desc, err := buh.Upload.Commit(buh, distribution.Descriptor{ desc, err := buh.Upload.Commit(buh, distribution.Descriptor{
Digest: dgst, Digest: dgst,
@ -257,6 +239,8 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht
switch err := err.(type) { switch err := err.(type) {
case distribution.ErrBlobInvalidDigest: case distribution.ErrBlobInvalidDigest:
buh.Errors = append(buh.Errors, v2.ErrorCodeDigestInvalid.WithDetail(err)) buh.Errors = append(buh.Errors, v2.ErrorCodeDigestInvalid.WithDetail(err))
case errcode.Error:
buh.Errors = append(buh.Errors, err)
default: default:
switch err { switch err {
case distribution.ErrAccessDenied: case distribution.ErrAccessDenied:
@ -308,21 +292,10 @@ func (buh *blobUploadHandler) CancelBlobUpload(w http.ResponseWriter, r *http.Re
// uploads always start at a 0 offset. This allows disabling resumable push by // uploads always start at a 0 offset. This allows disabling resumable push by
// always returning a 0 offset on check status. // always returning a 0 offset on check status.
func (buh *blobUploadHandler) blobUploadResponse(w http.ResponseWriter, r *http.Request, fresh bool) error { func (buh *blobUploadHandler) blobUploadResponse(w http.ResponseWriter, r *http.Request, fresh bool) error {
var offset int64
if !fresh {
var err error
offset, err = buh.Upload.Seek(0, os.SEEK_CUR)
if err != nil {
ctxu.GetLogger(buh).Errorf("unable get current offset of blob upload: %v", err)
return err
}
}
// TODO(stevvooe): Need a better way to manage the upload state automatically. // TODO(stevvooe): Need a better way to manage the upload state automatically.
buh.State.Name = buh.Repository.Named().Name() buh.State.Name = buh.Repository.Named().Name()
buh.State.UUID = buh.Upload.ID() buh.State.UUID = buh.Upload.ID()
buh.State.Offset = offset buh.State.Offset = buh.Upload.Size()
buh.State.StartedAt = buh.Upload.StartedAt() buh.State.StartedAt = buh.Upload.StartedAt()
token, err := hmacKey(buh.Config.HTTP.Secret).packUploadState(buh.State) token, err := hmacKey(buh.Config.HTTP.Secret).packUploadState(buh.State)
@ -341,13 +314,14 @@ func (buh *blobUploadHandler) blobUploadResponse(w http.ResponseWriter, r *http.
return err return err
} }
endRange := offset endRange := buh.Upload.Size()
if endRange > 0 { if endRange > 0 {
endRange = endRange - 1 endRange = endRange - 1
} }
w.Header().Set("Docker-Upload-UUID", buh.UUID) w.Header().Set("Docker-Upload-UUID", buh.UUID)
w.Header().Set("Location", uploadURL) w.Header().Set("Location", uploadURL)
w.Header().Set("Content-Length", "0") w.Header().Set("Content-Length", "0")
w.Header().Set("Range", fmt.Sprintf("0-%d", endRange)) w.Header().Set("Range", fmt.Sprintf("0-%d", endRange))

View file

@ -20,7 +20,7 @@ func closeResources(handler http.Handler, closers ...io.Closer) http.Handler {
}) })
} }
// copyFullPayload copies the payload of a HTTP request to destWriter. If it // copyFullPayload copies the payload of an HTTP request to destWriter. If it
// receives less content than expected, and the client disconnected during the // receives less content than expected, and the client disconnected during the
// upload, it avoids sending a 400 error to keep the logs cleaner. // upload, it avoids sending a 400 error to keep the logs cleaner.
func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, context ctxu.Context, action string, errSlice *errcode.Errors) error { func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, context ctxu.Context, action string, errSlice *errcode.Errors) error {
@ -46,7 +46,11 @@ func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWr
// instead of showing 0 for the HTTP status. // instead of showing 0 for the HTTP status.
responseWriter.WriteHeader(499) responseWriter.WriteHeader(499)
ctxu.GetLogger(context).Error("client disconnected during " + action) ctxu.GetLoggerWithFields(context, map[interface{}]interface{}{
"error": err,
"copied": copied,
"contentLength": r.ContentLength,
}, "error", "copied", "contentLength").Error("client disconnected during " + action)
return errors.New("client disconnected") return errors.New("client disconnected")
default: default:
} }

View file

@ -86,7 +86,11 @@ func (imh *imageManifestHandler) GetImageManifest(w http.ResponseWriter, r *http
return return
} }
manifest, err = manifests.Get(imh, imh.Digest) var options []distribution.ManifestServiceOption
if imh.Tag != "" {
options = append(options, distribution.WithTag(imh.Tag))
}
manifest, err = manifests.Get(imh, imh.Digest, options...)
if err != nil { if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithDetail(err)) imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithDetail(err))
return return
@ -245,7 +249,11 @@ func (imh *imageManifestHandler) PutImageManifest(w http.ResponseWriter, r *http
return return
} }
_, err = manifests.Put(imh, manifest) var options []distribution.ManifestServiceOption
if imh.Tag != "" {
options = append(options, distribution.WithTag(imh.Tag))
}
_, err = manifests.Put(imh, manifest, options...)
if err != nil { if err != nil {
// TODO(stevvooe): These error handling switches really need to be // TODO(stevvooe): These error handling switches really need to be
// handled by an app global mapper. // handled by an app global mapper.
@ -275,6 +283,8 @@ func (imh *imageManifestHandler) PutImageManifest(w http.ResponseWriter, r *http
} }
} }
} }
case errcode.Error:
imh.Errors = append(imh.Errors, err)
default: default:
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown.WithDetail(err)) imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
} }

View file

@ -41,6 +41,8 @@ func (th *tagsHandler) GetTags(w http.ResponseWriter, r *http.Request) {
switch err := err.(type) { switch err := err.(type) {
case distribution.ErrRepositoryUnknown: case distribution.ErrRepositoryUnknown:
th.Errors = append(th.Errors, v2.ErrorCodeNameUnknown.WithDetail(map[string]string{"name": th.Repository.Named().Name()})) th.Errors = append(th.Errors, v2.ErrorCodeNameUnknown.WithDetail(map[string]string{"name": th.Repository.Named().Name()}))
case errcode.Error:
th.Errors = append(th.Errors, err)
default: default:
th.Errors = append(th.Errors, errcode.ErrorCodeUnknown.WithDetail(err)) th.Errors = append(th.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
} }

View file

@ -5,6 +5,7 @@ import (
"github.com/docker/distribution" "github.com/docker/distribution"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage"
) )
// InitFunc is the type of a RegistryMiddleware factory function and is // InitFunc is the type of a RegistryMiddleware factory function and is
@ -12,6 +13,7 @@ import (
type InitFunc func(ctx context.Context, registry distribution.Namespace, options map[string]interface{}) (distribution.Namespace, error) type InitFunc func(ctx context.Context, registry distribution.Namespace, options map[string]interface{}) (distribution.Namespace, error)
var middlewares map[string]InitFunc var middlewares map[string]InitFunc
var registryoptions []storage.RegistryOption
// Register is used to register an InitFunc for // Register is used to register an InitFunc for
// a RegistryMiddleware backend with the given name. // a RegistryMiddleware backend with the given name.
@ -38,3 +40,15 @@ func Get(ctx context.Context, name string, options map[string]interface{}, regis
return nil, fmt.Errorf("no registry middleware registered with name: %s", name) return nil, fmt.Errorf("no registry middleware registered with name: %s", name)
} }
// RegisterOptions adds more options to RegistryOption list. Options get applied before
// any other configuration-based options.
func RegisterOptions(options ...storage.RegistryOption) error {
registryoptions = append(registryoptions, options...)
return nil
}
// GetRegistryOptions returns list of RegistryOption.
func GetRegistryOptions() []storage.RegistryOption {
return registryoptions
}

View file

@ -25,6 +25,13 @@ func (c credentials) Basic(u *url.URL) (string, string) {
return up.username, up.password return up.username, up.password
} }
func (c credentials) RefreshToken(u *url.URL, service string) string {
return ""
}
func (c credentials) SetRefreshToken(u *url.URL, service, token string) {
}
// configureAuth stores credentials for challenge responses // configureAuth stores credentials for challenge responses
func configureAuth(username, password string) (auth.CredentialStore, error) { func configureAuth(username, password string) (auth.CredentialStore, error) {
creds := map[string]userpass{ creds := map[string]userpass{

View file

@ -132,8 +132,15 @@ func makeTestEnv(t *testing.T, name string) *testEnv {
t.Fatalf("unable to create tempdir: %s", err) t.Fatalf("unable to create tempdir: %s", err)
} }
localDriver, err := filesystem.FromParameters(map[string]interface{}{
"rootdirectory": truthDir,
})
if err != nil {
t.Fatalf("unable to create filesystem driver: %s", err)
}
// todo: create a tempfile area here // todo: create a tempfile area here
localRegistry, err := storage.NewRegistry(ctx, filesystem.New(truthDir), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption) localRegistry, err := storage.NewRegistry(ctx, localDriver, storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption)
if err != nil { if err != nil {
t.Fatalf("error creating registry: %v", err) t.Fatalf("error creating registry: %v", err)
} }
@ -142,7 +149,14 @@ func makeTestEnv(t *testing.T, name string) *testEnv {
t.Fatalf("unexpected error getting repo: %v", err) t.Fatalf("unexpected error getting repo: %v", err)
} }
truthRegistry, err := storage.NewRegistry(ctx, filesystem.New(cacheDir), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider())) cacheDriver, err := filesystem.FromParameters(map[string]interface{}{
"rootdirectory": cacheDir,
})
if err != nil {
t.Fatalf("unable to create filesystem driver: %s", err)
}
truthRegistry, err := storage.NewRegistry(ctx, cacheDriver, storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()))
if err != nil { if err != nil {
t.Fatalf("error creating registry: %v", err) t.Fatalf("error creating registry: %v", err)
} }

View file

@ -60,12 +60,6 @@ func (sm statsManifest) Put(ctx context.Context, manifest distribution.Manifest,
return sm.manifests.Put(ctx, manifest) return sm.manifests.Put(ctx, manifest)
} }
/*func (sm statsManifest) Enumerate(ctx context.Context, manifests []distribution.Manifest, last distribution.Manifest) (n int, err error) {
sm.stats["enumerate"]++
return sm.manifests.Enumerate(ctx, manifests, last)
}
*/
type mockChallenger struct { type mockChallenger struct {
sync.Mutex sync.Mutex
count int count int
@ -75,7 +69,6 @@ type mockChallenger struct {
func (m *mockChallenger) tryEstablishChallenges(context.Context) error { func (m *mockChallenger) tryEstablishChallenges(context.Context) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
m.count++ m.count++
return nil return nil
} }
@ -93,9 +86,15 @@ func newManifestStoreTestEnv(t *testing.T, name, tag string) *manifestStoreTestE
if err != nil { if err != nil {
t.Fatalf("unable to parse reference: %s", err) t.Fatalf("unable to parse reference: %s", err)
} }
k, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background() ctx := context.Background()
truthRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider())) truthRegistry, err := storage.NewRegistry(ctx, inmemory.New(),
storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()),
storage.Schema1SigningKey(k))
if err != nil { if err != nil {
t.Fatalf("error creating registry: %v", err) t.Fatalf("error creating registry: %v", err)
} }
@ -117,7 +116,7 @@ func newManifestStoreTestEnv(t *testing.T, name, tag string) *manifestStoreTestE
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
localRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption) localRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption, storage.Schema1SigningKey(k))
if err != nil { if err != nil {
t.Fatalf("error creating registry: %v", err) t.Fatalf("error creating registry: %v", err)
} }

View file

@ -22,13 +22,13 @@ import (
type proxyingRegistry struct { type proxyingRegistry struct {
embedded distribution.Namespace // provides local registry functionality embedded distribution.Namespace // provides local registry functionality
scheduler *scheduler.TTLExpirationScheduler scheduler *scheduler.TTLExpirationScheduler
remoteURL string remoteURL url.URL
authChallenger authChallenger authChallenger authChallenger
} }
// NewRegistryPullThroughCache creates a registry acting as a pull through cache // NewRegistryPullThroughCache creates a registry acting as a pull through cache
func NewRegistryPullThroughCache(ctx context.Context, registry distribution.Namespace, driver driver.StorageDriver, config configuration.Proxy) (distribution.Namespace, error) { func NewRegistryPullThroughCache(ctx context.Context, registry distribution.Namespace, driver driver.StorageDriver, config configuration.Proxy) (distribution.Namespace, error) {
_, err := url.Parse(config.RemoteURL) remoteURL, err := url.Parse(config.RemoteURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -99,9 +99,9 @@ func NewRegistryPullThroughCache(ctx context.Context, registry distribution.Name
return &proxyingRegistry{ return &proxyingRegistry{
embedded: registry, embedded: registry,
scheduler: s, scheduler: s,
remoteURL: config.RemoteURL, remoteURL: *remoteURL,
authChallenger: &remoteAuthChallenger{ authChallenger: &remoteAuthChallenger{
remoteURL: config.RemoteURL, remoteURL: *remoteURL,
cm: auth.NewSimpleChallengeManager(), cm: auth.NewSimpleChallengeManager(),
cs: cs, cs: cs,
}, },
@ -131,7 +131,7 @@ func (pr *proxyingRegistry) Repository(ctx context.Context, name reference.Named
return nil, err return nil, err
} }
remoteRepo, err := client.NewRepository(ctx, name, pr.remoteURL, tr) remoteRepo, err := client.NewRepository(ctx, name, pr.remoteURL.String(), tr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -182,7 +182,7 @@ type authChallenger interface {
} }
type remoteAuthChallenger struct { type remoteAuthChallenger struct {
remoteURL string remoteURL url.URL
sync.Mutex sync.Mutex
cm auth.ChallengeManager cm auth.ChallengeManager
cs auth.CredentialStore cs auth.CredentialStore
@ -201,8 +201,9 @@ func (r *remoteAuthChallenger) tryEstablishChallenges(ctx context.Context) error
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
remoteURL := r.remoteURL + "/v2/" remoteURL := r.remoteURL
challenges, err := r.cm.GetChallenges(remoteURL) remoteURL.Path = "/v2/"
challenges, err := r.cm.GetChallenges(r.remoteURL)
if err != nil { if err != nil {
return err return err
} }
@ -212,7 +213,7 @@ func (r *remoteAuthChallenger) tryEstablishChallenges(ctx context.Context) error
} }
// establish challenge type with upstream // establish challenge type with upstream
if err := ping(r.cm, remoteURL, challengeHeader); err != nil { if err := ping(r.cm, remoteURL.String(), challengeHeader); err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package proxy package proxy
import ( import (
"reflect"
"sort" "sort"
"sync" "sync"
"testing" "testing"
@ -92,7 +93,7 @@ func TestGet(t *testing.T) {
t.Fatalf("Expected 1 auth challenge call, got %#v", proxyTags.authChallenger) t.Fatalf("Expected 1 auth challenge call, got %#v", proxyTags.authChallenger)
} }
if d != remoteDesc { if !reflect.DeepEqual(d, remoteDesc) {
t.Fatal("unable to get put tag") t.Fatal("unable to get put tag")
} }
@ -101,7 +102,7 @@ func TestGet(t *testing.T) {
t.Fatal("remote tag not pulled into store") t.Fatal("remote tag not pulled into store")
} }
if local != remoteDesc { if !reflect.DeepEqual(local, remoteDesc) {
t.Fatalf("unexpected descriptor pulled through") t.Fatalf("unexpected descriptor pulled through")
} }
@ -121,7 +122,7 @@ func TestGet(t *testing.T) {
t.Fatalf("Expected 2 auth challenge calls, got %#v", proxyTags.authChallenger) t.Fatalf("Expected 2 auth challenge calls, got %#v", proxyTags.authChallenger)
} }
if d != newRemoteDesc { if !reflect.DeepEqual(d, newRemoteDesc) {
t.Fatal("unable to get put tag") t.Fatal("unable to get put tag")
} }

View file

@ -267,7 +267,7 @@ func logLevel(level configuration.Loglevel) log.Level {
return l return l
} }
// panicHandler add a HTTP handler to web app. The handler recover the happening // panicHandler add an HTTP handler to web app. The handler recover the happening
// panic. logrus.Panic transmits panic message to pre-config log hooks, which is // panic. logrus.Panic transmits panic message to pre-config log hooks, which is
// defined in config.yml. // defined in config.yml.
func panicHandler(handler http.Handler) http.Handler { func panicHandler(handler http.Handler) http.Handler {

View file

@ -1,7 +1,14 @@
package registry package registry
import ( import (
"fmt"
"os"
"github.com/docker/distribution/context"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/docker/distribution/version" "github.com/docker/distribution/version"
"github.com/docker/libtrust"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -10,6 +17,7 @@ var showVersion bool
func init() { func init() {
RootCmd.AddCommand(ServeCmd) RootCmd.AddCommand(ServeCmd)
RootCmd.AddCommand(GCCmd) RootCmd.AddCommand(GCCmd)
GCCmd.Flags().BoolVarP(&dryRun, "dry-run", "d", false, "do everything except remove the blobs")
RootCmd.Flags().BoolVarP(&showVersion, "version", "v", false, "show the version and exit") RootCmd.Flags().BoolVarP(&showVersion, "version", "v", false, "show the version and exit")
} }
@ -26,3 +34,51 @@ var RootCmd = &cobra.Command{
cmd.Usage() cmd.Usage()
}, },
} }
var dryRun bool
// GCCmd is the cobra command that corresponds to the garbage-collect subcommand
var GCCmd = &cobra.Command{
Use: "garbage-collect <config>",
Short: "`garbage-collect` deletes layers not referenced by any manifests",
Long: "`garbage-collect` deletes layers not referenced by any manifests",
Run: func(cmd *cobra.Command, args []string) {
config, err := resolveConfiguration(args)
if err != nil {
fmt.Fprintf(os.Stderr, "configuration error: %v\n", err)
cmd.Usage()
os.Exit(1)
}
driver, err := factory.Create(config.Storage.Type(), config.Storage.Parameters())
if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
os.Exit(1)
}
ctx := context.Background()
ctx, err = configureLogging(ctx, config)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to configure logging with config: %s", err)
os.Exit(1)
}
k, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
fmt.Fprint(os.Stderr, err)
os.Exit(1)
}
registry, err := storage.NewRegistry(ctx, driver, storage.Schema1SigningKey(k))
if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct registry: %v", err)
os.Exit(1)
}
err = storage.MarkAndSweep(ctx, driver, registry, dryRun)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to garbage collect: %v", err)
os.Exit(1)
}
},
}

View file

@ -7,6 +7,8 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"path"
"reflect"
"testing" "testing"
"github.com/docker/distribution" "github.com/docker/distribution"
@ -41,10 +43,7 @@ func TestWriteSeek(t *testing.T) {
} }
contents := []byte{1, 2, 3} contents := []byte{1, 2, 3}
blobUpload.Write(contents) blobUpload.Write(contents)
offset, err := blobUpload.Seek(0, os.SEEK_CUR) offset := blobUpload.Size()
if err != nil {
t.Fatalf("unexpected error in blobUpload.Seek: %s", err)
}
if offset != int64(len(contents)) { if offset != int64(len(contents)) {
t.Fatalf("unexpected value for blobUpload offset: %v != %v", offset, len(contents)) t.Fatalf("unexpected value for blobUpload offset: %v != %v", offset, len(contents))
} }
@ -86,6 +85,15 @@ func TestSimpleBlobUpload(t *testing.T) {
t.Fatalf("unexpected error during upload cancellation: %v", err) t.Fatalf("unexpected error during upload cancellation: %v", err)
} }
// get the enclosing directory
uploadPath := path.Dir(blobUpload.(*blobWriter).path)
// ensure state was cleaned up
_, err = driver.List(ctx, uploadPath)
if err == nil {
t.Fatal("files in upload path after cleanup")
}
// Do a resume, get unknown upload // Do a resume, get unknown upload
blobUpload, err = bs.Resume(ctx, blobUpload.ID()) blobUpload, err = bs.Resume(ctx, blobUpload.ID())
if err != distribution.ErrBlobUploadUnknown { if err != distribution.ErrBlobUploadUnknown {
@ -113,11 +121,7 @@ func TestSimpleBlobUpload(t *testing.T) {
t.Fatalf("layer data write incomplete") t.Fatalf("layer data write incomplete")
} }
offset, err := blobUpload.Seek(0, os.SEEK_CUR) offset := blobUpload.Size()
if err != nil {
t.Fatalf("unexpected error seeking layer upload: %v", err)
}
if offset != nn { if offset != nn {
t.Fatalf("blobUpload not updated with correct offset: %v != %v", offset, nn) t.Fatalf("blobUpload not updated with correct offset: %v != %v", offset, nn)
} }
@ -135,6 +139,13 @@ func TestSimpleBlobUpload(t *testing.T) {
t.Fatalf("unexpected error finishing layer upload: %v", err) t.Fatalf("unexpected error finishing layer upload: %v", err)
} }
// ensure state was cleaned up
uploadPath = path.Dir(blobUpload.(*blobWriter).path)
_, err = driver.List(ctx, uploadPath)
if err == nil {
t.Fatal("files in upload path after commit")
}
// After finishing an upload, it should no longer exist. // After finishing an upload, it should no longer exist.
if _, err := bs.Resume(ctx, blobUpload.ID()); err != distribution.ErrBlobUploadUnknown { if _, err := bs.Resume(ctx, blobUpload.ID()); err != distribution.ErrBlobUploadUnknown {
t.Fatalf("expected layer upload to be unknown, got %v", err) t.Fatalf("expected layer upload to be unknown, got %v", err)
@ -146,7 +157,7 @@ func TestSimpleBlobUpload(t *testing.T) {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, bs) t.Fatalf("unexpected error checking for existence: %v, %#v", err, bs)
} }
if statDesc != desc { if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc) t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
} }
@ -400,7 +411,7 @@ func TestBlobMount(t *testing.T) {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, sbs) t.Fatalf("unexpected error checking for existence: %v, %#v", err, sbs)
} }
if statDesc != desc { if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc) t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
} }
@ -426,7 +437,7 @@ func TestBlobMount(t *testing.T) {
t.Fatalf("unexpected error mounting layer: %v", err) t.Fatalf("unexpected error mounting layer: %v", err)
} }
if ebm.Descriptor != desc { if !reflect.DeepEqual(ebm.Descriptor, desc) {
t.Fatalf("descriptors not equal: %v != %v", ebm.Descriptor, desc) t.Fatalf("descriptors not equal: %v != %v", ebm.Descriptor, desc)
} }
@ -436,7 +447,7 @@ func TestBlobMount(t *testing.T) {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, bs) t.Fatalf("unexpected error checking for existence: %v, %#v", err, bs)
} }
if statDesc != desc { if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc) t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
} }

View file

@ -75,7 +75,6 @@ func (bs *blobStore) Put(ctx context.Context, mediaType string, p []byte) (distr
} }
// TODO(stevvooe): Write out mediatype here, as well. // TODO(stevvooe): Write out mediatype here, as well.
return distribution.Descriptor{ return distribution.Descriptor{
Size: int64(len(p)), Size: int64(len(p)),

View file

@ -18,9 +18,10 @@ var (
errResumableDigestNotAvailable = errors.New("resumable digest not available") errResumableDigestNotAvailable = errors.New("resumable digest not available")
) )
// layerWriter is used to control the various aspects of resumable // blobWriter is used to control the various aspects of resumable
// layer upload. It implements the LayerUpload interface. // blob upload.
type blobWriter struct { type blobWriter struct {
ctx context.Context
blobStore *linkedBlobStore blobStore *linkedBlobStore
id string id string
@ -28,11 +29,12 @@ type blobWriter struct {
digester digest.Digester digester digest.Digester
written int64 // track the contiguous write written int64 // track the contiguous write
// implementes io.WriteSeeker, io.ReaderFrom and io.Closer to satisfy fileWriter storagedriver.FileWriter
// LayerUpload Interface driver storagedriver.StorageDriver
fileWriter path string
resumableDigestEnabled bool resumableDigestEnabled bool
committed bool
} }
var _ distribution.BlobWriter = &blobWriter{} var _ distribution.BlobWriter = &blobWriter{}
@ -51,10 +53,12 @@ func (bw *blobWriter) StartedAt() time.Time {
func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
context.GetLogger(ctx).Debug("(*blobWriter).Commit") context.GetLogger(ctx).Debug("(*blobWriter).Commit")
if err := bw.fileWriter.Close(); err != nil { if err := bw.fileWriter.Commit(); err != nil {
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }
bw.Close()
canonical, err := bw.validateBlob(ctx, desc) canonical, err := bw.validateBlob(ctx, desc)
if err != nil { if err != nil {
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
@ -77,6 +81,7 @@ func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor)
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }
bw.committed = true
return canonical, nil return canonical, nil
} }
@ -84,23 +89,34 @@ func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor)
// the writer and canceling the operation. // the writer and canceling the operation.
func (bw *blobWriter) Cancel(ctx context.Context) error { func (bw *blobWriter) Cancel(ctx context.Context) error {
context.GetLogger(ctx).Debug("(*blobWriter).Rollback") context.GetLogger(ctx).Debug("(*blobWriter).Rollback")
if err := bw.fileWriter.Cancel(); err != nil {
return err
}
if err := bw.Close(); err != nil {
context.GetLogger(ctx).Errorf("error closing blobwriter: %s", err)
}
if err := bw.removeResources(ctx); err != nil { if err := bw.removeResources(ctx); err != nil {
return err return err
} }
bw.Close()
return nil return nil
} }
func (bw *blobWriter) Size() int64 {
return bw.fileWriter.Size()
}
func (bw *blobWriter) Write(p []byte) (int, error) { func (bw *blobWriter) Write(p []byte) (int, error) {
// Ensure that the current write offset matches how many bytes have been // Ensure that the current write offset matches how many bytes have been
// written to the digester. If not, we need to update the digest state to // written to the digester. If not, we need to update the digest state to
// match the current write position. // match the current write position.
if err := bw.resumeDigestAt(bw.blobStore.ctx, bw.offset); err != nil && err != errResumableDigestNotAvailable { if err := bw.resumeDigest(bw.blobStore.ctx); err != nil && err != errResumableDigestNotAvailable {
return 0, err return 0, err
} }
n, err := io.MultiWriter(&bw.fileWriter, bw.digester.Hash()).Write(p) n, err := io.MultiWriter(bw.fileWriter, bw.digester.Hash()).Write(p)
bw.written += int64(n) bw.written += int64(n)
return n, err return n, err
@ -110,19 +126,19 @@ func (bw *blobWriter) ReadFrom(r io.Reader) (n int64, err error) {
// Ensure that the current write offset matches how many bytes have been // Ensure that the current write offset matches how many bytes have been
// written to the digester. If not, we need to update the digest state to // written to the digester. If not, we need to update the digest state to
// match the current write position. // match the current write position.
if err := bw.resumeDigestAt(bw.blobStore.ctx, bw.offset); err != nil && err != errResumableDigestNotAvailable { if err := bw.resumeDigest(bw.blobStore.ctx); err != nil && err != errResumableDigestNotAvailable {
return 0, err return 0, err
} }
nn, err := bw.fileWriter.ReadFrom(io.TeeReader(r, bw.digester.Hash())) nn, err := io.Copy(io.MultiWriter(bw.fileWriter, bw.digester.Hash()), r)
bw.written += nn bw.written += nn
return nn, err return nn, err
} }
func (bw *blobWriter) Close() error { func (bw *blobWriter) Close() error {
if bw.err != nil { if bw.committed {
return bw.err return errors.New("blobwriter close after commit")
} }
if err := bw.storeHashState(bw.blobStore.ctx); err != nil { if err := bw.storeHashState(bw.blobStore.ctx); err != nil {
@ -148,8 +164,10 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
} }
} }
var size int64
// Stat the on disk file // Stat the on disk file
if fi, err := bw.fileWriter.driver.Stat(ctx, bw.path); err != nil { if fi, err := bw.driver.Stat(ctx, bw.path); err != nil {
switch err := err.(type) { switch err := err.(type) {
case storagedriver.PathNotFoundError: case storagedriver.PathNotFoundError:
// NOTE(stevvooe): We really don't care if the file is // NOTE(stevvooe): We really don't care if the file is
@ -165,23 +183,23 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
return distribution.Descriptor{}, fmt.Errorf("unexpected directory at upload location %q", bw.path) return distribution.Descriptor{}, fmt.Errorf("unexpected directory at upload location %q", bw.path)
} }
bw.size = fi.Size() size = fi.Size()
} }
if desc.Size > 0 { if desc.Size > 0 {
if desc.Size != bw.size { if desc.Size != size {
return distribution.Descriptor{}, distribution.ErrBlobInvalidLength return distribution.Descriptor{}, distribution.ErrBlobInvalidLength
} }
} else { } else {
// if provided 0 or negative length, we can assume caller doesn't know or // if provided 0 or negative length, we can assume caller doesn't know or
// care about length. // care about length.
desc.Size = bw.size desc.Size = size
} }
// TODO(stevvooe): This section is very meandering. Need to be broken down // TODO(stevvooe): This section is very meandering. Need to be broken down
// to be a lot more clear. // to be a lot more clear.
if err := bw.resumeDigestAt(ctx, bw.size); err == nil { if err := bw.resumeDigest(ctx); err == nil {
canonical = bw.digester.Digest() canonical = bw.digester.Digest()
if canonical.Algorithm() == desc.Digest.Algorithm() { if canonical.Algorithm() == desc.Digest.Algorithm() {
@ -206,7 +224,7 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
// the same, we don't need to read the data from the backend. This is // the same, we don't need to read the data from the backend. This is
// because we've written the entire file in the lifecycle of the // because we've written the entire file in the lifecycle of the
// current instance. // current instance.
if bw.written == bw.size && digest.Canonical == desc.Digest.Algorithm() { if bw.written == size && digest.Canonical == desc.Digest.Algorithm() {
canonical = bw.digester.Digest() canonical = bw.digester.Digest()
verified = desc.Digest == canonical verified = desc.Digest == canonical
} }
@ -223,7 +241,7 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
} }
// Read the file from the backend driver and validate it. // Read the file from the backend driver and validate it.
fr, err := newFileReader(ctx, bw.fileWriter.driver, bw.path, desc.Size) fr, err := newFileReader(ctx, bw.driver, bw.path, desc.Size)
if err != nil { if err != nil {
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }
@ -357,7 +375,7 @@ func (bw *blobWriter) Reader() (io.ReadCloser, error) {
// todo(richardscothern): Change to exponential backoff, i=0.5, e=2, n=4 // todo(richardscothern): Change to exponential backoff, i=0.5, e=2, n=4
try := 1 try := 1
for try <= 5 { for try <= 5 {
_, err := bw.fileWriter.driver.Stat(bw.ctx, bw.path) _, err := bw.driver.Stat(bw.ctx, bw.path)
if err == nil { if err == nil {
break break
} }
@ -371,7 +389,7 @@ func (bw *blobWriter) Reader() (io.ReadCloser, error) {
} }
} }
readCloser, err := bw.fileWriter.driver.ReadStream(bw.ctx, bw.path, 0) readCloser, err := bw.driver.Reader(bw.ctx, bw.path, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -4,8 +4,6 @@ package storage
import ( import (
"fmt" "fmt"
"io"
"os"
"path" "path"
"strconv" "strconv"
@ -19,24 +17,18 @@ import (
_ "github.com/stevvooe/resumable/sha512" _ "github.com/stevvooe/resumable/sha512"
) )
// resumeDigestAt attempts to restore the state of the internal hash function // resumeDigest attempts to restore the state of the internal hash function
// by loading the most recent saved hash state less than or equal to the given // by loading the most recent saved hash state equal to the current size of the blob.
// offset. Any unhashed bytes remaining less than the given offset are hashed func (bw *blobWriter) resumeDigest(ctx context.Context) error {
// from the content uploaded so far.
func (bw *blobWriter) resumeDigestAt(ctx context.Context, offset int64) error {
if !bw.resumableDigestEnabled { if !bw.resumableDigestEnabled {
return errResumableDigestNotAvailable return errResumableDigestNotAvailable
} }
if offset < 0 {
return fmt.Errorf("cannot resume hash at negative offset: %d", offset)
}
h, ok := bw.digester.Hash().(resumable.Hash) h, ok := bw.digester.Hash().(resumable.Hash)
if !ok { if !ok {
return errResumableDigestNotAvailable return errResumableDigestNotAvailable
} }
offset := bw.fileWriter.Size()
if offset == int64(h.Len()) { if offset == int64(h.Len()) {
// State of digester is already at the requested offset. // State of digester is already at the requested offset.
return nil return nil
@ -49,24 +41,12 @@ func (bw *blobWriter) resumeDigestAt(ctx context.Context, offset int64) error {
return fmt.Errorf("unable to get stored hash states with offset %d: %s", offset, err) return fmt.Errorf("unable to get stored hash states with offset %d: %s", offset, err)
} }
// Find the highest stored hashState with offset less than or equal to // Find the highest stored hashState with offset equal to
// the requested offset. // the requested offset.
for _, hashState := range hashStates { for _, hashState := range hashStates {
if hashState.offset == offset { if hashState.offset == offset {
hashStateMatch = hashState hashStateMatch = hashState
break // Found an exact offset match. break // Found an exact offset match.
} else if hashState.offset < offset && hashState.offset > hashStateMatch.offset {
// This offset is closer to the requested offset.
hashStateMatch = hashState
} else if hashState.offset > offset {
// Remove any stored hash state with offsets higher than this one
// as writes to this resumed hasher will make those invalid. This
// is probably okay to skip for now since we don't expect anyone to
// use the API in this way. For that reason, we don't treat an
// an error here as a fatal error, but only log it.
if err := bw.driver.Delete(ctx, hashState.path); err != nil {
logrus.Errorf("unable to delete stale hash state %q: %s", hashState.path, err)
}
} }
} }
@ -86,20 +66,7 @@ func (bw *blobWriter) resumeDigestAt(ctx context.Context, offset int64) error {
// Mind the gap. // Mind the gap.
if gapLen := offset - int64(h.Len()); gapLen > 0 { if gapLen := offset - int64(h.Len()); gapLen > 0 {
// Need to read content from the upload to catch up to the desired offset. return errResumableDigestNotAvailable
fr, err := newFileReader(ctx, bw.driver, bw.path, bw.size)
if err != nil {
return err
}
defer fr.Close()
if _, err = fr.Seek(int64(h.Len()), os.SEEK_SET); err != nil {
return fmt.Errorf("unable to seek to layer reader offset %d: %s", h.Len(), err)
}
if _, err := io.CopyN(h, fr, gapLen); err != nil {
return err
}
} }
return nil return nil

View file

@ -1,6 +1,7 @@
package cachecheck package cachecheck
import ( import (
"reflect"
"testing" "testing"
"github.com/docker/distribution" "github.com/docker/distribution"
@ -79,7 +80,7 @@ func checkBlobDescriptorCacheSetAndRead(t *testing.T, ctx context.Context, provi
t.Fatalf("unexpected error statting fake2:abc: %v", err) t.Fatalf("unexpected error statting fake2:abc: %v", err)
} }
if expected != desc { if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc) t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
} }
@ -89,7 +90,7 @@ func checkBlobDescriptorCacheSetAndRead(t *testing.T, ctx context.Context, provi
t.Fatalf("descriptor not returned for canonical key: %v", err) t.Fatalf("descriptor not returned for canonical key: %v", err)
} }
if expected != desc { if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc) t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
} }
@ -99,7 +100,7 @@ func checkBlobDescriptorCacheSetAndRead(t *testing.T, ctx context.Context, provi
t.Fatalf("expected blob unknown in global cache: %v, %v", err, desc) t.Fatalf("expected blob unknown in global cache: %v, %v", err, desc)
} }
if desc != expected { if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc) t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
} }
@ -109,7 +110,7 @@ func checkBlobDescriptorCacheSetAndRead(t *testing.T, ctx context.Context, provi
t.Fatalf("unexpected error checking glboal descriptor: %v", err) t.Fatalf("unexpected error checking glboal descriptor: %v", err)
} }
if desc != expected { if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc) t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
} }
@ -126,7 +127,7 @@ func checkBlobDescriptorCacheSetAndRead(t *testing.T, ctx context.Context, provi
t.Fatalf("unexpected error getting descriptor: %v", err) t.Fatalf("unexpected error getting descriptor: %v", err)
} }
if desc != expected { if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", desc, expected) t.Fatalf("unexpected descriptor: %#v != %#v", desc, expected)
} }
@ -137,7 +138,7 @@ func checkBlobDescriptorCacheSetAndRead(t *testing.T, ctx context.Context, provi
expected.MediaType = "application/octet-stream" // expect original mediatype in global expected.MediaType = "application/octet-stream" // expect original mediatype in global
if desc != expected { if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", desc, expected) t.Fatalf("unexpected descriptor: %#v != %#v", desc, expected)
} }
} }
@ -163,7 +164,7 @@ func checkBlobDescriptorCacheClear(t *testing.T, ctx context.Context, provider c
t.Fatalf("unexpected error statting fake2:abc: %v", err) t.Fatalf("unexpected error statting fake2:abc: %v", err)
} }
if expected != desc { if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc) t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
} }

View file

@ -3,6 +3,7 @@
package azure package azure
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
@ -26,6 +27,7 @@ const (
paramAccountKey = "accountkey" paramAccountKey = "accountkey"
paramContainer = "container" paramContainer = "container"
paramRealm = "realm" paramRealm = "realm"
maxChunkSize = 4 * 1024 * 1024
) )
type driver struct { type driver struct {
@ -117,18 +119,21 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
if _, err := d.client.DeleteBlobIfExists(d.container, path); err != nil { if _, err := d.client.DeleteBlobIfExists(d.container, path); err != nil {
return err return err
} }
if err := d.client.CreateBlockBlob(d.container, path); err != nil { writer, err := d.Writer(ctx, path, false)
if err != nil {
return err return err
} }
bs := newAzureBlockStorage(d.client) defer writer.Close()
bw := newRandomBlobWriter(&bs, azure.MaxBlobBlockSize) _, err = writer.Write(contents)
_, err := bw.WriteBlobAt(d.container, path, 0, bytes.NewReader(contents)) if err != nil {
return err return err
}
return writer.Commit()
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
if ok, err := d.client.BlobExists(d.container, path); err != nil { if ok, err := d.client.BlobExists(d.container, path); err != nil {
return nil, err return nil, err
} else if !ok { } else if !ok {
@ -153,25 +158,38 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return resp, nil return resp, nil
} }
// WriteStream stores the contents of the provided io.ReadCloser at a location // Writer returns a FileWriter which will store the content written to it
// designated by the given path. // at the location designated by "path" after the call to Commit.
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (int64, error) { func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
if blobExists, err := d.client.BlobExists(d.container, path); err != nil { blobExists, err := d.client.BlobExists(d.container, path)
return 0, err if err != nil {
} else if !blobExists { return nil, err
err := d.client.CreateBlockBlob(d.container, path) }
var size int64
if blobExists {
if append {
blobProperties, err := d.client.GetBlobProperties(d.container, path)
if err != nil {
return nil, err
}
size = blobProperties.ContentLength
} else {
err := d.client.DeleteBlob(d.container, path)
if err != nil {
return nil, err
}
}
} else {
if append {
return nil, storagedriver.PathNotFoundError{Path: path}
}
err := d.client.PutAppendBlob(d.container, path, nil)
if err != nil { if err != nil {
return 0, err return nil, err
} }
} }
if offset < 0 {
return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
bs := newAzureBlockStorage(d.client) return d.newWriter(path, size), nil
bw := newRandomBlobWriter(&bs, azure.MaxBlobBlockSize)
zw := newZeroFillWriter(&bw)
return zw.Write(d.container, path, offset, reader)
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -236,6 +254,9 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
} }
list := directDescendants(blobs, path) list := directDescendants(blobs, path)
if path != "" && len(list) == 0 {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return list, nil return list, nil
} }
@ -361,6 +382,101 @@ func (d *driver) listBlobs(container, virtPath string) ([]string, error) {
} }
func is404(err error) bool { func is404(err error) bool {
e, ok := err.(azure.AzureStorageServiceError) statusCodeErr, ok := err.(azure.AzureStorageServiceError)
return ok && e.StatusCode == http.StatusNotFound return ok && statusCodeErr.StatusCode == http.StatusNotFound
}
type writer struct {
driver *driver
path string
size int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(path string, size int64) storagedriver.FileWriter {
return &writer{
driver: d,
path: path,
size: size,
bw: bufio.NewWriterSize(&blockWriter{
client: d.client,
container: d.container,
path: path,
}, maxChunkSize),
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
n, err := w.bw.Write(p)
w.size += int64(n)
return n, err
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.bw.Flush()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
return w.driver.client.DeleteBlob(w.driver.container, w.path)
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
w.committed = true
return w.bw.Flush()
}
type blockWriter struct {
client azure.BlobStorageClient
container string
path string
}
func (bw *blockWriter) Write(p []byte) (int, error) {
n := 0
for offset := 0; offset < len(p); offset += maxChunkSize {
chunkSize := maxChunkSize
if offset+chunkSize > len(p) {
chunkSize = len(p) - offset
}
err := bw.client.AppendBlock(bw.container, bw.path, p[offset:offset+chunkSize])
if err != nil {
return n, err
}
n += chunkSize
}
return n, nil
} }

View file

@ -1,24 +0,0 @@
package azure
import (
"fmt"
"io"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
// azureBlockStorage is adaptor between azure.BlobStorageClient and
// blockStorage interface.
type azureBlockStorage struct {
azure.BlobStorageClient
}
func (b *azureBlockStorage) GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error) {
return b.BlobStorageClient.GetBlobRange(container, blob, fmt.Sprintf("%v-%v", start, start+length-1))
}
func newAzureBlockStorage(b azure.BlobStorageClient) azureBlockStorage {
a := azureBlockStorage{}
a.BlobStorageClient = b
return a
}

View file

@ -1,155 +0,0 @@
package azure
import (
"bytes"
"fmt"
"io"
"io/ioutil"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
type StorageSimulator struct {
blobs map[string]*BlockBlob
}
type BlockBlob struct {
blocks map[string]*DataBlock
blockList []string
}
type DataBlock struct {
data []byte
committed bool
}
func (s *StorageSimulator) path(container, blob string) string {
return fmt.Sprintf("%s/%s", container, blob)
}
func (s *StorageSimulator) BlobExists(container, blob string) (bool, error) {
_, ok := s.blobs[s.path(container, blob)]
return ok, nil
}
func (s *StorageSimulator) GetBlob(container, blob string) (io.ReadCloser, error) {
bb, ok := s.blobs[s.path(container, blob)]
if !ok {
return nil, fmt.Errorf("blob not found")
}
var readers []io.Reader
for _, bID := range bb.blockList {
readers = append(readers, bytes.NewReader(bb.blocks[bID].data))
}
return ioutil.NopCloser(io.MultiReader(readers...)), nil
}
func (s *StorageSimulator) GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error) {
r, err := s.GetBlob(container, blob)
if err != nil {
return nil, err
}
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
return ioutil.NopCloser(bytes.NewReader(b[start : start+length])), nil
}
func (s *StorageSimulator) CreateBlockBlob(container, blob string) error {
path := s.path(container, blob)
bb := &BlockBlob{
blocks: make(map[string]*DataBlock),
blockList: []string{},
}
s.blobs[path] = bb
return nil
}
func (s *StorageSimulator) PutBlock(container, blob, blockID string, chunk []byte) error {
path := s.path(container, blob)
bb, ok := s.blobs[path]
if !ok {
return fmt.Errorf("blob not found")
}
data := make([]byte, len(chunk))
copy(data, chunk)
bb.blocks[blockID] = &DataBlock{data: data, committed: false} // add block to blob
return nil
}
func (s *StorageSimulator) GetBlockList(container, blob string, blockType azure.BlockListType) (azure.BlockListResponse, error) {
resp := azure.BlockListResponse{}
bb, ok := s.blobs[s.path(container, blob)]
if !ok {
return resp, fmt.Errorf("blob not found")
}
// Iterate committed blocks (in order)
if blockType == azure.BlockListTypeAll || blockType == azure.BlockListTypeCommitted {
for _, blockID := range bb.blockList {
b := bb.blocks[blockID]
block := azure.BlockResponse{
Name: blockID,
Size: int64(len(b.data)),
}
resp.CommittedBlocks = append(resp.CommittedBlocks, block)
}
}
// Iterate uncommitted blocks (in no order)
if blockType == azure.BlockListTypeAll || blockType == azure.BlockListTypeCommitted {
for blockID, b := range bb.blocks {
block := azure.BlockResponse{
Name: blockID,
Size: int64(len(b.data)),
}
if !b.committed {
resp.UncommittedBlocks = append(resp.UncommittedBlocks, block)
}
}
}
return resp, nil
}
func (s *StorageSimulator) PutBlockList(container, blob string, blocks []azure.Block) error {
bb, ok := s.blobs[s.path(container, blob)]
if !ok {
return fmt.Errorf("blob not found")
}
var blockIDs []string
for _, v := range blocks {
bl, ok := bb.blocks[v.ID]
if !ok { // check if block ID exists
return fmt.Errorf("Block id '%s' not found", v.ID)
}
bl.committed = true
blockIDs = append(blockIDs, v.ID)
}
// Mark all other blocks uncommitted
for k, b := range bb.blocks {
inList := false
for _, v := range blockIDs {
if k == v {
inList = true
break
}
}
if !inList {
b.committed = false
}
}
bb.blockList = blockIDs
return nil
}
func NewStorageSimulator() StorageSimulator {
return StorageSimulator{
blobs: make(map[string]*BlockBlob),
}
}

View file

@ -1,60 +0,0 @@
package azure
import (
"encoding/base64"
"fmt"
"math/rand"
"sync"
"time"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
type blockIDGenerator struct {
pool map[string]bool
r *rand.Rand
m sync.Mutex
}
// Generate returns an unused random block id and adds the generated ID
// to list of used IDs so that the same block name is not used again.
func (b *blockIDGenerator) Generate() string {
b.m.Lock()
defer b.m.Unlock()
var id string
for {
id = toBlockID(int(b.r.Int()))
if !b.exists(id) {
break
}
}
b.pool[id] = true
return id
}
func (b *blockIDGenerator) exists(id string) bool {
_, used := b.pool[id]
return used
}
func (b *blockIDGenerator) Feed(blocks azure.BlockListResponse) {
b.m.Lock()
defer b.m.Unlock()
for _, bl := range append(blocks.CommittedBlocks, blocks.UncommittedBlocks...) {
b.pool[bl.Name] = true
}
}
func newBlockIDGenerator() *blockIDGenerator {
return &blockIDGenerator{
pool: make(map[string]bool),
r: rand.New(rand.NewSource(time.Now().UnixNano()))}
}
// toBlockId converts given integer to base64-encoded block ID of a fixed length.
func toBlockID(i int) string {
s := fmt.Sprintf("%029d", i) // add zero padding for same length-blobs
return base64.StdEncoding.EncodeToString([]byte(s))
}

View file

@ -1,74 +0,0 @@
package azure
import (
"math"
"testing"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
func Test_blockIdGenerator(t *testing.T) {
r := newBlockIDGenerator()
for i := 1; i <= 10; i++ {
if expected := i - 1; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
if id := r.Generate(); id == "" {
t.Fatal("returned empty id")
}
if expected := i; len(r.pool) != expected {
t.Fatalf("rand pool has wrong number of items: %d, expected:%d", len(r.pool), expected)
}
}
}
func Test_blockIdGenerator_Feed(t *testing.T) {
r := newBlockIDGenerator()
if expected := 0; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
// feed empty list
blocks := azure.BlockListResponse{}
r.Feed(blocks)
if expected := 0; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
// feed blocks
blocks = azure.BlockListResponse{
CommittedBlocks: []azure.BlockResponse{
{"1", 1},
{"2", 2},
},
UncommittedBlocks: []azure.BlockResponse{
{"3", 3},
}}
r.Feed(blocks)
if expected := 3; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
// feed same block IDs with committed/uncommitted place changed
blocks = azure.BlockListResponse{
CommittedBlocks: []azure.BlockResponse{
{"3", 3},
},
UncommittedBlocks: []azure.BlockResponse{
{"1", 1},
}}
r.Feed(blocks)
if expected := 3; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
}
func Test_toBlockId(t *testing.T) {
min := 0
max := math.MaxInt64
if len(toBlockID(min)) != len(toBlockID(max)) {
t.Fatalf("different-sized blockIDs are returned")
}
}

View file

@ -1,208 +0,0 @@
package azure
import (
"fmt"
"io"
"io/ioutil"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
// blockStorage is the interface required from a block storage service
// client implementation
type blockStorage interface {
CreateBlockBlob(container, blob string) error
GetBlob(container, blob string) (io.ReadCloser, error)
GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error)
PutBlock(container, blob, blockID string, chunk []byte) error
GetBlockList(container, blob string, blockType azure.BlockListType) (azure.BlockListResponse, error)
PutBlockList(container, blob string, blocks []azure.Block) error
}
// randomBlobWriter enables random access semantics on Azure block blobs
// by enabling writing arbitrary length of chunks to arbitrary write offsets
// within the blob. Normally, Azure Blob Storage does not support random
// access semantics on block blobs; however, this writer can download, split and
// reupload the overlapping blocks and discards those being overwritten entirely.
type randomBlobWriter struct {
bs blockStorage
blockSize int
}
func newRandomBlobWriter(bs blockStorage, blockSize int) randomBlobWriter {
return randomBlobWriter{bs: bs, blockSize: blockSize}
}
// WriteBlobAt writes the given chunk to the specified position of an existing blob.
// The offset must be equals to size of the blob or smaller than it.
func (r *randomBlobWriter) WriteBlobAt(container, blob string, offset int64, chunk io.Reader) (int64, error) {
rand := newBlockIDGenerator()
blocks, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeCommitted)
if err != nil {
return 0, err
}
rand.Feed(blocks) // load existing block IDs
// Check for write offset for existing blob
size := getBlobSize(blocks)
if offset < 0 || offset > size {
return 0, fmt.Errorf("wrong offset for Write: %v", offset)
}
// Upload the new chunk as blocks
blockList, nn, err := r.writeChunkToBlocks(container, blob, chunk, rand)
if err != nil {
return 0, err
}
// For non-append operations, existing blocks may need to be splitted
if offset != size {
// Split the block on the left end (if any)
leftBlocks, err := r.blocksLeftSide(container, blob, offset, rand)
if err != nil {
return 0, err
}
blockList = append(leftBlocks, blockList...)
// Split the block on the right end (if any)
rightBlocks, err := r.blocksRightSide(container, blob, offset, nn, rand)
if err != nil {
return 0, err
}
blockList = append(blockList, rightBlocks...)
} else {
// Use existing block list
var existingBlocks []azure.Block
for _, v := range blocks.CommittedBlocks {
existingBlocks = append(existingBlocks, azure.Block{ID: v.Name, Status: azure.BlockStatusCommitted})
}
blockList = append(existingBlocks, blockList...)
}
// Put block list
return nn, r.bs.PutBlockList(container, blob, blockList)
}
func (r *randomBlobWriter) GetSize(container, blob string) (int64, error) {
blocks, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeCommitted)
if err != nil {
return 0, err
}
return getBlobSize(blocks), nil
}
// writeChunkToBlocks writes given chunk to one or multiple blocks within specified
// blob and returns their block representations. Those blocks are not committed, yet
func (r *randomBlobWriter) writeChunkToBlocks(container, blob string, chunk io.Reader, rand *blockIDGenerator) ([]azure.Block, int64, error) {
var newBlocks []azure.Block
var nn int64
// Read chunks of at most size N except the last chunk to
// maximize block size and minimize block count.
buf := make([]byte, r.blockSize)
for {
n, err := io.ReadFull(chunk, buf)
if err == io.EOF {
break
}
nn += int64(n)
data := buf[:n]
blockID := rand.Generate()
if err := r.bs.PutBlock(container, blob, blockID, data); err != nil {
return newBlocks, nn, err
}
newBlocks = append(newBlocks, azure.Block{ID: blockID, Status: azure.BlockStatusUncommitted})
}
return newBlocks, nn, nil
}
// blocksLeftSide returns the blocks that are going to be at the left side of
// the writeOffset: [0, writeOffset) by identifying blocks that will remain
// the same and splitting blocks and reuploading them as needed.
func (r *randomBlobWriter) blocksLeftSide(container, blob string, writeOffset int64, rand *blockIDGenerator) ([]azure.Block, error) {
var left []azure.Block
bx, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeAll)
if err != nil {
return left, err
}
o := writeOffset
elapsed := int64(0)
for _, v := range bx.CommittedBlocks {
blkSize := int64(v.Size)
if o >= blkSize { // use existing block
left = append(left, azure.Block{ID: v.Name, Status: azure.BlockStatusCommitted})
o -= blkSize
elapsed += blkSize
} else if o > 0 { // current block needs to be splitted
start := elapsed
size := o
part, err := r.bs.GetSectionReader(container, blob, start, size)
if err != nil {
return left, err
}
newBlockID := rand.Generate()
data, err := ioutil.ReadAll(part)
if err != nil {
return left, err
}
if err = r.bs.PutBlock(container, blob, newBlockID, data); err != nil {
return left, err
}
left = append(left, azure.Block{ID: newBlockID, Status: azure.BlockStatusUncommitted})
break
}
}
return left, nil
}
// blocksRightSide returns the blocks that are going to be at the right side of
// the written chunk: [writeOffset+size, +inf) by identifying blocks that will remain
// the same and splitting blocks and reuploading them as needed.
func (r *randomBlobWriter) blocksRightSide(container, blob string, writeOffset int64, chunkSize int64, rand *blockIDGenerator) ([]azure.Block, error) {
var right []azure.Block
bx, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeAll)
if err != nil {
return nil, err
}
re := writeOffset + chunkSize - 1 // right end of written chunk
var elapsed int64
for _, v := range bx.CommittedBlocks {
var (
bs = elapsed // left end of current block
be = elapsed + int64(v.Size) - 1 // right end of current block
)
if bs > re { // take the block as is
right = append(right, azure.Block{ID: v.Name, Status: azure.BlockStatusCommitted})
} else if be > re { // current block needs to be splitted
part, err := r.bs.GetSectionReader(container, blob, re+1, be-(re+1)+1)
if err != nil {
return right, err
}
newBlockID := rand.Generate()
data, err := ioutil.ReadAll(part)
if err != nil {
return right, err
}
if err = r.bs.PutBlock(container, blob, newBlockID, data); err != nil {
return right, err
}
right = append(right, azure.Block{ID: newBlockID, Status: azure.BlockStatusUncommitted})
}
elapsed += int64(v.Size)
}
return right, nil
}
func getBlobSize(blocks azure.BlockListResponse) int64 {
var n int64
for _, v := range blocks.CommittedBlocks {
n += int64(v.Size)
}
return n
}

View file

@ -1,339 +0,0 @@
package azure
import (
"bytes"
"io"
"io/ioutil"
"math/rand"
"reflect"
"strings"
"testing"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
func TestRandomWriter_writeChunkToBlocks(t *testing.T) {
s := NewStorageSimulator()
rw := newRandomBlobWriter(&s, 3)
rand := newBlockIDGenerator()
c := []byte("AAABBBCCCD")
if err := rw.bs.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
bw, nn, err := rw.writeChunkToBlocks("a", "b", bytes.NewReader(c), rand)
if err != nil {
t.Fatal(err)
}
if expected := int64(len(c)); nn != expected {
t.Fatalf("wrong nn:%v, expected:%v", nn, expected)
}
if expected := 4; len(bw) != expected {
t.Fatal("unexpected written block count")
}
bx, err := s.GetBlockList("a", "b", azure.BlockListTypeAll)
if err != nil {
t.Fatal(err)
}
if expected := 0; len(bx.CommittedBlocks) != expected {
t.Fatal("unexpected committed block count")
}
if expected := 4; len(bx.UncommittedBlocks) != expected {
t.Fatalf("unexpected uncommitted block count: %d -- %#v", len(bx.UncommittedBlocks), bx)
}
if err := rw.bs.PutBlockList("a", "b", bw); err != nil {
t.Fatal(err)
}
r, err := rw.bs.GetBlob("a", "b")
if err != nil {
t.Fatal(err)
}
assertBlobContents(t, r, c)
}
func TestRandomWriter_blocksLeftSide(t *testing.T) {
blob := "AAAAABBBBBCCC"
cases := []struct {
offset int64
expectedBlob string
expectedPattern []azure.BlockStatus
}{
{0, "", []azure.BlockStatus{}}, // write to beginning, discard all
{13, blob, []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // write to end, no change
{1, "A", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // write at 1
{5, "AAAAA", []azure.BlockStatus{azure.BlockStatusCommitted}}, // write just after first block
{6, "AAAAAB", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusUncommitted}}, // split the second block
{9, "AAAAABBBB", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusUncommitted}}, // write just after first block
}
for _, c := range cases {
s := NewStorageSimulator()
rw := newRandomBlobWriter(&s, 5)
rand := newBlockIDGenerator()
if err := rw.bs.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
bw, _, err := rw.writeChunkToBlocks("a", "b", strings.NewReader(blob), rand)
if err != nil {
t.Fatal(err)
}
if err := rw.bs.PutBlockList("a", "b", bw); err != nil {
t.Fatal(err)
}
bx, err := rw.blocksLeftSide("a", "b", c.offset, rand)
if err != nil {
t.Fatal(err)
}
bs := []azure.BlockStatus{}
for _, v := range bx {
bs = append(bs, v.Status)
}
if !reflect.DeepEqual(bs, c.expectedPattern) {
t.Logf("Committed blocks %v", bw)
t.Fatalf("For offset %v: Expected pattern: %v, Got: %v\n(Returned: %v)", c.offset, c.expectedPattern, bs, bx)
}
if rw.bs.PutBlockList("a", "b", bx); err != nil {
t.Fatal(err)
}
r, err := rw.bs.GetBlob("a", "b")
if err != nil {
t.Fatal(err)
}
cout, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
outBlob := string(cout)
if outBlob != c.expectedBlob {
t.Fatalf("wrong blob contents: %v, expected: %v", outBlob, c.expectedBlob)
}
}
}
func TestRandomWriter_blocksRightSide(t *testing.T) {
blob := "AAAAABBBBBCCC"
cases := []struct {
offset int64
size int64
expectedBlob string
expectedPattern []azure.BlockStatus
}{
{0, 100, "", []azure.BlockStatus{}}, // overwrite the entire blob
{0, 3, "AABBBBBCCC", []azure.BlockStatus{azure.BlockStatusUncommitted, azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // split first block
{4, 1, "BBBBBCCC", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // write to last char of first block
{1, 6, "BBBCCC", []azure.BlockStatus{azure.BlockStatusUncommitted, azure.BlockStatusCommitted}}, // overwrite splits first and second block, last block remains
{3, 8, "CC", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // overwrite a block in middle block, split end block
{10, 1, "CC", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // overwrite first byte of rightmost block
{11, 2, "", []azure.BlockStatus{}}, // overwrite the rightmost index
{13, 20, "", []azure.BlockStatus{}}, // append to the end
}
for _, c := range cases {
s := NewStorageSimulator()
rw := newRandomBlobWriter(&s, 5)
rand := newBlockIDGenerator()
if err := rw.bs.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
bw, _, err := rw.writeChunkToBlocks("a", "b", strings.NewReader(blob), rand)
if err != nil {
t.Fatal(err)
}
if err := rw.bs.PutBlockList("a", "b", bw); err != nil {
t.Fatal(err)
}
bx, err := rw.blocksRightSide("a", "b", c.offset, c.size, rand)
if err != nil {
t.Fatal(err)
}
bs := []azure.BlockStatus{}
for _, v := range bx {
bs = append(bs, v.Status)
}
if !reflect.DeepEqual(bs, c.expectedPattern) {
t.Logf("Committed blocks %v", bw)
t.Fatalf("For offset %v-size:%v: Expected pattern: %v, Got: %v\n(Returned: %v)", c.offset, c.size, c.expectedPattern, bs, bx)
}
if rw.bs.PutBlockList("a", "b", bx); err != nil {
t.Fatal(err)
}
r, err := rw.bs.GetBlob("a", "b")
if err != nil {
t.Fatal(err)
}
cout, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
outBlob := string(cout)
if outBlob != c.expectedBlob {
t.Fatalf("For offset %v-size:%v: wrong blob contents: %v, expected: %v", c.offset, c.size, outBlob, c.expectedBlob)
}
}
}
func TestRandomWriter_Write_NewBlob(t *testing.T) {
var (
s = NewStorageSimulator()
rw = newRandomBlobWriter(&s, 1024*3) // 3 KB blocks
blob = randomContents(1024 * 7) // 7 KB blob
)
if err := rw.bs.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
if _, err := rw.WriteBlobAt("a", "b", 10, bytes.NewReader(blob)); err == nil {
t.Fatal("expected error, got nil")
}
if _, err := rw.WriteBlobAt("a", "b", 100000, bytes.NewReader(blob)); err == nil {
t.Fatal("expected error, got nil")
}
if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(blob)); err != nil {
t.Fatal(err)
} else if expected := int64(len(blob)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, blob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if len(bx.CommittedBlocks) != 3 {
t.Fatalf("got wrong number of committed blocks: %v", len(bx.CommittedBlocks))
}
// Replace first 512 bytes
leftChunk := randomContents(512)
blob = append(leftChunk, blob[512:]...)
if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(leftChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(leftChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, blob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if expected := 4; len(bx.CommittedBlocks) != expected {
t.Fatalf("got wrong number of committed blocks: %v, expected: %v", len(bx.CommittedBlocks), expected)
}
// Replace last 512 bytes with 1024 bytes
rightChunk := randomContents(1024)
offset := int64(len(blob) - 512)
blob = append(blob[:offset], rightChunk...)
if nn, err := rw.WriteBlobAt("a", "b", offset, bytes.NewReader(rightChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(rightChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, blob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if expected := 5; len(bx.CommittedBlocks) != expected {
t.Fatalf("got wrong number of committed blocks: %v, expected: %v", len(bx.CommittedBlocks), expected)
}
// Replace 2K-4K (overlaps 2 blocks from L/R)
newChunk := randomContents(1024 * 2)
offset = 1024 * 2
blob = append(append(blob[:offset], newChunk...), blob[offset+int64(len(newChunk)):]...)
if nn, err := rw.WriteBlobAt("a", "b", offset, bytes.NewReader(newChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(newChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, blob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if expected := 6; len(bx.CommittedBlocks) != expected {
t.Fatalf("got wrong number of committed blocks: %v, expected: %v\n%v", len(bx.CommittedBlocks), expected, bx.CommittedBlocks)
}
// Replace the entire blob
newBlob := randomContents(1024 * 30)
if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(newBlob)); err != nil {
t.Fatal(err)
} else if expected := int64(len(newBlob)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, newBlob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if expected := 10; len(bx.CommittedBlocks) != expected {
t.Fatalf("got wrong number of committed blocks: %v, expected: %v\n%v", len(bx.CommittedBlocks), expected, bx.CommittedBlocks)
} else if expected, size := int64(1024*30), getBlobSize(bx); size != expected {
t.Fatalf("committed block size does not indicate blob size")
}
}
func Test_getBlobSize(t *testing.T) {
// with some committed blocks
if expected, size := int64(151), getBlobSize(azure.BlockListResponse{
CommittedBlocks: []azure.BlockResponse{
{"A", 100},
{"B", 50},
{"C", 1},
},
UncommittedBlocks: []azure.BlockResponse{
{"D", 200},
}}); expected != size {
t.Fatalf("wrong blob size: %v, expected: %v", size, expected)
}
// with no committed blocks
if expected, size := int64(0), getBlobSize(azure.BlockListResponse{
UncommittedBlocks: []azure.BlockResponse{
{"A", 100},
{"B", 50},
{"C", 1},
{"D", 200},
}}); expected != size {
t.Fatalf("wrong blob size: %v, expected: %v", size, expected)
}
}
func assertBlobContents(t *testing.T, r io.Reader, expected []byte) {
out, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(out, expected) {
t.Fatalf("wrong blob contents. size: %v, expected: %v", len(out), len(expected))
}
}
func randomContents(length int64) []byte {
b := make([]byte, length)
for i := range b {
b[i] = byte(rand.Intn(2 << 8))
}
return b
}

View file

@ -1,49 +0,0 @@
package azure
import (
"bytes"
"io"
)
type blockBlobWriter interface {
GetSize(container, blob string) (int64, error)
WriteBlobAt(container, blob string, offset int64, chunk io.Reader) (int64, error)
}
// zeroFillWriter enables writing to an offset outside a block blob's size
// by offering the chunk to the underlying writer as a contiguous data with
// the gap in between filled with NUL (zero) bytes.
type zeroFillWriter struct {
blockBlobWriter
}
func newZeroFillWriter(b blockBlobWriter) zeroFillWriter {
w := zeroFillWriter{}
w.blockBlobWriter = b
return w
}
// Write writes the given chunk to the specified existing blob even though
// offset is out of blob's size. The gaps are filled with zeros. Returned
// written number count does not include zeros written.
func (z *zeroFillWriter) Write(container, blob string, offset int64, chunk io.Reader) (int64, error) {
size, err := z.blockBlobWriter.GetSize(container, blob)
if err != nil {
return 0, err
}
var reader io.Reader
var zeroPadding int64
if offset <= size {
reader = chunk
} else {
zeroPadding = offset - size
offset = size // adjust offset to be the append index
zeros := bytes.NewReader(make([]byte, zeroPadding))
reader = io.MultiReader(zeros, chunk)
}
nn, err := z.blockBlobWriter.WriteBlobAt(container, blob, offset, reader)
nn -= zeroPadding
return nn, err
}

View file

@ -1,126 +0,0 @@
package azure
import (
"bytes"
"testing"
)
func Test_zeroFillWrite_AppendNoGap(t *testing.T) {
s := NewStorageSimulator()
bw := newRandomBlobWriter(&s, 1024*1)
zw := newZeroFillWriter(&bw)
if err := s.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
firstChunk := randomContents(1024*3 + 512)
if nn, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(firstChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, firstChunk)
}
secondChunk := randomContents(256)
if nn, err := zw.Write("a", "b", int64(len(firstChunk)), bytes.NewReader(secondChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(secondChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, append(firstChunk, secondChunk...))
}
}
func Test_zeroFillWrite_StartWithGap(t *testing.T) {
s := NewStorageSimulator()
bw := newRandomBlobWriter(&s, 1024*2)
zw := newZeroFillWriter(&bw)
if err := s.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
chunk := randomContents(1024 * 5)
padding := int64(1024*2 + 256)
if nn, err := zw.Write("a", "b", padding, bytes.NewReader(chunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(chunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, append(make([]byte, padding), chunk...))
}
}
func Test_zeroFillWrite_AppendWithGap(t *testing.T) {
s := NewStorageSimulator()
bw := newRandomBlobWriter(&s, 1024*2)
zw := newZeroFillWriter(&bw)
if err := s.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
firstChunk := randomContents(1024*3 + 512)
if _, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil {
t.Fatal(err)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, firstChunk)
}
secondChunk := randomContents(256)
padding := int64(1024 * 4)
if nn, err := zw.Write("a", "b", int64(len(firstChunk))+padding, bytes.NewReader(secondChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(secondChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, append(firstChunk, append(make([]byte, padding), secondChunk...)...))
}
}
func Test_zeroFillWrite_LiesWithinSize(t *testing.T) {
s := NewStorageSimulator()
bw := newRandomBlobWriter(&s, 1024*2)
zw := newZeroFillWriter(&bw)
if err := s.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
firstChunk := randomContents(1024 * 3)
if _, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil {
t.Fatal(err)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, firstChunk)
}
// in this case, zerofill won't be used
secondChunk := randomContents(256)
if nn, err := zw.Write("a", "b", 0, bytes.NewReader(secondChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(secondChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, append(secondChunk, firstChunk[len(secondChunk):]...))
}
}

View file

@ -102,10 +102,10 @@ func (base *Base) PutContent(ctx context.Context, path string, content []byte) e
return base.setDriverName(base.StorageDriver.PutContent(ctx, path, content)) return base.setDriverName(base.StorageDriver.PutContent(ctx, path, content))
} }
// ReadStream wraps ReadStream of underlying storage driver. // Reader wraps Reader of underlying storage driver.
func (base *Base) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (base *Base) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
ctx, done := context.WithTrace(ctx) ctx, done := context.WithTrace(ctx)
defer done("%s.ReadStream(%q, %d)", base.Name(), path, offset) defer done("%s.Reader(%q, %d)", base.Name(), path, offset)
if offset < 0 { if offset < 0 {
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset, DriverName: base.StorageDriver.Name()} return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset, DriverName: base.StorageDriver.Name()}
@ -115,25 +115,21 @@ func (base *Base) ReadStream(ctx context.Context, path string, offset int64) (io
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()} return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
} }
rc, e := base.StorageDriver.ReadStream(ctx, path, offset) rc, e := base.StorageDriver.Reader(ctx, path, offset)
return rc, base.setDriverName(e) return rc, base.setDriverName(e)
} }
// WriteStream wraps WriteStream of underlying storage driver. // Writer wraps Writer of underlying storage driver.
func (base *Base) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (nn int64, err error) { func (base *Base) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
ctx, done := context.WithTrace(ctx) ctx, done := context.WithTrace(ctx)
defer done("%s.WriteStream(%q, %d)", base.Name(), path, offset) defer done("%s.Writer(%q, %v)", base.Name(), path, append)
if offset < 0 {
return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset, DriverName: base.StorageDriver.Name()}
}
if !storagedriver.PathRegexp.MatchString(path) { if !storagedriver.PathRegexp.MatchString(path) {
return 0, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()} return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
} }
i64, e := base.StorageDriver.WriteStream(ctx, path, offset, reader) writer, e := base.StorageDriver.Writer(ctx, path, append)
return i64, base.setDriverName(e) return writer, base.setDriverName(e)
} }
// Stat wraps Stat of underlying storage driver. // Stat wraps Stat of underlying storage driver.

View file

@ -0,0 +1,145 @@
package base
import (
"io"
"sync"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
type regulator struct {
storagedriver.StorageDriver
*sync.Cond
available uint64
}
// NewRegulator wraps the given driver and is used to regulate concurrent calls
// to the given storage driver to a maximum of the given limit. This is useful
// for storage drivers that would otherwise create an unbounded number of OS
// threads if allowed to be called unregulated.
func NewRegulator(driver storagedriver.StorageDriver, limit uint64) storagedriver.StorageDriver {
return &regulator{
StorageDriver: driver,
Cond: sync.NewCond(&sync.Mutex{}),
available: limit,
}
}
func (r *regulator) enter() {
r.L.Lock()
for r.available == 0 {
r.Wait()
}
r.available--
r.L.Unlock()
}
func (r *regulator) exit() {
r.L.Lock()
// We only need to signal to a waiting FS operation if we're already at the
// limit of threads used
if r.available == 0 {
r.Signal()
}
r.available++
r.L.Unlock()
}
// Name returns the human-readable "name" of the driver, useful in error
// messages and logging. By convention, this will just be the registration
// name, but drivers may provide other information here.
func (r *regulator) Name() string {
r.enter()
defer r.exit()
return r.StorageDriver.Name()
}
// GetContent retrieves the content stored at "path" as a []byte.
// This should primarily be used for small objects.
func (r *regulator) GetContent(ctx context.Context, path string) ([]byte, error) {
r.enter()
defer r.exit()
return r.StorageDriver.GetContent(ctx, path)
}
// PutContent stores the []byte content at a location designated by "path".
// This should primarily be used for small objects.
func (r *regulator) PutContent(ctx context.Context, path string, content []byte) error {
r.enter()
defer r.exit()
return r.StorageDriver.PutContent(ctx, path, content)
}
// Reader retrieves an io.ReadCloser for the content stored at "path"
// with a given byte offset.
// May be used to resume reading a stream by providing a nonzero offset.
func (r *regulator) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Reader(ctx, path, offset)
}
// Writer stores the contents of the provided io.ReadCloser at a
// location designated by the given path.
// May be used to resume writing a stream by providing a nonzero offset.
// The offset must be no larger than the CurrentSize for this path.
func (r *regulator) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Writer(ctx, path, append)
}
// Stat retrieves the FileInfo for the given path, including the current
// size in bytes and the creation time.
func (r *regulator) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Stat(ctx, path)
}
// List returns a list of the objects that are direct descendants of the
//given path.
func (r *regulator) List(ctx context.Context, path string) ([]string, error) {
r.enter()
defer r.exit()
return r.StorageDriver.List(ctx, path)
}
// Move moves an object stored at sourcePath to destPath, removing the
// original object.
// Note: This may be no more efficient than a copy followed by a delete for
// many implementations.
func (r *regulator) Move(ctx context.Context, sourcePath string, destPath string) error {
r.enter()
defer r.exit()
return r.StorageDriver.Move(ctx, sourcePath, destPath)
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (r *regulator) Delete(ctx context.Context, path string) error {
r.enter()
defer r.exit()
return r.StorageDriver.Delete(ctx, path)
}
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// May return an ErrUnsupportedMethod in certain StorageDriver
// implementations.
func (r *regulator) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
r.enter()
defer r.exit()
return r.StorageDriver.URLFor(ctx, path, options)
}

View file

@ -11,7 +11,14 @@ import (
var driverFactories = make(map[string]StorageDriverFactory) var driverFactories = make(map[string]StorageDriverFactory)
// StorageDriverFactory is a factory interface for creating storagedriver.StorageDriver interfaces // StorageDriverFactory is a factory interface for creating storagedriver.StorageDriver interfaces
// Storage drivers should call Register() with a factory to make the driver available by name // Storage drivers should call Register() with a factory to make the driver available by name.
// Individual StorageDriver implementations generally register with the factory via the Register
// func (below) in their init() funcs, and as such they should be imported anonymously before use.
// See below for an example of how to register and get a StorageDriver for S3
//
// import _ "github.com/docker/distribution/registry/storage/driver/s3-aws"
// s3Driver, err = factory.Create("s3", storageParams)
// // assuming no error, s3Driver is the StorageDriver that communicates with S3 according to storageParams
type StorageDriverFactory interface { type StorageDriverFactory interface {
// Create returns a new storagedriver.StorageDriver with the given parameters // Create returns a new storagedriver.StorageDriver with the given parameters
// Parameters will vary by driver and may be ignored // Parameters will vary by driver and may be ignored
@ -21,6 +28,8 @@ type StorageDriverFactory interface {
// Register makes a storage driver available by the provided name. // Register makes a storage driver available by the provided name.
// If Register is called twice with the same name or if driver factory is nil, it panics. // If Register is called twice with the same name or if driver factory is nil, it panics.
// Additionally, it is not concurrency safe. Most Storage Drivers call this function
// in their init() functions. See the documentation for StorageDriverFactory for more.
func Register(name string, factory StorageDriverFactory) { func Register(name string, factory StorageDriverFactory) {
if factory == nil { if factory == nil {
panic("Must not provide nil StorageDriverFactory") panic("Must not provide nil StorageDriverFactory")

View file

@ -1,12 +1,15 @@
package filesystem package filesystem
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
"reflect"
"strconv"
"time" "time"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
@ -15,8 +18,23 @@ import (
"github.com/docker/distribution/registry/storage/driver/factory" "github.com/docker/distribution/registry/storage/driver/factory"
) )
const driverName = "filesystem" const (
const defaultRootDirectory = "/var/lib/registry" driverName = "filesystem"
defaultRootDirectory = "/var/lib/registry"
defaultMaxThreads = uint64(100)
// minThreads is the minimum value for the maxthreads configuration
// parameter. If the driver's parameters are less than this we set
// the parameters to minThreads
minThreads = uint64(25)
)
// DriverParameters represents all configuration options available for the
// filesystem driver
type DriverParameters struct {
RootDirectory string
MaxThreads uint64
}
func init() { func init() {
factory.Register(driverName, &filesystemDriverFactory{}) factory.Register(driverName, &filesystemDriverFactory{})
@ -26,7 +44,7 @@ func init() {
type filesystemDriverFactory struct{} type filesystemDriverFactory struct{}
func (factory *filesystemDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func (factory *filesystemDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters), nil return FromParameters(parameters)
} }
type driver struct { type driver struct {
@ -46,25 +64,72 @@ type Driver struct {
// FromParameters constructs a new Driver with a given parameters map // FromParameters constructs a new Driver with a given parameters map
// Optional Parameters: // Optional Parameters:
// - rootdirectory // - rootdirectory
func FromParameters(parameters map[string]interface{}) *Driver { // - maxthreads
var rootDirectory = defaultRootDirectory func FromParameters(parameters map[string]interface{}) (*Driver, error) {
params, err := fromParametersImpl(parameters)
if err != nil || params == nil {
return nil, err
}
return New(*params), nil
}
func fromParametersImpl(parameters map[string]interface{}) (*DriverParameters, error) {
var (
err error
maxThreads = defaultMaxThreads
rootDirectory = defaultRootDirectory
)
if parameters != nil { if parameters != nil {
rootDir, ok := parameters["rootdirectory"] if rootDir, ok := parameters["rootdirectory"]; ok {
if ok {
rootDirectory = fmt.Sprint(rootDir) rootDirectory = fmt.Sprint(rootDir)
} }
// Get maximum number of threads for blocking filesystem operations,
// if specified
threads := parameters["maxthreads"]
switch v := threads.(type) {
case string:
if maxThreads, err = strconv.ParseUint(v, 0, 64); err != nil {
return nil, fmt.Errorf("maxthreads parameter must be an integer, %v invalid", threads)
}
case uint64:
maxThreads = v
case int, int32, int64:
val := reflect.ValueOf(v).Convert(reflect.TypeOf(threads)).Int()
// If threads is negative casting to uint64 will wrap around and
// give you the hugest thread limit ever. Let's be sensible, here
if val > 0 {
maxThreads = uint64(val)
}
case uint, uint32:
maxThreads = reflect.ValueOf(v).Convert(reflect.TypeOf(threads)).Uint()
case nil:
// do nothing
default:
return nil, fmt.Errorf("invalid value for maxthreads: %#v", threads)
}
if maxThreads < minThreads {
maxThreads = minThreads
}
} }
return New(rootDirectory)
params := &DriverParameters{
RootDirectory: rootDirectory,
MaxThreads: maxThreads,
}
return params, nil
} }
// New constructs a new Driver with a given rootDirectory // New constructs a new Driver with a given rootDirectory
func New(rootDirectory string) *Driver { func New(params DriverParameters) *Driver {
fsDriver := &driver{rootDirectory: params.RootDirectory}
return &Driver{ return &Driver{
baseEmbed: baseEmbed{ baseEmbed: baseEmbed{
Base: base.Base{ Base: base.Base{
StorageDriver: &driver{ StorageDriver: base.NewRegulator(fsDriver, params.MaxThreads),
rootDirectory: rootDirectory,
},
}, },
}, },
} }
@ -78,7 +143,7 @@ func (d *driver) Name() string {
// GetContent retrieves the content stored at "path" as a []byte. // GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) { func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
rc, err := d.ReadStream(ctx, path, 0) rc, err := d.Reader(ctx, path, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -94,16 +159,22 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
// PutContent stores the []byte content at a location designated by "path". // PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte) error { func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte) error {
if _, err := d.WriteStream(ctx, subPath, 0, bytes.NewReader(contents)); err != nil { writer, err := d.Writer(ctx, subPath, false)
if err != nil {
return err return err
} }
defer writer.Close()
return os.Truncate(d.fullPath(subPath), int64(len(contents))) _, err = io.Copy(writer, bytes.NewReader(contents))
if err != nil {
writer.Cancel()
return err
}
return writer.Commit()
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
file, err := os.OpenFile(d.fullPath(path), os.O_RDONLY, 0644) file, err := os.OpenFile(d.fullPath(path), os.O_RDONLY, 0644)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -125,40 +196,36 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return file, nil return file, nil
} }
// WriteStream stores the contents of the provided io.Reader at a location func (d *driver) Writer(ctx context.Context, subPath string, append bool) (storagedriver.FileWriter, error) {
// designated by the given path.
func (d *driver) WriteStream(ctx context.Context, subPath string, offset int64, reader io.Reader) (nn int64, err error) {
// TODO(stevvooe): This needs to be a requirement.
// if !path.IsAbs(subPath) {
// return fmt.Errorf("absolute path required: %q", subPath)
// }
fullPath := d.fullPath(subPath) fullPath := d.fullPath(subPath)
parentDir := path.Dir(fullPath) parentDir := path.Dir(fullPath)
if err := os.MkdirAll(parentDir, 0777); err != nil { if err := os.MkdirAll(parentDir, 0777); err != nil {
return 0, err return nil, err
} }
fp, err := os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE, 0666) fp, err := os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE, 0666)
if err != nil { if err != nil {
// TODO(stevvooe): A few missing conditions in storage driver: return nil, err
// 1. What if the path is already a directory?
// 2. Should number 1 be exposed explicitly in storagedriver?
// 2. Can this path not exist, even if we create above?
return 0, err
}
defer fp.Close()
nn, err = fp.Seek(offset, os.SEEK_SET)
if err != nil {
return 0, err
} }
if nn != offset { var offset int64
return 0, fmt.Errorf("bad seek to %v, expected %v in fp=%v", offset, nn, fp)
if !append {
err := fp.Truncate(0)
if err != nil {
fp.Close()
return nil, err
}
} else {
n, err := fp.Seek(0, os.SEEK_END)
if err != nil {
fp.Close()
return nil, err
}
offset = int64(n)
} }
return io.Copy(fp, reader) return newFileWriter(fp, offset), nil
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -286,3 +353,88 @@ func (fi fileInfo) ModTime() time.Time {
func (fi fileInfo) IsDir() bool { func (fi fileInfo) IsDir() bool {
return fi.FileInfo.IsDir() return fi.FileInfo.IsDir()
} }
type fileWriter struct {
file *os.File
size int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func newFileWriter(file *os.File, size int64) *fileWriter {
return &fileWriter{
file: file,
size: size,
bw: bufio.NewWriter(file),
}
}
func (fw *fileWriter) Write(p []byte) (int, error) {
if fw.closed {
return 0, fmt.Errorf("already closed")
} else if fw.committed {
return 0, fmt.Errorf("already committed")
} else if fw.cancelled {
return 0, fmt.Errorf("already cancelled")
}
n, err := fw.bw.Write(p)
fw.size += int64(n)
return n, err
}
func (fw *fileWriter) Size() int64 {
return fw.size
}
func (fw *fileWriter) Close() error {
if fw.closed {
return fmt.Errorf("already closed")
}
if err := fw.bw.Flush(); err != nil {
return err
}
if err := fw.file.Sync(); err != nil {
return err
}
if err := fw.file.Close(); err != nil {
return err
}
fw.closed = true
return nil
}
func (fw *fileWriter) Cancel() error {
if fw.closed {
return fmt.Errorf("already closed")
}
fw.cancelled = true
fw.file.Close()
return os.Remove(fw.file.Name())
}
func (fw *fileWriter) Commit() error {
if fw.closed {
return fmt.Errorf("already closed")
} else if fw.committed {
return fmt.Errorf("already committed")
} else if fw.cancelled {
return fmt.Errorf("already cancelled")
}
if err := fw.bw.Flush(); err != nil {
return err
}
if err := fw.file.Sync(); err != nil {
return err
}
fw.committed = true
return nil
}

View file

@ -3,6 +3,7 @@ package filesystem
import ( import (
"io/ioutil" "io/ioutil"
"os" "os"
"reflect"
"testing" "testing"
storagedriver "github.com/docker/distribution/registry/storage/driver" storagedriver "github.com/docker/distribution/registry/storage/driver"
@ -20,7 +21,93 @@ func init() {
} }
defer os.Remove(root) defer os.Remove(root)
driver, err := FromParameters(map[string]interface{}{
"rootdirectory": root,
})
if err != nil {
panic(err)
}
testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) { testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) {
return New(root), nil return driver, nil
}, testsuites.NeverSkip) }, testsuites.NeverSkip)
} }
func TestFromParametersImpl(t *testing.T) {
tests := []struct {
params map[string]interface{} // techincally the yaml can contain anything
expected DriverParameters
pass bool
}{
// check we use default threads and root dirs
{
params: map[string]interface{}{},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: defaultMaxThreads,
},
pass: true,
},
// Testing initiation with a string maxThreads which can't be parsed
{
params: map[string]interface{}{
"maxthreads": "fail",
},
expected: DriverParameters{},
pass: false,
},
{
params: map[string]interface{}{
"maxthreads": "100",
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: uint64(100),
},
pass: true,
},
{
params: map[string]interface{}{
"maxthreads": 100,
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: uint64(100),
},
pass: true,
},
// check that we use minimum thread counts
{
params: map[string]interface{}{
"maxthreads": 1,
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: minThreads,
},
pass: true,
},
}
for _, item := range tests {
params, err := fromParametersImpl(item.params)
if !item.pass {
// We only need to assert that expected failures have an error
if err == nil {
t.Fatalf("expected error configuring filesystem driver with invalid param: %+v", item.params)
}
continue
}
if err != nil {
t.Fatalf("unexpected error creating filesystem driver: %s", err)
}
// Note that we get a pointer to params back
if !reflect.DeepEqual(*params, item.expected) {
t.Fatalf("unexpected params from filesystem driver. expected %+v, got %+v", item.expected, params)
}
}
}

View file

@ -7,11 +7,8 @@
// Because gcs is a key, value store the Stat call does not support last modification // Because gcs is a key, value store the Stat call does not support last modification
// time for directories (directories are an abstraction for key, value stores) // time for directories (directories are an abstraction for key, value stores)
// //
// Keep in mind that gcs guarantees only eventual consistency, so do not assume // Note that the contents of incomplete uploads are not accessible even though
// that a successful write will mean immediate access to the data written (although // Stat returns their length
// in most regions a new object put has guaranteed read after write). The only true
// guarantee is that once you call Stat and receive a certain file size, that much of
// the file is already accessible.
// //
// +build include_gcs // +build include_gcs
@ -25,7 +22,10 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"time" "time"
@ -34,7 +34,6 @@ import (
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
"golang.org/x/oauth2/jwt" "golang.org/x/oauth2/jwt"
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
storageapi "google.golang.org/api/storage/v1"
"google.golang.org/cloud" "google.golang.org/cloud"
"google.golang.org/cloud/storage" "google.golang.org/cloud/storage"
@ -46,8 +45,18 @@ import (
"github.com/docker/distribution/registry/storage/driver/factory" "github.com/docker/distribution/registry/storage/driver/factory"
) )
const driverName = "gcs" const (
const dummyProjectID = "<unknown>" driverName = "gcs"
dummyProjectID = "<unknown>"
uploadSessionContentType = "application/x-docker-upload-session"
minChunkSize = 256 * 1024
defaultChunkSize = 20 * minChunkSize
maxTries = 5
)
var rangeHeader = regexp.MustCompile(`^bytes=([0-9])+-([0-9]+)$`)
// driverParameters is a struct that encapsulates all of the driver parameters after all values have been set // driverParameters is a struct that encapsulates all of the driver parameters after all values have been set
type driverParameters struct { type driverParameters struct {
@ -57,6 +66,7 @@ type driverParameters struct {
privateKey []byte privateKey []byte
client *http.Client client *http.Client
rootDirectory string rootDirectory string
chunkSize int
} }
func init() { func init() {
@ -79,6 +89,7 @@ type driver struct {
email string email string
privateKey []byte privateKey []byte
rootDirectory string rootDirectory string
chunkSize int
} }
// FromParameters constructs a new Driver with a given parameters map // FromParameters constructs a new Driver with a given parameters map
@ -95,6 +106,31 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
rootDirectory = "" rootDirectory = ""
} }
chunkSize := defaultChunkSize
chunkSizeParam, ok := parameters["chunksize"]
if ok {
switch v := chunkSizeParam.(type) {
case string:
vv, err := strconv.Atoi(v)
if err != nil {
return nil, fmt.Errorf("chunksize parameter must be an integer, %v invalid", chunkSizeParam)
}
chunkSize = vv
case int, uint, int32, uint32, uint64, int64:
chunkSize = int(reflect.ValueOf(v).Convert(reflect.TypeOf(chunkSize)).Int())
default:
return nil, fmt.Errorf("invalid valud for chunksize: %#v", chunkSizeParam)
}
if chunkSize < minChunkSize {
return nil, fmt.Errorf("The chunksize %#v parameter should be a number that is larger than or equal to %d", chunkSize, minChunkSize)
}
if chunkSize%minChunkSize != 0 {
return nil, fmt.Errorf("chunksize should be a multiple of %d", minChunkSize)
}
}
var ts oauth2.TokenSource var ts oauth2.TokenSource
jwtConf := new(jwt.Config) jwtConf := new(jwt.Config)
if keyfile, ok := parameters["keyfile"]; ok { if keyfile, ok := parameters["keyfile"]; ok {
@ -113,7 +149,6 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
params := driverParameters{ params := driverParameters{
@ -122,6 +157,7 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
email: jwtConf.Email, email: jwtConf.Email,
privateKey: jwtConf.PrivateKey, privateKey: jwtConf.PrivateKey,
client: oauth2.NewClient(context.Background(), ts), client: oauth2.NewClient(context.Background(), ts),
chunkSize: chunkSize,
} }
return New(params) return New(params)
@ -133,12 +169,16 @@ func New(params driverParameters) (storagedriver.StorageDriver, error) {
if rootDirectory != "" { if rootDirectory != "" {
rootDirectory += "/" rootDirectory += "/"
} }
if params.chunkSize <= 0 || params.chunkSize%minChunkSize != 0 {
return nil, fmt.Errorf("Invalid chunksize: %d is not a positive multiple of %d", params.chunkSize, minChunkSize)
}
d := &driver{ d := &driver{
bucket: params.bucket, bucket: params.bucket,
rootDirectory: rootDirectory, rootDirectory: rootDirectory,
email: params.email, email: params.email,
privateKey: params.privateKey, privateKey: params.privateKey,
client: params.client, client: params.client,
chunkSize: params.chunkSize,
} }
return &base.Base{ return &base.Base{
@ -155,7 +195,17 @@ func (d *driver) Name() string {
// GetContent retrieves the content stored at "path" as a []byte. // GetContent retrieves the content stored at "path" as a []byte.
// This should primarily be used for small objects. // This should primarily be used for small objects.
func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) { func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) {
rc, err := d.ReadStream(context, path, 0) gcsContext := d.context(context)
name := d.pathToKey(path)
var rc io.ReadCloser
err := retry(func() error {
var err error
rc, err = storage.NewReader(gcsContext, d.bucket, name)
return err
})
if err == storage.ErrObjectNotExist {
return nil, storagedriver.PathNotFoundError{Path: path}
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -171,25 +221,53 @@ func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) {
// PutContent stores the []byte content at a location designated by "path". // PutContent stores the []byte content at a location designated by "path".
// This should primarily be used for small objects. // This should primarily be used for small objects.
func (d *driver) PutContent(context ctx.Context, path string, contents []byte) error { func (d *driver) PutContent(context ctx.Context, path string, contents []byte) error {
wc := storage.NewWriter(d.context(context), d.bucket, d.pathToKey(path)) return retry(func() error {
wc.ContentType = "application/octet-stream" wc := storage.NewWriter(d.context(context), d.bucket, d.pathToKey(path))
defer wc.Close() wc.ContentType = "application/octet-stream"
_, err := wc.Write(contents) return putContentsClose(wc, contents)
return err })
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" // Reader retrieves an io.ReadCloser for the content stored at "path"
// with a given byte offset. // with a given byte offset.
// May be used to resume reading a stream by providing a nonzero offset. // May be used to resume reading a stream by providing a nonzero offset.
func (d *driver) ReadStream(context ctx.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(context ctx.Context, path string, offset int64) (io.ReadCloser, error) {
name := d.pathToKey(path) res, err := getObject(d.client, d.bucket, d.pathToKey(path), offset)
if err != nil {
if res != nil {
if res.StatusCode == http.StatusNotFound {
res.Body.Close()
return nil, storagedriver.PathNotFoundError{Path: path}
}
if res.StatusCode == http.StatusRequestedRangeNotSatisfiable {
res.Body.Close()
obj, err := storageStatObject(d.context(context), d.bucket, d.pathToKey(path))
if err != nil {
return nil, err
}
if offset == int64(obj.Size) {
return ioutil.NopCloser(bytes.NewReader([]byte{})), nil
}
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
}
return nil, err
}
if res.Header.Get("Content-Type") == uploadSessionContentType {
defer res.Body.Close()
return nil, storagedriver.PathNotFoundError{Path: path}
}
return res.Body, nil
}
func getObject(client *http.Client, bucket string, name string, offset int64) (*http.Response, error) {
// copied from google.golang.org/cloud/storage#NewReader : // copied from google.golang.org/cloud/storage#NewReader :
// to set the additional "Range" header // to set the additional "Range" header
u := &url.URL{ u := &url.URL{
Scheme: "https", Scheme: "https",
Host: "storage.googleapis.com", Host: "storage.googleapis.com",
Path: fmt.Sprintf("/%s/%s", d.bucket, name), Path: fmt.Sprintf("/%s/%s", bucket, name),
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), nil)
if err != nil { if err != nil {
@ -198,122 +276,253 @@ func (d *driver) ReadStream(context ctx.Context, path string, offset int64) (io.
if offset > 0 { if offset > 0 {
req.Header.Set("Range", fmt.Sprintf("bytes=%v-", offset)) req.Header.Set("Range", fmt.Sprintf("bytes=%v-", offset))
} }
res, err := d.client.Do(req) var res *http.Response
err = retry(func() error {
var err error
res, err = client.Do(req)
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if res.StatusCode == http.StatusNotFound { return res, googleapi.CheckMediaResponse(res)
res.Body.Close() }
return nil, storagedriver.PathNotFoundError{Path: path}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(context ctx.Context, path string, append bool) (storagedriver.FileWriter, error) {
writer := &writer{
client: d.client,
bucket: d.bucket,
name: d.pathToKey(path),
buffer: make([]byte, d.chunkSize),
} }
if res.StatusCode == http.StatusRequestedRangeNotSatisfiable {
res.Body.Close() if append {
obj, err := storageStatObject(d.context(context), d.bucket, name) err := writer.init(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if offset == int64(obj.Size) {
return ioutil.NopCloser(bytes.NewReader([]byte{})), nil
}
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
} }
if res.StatusCode < 200 || res.StatusCode > 299 { return writer, nil
res.Body.Close()
return nil, fmt.Errorf("storage: can't read object %v/%v, status code: %v", d.bucket, name, res.Status)
}
return res.Body, nil
} }
// WriteStream stores the contents of the provided io.ReadCloser at a type writer struct {
// location designated by the given path. client *http.Client
// May be used to resume writing a stream by providing a nonzero offset. bucket string
// The offset must be no larger than the CurrentSize for this path. name string
func (d *driver) WriteStream(context ctx.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) { size int64
if offset < 0 { offset int64
return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset} closed bool
} sessionURI string
buffer []byte
buffSize int
}
if offset == 0 { // Cancel removes any written content from this FileWriter.
return d.writeCompletely(context, path, 0, reader) func (w *writer) Cancel() error {
} err := w.checkClosed()
service, err := storageapi.New(d.client)
if err != nil { if err != nil {
return 0, err
}
objService := storageapi.NewObjectsService(service)
var obj *storageapi.Object
err = retry(5, func() error {
o, err := objService.Get(d.bucket, d.pathToKey(path)).Do()
obj = o
return err return err
}) }
// obj, err := retry(5, objService.Get(d.bucket, d.pathToKey(path)).Do) w.closed = true
err = storageDeleteObject(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
if err != nil { if err != nil {
return 0, err if status, ok := err.(*googleapi.Error); ok {
} if status.Code == http.StatusNotFound {
err = nil
// cannot append more chunks, so redo from scratch }
if obj.ComponentCount >= 1023 { }
return d.writeCompletely(context, path, offset, reader) }
} return err
}
// skip from reader
objSize := int64(obj.Size) func (w *writer) Close() error {
nn, err := skip(reader, objSize-offset) if w.closed {
if err != nil { return nil
return nn, err }
} w.closed = true
// Size <= offset err := w.writeChunk()
partName := fmt.Sprintf("%v#part-%d#", d.pathToKey(path), obj.ComponentCount) if err != nil {
gcsContext := d.context(context) return err
wc := storage.NewWriter(gcsContext, d.bucket, partName) }
wc.ContentType = "application/octet-stream"
// Copy the remaining bytes from the buffer to the upload session
if objSize < offset { // Normally buffSize will be smaller than minChunkSize. However, in the
err = writeZeros(wc, offset-objSize) // unlikely event that the upload session failed to start, this number could be higher.
if err != nil { // In this case we can safely clip the remaining bytes to the minChunkSize
wc.CloseWithError(err) if w.buffSize > minChunkSize {
return nn, err w.buffSize = minChunkSize
}
// commit the writes by updating the upload session
err = retry(func() error {
wc := storage.NewWriter(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
wc.ContentType = uploadSessionContentType
wc.Metadata = map[string]string{
"Session-URI": w.sessionURI,
"Offset": strconv.FormatInt(w.offset, 10),
}
return putContentsClose(wc, w.buffer[0:w.buffSize])
})
if err != nil {
return err
}
w.size = w.offset + int64(w.buffSize)
w.buffSize = 0
return nil
}
func putContentsClose(wc *storage.Writer, contents []byte) error {
size := len(contents)
var nn int
var err error
for nn < size {
n, err := wc.Write(contents[nn:size])
nn += n
if err != nil {
break
} }
} }
n, err := io.Copy(wc, reader)
if err != nil { if err != nil {
wc.CloseWithError(err) wc.CloseWithError(err)
return nn, err return err
}
return wc.Close()
}
// Commit flushes all content written to this FileWriter and makes it
// available for future calls to StorageDriver.GetContent and
// StorageDriver.Reader.
func (w *writer) Commit() error {
if err := w.checkClosed(); err != nil {
return err
}
w.closed = true
// no session started yet just perform a simple upload
if w.sessionURI == "" {
err := retry(func() error {
wc := storage.NewWriter(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
wc.ContentType = "application/octet-stream"
return putContentsClose(wc, w.buffer[0:w.buffSize])
})
if err != nil {
return err
}
w.size = w.offset + int64(w.buffSize)
w.buffSize = 0
return nil
}
size := w.offset + int64(w.buffSize)
var nn int
// loop must be performed at least once to ensure the file is committed even when
// the buffer is empty
for {
n, err := putChunk(w.client, w.sessionURI, w.buffer[nn:w.buffSize], w.offset, size)
nn += int(n)
w.offset += n
w.size = w.offset
if err != nil {
w.buffSize = copy(w.buffer, w.buffer[nn:w.buffSize])
return err
}
if nn == w.buffSize {
break
}
}
w.buffSize = 0
return nil
}
func (w *writer) checkClosed() error {
if w.closed {
return fmt.Errorf("Writer already closed")
}
return nil
}
func (w *writer) writeChunk() error {
var err error
// chunks can be uploaded only in multiples of minChunkSize
// chunkSize is a multiple of minChunkSize less than or equal to buffSize
chunkSize := w.buffSize - (w.buffSize % minChunkSize)
if chunkSize == 0 {
return nil
}
// if their is no sessionURI yet, obtain one by starting the session
if w.sessionURI == "" {
w.sessionURI, err = startSession(w.client, w.bucket, w.name)
} }
err = wc.Close()
if err != nil { if err != nil {
return nn, err return err
} }
// wc was closed successfully, so the temporary part exists, schedule it for deletion at the end nn, err := putChunk(w.client, w.sessionURI, w.buffer[0:chunkSize], w.offset, -1)
// of the function w.offset += nn
defer storageDeleteObject(gcsContext, d.bucket, partName) if w.offset > w.size {
w.size = w.offset
}
// shift the remaining bytes to the start of the buffer
w.buffSize = copy(w.buffer, w.buffer[int(nn):w.buffSize])
req := &storageapi.ComposeRequest{ return err
Destination: &storageapi.Object{Bucket: obj.Bucket, Name: obj.Name, ContentType: obj.ContentType}, }
SourceObjects: []*storageapi.ComposeRequestSourceObjects{
{ func (w *writer) Write(p []byte) (int, error) {
Name: obj.Name, err := w.checkClosed()
Generation: obj.Generation, if err != nil {
}, { return 0, err
Name: partName,
Generation: wc.Object().Generation,
}},
} }
err = retry(5, func() error { _, err := objService.Compose(d.bucket, obj.Name, req).Do(); return err }) var nn int
if err == nil { for nn < len(p) {
nn = nn + n n := copy(w.buffer[w.buffSize:], p[nn:])
w.buffSize += n
if w.buffSize == cap(w.buffer) {
err = w.writeChunk()
if err != nil {
break
}
}
nn += n
} }
return nn, err return nn, err
} }
// Size returns the number of bytes written to this FileWriter.
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) init(path string) error {
res, err := getObject(w.client, w.bucket, w.name, 0)
if err != nil {
return err
}
defer res.Body.Close()
if res.Header.Get("Content-Type") != uploadSessionContentType {
return storagedriver.PathNotFoundError{Path: path}
}
offset, err := strconv.ParseInt(res.Header.Get("X-Goog-Meta-Offset"), 10, 64)
if err != nil {
return err
}
buffer, err := ioutil.ReadAll(res.Body)
if err != nil {
return err
}
w.sessionURI = res.Header.Get("X-Goog-Meta-Session-URI")
w.buffSize = copy(w.buffer, buffer)
w.offset = offset
w.size = offset + int64(w.buffSize)
return nil
}
type request func() error type request func() error
func retry(maxTries int, req request) error { func retry(req request) error {
backoff := time.Second backoff := time.Second
var err error var err error
for i := 0; i < maxTries; i++ { for i := 0; i < maxTries; i++ {
@ -335,53 +544,6 @@ func retry(maxTries int, req request) error {
return err return err
} }
func (d *driver) writeCompletely(context ctx.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) {
wc := storage.NewWriter(d.context(context), d.bucket, d.pathToKey(path))
wc.ContentType = "application/octet-stream"
defer wc.Close()
// Copy the first offset bytes of the existing contents
// (padded with zeros if needed) into the writer
if offset > 0 {
existing, err := d.ReadStream(context, path, 0)
if err != nil {
return 0, err
}
defer existing.Close()
n, err := io.CopyN(wc, existing, offset)
if err == io.EOF {
err = writeZeros(wc, offset-n)
}
if err != nil {
return 0, err
}
}
return io.Copy(wc, reader)
}
func skip(reader io.Reader, count int64) (int64, error) {
if count <= 0 {
return 0, nil
}
return io.CopyN(ioutil.Discard, reader, count)
}
func writeZeros(wc io.Writer, count int64) error {
buf := make([]byte, 32*1024)
for count > 0 {
size := cap(buf)
if int64(size) > count {
size = int(count)
}
n, err := wc.Write(buf[0:size])
if err != nil {
return err
}
count = count - int64(n)
}
return nil
}
// Stat retrieves the FileInfo for the given path, including the current // Stat retrieves the FileInfo for the given path, including the current
// size in bytes and the creation time. // size in bytes and the creation time.
func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo, error) { func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo, error) {
@ -390,6 +552,9 @@ func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo,
gcsContext := d.context(context) gcsContext := d.context(context)
obj, err := storageStatObject(gcsContext, d.bucket, d.pathToKey(path)) obj, err := storageStatObject(gcsContext, d.bucket, d.pathToKey(path))
if err == nil { if err == nil {
if obj.ContentType == uploadSessionContentType {
return nil, storagedriver.PathNotFoundError{Path: path}
}
fi = storagedriver.FileInfoFields{ fi = storagedriver.FileInfoFields{
Path: path, Path: path,
Size: obj.Size, Size: obj.Size,
@ -440,15 +605,10 @@ func (d *driver) List(context ctx.Context, path string) ([]string, error) {
} }
for _, object := range objects.Results { for _, object := range objects.Results {
// GCS does not guarantee strong consistency between // GCS does not guarantee strong consistency between
// DELETE and LIST operationsCheck that the object is not deleted, // DELETE and LIST operations. Check that the object is not deleted,
// so filter out any objects with a non-zero time-deleted // and filter out any objects with a non-zero time-deleted
if object.Deleted.IsZero() { if object.Deleted.IsZero() && object.ContentType != uploadSessionContentType {
name := object.Name list = append(list, d.keyToPath(object.Name))
// Ignore objects with names that end with '#' (these are uploaded parts)
if name[len(name)-1] != '#' {
name = d.keyToPath(name)
list = append(list, name)
}
} }
} }
for _, subpath := range objects.Prefixes { for _, subpath := range objects.Prefixes {
@ -474,7 +634,7 @@ func (d *driver) Move(context ctx.Context, sourcePath string, destPath string) e
gcsContext := d.context(context) gcsContext := d.context(context)
_, err := storageCopyObject(gcsContext, d.bucket, d.pathToKey(sourcePath), d.bucket, d.pathToKey(destPath), nil) _, err := storageCopyObject(gcsContext, d.bucket, d.pathToKey(sourcePath), d.bucket, d.pathToKey(destPath), nil)
if err != nil { if err != nil {
if status := err.(*googleapi.Error); status != nil { if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound { if status.Code == http.StatusNotFound {
return storagedriver.PathNotFoundError{Path: sourcePath} return storagedriver.PathNotFoundError{Path: sourcePath}
} }
@ -482,7 +642,7 @@ func (d *driver) Move(context ctx.Context, sourcePath string, destPath string) e
return err return err
} }
err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(sourcePath)) err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(sourcePath))
// if deleting the file fails, log the error, but do not fail; the file was succesfully copied, // if deleting the file fails, log the error, but do not fail; the file was successfully copied,
// and the original should eventually be cleaned when purging the uploads folder. // and the original should eventually be cleaned when purging the uploads folder.
if err != nil { if err != nil {
logrus.Infof("error deleting file: %v due to %v", sourcePath, err) logrus.Infof("error deleting file: %v due to %v", sourcePath, err)
@ -545,7 +705,7 @@ func (d *driver) Delete(context ctx.Context, path string) error {
} }
err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(path)) err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(path))
if err != nil { if err != nil {
if status := err.(*googleapi.Error); status != nil { if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound { if status.Code == http.StatusNotFound {
return storagedriver.PathNotFoundError{Path: path} return storagedriver.PathNotFoundError{Path: path}
} }
@ -555,14 +715,14 @@ func (d *driver) Delete(context ctx.Context, path string) error {
} }
func storageDeleteObject(context context.Context, bucket string, name string) error { func storageDeleteObject(context context.Context, bucket string, name string) error {
return retry(5, func() error { return retry(func() error {
return storage.DeleteObject(context, bucket, name) return storage.DeleteObject(context, bucket, name)
}) })
} }
func storageStatObject(context context.Context, bucket string, name string) (*storage.Object, error) { func storageStatObject(context context.Context, bucket string, name string) (*storage.Object, error) {
var obj *storage.Object var obj *storage.Object
err := retry(5, func() error { err := retry(func() error {
var err error var err error
obj, err = storage.StatObject(context, bucket, name) obj, err = storage.StatObject(context, bucket, name)
return err return err
@ -572,7 +732,7 @@ func storageStatObject(context context.Context, bucket string, name string) (*st
func storageListObjects(context context.Context, bucket string, q *storage.Query) (*storage.Objects, error) { func storageListObjects(context context.Context, bucket string, q *storage.Query) (*storage.Objects, error) {
var objs *storage.Objects var objs *storage.Objects
err := retry(5, func() error { err := retry(func() error {
var err error var err error
objs, err = storage.ListObjects(context, bucket, q) objs, err = storage.ListObjects(context, bucket, q)
return err return err
@ -582,7 +742,7 @@ func storageListObjects(context context.Context, bucket string, q *storage.Query
func storageCopyObject(context context.Context, srcBucket, srcName string, destBucket, destName string, attrs *storage.ObjectAttrs) (*storage.Object, error) { func storageCopyObject(context context.Context, srcBucket, srcName string, destBucket, destName string, attrs *storage.ObjectAttrs) (*storage.Object, error) {
var obj *storage.Object var obj *storage.Object
err := retry(5, func() error { err := retry(func() error {
var err error var err error
obj, err = storage.CopyObject(context, srcBucket, srcName, destBucket, destName, attrs) obj, err = storage.CopyObject(context, srcBucket, srcName, destBucket, destName, attrs)
return err return err
@ -626,6 +786,80 @@ func (d *driver) URLFor(context ctx.Context, path string, options map[string]int
return storage.SignedURL(d.bucket, name, opts) return storage.SignedURL(d.bucket, name, opts)
} }
func startSession(client *http.Client, bucket string, name string) (uri string, err error) {
u := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: fmt.Sprintf("/upload/storage/v1/b/%v/o", bucket),
RawQuery: fmt.Sprintf("uploadType=resumable&name=%v", name),
}
err = retry(func() error {
req, err := http.NewRequest("POST", u.String(), nil)
if err != nil {
return err
}
req.Header.Set("X-Upload-Content-Type", "application/octet-stream")
req.Header.Set("Content-Length", "0")
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
err = googleapi.CheckMediaResponse(resp)
if err != nil {
return err
}
uri = resp.Header.Get("Location")
return nil
})
return uri, err
}
func putChunk(client *http.Client, sessionURI string, chunk []byte, from int64, totalSize int64) (int64, error) {
bytesPut := int64(0)
err := retry(func() error {
req, err := http.NewRequest("PUT", sessionURI, bytes.NewReader(chunk))
if err != nil {
return err
}
length := int64(len(chunk))
to := from + length - 1
size := "*"
if totalSize >= 0 {
size = strconv.FormatInt(totalSize, 10)
}
req.Header.Set("Content-Type", "application/octet-stream")
if from == to+1 {
req.Header.Set("Content-Range", fmt.Sprintf("bytes */%v", size))
} else {
req.Header.Set("Content-Range", fmt.Sprintf("bytes %v-%v/%v", from, to, size))
}
req.Header.Set("Content-Length", strconv.FormatInt(length, 10))
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if totalSize < 0 && resp.StatusCode == 308 {
groups := rangeHeader.FindStringSubmatch(resp.Header.Get("Range"))
end, err := strconv.ParseInt(groups[2], 10, 64)
if err != nil {
return err
}
bytesPut = end - from + 1
return nil
}
err = googleapi.CheckMediaResponse(resp)
if err != nil {
return err
}
bytesPut = to - from + 1
return nil
})
return bytesPut, err
}
func (d *driver) context(context ctx.Context) context.Context { func (d *driver) context(context ctx.Context) context.Context {
return cloud.WithContext(context, dummyProjectID, d.client) return cloud.WithContext(context, dummyProjectID, d.client)
} }

View file

@ -75,6 +75,7 @@ func init() {
email: email, email: email,
privateKey: privateKey, privateKey: privateKey,
client: oauth2.NewClient(ctx.Background(), ts), client: oauth2.NewClient(ctx.Background(), ts),
chunkSize: defaultChunkSize,
} }
return New(parameters) return New(parameters)
@ -85,6 +86,102 @@ func init() {
}, skipGCS) }, skipGCS)
} }
// Test Committing a FileWriter without having called Write
func TestCommitEmpty(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
driver, err := gcsDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
filename := "/test"
ctx := ctx.Background()
writer, err := driver.Writer(ctx, filename, false)
defer driver.Delete(ctx, filename)
if err != nil {
t.Fatalf("driver.Writer: unexpected error: %v", err)
}
err = writer.Commit()
if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err)
}
err = writer.Close()
if err != nil {
t.Fatalf("writer.Close: unexpected error: %v", err)
}
if writer.Size() != 0 {
t.Fatalf("writer.Size: %d != 0", writer.Size())
}
readContents, err := driver.GetContent(ctx, filename)
if err != nil {
t.Fatalf("driver.GetContent: unexpected error: %v", err)
}
if len(readContents) != 0 {
t.Fatalf("len(driver.GetContent(..)): %d != 0", len(readContents))
}
}
// Test Committing a FileWriter after having written exactly
// defaultChunksize bytes.
func TestCommit(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
driver, err := gcsDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
filename := "/test"
ctx := ctx.Background()
contents := make([]byte, defaultChunkSize)
writer, err := driver.Writer(ctx, filename, false)
defer driver.Delete(ctx, filename)
if err != nil {
t.Fatalf("driver.Writer: unexpected error: %v", err)
}
_, err = writer.Write(contents)
if err != nil {
t.Fatalf("writer.Write: unexpected error: %v", err)
}
err = writer.Commit()
if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err)
}
err = writer.Close()
if err != nil {
t.Fatalf("writer.Close: unexpected error: %v", err)
}
if writer.Size() != int64(len(contents)) {
t.Fatalf("writer.Size: %d != %d", writer.Size(), len(contents))
}
readContents, err := driver.GetContent(ctx, filename)
if err != nil {
t.Fatalf("driver.GetContent: unexpected error: %v", err)
}
if len(readContents) != len(contents) {
t.Fatalf("len(driver.GetContent(..)): %d != %d", len(readContents), len(contents))
}
}
func TestRetry(t *testing.T) { func TestRetry(t *testing.T) {
if skipGCS() != "" { if skipGCS() != "" {
t.Skip(skipGCS()) t.Skip(skipGCS())
@ -100,7 +197,7 @@ func TestRetry(t *testing.T) {
} }
} }
err := retry(2, func() error { err := retry(func() error {
return &googleapi.Error{ return &googleapi.Error{
Code: 503, Code: 503,
Message: "google api error", Message: "google api error",
@ -108,7 +205,7 @@ func TestRetry(t *testing.T) {
}) })
assertError("googleapi: Error 503: google api error", err) assertError("googleapi: Error 503: google api error", err)
err = retry(2, func() error { err = retry(func() error {
return &googleapi.Error{ return &googleapi.Error{
Code: 404, Code: 404,
Message: "google api error", Message: "google api error",
@ -116,7 +213,7 @@ func TestRetry(t *testing.T) {
}) })
assertError("googleapi: Error 404: google api error", err) assertError("googleapi: Error 404: google api error", err)
err = retry(2, func() error { err = retry(func() error {
return fmt.Errorf("error") return fmt.Errorf("error")
}) })
assertError("error", err) assertError("error", err)

View file

@ -1,7 +1,6 @@
package inmemory package inmemory
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -74,7 +73,7 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
d.mutex.RLock() d.mutex.RLock()
defer d.mutex.RUnlock() defer d.mutex.RUnlock()
rc, err := d.ReadStream(ctx, path, 0) rc, err := d.Reader(ctx, path, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,7 +87,9 @@ func (d *driver) PutContent(ctx context.Context, p string, contents []byte) erro
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
f, err := d.root.mkfile(p) normalized := normalize(p)
f, err := d.root.mkfile(normalized)
if err != nil { if err != nil {
// TODO(stevvooe): Again, we need to clarify when this is not a // TODO(stevvooe): Again, we need to clarify when this is not a
// directory in StorageDriver API. // directory in StorageDriver API.
@ -101,9 +102,9 @@ func (d *driver) PutContent(ctx context.Context, p string, contents []byte) erro
return nil return nil
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
d.mutex.RLock() d.mutex.RLock()
defer d.mutex.RUnlock() defer d.mutex.RUnlock()
@ -111,10 +112,10 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset} return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
} }
path = normalize(path) normalized := normalize(path)
found := d.root.find(path) found := d.root.find(normalized)
if found.path() != path { if found.path() != normalized {
return nil, storagedriver.PathNotFoundError{Path: path} return nil, storagedriver.PathNotFoundError{Path: path}
} }
@ -125,46 +126,24 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return ioutil.NopCloser(found.(*file).sectionReader(offset)), nil return ioutil.NopCloser(found.(*file).sectionReader(offset)), nil
} }
// WriteStream stores the contents of the provided io.ReadCloser at a location // Writer returns a FileWriter which will store the content written to it
// designated by the given path. // at the location designated by "path" after the call to Commit.
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (nn int64, err error) { func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
if offset < 0 {
return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
normalized := normalize(path) normalized := normalize(path)
f, err := d.root.mkfile(normalized) f, err := d.root.mkfile(normalized)
if err != nil { if err != nil {
return 0, fmt.Errorf("not a file") return nil, fmt.Errorf("not a file")
} }
// Unlock while we are reading from the source, in case we are reading if !append {
// from the same mfs instance. This can be fixed by a more granular f.truncate()
// locking model.
d.mutex.Unlock()
d.mutex.RLock() // Take the readlock to block other writers.
var buf bytes.Buffer
nn, err = buf.ReadFrom(reader)
if err != nil {
// TODO(stevvooe): This condition is odd and we may need to clarify:
// we've read nn bytes from reader but have written nothing to the
// backend. What is the correct return value? Really, the caller needs
// to know that the reader has been advanced and reattempting the
// operation is incorrect.
d.mutex.RUnlock()
d.mutex.Lock()
return nn, err
} }
d.mutex.RUnlock() return d.newWriter(f), nil
d.mutex.Lock()
f.WriteAt(buf.Bytes(), offset)
return nn, err
} }
// Stat returns info about the provided path. // Stat returns info about the provided path.
@ -173,7 +152,7 @@ func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo,
defer d.mutex.RUnlock() defer d.mutex.RUnlock()
normalized := normalize(path) normalized := normalize(path)
found := d.root.find(path) found := d.root.find(normalized)
if found.path() != normalized { if found.path() != normalized {
return nil, storagedriver.PathNotFoundError{Path: path} return nil, storagedriver.PathNotFoundError{Path: path}
@ -260,3 +239,74 @@ func (d *driver) Delete(ctx context.Context, path string) error {
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{} return "", storagedriver.ErrUnsupportedMethod{}
} }
type writer struct {
d *driver
f *file
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(f *file) storagedriver.FileWriter {
return &writer{
d: d,
f: f,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
w.d.mutex.Lock()
defer w.d.mutex.Unlock()
return w.f.WriteAt(p, int64(len(w.f.data)))
}
func (w *writer) Size() int64 {
w.d.mutex.RLock()
defer w.d.mutex.RUnlock()
return int64(len(w.f.data))
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return nil
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
w.d.mutex.Lock()
defer w.d.mutex.Unlock()
return w.d.root.delete(w.f.path())
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
w.committed = true
return nil
}

View file

@ -18,7 +18,7 @@ import (
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware" storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
) )
// cloudFrontStorageMiddleware provides an simple implementation of layerHandler that // cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
// constructs temporary signed CloudFront URLs from the storagedriver layer URL, // constructs temporary signed CloudFront URLs from the storagedriver layer URL,
// then issues HTTP Temporary Redirects to this CloudFront content URL. // then issues HTTP Temporary Redirects to this CloudFront content URL.
type cloudFrontStorageMiddleware struct { type cloudFrontStorageMiddleware struct {

View file

@ -0,0 +1,50 @@
package middleware
import (
"fmt"
"net/url"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
)
type redirectStorageMiddleware struct {
storagedriver.StorageDriver
scheme string
host string
}
var _ storagedriver.StorageDriver = &redirectStorageMiddleware{}
func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
o, ok := options["baseurl"]
if !ok {
return nil, fmt.Errorf("no baseurl provided")
}
b, ok := o.(string)
if !ok {
return nil, fmt.Errorf("baseurl must be a string")
}
u, err := url.Parse(b)
if err != nil {
return nil, fmt.Errorf("unable to parse redirect baseurl: %s", b)
}
if u.Scheme == "" {
return nil, fmt.Errorf("no scheme specified for redirect baseurl")
}
if u.Host == "" {
return nil, fmt.Errorf("no host specified for redirect baseurl")
}
return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host}, nil
}
func (r *redirectStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
u := &url.URL{Scheme: r.scheme, Host: r.host, Path: path}
return u.String(), nil
}
func init() {
storagemiddleware.Register("redirect", storagemiddleware.InitFunc(newRedirectStorageMiddleware))
}

View file

@ -0,0 +1,58 @@
package middleware
import (
"testing"
check "gopkg.in/check.v1"
)
func Test(t *testing.T) { check.TestingT(t) }
type MiddlewareSuite struct{}
var _ = check.Suite(&MiddlewareSuite{})
func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
options := make(map[string]interface{})
_, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.ErrorMatches, "no baseurl provided")
}
func (s *MiddlewareSuite) TestMissingScheme(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "example.com"
_, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl")
}
func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "https://example.com:5443"
middleware, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware)
c.Assert(ok, check.Equals, true)
c.Assert(m.scheme, check.Equals, "https")
c.Assert(m.host, check.Equals, "example.com:5443")
url, err := middleware.URLFor(nil, "/rick/data", nil)
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com:5443/rick/data")
}
func (s *MiddlewareSuite) TestHTTP(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "http://example.com"
middleware, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware)
c.Assert(ok, check.Equals, true)
c.Assert(m.scheme, check.Equals, "http")
c.Assert(m.host, check.Equals, "example.com")
url, err := middleware.URLFor(nil, "morty/data", nil)
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "http://example.com/morty/data")
}

View file

@ -20,7 +20,6 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
@ -75,9 +74,6 @@ type driver struct {
ChunkSize int64 ChunkSize int64
Encrypt bool Encrypt bool
RootDirectory string RootDirectory string
pool sync.Pool // pool []byte buffers used for WriteStream
zeros []byte // shared, zero-valued buffer used for WriteStream
} }
type baseEmbed struct { type baseEmbed struct {
@ -99,8 +95,7 @@ type Driver struct {
// - encrypt // - encrypt
func FromParameters(parameters map[string]interface{}) (*Driver, error) { func FromParameters(parameters map[string]interface{}) (*Driver, error) {
// Providing no values for these is valid in case the user is authenticating // Providing no values for these is valid in case the user is authenticating
// with an IAM on an ec2 instance (in which case the instance credentials will
// be summoned when GetAuth is called)
accessKey, ok := parameters["accesskeyid"] accessKey, ok := parameters["accesskeyid"]
if !ok { if !ok {
return nil, fmt.Errorf("No accesskeyid parameter provided") return nil, fmt.Errorf("No accesskeyid parameter provided")
@ -220,11 +215,6 @@ func New(params DriverParameters) (*Driver, error) {
ChunkSize: params.ChunkSize, ChunkSize: params.ChunkSize,
Encrypt: params.Encrypt, Encrypt: params.Encrypt,
RootDirectory: params.RootDirectory, RootDirectory: params.RootDirectory,
zeros: make([]byte, params.ChunkSize),
}
d.pool.New = func() interface{} {
return make([]byte, d.ChunkSize)
} }
return &Driver{ return &Driver{
@ -256,9 +246,9 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
return parseError(path, d.Bucket.Put(d.ossPath(path), contents, d.getContentType(), getPermissions(), d.getOptions())) return parseError(path, d.Bucket.Put(d.ossPath(path), contents, d.getContentType(), getPermissions(), d.getOptions()))
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
headers := make(http.Header) headers := make(http.Header)
headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-") headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-")
@ -279,315 +269,37 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return resp.Body, nil return resp.Body, nil
} }
// WriteStream stores the contents of the provided io.Reader at a // Writer returns a FileWriter which will store the content written to it
// location designated by the given path. The driver will know it has // at the location designated by "path" after the call to Commit.
// received the full contents when the reader returns io.EOF. The number func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
// of successfully READ bytes will be returned, even if an error is key := d.ossPath(path)
// returned. May be used to resume writing a stream by providing a nonzero if !append {
// offset. Offsets past the current size will write from the position // TODO (brianbland): cancel other uploads at this path
// beyond the end of the file. multi, err := d.Bucket.InitMulti(key, d.getContentType(), getPermissions(), d.getOptions())
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) { if err != nil {
partNumber := 1 return nil, err
bytesRead := 0 }
var putErrChan chan error return d.newWriter(key, multi, nil), nil
parts := []oss.Part{} }
var part oss.Part multis, _, err := d.Bucket.ListMulti(key, "")
done := make(chan struct{}) // stopgap to free up waiting goroutines
multi, err := d.Bucket.InitMulti(d.ossPath(path), d.getContentType(), getPermissions(), d.getOptions())
if err != nil { if err != nil {
return 0, err return nil, parseError(path, err)
} }
for _, multi := range multis {
buf := d.getbuf() if key != multi.Key {
continue
// We never want to leave a dangling multipart upload, our only consistent state is
// when there is a whole object at path. This is in order to remain consistent with
// the stat call.
//
// Note that if the machine dies before executing the defer, we will be left with a dangling
// multipart upload, which will eventually be cleaned up, but we will lose all of the progress
// made prior to the machine crashing.
defer func() {
if putErrChan != nil {
if putErr := <-putErrChan; putErr != nil {
err = putErr
}
} }
parts, err := multi.ListParts()
if len(parts) > 0 {
if multi == nil {
// Parts should be empty if the multi is not initialized
panic("Unreachable")
} else {
if multi.Complete(parts) != nil {
multi.Abort()
}
}
}
d.putbuf(buf) // needs to be here to pick up new buf value
close(done) // free up any waiting goroutines
}()
// Fills from 0 to total from current
fromSmallCurrent := func(total int64) error {
current, err := d.ReadStream(ctx, path, 0)
if err != nil { if err != nil {
return err return nil, parseError(path, err)
} }
var multiSize int64
bytesRead = 0 for _, part := range parts {
for int64(bytesRead) < total { multiSize += part.Size
//The loop should very rarely enter a second iteration
nn, err := current.Read(buf[bytesRead:total])
bytesRead += nn
if err != nil {
if err != io.EOF {
return err
}
break
}
} }
return nil return d.newWriter(key, multi, parts), nil
} }
return nil, storagedriver.PathNotFoundError{Path: path}
// Fills from parameter to chunkSize from reader
fromReader := func(from int64) error {
bytesRead = 0
for from+int64(bytesRead) < d.ChunkSize {
nn, err := reader.Read(buf[from+int64(bytesRead):])
totalRead += int64(nn)
bytesRead += nn
if err != nil {
if err != io.EOF {
return err
}
break
}
}
if putErrChan == nil {
putErrChan = make(chan error)
} else {
if putErr := <-putErrChan; putErr != nil {
putErrChan = nil
return putErr
}
}
go func(bytesRead int, from int64, buf []byte) {
defer d.putbuf(buf) // this buffer gets dropped after this call
// DRAGONS(stevvooe): There are few things one might want to know
// about this section. First, the putErrChan is expecting an error
// and a nil or just a nil to come through the channel. This is
// covered by the silly defer below. The other aspect is the OSS
// retry backoff to deal with RequestTimeout errors. Even though
// the underlying OSS library should handle it, it doesn't seem to
// be part of the shouldRetry function (see denverdino/aliyungo/oss).
defer func() {
select {
case putErrChan <- nil: // for some reason, we do this no matter what.
case <-done:
return // ensure we don't leak the goroutine
}
}()
if bytesRead <= 0 {
return
}
var err error
var part oss.Part
part, err = multi.PutPartWithTimeout(int(partNumber), bytes.NewReader(buf[0:int64(bytesRead)+from]), defaultTimeout)
if err != nil {
logrus.Errorf("error putting part, aborting: %v", err)
select {
case putErrChan <- err:
case <-done:
return // don't leak the goroutine
}
}
// parts and partNumber are safe, because this function is the
// only one modifying them and we force it to be executed
// serially.
parts = append(parts, part)
partNumber++
}(bytesRead, from, buf)
buf = d.getbuf() // use a new buffer for the next call
return nil
}
if offset > 0 {
resp, err := d.Bucket.Head(d.ossPath(path), nil)
if err != nil {
if ossErr, ok := err.(*oss.Error); !ok || ossErr.StatusCode != http.StatusNotFound {
return 0, err
}
}
currentLength := int64(0)
if err == nil {
currentLength = resp.ContentLength
}
if currentLength >= offset {
if offset < d.ChunkSize {
// chunkSize > currentLength >= offset
if err = fromSmallCurrent(offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// currentLength >= offset >= chunkSize
_, part, err = multi.PutPartCopy(partNumber,
oss.CopyOptions{CopySourceOptions: "bytes=0-" + strconv.FormatInt(offset-1, 10)},
d.Bucket.Path(d.ossPath(path)))
if err != nil {
return 0, err
}
parts = append(parts, part)
partNumber++
}
} else {
// Fills between parameters with 0s but only when to - from <= chunkSize
fromZeroFillSmall := func(from, to int64) error {
bytesRead = 0
for from+int64(bytesRead) < to {
nn, err := bytes.NewReader(d.zeros).Read(buf[from+int64(bytesRead) : to])
bytesRead += nn
if err != nil {
return err
}
}
return nil
}
// Fills between parameters with 0s, making new parts
fromZeroFillLarge := func(from, to int64) error {
bytesRead64 := int64(0)
for to-(from+bytesRead64) >= d.ChunkSize {
part, err := multi.PutPartWithTimeout(int(partNumber), bytes.NewReader(d.zeros), defaultTimeout)
if err != nil {
return err
}
bytesRead64 += d.ChunkSize
parts = append(parts, part)
partNumber++
}
return fromZeroFillSmall(0, (to-from)%d.ChunkSize)
}
// currentLength < offset
if currentLength < d.ChunkSize {
if offset < d.ChunkSize {
// chunkSize > offset > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// offset >= chunkSize > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, d.ChunkSize); err != nil {
return totalRead, err
}
part, err = multi.PutPartWithTimeout(int(partNumber), bytes.NewReader(buf), defaultTimeout)
if err != nil {
return totalRead, err
}
parts = append(parts, part)
partNumber++
//Zero fill from chunkSize up to offset, then some reader
if err = fromZeroFillLarge(d.ChunkSize, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+(offset%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
} else {
// offset > currentLength >= chunkSize
_, part, err = multi.PutPartCopy(partNumber,
oss.CopyOptions{},
d.Bucket.Path(d.ossPath(path)))
if err != nil {
return 0, err
}
parts = append(parts, part)
partNumber++
//Zero fill from currentLength up to offset, then some reader
if err = fromZeroFillLarge(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader((offset - currentLength) % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+((offset-currentLength)%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
}
}
for {
if err = fromReader(0); err != nil {
return totalRead, err
}
if int64(bytesRead) < d.ChunkSize {
break
}
}
return totalRead, nil
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -778,12 +490,181 @@ func (d *driver) getContentType() string {
return "application/octet-stream" return "application/octet-stream"
} }
// getbuf returns a buffer from the driver's pool with length d.ChunkSize. // writer attempts to upload parts to S3 in a buffered fashion where the last
func (d *driver) getbuf() []byte { // part is at least as large as the chunksize, so the multipart upload could be
return d.pool.Get().([]byte) // cleanly resumed in the future. This is violated if Close is called after less
// than a full chunk is written.
type writer struct {
driver *driver
key string
multi *oss.Multi
parts []oss.Part
size int64
readyPart []byte
pendingPart []byte
closed bool
committed bool
cancelled bool
} }
func (d *driver) putbuf(p []byte) { func (d *driver) newWriter(key string, multi *oss.Multi, parts []oss.Part) storagedriver.FileWriter {
copy(p, d.zeros) var size int64
d.pool.Put(p) for _, part := range parts {
size += part.Size
}
return &writer{
driver: d,
key: key,
multi: multi,
parts: parts,
size: size,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
// If the last written part is smaller than minChunkSize, we need to make a
// new multipart upload :sadface:
if len(w.parts) > 0 && int(w.parts[len(w.parts)-1].Size) < minChunkSize {
err := w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return 0, err
}
multi, err := w.driver.Bucket.InitMulti(w.key, w.driver.getContentType(), getPermissions(), w.driver.getOptions())
if err != nil {
return 0, err
}
w.multi = multi
// If the entire written file is smaller than minChunkSize, we need to make
// a new part from scratch :double sad face:
if w.size < minChunkSize {
contents, err := w.driver.Bucket.Get(w.key)
if err != nil {
return 0, err
}
w.parts = nil
w.readyPart = contents
} else {
// Otherwise we can use the old file as the new first part
_, part, err := multi.PutPartCopy(1, oss.CopyOptions{}, w.driver.Bucket.Name+"/"+w.key)
if err != nil {
return 0, err
}
w.parts = []oss.Part{part}
}
}
var n int
for len(p) > 0 {
// If no parts are ready to write, fill up the first part
if neededBytes := int(w.driver.ChunkSize) - len(w.readyPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.readyPart = append(w.readyPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
} else {
w.readyPart = append(w.readyPart, p...)
n += len(p)
p = nil
}
}
if neededBytes := int(w.driver.ChunkSize) - len(w.pendingPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.pendingPart = append(w.pendingPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
err := w.flushPart()
if err != nil {
w.size += int64(n)
return n, err
}
} else {
w.pendingPart = append(w.pendingPart, p...)
n += len(p)
p = nil
}
}
}
w.size += int64(n)
return n, nil
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.flushPart()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
err := w.multi.Abort()
return err
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
err := w.flushPart()
if err != nil {
return err
}
w.committed = true
err = w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return err
}
return nil
}
// flushPart flushes buffers to write a part to S3.
// Only called by Write (with both buffers full) and Close/Commit (always)
func (w *writer) flushPart() error {
if len(w.readyPart) == 0 && len(w.pendingPart) == 0 {
// nothing to write
return nil
}
if len(w.pendingPart) < int(w.driver.ChunkSize) {
// closing with a small pending part
// combine ready and pending to avoid writing a small part
w.readyPart = append(w.readyPart, w.pendingPart...)
w.pendingPart = nil
}
part, err := w.multi.PutPart(len(w.parts)+1, bytes.NewReader(w.readyPart))
if err != nil {
return err
}
w.parts = append(w.parts, part)
w.readyPart = w.pendingPart
w.pendingPart = nil
return nil
} }

View file

@ -1,3 +0,0 @@
// Package rados implements the rados storage driver backend. Support can be
// enabled by including the "include_rados" build tag.
package rados

View file

@ -1,632 +0,0 @@
// +build include_rados
package rados
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"path"
"strconv"
log "github.com/Sirupsen/logrus"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/docker/distribution/uuid"
"github.com/noahdesu/go-ceph/rados"
)
const driverName = "rados"
// Prefix all the stored blob
const objectBlobPrefix = "blob:"
// Stripes objects size to 4M
const defaultChunkSize = 4 << 20
const defaultXattrTotalSizeName = "total-size"
// Max number of keys fetched from omap at each read operation
const defaultKeysFetched = 1
//DriverParameters A struct that encapsulates all of the driver parameters after all values have been set
type DriverParameters struct {
poolname string
username string
chunksize uint64
}
func init() {
factory.Register(driverName, &radosDriverFactory{})
}
// radosDriverFactory implements the factory.StorageDriverFactory interface
type radosDriverFactory struct{}
func (factory *radosDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
type driver struct {
Conn *rados.Conn
Ioctx *rados.IOContext
chunksize uint64
}
type baseEmbed struct {
base.Base
}
// Driver is a storagedriver.StorageDriver implementation backed by Ceph RADOS
// Objects are stored at absolute keys in the provided bucket.
type Driver struct {
baseEmbed
}
// FromParameters constructs a new Driver with a given parameters map
// Required parameters:
// - poolname: the ceph pool name
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
pool, ok := parameters["poolname"]
if !ok {
return nil, fmt.Errorf("No poolname parameter provided")
}
username, ok := parameters["username"]
if !ok {
username = ""
}
chunksize := uint64(defaultChunkSize)
chunksizeParam, ok := parameters["chunksize"]
if ok {
chunksize, ok = chunksizeParam.(uint64)
if !ok {
return nil, fmt.Errorf("The chunksize parameter should be a number")
}
}
params := DriverParameters{
fmt.Sprint(pool),
fmt.Sprint(username),
chunksize,
}
return New(params)
}
// New constructs a new Driver
func New(params DriverParameters) (*Driver, error) {
var conn *rados.Conn
var err error
if params.username != "" {
log.Infof("Opening connection to pool %s using user %s", params.poolname, params.username)
conn, err = rados.NewConnWithUser(params.username)
} else {
log.Infof("Opening connection to pool %s", params.poolname)
conn, err = rados.NewConn()
}
if err != nil {
return nil, err
}
err = conn.ReadDefaultConfigFile()
if err != nil {
return nil, err
}
err = conn.Connect()
if err != nil {
return nil, err
}
log.Infof("Connected")
ioctx, err := conn.OpenIOContext(params.poolname)
log.Infof("Connected to pool %s", params.poolname)
if err != nil {
return nil, err
}
d := &driver{
Ioctx: ioctx,
Conn: conn,
chunksize: params.chunksize,
}
return &Driver{
baseEmbed: baseEmbed{
Base: base.Base{
StorageDriver: d,
},
},
}, nil
}
// Implement the storagedriver.StorageDriver interface
func (d *driver) Name() string {
return driverName
}
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
rc, err := d.ReadStream(ctx, path, 0)
if err != nil {
return nil, err
}
defer rc.Close()
p, err := ioutil.ReadAll(rc)
if err != nil {
return nil, err
}
return p, nil
}
// PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
if _, err := d.WriteStream(ctx, path, 0, bytes.NewReader(contents)); err != nil {
return err
}
return nil
}
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
type readStreamReader struct {
driver *driver
oid string
size uint64
offset uint64
}
func (r *readStreamReader) Read(b []byte) (n int, err error) {
// Determine the part available to read
bufferOffset := uint64(0)
bufferSize := uint64(len(b))
// End of the object, read less than the buffer size
if bufferSize > r.size-r.offset {
bufferSize = r.size - r.offset
}
// Fill `b`
for bufferOffset < bufferSize {
// Get the offset in the object chunk
chunkedOid, chunkedOffset := r.driver.getChunkNameFromOffset(r.oid, r.offset)
// Determine the best size to read
bufferEndOffset := bufferSize
if bufferEndOffset-bufferOffset > r.driver.chunksize-chunkedOffset {
bufferEndOffset = bufferOffset + (r.driver.chunksize - chunkedOffset)
}
// Read the chunk
n, err = r.driver.Ioctx.Read(chunkedOid, b[bufferOffset:bufferEndOffset], chunkedOffset)
if err != nil {
return int(bufferOffset), err
}
bufferOffset += uint64(n)
r.offset += uint64(n)
}
// EOF if the offset is at the end of the object
if r.offset == r.size {
return int(bufferOffset), io.EOF
}
return int(bufferOffset), nil
}
func (r *readStreamReader) Close() error {
return nil
}
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
// get oid from filename
oid, err := d.getOid(path)
if err != nil {
return nil, err
}
// get object stat
stat, err := d.Stat(ctx, path)
if err != nil {
return nil, err
}
if offset > stat.Size() {
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
return &readStreamReader{
driver: d,
oid: oid,
size: uint64(stat.Size()),
offset: uint64(offset),
}, nil
}
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) {
buf := make([]byte, d.chunksize)
totalRead = 0
oid, err := d.getOid(path)
if err != nil {
switch err.(type) {
// Trying to write new object, generate new blob identifier for it
case storagedriver.PathNotFoundError:
oid = d.generateOid()
err = d.putOid(path, oid)
if err != nil {
return 0, err
}
default:
return 0, err
}
} else {
// Check total object size only for existing ones
totalSize, err := d.getXattrTotalSize(ctx, oid)
if err != nil {
return 0, err
}
// If offset if after the current object size, fill the gap with zeros
for totalSize < uint64(offset) {
sizeToWrite := d.chunksize
if totalSize-uint64(offset) < sizeToWrite {
sizeToWrite = totalSize - uint64(offset)
}
chunkName, chunkOffset := d.getChunkNameFromOffset(oid, uint64(totalSize))
err = d.Ioctx.Write(chunkName, buf[:sizeToWrite], uint64(chunkOffset))
if err != nil {
return totalRead, err
}
totalSize += sizeToWrite
}
}
// Writer
for {
// Align to chunk size
sizeRead := uint64(0)
sizeToRead := uint64(offset+totalRead) % d.chunksize
if sizeToRead == 0 {
sizeToRead = d.chunksize
}
// Read from `reader`
for sizeRead < sizeToRead {
nn, err := reader.Read(buf[sizeRead:sizeToRead])
sizeRead += uint64(nn)
if err != nil {
if err != io.EOF {
return totalRead, err
}
break
}
}
// End of file and nothing was read
if sizeRead == 0 {
break
}
// Write chunk object
chunkName, chunkOffset := d.getChunkNameFromOffset(oid, uint64(offset+totalRead))
err = d.Ioctx.Write(chunkName, buf[:sizeRead], uint64(chunkOffset))
if err != nil {
return totalRead, err
}
// Update total object size as xattr in the first chunk of the object
err = d.setXattrTotalSize(oid, uint64(offset+totalRead)+sizeRead)
if err != nil {
return totalRead, err
}
totalRead += int64(sizeRead)
// End of file
if sizeRead < sizeToRead {
break
}
}
return totalRead, nil
}
// Stat retrieves the FileInfo for the given path, including the current size
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
// get oid from filename
oid, err := d.getOid(path)
if err != nil {
return nil, err
}
// the path is a virtual directory?
if oid == "" {
return storagedriver.FileInfoInternal{
FileInfoFields: storagedriver.FileInfoFields{
Path: path,
Size: 0,
IsDir: true,
},
}, nil
}
// stat first chunk
stat, err := d.Ioctx.Stat(oid + "-0")
if err != nil {
return nil, err
}
// get total size of chunked object
totalSize, err := d.getXattrTotalSize(ctx, oid)
if err != nil {
return nil, err
}
return storagedriver.FileInfoInternal{
FileInfoFields: storagedriver.FileInfoFields{
Path: path,
Size: int64(totalSize),
ModTime: stat.ModTime,
},
}, nil
}
// List returns a list of the objects that are direct descendants of the given path.
func (d *driver) List(ctx context.Context, dirPath string) ([]string, error) {
files, err := d.listDirectoryOid(dirPath)
if err != nil {
return nil, storagedriver.PathNotFoundError{Path: dirPath}
}
keys := make([]string, 0, len(files))
for k := range files {
if k != dirPath {
keys = append(keys, path.Join(dirPath, k))
}
}
return keys, nil
}
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
// Get oid
oid, err := d.getOid(sourcePath)
if err != nil {
return err
}
// Move reference
err = d.putOid(destPath, oid)
if err != nil {
return err
}
// Delete old reference
err = d.deleteOid(sourcePath)
if err != nil {
return err
}
return nil
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, objectPath string) error {
// Get oid
oid, err := d.getOid(objectPath)
if err != nil {
return err
}
// Deleting virtual directory
if oid == "" {
objects, err := d.listDirectoryOid(objectPath)
if err != nil {
return err
}
for object := range objects {
err = d.Delete(ctx, path.Join(objectPath, object))
if err != nil {
return err
}
}
} else {
// Delete object chunks
totalSize, err := d.getXattrTotalSize(ctx, oid)
if err != nil {
return err
}
for offset := uint64(0); offset < totalSize; offset += d.chunksize {
chunkName, _ := d.getChunkNameFromOffset(oid, offset)
err = d.Ioctx.Delete(chunkName)
if err != nil {
return err
}
}
// Delete reference
err = d.deleteOid(objectPath)
if err != nil {
return err
}
}
return nil
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{}
}
// Generate a blob identifier
func (d *driver) generateOid() string {
return objectBlobPrefix + uuid.Generate().String()
}
// Reference a object and its hierarchy
func (d *driver) putOid(objectPath string, oid string) error {
directory := path.Dir(objectPath)
base := path.Base(objectPath)
createParentReference := true
// After creating this reference, skip the parents referencing since the
// hierarchy already exists
if oid == "" {
firstReference, err := d.Ioctx.GetOmapValues(directory, "", "", 1)
if (err == nil) && (len(firstReference) > 0) {
createParentReference = false
}
}
oids := map[string][]byte{
base: []byte(oid),
}
// Reference object
err := d.Ioctx.SetOmap(directory, oids)
if err != nil {
return err
}
// Esure parent virtual directories
if createParentReference {
return d.putOid(directory, "")
}
return nil
}
// Get the object identifier from an object name
func (d *driver) getOid(objectPath string) (string, error) {
directory := path.Dir(objectPath)
base := path.Base(objectPath)
files, err := d.Ioctx.GetOmapValues(directory, "", base, 1)
if (err != nil) || (files[base] == nil) {
return "", storagedriver.PathNotFoundError{Path: objectPath}
}
return string(files[base]), nil
}
// List the objects of a virtual directory
func (d *driver) listDirectoryOid(path string) (list map[string][]byte, err error) {
return d.Ioctx.GetAllOmapValues(path, "", "", defaultKeysFetched)
}
// Remove a file from the files hierarchy
func (d *driver) deleteOid(objectPath string) error {
// Remove object reference
directory := path.Dir(objectPath)
base := path.Base(objectPath)
err := d.Ioctx.RmOmapKeys(directory, []string{base})
if err != nil {
return err
}
// Remove virtual directory if empty (no more references)
firstReference, err := d.Ioctx.GetOmapValues(directory, "", "", 1)
if err != nil {
return err
}
if len(firstReference) == 0 {
// Delete omap
err := d.Ioctx.Delete(directory)
if err != nil {
return err
}
// Remove reference on parent omaps
if directory != "" {
return d.deleteOid(directory)
}
}
return nil
}
// Takes an offset in an chunked object and return the chunk name and a new
// offset in this chunk object
func (d *driver) getChunkNameFromOffset(oid string, offset uint64) (string, uint64) {
chunkID := offset / d.chunksize
chunkedOid := oid + "-" + strconv.FormatInt(int64(chunkID), 10)
chunkedOffset := offset % d.chunksize
return chunkedOid, chunkedOffset
}
// Set the total size of a chunked object `oid`
func (d *driver) setXattrTotalSize(oid string, size uint64) error {
// Convert uint64 `size` to []byte
xattr := make([]byte, binary.MaxVarintLen64)
binary.LittleEndian.PutUint64(xattr, size)
// Save the total size as a xattr in the first chunk
return d.Ioctx.SetXattr(oid+"-0", defaultXattrTotalSizeName, xattr)
}
// Get the total size of the chunked object `oid` stored as xattr
func (d *driver) getXattrTotalSize(ctx context.Context, oid string) (uint64, error) {
// Fetch xattr as []byte
xattr := make([]byte, binary.MaxVarintLen64)
xattrLength, err := d.Ioctx.GetXattr(oid+"-0", defaultXattrTotalSizeName, xattr)
if err != nil {
return 0, err
}
if xattrLength != len(xattr) {
context.GetLogger(ctx).Errorf("object %s xattr length mismatch: %d != %d", oid, xattrLength, len(xattr))
return 0, storagedriver.PathNotFoundError{Path: oid}
}
// Convert []byte as uint64
totalSize := binary.LittleEndian.Uint64(xattr)
return totalSize, nil
}

View file

@ -1,40 +0,0 @@
// +build include_rados
package rados
import (
"os"
"testing"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
"gopkg.in/check.v1"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
func init() {
poolname := os.Getenv("RADOS_POOL")
username := os.Getenv("RADOS_USER")
driverConstructor := func() (storagedriver.StorageDriver, error) {
parameters := DriverParameters{
poolname,
username,
defaultChunkSize,
}
return New(parameters)
}
skipCheck := func() string {
if poolname == "" {
return "RADOS_POOL must be set to run Rado tests"
}
return ""
}
testsuites.RegisterSuite(driverConstructor, skipCheck)
}

File diff suppressed because it is too large Load diff

View file

@ -27,9 +27,11 @@ func init() {
secretKey := os.Getenv("AWS_SECRET_KEY") secretKey := os.Getenv("AWS_SECRET_KEY")
bucket := os.Getenv("S3_BUCKET") bucket := os.Getenv("S3_BUCKET")
encrypt := os.Getenv("S3_ENCRYPT") encrypt := os.Getenv("S3_ENCRYPT")
keyID := os.Getenv("S3_KEY_ID")
secure := os.Getenv("S3_SECURE") secure := os.Getenv("S3_SECURE")
region := os.Getenv("AWS_REGION") region := os.Getenv("AWS_REGION")
root, err := ioutil.TempDir("", "driver-") root, err := ioutil.TempDir("", "driver-")
regionEndpoint := os.Getenv("REGION_ENDPOINT")
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -57,7 +59,9 @@ func init() {
secretKey, secretKey,
bucket, bucket,
region, region,
regionEndpoint,
encryptBool, encryptBool,
keyID,
secureBool, secureBool,
minChunkSize, minChunkSize,
rootDirectory, rootDirectory,

View file

@ -21,10 +21,8 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/Sirupsen/logrus"
"github.com/docker/goamz/aws" "github.com/docker/goamz/aws"
"github.com/docker/goamz/s3" "github.com/docker/goamz/s3"
@ -79,9 +77,6 @@ type driver struct {
Encrypt bool Encrypt bool
RootDirectory string RootDirectory string
StorageClass s3.StorageClass StorageClass s3.StorageClass
pool sync.Pool // pool []byte buffers used for WriteStream
zeros []byte // shared, zero-valued buffer used for WriteStream
} }
type baseEmbed struct { type baseEmbed struct {
@ -301,11 +296,6 @@ func New(params DriverParameters) (*Driver, error) {
Encrypt: params.Encrypt, Encrypt: params.Encrypt,
RootDirectory: params.RootDirectory, RootDirectory: params.RootDirectory,
StorageClass: params.StorageClass, StorageClass: params.StorageClass,
zeros: make([]byte, params.ChunkSize),
}
d.pool.New = func() interface{} {
return make([]byte, d.ChunkSize)
} }
return &Driver{ return &Driver{
@ -337,9 +327,9 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
return parseError(path, d.Bucket.Put(d.s3Path(path), contents, d.getContentType(), getPermissions(), d.getOptions())) return parseError(path, d.Bucket.Put(d.s3Path(path), contents, d.getContentType(), getPermissions(), d.getOptions()))
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
headers := make(http.Header) headers := make(http.Header)
headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-") headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-")
@ -354,343 +344,37 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return resp.Body, nil return resp.Body, nil
} }
// WriteStream stores the contents of the provided io.Reader at a // Writer returns a FileWriter which will store the content written to it
// location designated by the given path. The driver will know it has // at the location designated by "path" after the call to Commit.
// received the full contents when the reader returns io.EOF. The number func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
// of successfully READ bytes will be returned, even if an error is key := d.s3Path(path)
// returned. May be used to resume writing a stream by providing a nonzero if !append {
// offset. Offsets past the current size will write from the position // TODO (brianbland): cancel other uploads at this path
// beyond the end of the file. multi, err := d.Bucket.InitMulti(key, d.getContentType(), getPermissions(), d.getOptions())
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) { if err != nil {
partNumber := 1 return nil, err
bytesRead := 0 }
var putErrChan chan error return d.newWriter(key, multi, nil), nil
parts := []s3.Part{} }
var part s3.Part multis, _, err := d.Bucket.ListMulti(key, "")
done := make(chan struct{}) // stopgap to free up waiting goroutines
multi, err := d.Bucket.InitMulti(d.s3Path(path), d.getContentType(), getPermissions(), d.getOptions())
if err != nil { if err != nil {
return 0, err return nil, parseError(path, err)
} }
for _, multi := range multis {
buf := d.getbuf() if key != multi.Key {
continue
// We never want to leave a dangling multipart upload, our only consistent state is
// when there is a whole object at path. This is in order to remain consistent with
// the stat call.
//
// Note that if the machine dies before executing the defer, we will be left with a dangling
// multipart upload, which will eventually be cleaned up, but we will lose all of the progress
// made prior to the machine crashing.
defer func() {
if putErrChan != nil {
if putErr := <-putErrChan; putErr != nil {
err = putErr
}
} }
parts, err := multi.ListParts()
if len(parts) > 0 {
if multi == nil {
// Parts should be empty if the multi is not initialized
panic("Unreachable")
} else {
if multi.Complete(parts) != nil {
multi.Abort()
}
}
}
d.putbuf(buf) // needs to be here to pick up new buf value
close(done) // free up any waiting goroutines
}()
// Fills from 0 to total from current
fromSmallCurrent := func(total int64) error {
current, err := d.ReadStream(ctx, path, 0)
if err != nil { if err != nil {
return err return nil, parseError(path, err)
} }
var multiSize int64
bytesRead = 0 for _, part := range parts {
for int64(bytesRead) < total { multiSize += part.Size
//The loop should very rarely enter a second iteration
nn, err := current.Read(buf[bytesRead:total])
bytesRead += nn
if err != nil {
if err != io.EOF {
return err
}
break
}
} }
return nil return d.newWriter(key, multi, parts), nil
} }
return nil, storagedriver.PathNotFoundError{Path: path}
// Fills from parameter to chunkSize from reader
fromReader := func(from int64) error {
bytesRead = 0
for from+int64(bytesRead) < d.ChunkSize {
nn, err := reader.Read(buf[from+int64(bytesRead):])
totalRead += int64(nn)
bytesRead += nn
if err != nil {
if err != io.EOF {
return err
}
break
}
}
if putErrChan == nil {
putErrChan = make(chan error)
} else {
if putErr := <-putErrChan; putErr != nil {
putErrChan = nil
return putErr
}
}
go func(bytesRead int, from int64, buf []byte) {
defer d.putbuf(buf) // this buffer gets dropped after this call
// DRAGONS(stevvooe): There are few things one might want to know
// about this section. First, the putErrChan is expecting an error
// and a nil or just a nil to come through the channel. This is
// covered by the silly defer below. The other aspect is the s3
// retry backoff to deal with RequestTimeout errors. Even though
// the underlying s3 library should handle it, it doesn't seem to
// be part of the shouldRetry function (see AdRoll/goamz/s3).
defer func() {
select {
case putErrChan <- nil: // for some reason, we do this no matter what.
case <-done:
return // ensure we don't leak the goroutine
}
}()
if bytesRead <= 0 {
return
}
var err error
var part s3.Part
loop:
for retries := 0; retries < 5; retries++ {
part, err = multi.PutPart(int(partNumber), bytes.NewReader(buf[0:int64(bytesRead)+from]))
if err == nil {
break // success!
}
// NOTE(stevvooe): This retry code tries to only retry under
// conditions where the s3 package does not. We may add s3
// error codes to the below if we see others bubble up in the
// application. Right now, the most troubling is
// RequestTimeout, which seems to only triggered when a tcp
// connection to s3 slows to a crawl. If the RequestTimeout
// ends up getting added to the s3 library and we don't see
// other errors, this retry loop can be removed.
switch err := err.(type) {
case *s3.Error:
switch err.Code {
case "RequestTimeout":
// allow retries on only this error.
default:
break loop
}
}
backoff := 100 * time.Millisecond * time.Duration(retries+1)
logrus.Errorf("error putting part, retrying after %v: %v", err, backoff.String())
time.Sleep(backoff)
}
if err != nil {
logrus.Errorf("error putting part, aborting: %v", err)
select {
case putErrChan <- err:
case <-done:
return // don't leak the goroutine
}
}
// parts and partNumber are safe, because this function is the
// only one modifying them and we force it to be executed
// serially.
parts = append(parts, part)
partNumber++
}(bytesRead, from, buf)
buf = d.getbuf() // use a new buffer for the next call
return nil
}
if offset > 0 {
resp, err := d.Bucket.Head(d.s3Path(path), nil)
if err != nil {
if s3Err, ok := err.(*s3.Error); !ok || s3Err.Code != "NoSuchKey" {
return 0, err
}
}
currentLength := int64(0)
if err == nil {
currentLength = resp.ContentLength
}
if currentLength >= offset {
if offset < d.ChunkSize {
// chunkSize > currentLength >= offset
if err = fromSmallCurrent(offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// currentLength >= offset >= chunkSize
_, part, err = multi.PutPartCopy(partNumber,
s3.CopyOptions{CopySourceOptions: "bytes=0-" + strconv.FormatInt(offset-1, 10)},
d.Bucket.Name+"/"+d.s3Path(path))
if err != nil {
return 0, err
}
parts = append(parts, part)
partNumber++
}
} else {
// Fills between parameters with 0s but only when to - from <= chunkSize
fromZeroFillSmall := func(from, to int64) error {
bytesRead = 0
for from+int64(bytesRead) < to {
nn, err := bytes.NewReader(d.zeros).Read(buf[from+int64(bytesRead) : to])
bytesRead += nn
if err != nil {
return err
}
}
return nil
}
// Fills between parameters with 0s, making new parts
fromZeroFillLarge := func(from, to int64) error {
bytesRead64 := int64(0)
for to-(from+bytesRead64) >= d.ChunkSize {
part, err := multi.PutPart(int(partNumber), bytes.NewReader(d.zeros))
if err != nil {
return err
}
bytesRead64 += d.ChunkSize
parts = append(parts, part)
partNumber++
}
return fromZeroFillSmall(0, (to-from)%d.ChunkSize)
}
// currentLength < offset
if currentLength < d.ChunkSize {
if offset < d.ChunkSize {
// chunkSize > offset > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// offset >= chunkSize > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, d.ChunkSize); err != nil {
return totalRead, err
}
part, err = multi.PutPart(int(partNumber), bytes.NewReader(buf))
if err != nil {
return totalRead, err
}
parts = append(parts, part)
partNumber++
//Zero fill from chunkSize up to offset, then some reader
if err = fromZeroFillLarge(d.ChunkSize, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+(offset%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
} else {
// offset > currentLength >= chunkSize
_, part, err = multi.PutPartCopy(partNumber,
s3.CopyOptions{},
d.Bucket.Name+"/"+d.s3Path(path))
if err != nil {
return 0, err
}
parts = append(parts, part)
partNumber++
//Zero fill from currentLength up to offset, then some reader
if err = fromZeroFillLarge(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader((offset - currentLength) % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+((offset-currentLength)%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
}
}
for {
if err = fromReader(0); err != nil {
return totalRead, err
}
if int64(bytesRead) < d.ChunkSize {
break
}
}
return totalRead, nil
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -882,12 +566,181 @@ func (d *driver) getContentType() string {
return "application/octet-stream" return "application/octet-stream"
} }
// getbuf returns a buffer from the driver's pool with length d.ChunkSize. // writer attempts to upload parts to S3 in a buffered fashion where the last
func (d *driver) getbuf() []byte { // part is at least as large as the chunksize, so the multipart upload could be
return d.pool.Get().([]byte) // cleanly resumed in the future. This is violated if Close is called after less
// than a full chunk is written.
type writer struct {
driver *driver
key string
multi *s3.Multi
parts []s3.Part
size int64
readyPart []byte
pendingPart []byte
closed bool
committed bool
cancelled bool
} }
func (d *driver) putbuf(p []byte) { func (d *driver) newWriter(key string, multi *s3.Multi, parts []s3.Part) storagedriver.FileWriter {
copy(p, d.zeros) var size int64
d.pool.Put(p) for _, part := range parts {
size += part.Size
}
return &writer{
driver: d,
key: key,
multi: multi,
parts: parts,
size: size,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
// If the last written part is smaller than minChunkSize, we need to make a
// new multipart upload :sadface:
if len(w.parts) > 0 && int(w.parts[len(w.parts)-1].Size) < minChunkSize {
err := w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return 0, err
}
multi, err := w.driver.Bucket.InitMulti(w.key, w.driver.getContentType(), getPermissions(), w.driver.getOptions())
if err != nil {
return 0, err
}
w.multi = multi
// If the entire written file is smaller than minChunkSize, we need to make
// a new part from scratch :double sad face:
if w.size < minChunkSize {
contents, err := w.driver.Bucket.Get(w.key)
if err != nil {
return 0, err
}
w.parts = nil
w.readyPart = contents
} else {
// Otherwise we can use the old file as the new first part
_, part, err := multi.PutPartCopy(1, s3.CopyOptions{}, w.driver.Bucket.Name+"/"+w.key)
if err != nil {
return 0, err
}
w.parts = []s3.Part{part}
}
}
var n int
for len(p) > 0 {
// If no parts are ready to write, fill up the first part
if neededBytes := int(w.driver.ChunkSize) - len(w.readyPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.readyPart = append(w.readyPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
} else {
w.readyPart = append(w.readyPart, p...)
n += len(p)
p = nil
}
}
if neededBytes := int(w.driver.ChunkSize) - len(w.pendingPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.pendingPart = append(w.pendingPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
err := w.flushPart()
if err != nil {
w.size += int64(n)
return n, err
}
} else {
w.pendingPart = append(w.pendingPart, p...)
n += len(p)
p = nil
}
}
}
w.size += int64(n)
return n, nil
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.flushPart()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
err := w.multi.Abort()
return err
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
err := w.flushPart()
if err != nil {
return err
}
w.committed = true
err = w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return err
}
return nil
}
// flushPart flushes buffers to write a part to S3.
// Only called by Write (with both buffers full) and Close/Commit (always)
func (w *writer) flushPart() error {
if len(w.readyPart) == 0 && len(w.pendingPart) == 0 {
// nothing to write
return nil
}
if len(w.pendingPart) < int(w.driver.ChunkSize) {
// closing with a small pending part
// combine ready and pending to avoid writing a small part
w.readyPart = append(w.readyPart, w.pendingPart...)
w.pendingPart = nil
}
part, err := w.multi.PutPart(len(w.parts)+1, bytes.NewReader(w.readyPart))
if err != nil {
return err
}
w.parts = append(w.parts, part)
w.readyPart = w.pendingPart
w.pendingPart = nil
return nil
} }

View file

@ -34,7 +34,11 @@ func (version Version) Minor() uint {
const CurrentVersion Version = "0.1" const CurrentVersion Version = "0.1"
// StorageDriver defines methods that a Storage Driver must implement for a // StorageDriver defines methods that a Storage Driver must implement for a
// filesystem-like key/value object storage. // filesystem-like key/value object storage. Storage Drivers are automatically
// registered via an internal registration mechanism, and generally created
// via the StorageDriverFactory interface (https://godoc.org/github.com/docker/distribution/registry/storage/driver/factory).
// Please see the aforementioned factory package for example code showing how to get an instance
// of a StorageDriver
type StorageDriver interface { type StorageDriver interface {
// Name returns the human-readable "name" of the driver, useful in error // Name returns the human-readable "name" of the driver, useful in error
// messages and logging. By convention, this will just be the registration // messages and logging. By convention, this will just be the registration
@ -49,15 +53,14 @@ type StorageDriver interface {
// This should primarily be used for small objects. // This should primarily be used for small objects.
PutContent(ctx context.Context, path string, content []byte) error PutContent(ctx context.Context, path string, content []byte) error
// ReadStream retrieves an io.ReadCloser for the content stored at "path" // Reader retrieves an io.ReadCloser for the content stored at "path"
// with a given byte offset. // with a given byte offset.
// May be used to resume reading a stream by providing a nonzero offset. // May be used to resume reading a stream by providing a nonzero offset.
ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error)
// WriteStream stores the contents of the provided io.ReadCloser at a // Writer returns a FileWriter which will store the content written to it
// location designated by the given path. // at the location designated by "path" after the call to Commit.
// May be used to resume writing a stream by providing a nonzero offset. Writer(ctx context.Context, path string, append bool) (FileWriter, error)
WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (nn int64, err error)
// Stat retrieves the FileInfo for the given path, including the current // Stat retrieves the FileInfo for the given path, including the current
// size in bytes and the creation time. // size in bytes and the creation time.
@ -83,6 +86,25 @@ type StorageDriver interface {
URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error)
} }
// FileWriter provides an abstraction for an opened writable file-like object in
// the storage backend. The FileWriter must flush all content written to it on
// the call to Close, but is only required to make its content readable on a
// call to Commit.
type FileWriter interface {
io.WriteCloser
// Size returns the number of bytes written to this FileWriter.
Size() int64
// Cancel removes any written content from this FileWriter.
Cancel() error
// Commit flushes all content written to this FileWriter and makes it
// available for future calls to StorageDriver.GetContent and
// StorageDriver.Reader.
Commit() error
}
// PathRegexp is the regular expression which each file path must match. A // PathRegexp is the regular expression which each file path must match. A
// file path is absolute, beginning with a slash and containing a positive // file path is absolute, beginning with a slash and containing a positive
// number of path components separated by slashes, where each component is // number of path components separated by slashes, where each component is

View file

@ -16,8 +16,8 @@
package swift package swift
import ( import (
"bufio"
"bytes" "bytes"
"crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"crypto/tls" "crypto/tls"
@ -49,6 +49,9 @@ const defaultChunkSize = 20 * 1024 * 1024
// minChunkSize defines the minimum size of a segment // minChunkSize defines the minimum size of a segment
const minChunkSize = 1 << 20 const minChunkSize = 1 << 20
// contentType defines the Content-Type header associated with stored segments
const contentType = "application/octet-stream"
// readAfterWriteTimeout defines the time we wait before an object appears after having been uploaded // readAfterWriteTimeout defines the time we wait before an object appears after having been uploaded
var readAfterWriteTimeout = 15 * time.Second var readAfterWriteTimeout = 15 * time.Second
@ -66,6 +69,7 @@ type Parameters struct {
DomainID string DomainID string
TrustID string TrustID string
Region string Region string
AuthVersion int
Container string Container string
Prefix string Prefix string
InsecureSkipVerify bool InsecureSkipVerify bool
@ -171,6 +175,7 @@ func New(params Parameters) (*Driver, error) {
ApiKey: params.Password, ApiKey: params.Password,
AuthUrl: params.AuthURL, AuthUrl: params.AuthURL,
Region: params.Region, Region: params.Region,
AuthVersion: params.AuthVersion,
UserAgent: "distribution/" + version.Version, UserAgent: "distribution/" + version.Version,
Tenant: params.Tenant, Tenant: params.Tenant,
TenantId: params.TenantID, TenantId: params.TenantID,
@ -277,21 +282,21 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
if err == swift.ObjectNotFound { if err == swift.ObjectNotFound {
return nil, storagedriver.PathNotFoundError{Path: path} return nil, storagedriver.PathNotFoundError{Path: path}
} }
return content, nil return content, err
} }
// PutContent stores the []byte content at a location designated by "path". // PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error { func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
err := d.Conn.ObjectPutBytes(d.Container, d.swiftPath(path), contents, d.getContentType()) err := d.Conn.ObjectPutBytes(d.Container, d.swiftPath(path), contents, contentType)
if err == swift.ObjectNotFound { if err == swift.ObjectNotFound {
return storagedriver.PathNotFoundError{Path: path} return storagedriver.PathNotFoundError{Path: path}
} }
return err return err
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
headers := make(swift.Headers) headers := make(swift.Headers)
headers["Range"] = "bytes=" + strconv.FormatInt(offset, 10) + "-" headers["Range"] = "bytes=" + strconv.FormatInt(offset, 10) + "-"
@ -305,224 +310,46 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return file, err return file, err
} }
// WriteStream stores the contents of the provided io.Reader at a // Writer returns a FileWriter which will store the content written to it
// location designated by the given path. The driver will know it has // at the location designated by "path" after the call to Commit.
// received the full contents when the reader returns io.EOF. The number func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
// of successfully READ bytes will be returned, even if an error is
// returned. May be used to resume writing a stream by providing a nonzero
// offset. Offsets past the current size will write from the position
// beyond the end of the file.
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (int64, error) {
var ( var (
segments []swift.Object segments []swift.Object
multi io.Reader segmentsPath string
paddingReader io.Reader err error
currentLength int64
cursor int64
segmentPath string
) )
partNumber := 1 if !append {
chunkSize := int64(d.ChunkSize) segmentsPath, err = d.swiftSegmentPath(path)
zeroBuf := make([]byte, d.ChunkSize) if err != nil {
hash := md5.New() return nil, err
getSegment := func() string {
return fmt.Sprintf("%s/%016d", segmentPath, partNumber)
}
max := func(a int64, b int64) int64 {
if a > b {
return a
}
return b
}
createManifest := true
info, headers, err := d.Conn.Object(d.Container, d.swiftPath(path))
if err == nil {
manifest, ok := headers["X-Object-Manifest"]
if !ok {
if segmentPath, err = d.swiftSegmentPath(path); err != nil {
return 0, err
}
if err := d.Conn.ObjectMove(d.Container, d.swiftPath(path), d.Container, getSegment()); err != nil {
return 0, err
}
segments = append(segments, info)
} else {
_, segmentPath = parseManifest(manifest)
if segments, err = d.getAllSegments(segmentPath); err != nil {
return 0, err
}
createManifest = false
}
currentLength = info.Bytes
} else if err == swift.ObjectNotFound {
if segmentPath, err = d.swiftSegmentPath(path); err != nil {
return 0, err
} }
} else { } else {
return 0, err info, headers, err := d.Conn.Object(d.Container, d.swiftPath(path))
} if err == swift.ObjectNotFound {
return nil, storagedriver.PathNotFoundError{Path: path}
// First, we skip the existing segments that are not modified by this call } else if err != nil {
for i := range segments { return nil, err
if offset < cursor+segments[i].Bytes {
break
} }
cursor += segments[i].Bytes manifest, ok := headers["X-Object-Manifest"]
hash.Write([]byte(segments[i].Hash)) if !ok {
partNumber++ segmentsPath, err = d.swiftSegmentPath(path)
}
// We reached the end of the file but we haven't reached 'offset' yet
// Therefore we add blocks of zeros
if offset >= currentLength {
for offset-currentLength >= chunkSize {
// Insert a block a zero
headers, err := d.Conn.ObjectPut(d.Container, getSegment(), bytes.NewReader(zeroBuf), false, "", d.getContentType(), nil)
if err != nil { if err != nil {
if err == swift.ObjectNotFound { return nil, err
return 0, storagedriver.PathNotFoundError{Path: getSegment()}
}
return 0, err
} }
currentLength += chunkSize if err := d.Conn.ObjectMove(d.Container, d.swiftPath(path), d.Container, getSegmentPath(segmentsPath, len(segments))); err != nil {
partNumber++ return nil, err
hash.Write([]byte(headers["Etag"]))
}
cursor = currentLength
paddingReader = bytes.NewReader(zeroBuf)
} else if offset-cursor > 0 {
// Offset is inside the current segment : we need to read the
// data from the beginning of the segment to offset
file, _, err := d.Conn.ObjectOpen(d.Container, getSegment(), false, nil)
if err != nil {
if err == swift.ObjectNotFound {
return 0, storagedriver.PathNotFoundError{Path: getSegment()}
} }
return 0, err segments = []swift.Object{info}
} } else {
defer file.Close() _, segmentsPath = parseManifest(manifest)
paddingReader = file if segments, err = d.getAllSegments(segmentsPath); err != nil {
} return nil, err
readers := []io.Reader{}
if paddingReader != nil {
readers = append(readers, io.LimitReader(paddingReader, offset-cursor))
}
readers = append(readers, io.LimitReader(reader, chunkSize-(offset-cursor)))
multi = io.MultiReader(readers...)
writeSegment := func(segment string) (finished bool, bytesRead int64, err error) {
currentSegment, err := d.Conn.ObjectCreate(d.Container, segment, false, "", d.getContentType(), nil)
if err != nil {
if err == swift.ObjectNotFound {
return false, bytesRead, storagedriver.PathNotFoundError{Path: segment}
} }
return false, bytesRead, err
}
segmentHash := md5.New()
writer := io.MultiWriter(currentSegment, segmentHash)
n, err := io.Copy(writer, multi)
if err != nil {
return false, bytesRead, err
}
if n > 0 {
defer func() {
closeError := currentSegment.Close()
if err != nil {
err = closeError
}
hexHash := hex.EncodeToString(segmentHash.Sum(nil))
hash.Write([]byte(hexHash))
}()
bytesRead += n - max(0, offset-cursor)
}
if n < chunkSize {
// We wrote all the data
if cursor+n < currentLength {
// Copy the end of the chunk
headers := make(swift.Headers)
headers["Range"] = "bytes=" + strconv.FormatInt(cursor+n, 10) + "-" + strconv.FormatInt(cursor+chunkSize, 10)
file, _, err := d.Conn.ObjectOpen(d.Container, d.swiftPath(path), false, headers)
if err != nil {
if err == swift.ObjectNotFound {
return false, bytesRead, storagedriver.PathNotFoundError{Path: path}
}
return false, bytesRead, err
}
_, copyErr := io.Copy(writer, file)
if err := file.Close(); err != nil {
if err == swift.ObjectNotFound {
return false, bytesRead, storagedriver.PathNotFoundError{Path: path}
}
return false, bytesRead, err
}
if copyErr != nil {
return false, bytesRead, copyErr
}
}
return true, bytesRead, nil
}
multi = io.LimitReader(reader, chunkSize)
cursor += chunkSize
partNumber++
return false, bytesRead, nil
}
finished := false
read := int64(0)
bytesRead := int64(0)
for finished == false {
finished, read, err = writeSegment(getSegment())
bytesRead += read
if err != nil {
return bytesRead, err
} }
} }
for ; partNumber < len(segments); partNumber++ { return d.newWriter(path, segmentsPath, segments), nil
hash.Write([]byte(segments[partNumber].Hash))
}
if createManifest {
if err := d.createManifest(path, d.Container+"/"+segmentPath); err != nil {
return 0, err
}
}
expectedHash := hex.EncodeToString(hash.Sum(nil))
waitingTime := readAfterWriteWait
endTime := time.Now().Add(readAfterWriteTimeout)
for {
var infos swift.Object
if infos, _, err = d.Conn.Object(d.Container, d.swiftPath(path)); err == nil {
if strings.Trim(infos.Hash, "\"") == expectedHash {
return bytesRead, nil
}
err = fmt.Errorf("Timeout expired while waiting for segments of %s to show up", path)
}
if time.Now().Add(waitingTime).After(endTime) {
break
}
time.Sleep(waitingTime)
waitingTime *= 2
}
return bytesRead, err
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -551,23 +378,26 @@ func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo,
fi.IsDir = true fi.IsDir = true
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
} else if obj.Name == swiftPath { } else if obj.Name == swiftPath {
// On Swift 1.12, the 'bytes' field is always 0 // The file exists. But on Swift 1.12, the 'bytes' field is always 0 so
// so we need to do a second HEAD request // we need to do a separate HEAD request.
info, _, err := d.Conn.Object(d.Container, swiftPath) break
if err != nil {
if err == swift.ObjectNotFound {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
fi.IsDir = false
fi.Size = info.Bytes
fi.ModTime = info.LastModified
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
} }
} }
return nil, storagedriver.PathNotFoundError{Path: path} //Don't trust an empty `objects` slice. A container listing can be
//outdated. For files, we can make a HEAD request on the object which
//reports existence (at least) much more reliably.
info, _, err := d.Conn.Object(d.Container, swiftPath)
if err != nil {
if err == swift.ObjectNotFound {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
fi.IsDir = false
fi.Size = info.Bytes
fi.ModTime = info.LastModified
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
} }
// List returns a list of the objects that are direct descendants of the given path. // List returns a list of the objects that are direct descendants of the given path.
@ -763,22 +593,59 @@ func (d *driver) swiftSegmentPath(path string) (string, error) {
return strings.TrimLeft(strings.TrimRight(d.Prefix+"/segments/"+path[0:3]+"/"+path[3:], "/"), "/"), nil return strings.TrimLeft(strings.TrimRight(d.Prefix+"/segments/"+path[0:3]+"/"+path[3:], "/"), "/"), nil
} }
func (d *driver) getContentType() string {
return "application/octet-stream"
}
func (d *driver) getAllSegments(path string) ([]swift.Object, error) { func (d *driver) getAllSegments(path string) ([]swift.Object, error) {
//a simple container listing works 99.9% of the time
segments, err := d.Conn.ObjectsAll(d.Container, &swift.ObjectsOpts{Prefix: path}) segments, err := d.Conn.ObjectsAll(d.Container, &swift.ObjectsOpts{Prefix: path})
if err == swift.ContainerNotFound { if err != nil {
return nil, storagedriver.PathNotFoundError{Path: path} if err == swift.ContainerNotFound {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return nil, err
}
//build a lookup table by object name
hasObjectName := make(map[string]struct{})
for _, segment := range segments {
hasObjectName[segment.Name] = struct{}{}
}
//The container listing might be outdated (i.e. not contain all existing
//segment objects yet) because of temporary inconsistency (Swift is only
//eventually consistent!). Check its completeness.
segmentNumber := 0
for {
segmentNumber++
segmentPath := getSegmentPath(path, segmentNumber)
if _, seen := hasObjectName[segmentPath]; seen {
continue
}
//This segment is missing in the container listing. Use a more reliable
//request to check its existence. (HEAD requests on segments are
//guaranteed to return the correct metadata, except for the pathological
//case of an outage of large parts of the Swift cluster or its network,
//since every segment is only written once.)
segment, _, err := d.Conn.Object(d.Container, segmentPath)
switch err {
case nil:
//found new segment -> keep going, more might be missing
segments = append(segments, segment)
continue
case swift.ObjectNotFound:
//This segment is missing. Since we upload segments sequentially,
//there won't be any more segments after it.
return segments, nil
default:
return nil, err //unexpected error
}
} }
return segments, err
} }
func (d *driver) createManifest(path string, segments string) error { func (d *driver) createManifest(path string, segments string) error {
headers := make(swift.Headers) headers := make(swift.Headers)
headers["X-Object-Manifest"] = segments headers["X-Object-Manifest"] = segments
manifest, err := d.Conn.ObjectCreate(d.Container, d.swiftPath(path), false, "", d.getContentType(), headers) manifest, err := d.Conn.ObjectCreate(d.Container, d.swiftPath(path), false, "", contentType, headers)
if err != nil { if err != nil {
if err == swift.ObjectNotFound { if err == swift.ObjectNotFound {
return storagedriver.PathNotFoundError{Path: path} return storagedriver.PathNotFoundError{Path: path}
@ -810,3 +677,159 @@ func generateSecret() (string, error) {
} }
return hex.EncodeToString(secretBytes[:]), nil return hex.EncodeToString(secretBytes[:]), nil
} }
func getSegmentPath(segmentsPath string, partNumber int) string {
return fmt.Sprintf("%s/%016d", segmentsPath, partNumber)
}
type writer struct {
driver *driver
path string
segmentsPath string
size int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(path, segmentsPath string, segments []swift.Object) storagedriver.FileWriter {
var size int64
for _, segment := range segments {
size += segment.Bytes
}
return &writer{
driver: d,
path: path,
segmentsPath: segmentsPath,
size: size,
bw: bufio.NewWriterSize(&segmentWriter{
conn: d.Conn,
container: d.Container,
segmentsPath: segmentsPath,
segmentNumber: len(segments) + 1,
maxChunkSize: d.ChunkSize,
}, d.ChunkSize),
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
n, err := w.bw.Write(p)
w.size += int64(n)
return n, err
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
if err := w.bw.Flush(); err != nil {
return err
}
if !w.committed && !w.cancelled {
if err := w.driver.createManifest(w.path, w.driver.Container+"/"+w.segmentsPath); err != nil {
return err
}
if err := w.waitForSegmentsToShowUp(); err != nil {
return err
}
}
w.closed = true
return nil
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
return w.driver.Delete(context.Background(), w.path)
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
if err := w.bw.Flush(); err != nil {
return err
}
if err := w.driver.createManifest(w.path, w.driver.Container+"/"+w.segmentsPath); err != nil {
return err
}
w.committed = true
return w.waitForSegmentsToShowUp()
}
func (w *writer) waitForSegmentsToShowUp() error {
var err error
waitingTime := readAfterWriteWait
endTime := time.Now().Add(readAfterWriteTimeout)
for {
var info swift.Object
if info, _, err = w.driver.Conn.Object(w.driver.Container, w.driver.swiftPath(w.path)); err == nil {
if info.Bytes == w.size {
break
}
err = fmt.Errorf("Timeout expired while waiting for segments of %s to show up", w.path)
}
if time.Now().Add(waitingTime).After(endTime) {
break
}
time.Sleep(waitingTime)
waitingTime *= 2
}
return err
}
type segmentWriter struct {
conn swift.Connection
container string
segmentsPath string
segmentNumber int
maxChunkSize int
}
func (sw *segmentWriter) Write(p []byte) (int, error) {
n := 0
for offset := 0; offset < len(p); offset += sw.maxChunkSize {
chunkSize := sw.maxChunkSize
if offset+chunkSize > len(p) {
chunkSize = len(p) - offset
}
_, err := sw.conn.ObjectPut(sw.container, getSegmentPath(sw.segmentsPath, sw.segmentNumber), bytes.NewReader(p[offset:offset+chunkSize]), false, "", contentType, nil)
if err != nil {
return n, err
}
sw.segmentNumber++
n += chunkSize
}
return n, nil
}

View file

@ -33,6 +33,7 @@ func init() {
trustID string trustID string
container string container string
region string region string
AuthVersion int
insecureSkipVerify bool insecureSkipVerify bool
secretKey string secretKey string
accessKey string accessKey string
@ -52,6 +53,7 @@ func init() {
trustID = os.Getenv("SWIFT_TRUST_ID") trustID = os.Getenv("SWIFT_TRUST_ID")
container = os.Getenv("SWIFT_CONTAINER_NAME") container = os.Getenv("SWIFT_CONTAINER_NAME")
region = os.Getenv("SWIFT_REGION_NAME") region = os.Getenv("SWIFT_REGION_NAME")
AuthVersion, _ = strconv.Atoi(os.Getenv("SWIFT_AUTH_VERSION"))
insecureSkipVerify, _ = strconv.ParseBool(os.Getenv("SWIFT_INSECURESKIPVERIFY")) insecureSkipVerify, _ = strconv.ParseBool(os.Getenv("SWIFT_INSECURESKIPVERIFY"))
secretKey = os.Getenv("SWIFT_SECRET_KEY") secretKey = os.Getenv("SWIFT_SECRET_KEY")
accessKey = os.Getenv("SWIFT_ACCESS_KEY") accessKey = os.Getenv("SWIFT_ACCESS_KEY")
@ -85,6 +87,7 @@ func init() {
domainID, domainID,
trustID, trustID,
region, region,
AuthVersion,
container, container,
root, root,
insecureSkipVerify, insecureSkipVerify,

View file

@ -282,11 +282,19 @@ func (suite *DriverSuite) TestWriteReadLargeStreams(c *check.C) {
var fileSize int64 = 5 * 1024 * 1024 * 1024 var fileSize int64 = 5 * 1024 * 1024 * 1024
contents := newRandReader(fileSize) contents := newRandReader(fileSize)
written, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, io.TeeReader(contents, checksum))
writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
written, err := io.Copy(writer, io.TeeReader(contents, checksum))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(written, check.Equals, fileSize) c.Assert(written, check.Equals, fileSize)
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil)
reader, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@ -296,9 +304,9 @@ func (suite *DriverSuite) TestWriteReadLargeStreams(c *check.C) {
c.Assert(writtenChecksum.Sum(nil), check.DeepEquals, checksum.Sum(nil)) c.Assert(writtenChecksum.Sum(nil), check.DeepEquals, checksum.Sum(nil))
} }
// TestReadStreamWithOffset tests that the appropriate data is streamed when // TestReaderWithOffset tests that the appropriate data is streamed when
// reading with a given offset. // reading with a given offset.
func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) { func (suite *DriverSuite) TestReaderWithOffset(c *check.C) {
filename := randomPath(32) filename := randomPath(32)
defer suite.deletePath(c, firstPart(filename)) defer suite.deletePath(c, firstPart(filename))
@ -311,7 +319,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
err := suite.StorageDriver.PutContent(suite.ctx, filename, append(append(contentsChunk1, contentsChunk2...), contentsChunk3...)) err := suite.StorageDriver.PutContent(suite.ctx, filename, append(append(contentsChunk1, contentsChunk2...), contentsChunk3...))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) reader, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@ -320,7 +328,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
c.Assert(readContents, check.DeepEquals, append(append(contentsChunk1, contentsChunk2...), contentsChunk3...)) c.Assert(readContents, check.DeepEquals, append(append(contentsChunk1, contentsChunk2...), contentsChunk3...))
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, chunkSize) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, chunkSize)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@ -329,7 +337,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
c.Assert(readContents, check.DeepEquals, append(contentsChunk2, contentsChunk3...)) c.Assert(readContents, check.DeepEquals, append(contentsChunk2, contentsChunk3...))
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, chunkSize*2) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, chunkSize*2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@ -338,7 +346,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
c.Assert(readContents, check.DeepEquals, contentsChunk3) c.Assert(readContents, check.DeepEquals, contentsChunk3)
// Ensure we get invalid offest for negative offsets. // Ensure we get invalid offest for negative offsets.
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, -1) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, -1)
c.Assert(err, check.FitsTypeOf, storagedriver.InvalidOffsetError{}) c.Assert(err, check.FitsTypeOf, storagedriver.InvalidOffsetError{})
c.Assert(err.(storagedriver.InvalidOffsetError).Offset, check.Equals, int64(-1)) c.Assert(err.(storagedriver.InvalidOffsetError).Offset, check.Equals, int64(-1))
c.Assert(err.(storagedriver.InvalidOffsetError).Path, check.Equals, filename) c.Assert(err.(storagedriver.InvalidOffsetError).Path, check.Equals, filename)
@ -347,7 +355,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
// Read past the end of the content and make sure we get a reader that // Read past the end of the content and make sure we get a reader that
// returns 0 bytes and io.EOF // returns 0 bytes and io.EOF
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, chunkSize*3) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, chunkSize*3)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@ -357,7 +365,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
c.Assert(n, check.Equals, 0) c.Assert(n, check.Equals, 0)
// Check the N-1 boundary condition, ensuring we get 1 byte then io.EOF. // Check the N-1 boundary condition, ensuring we get 1 byte then io.EOF.
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, chunkSize*3-1) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, chunkSize*3-1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@ -395,78 +403,51 @@ func (suite *DriverSuite) testContinueStreamAppend(c *check.C, chunkSize int64)
contentsChunk1 := randomContents(chunkSize) contentsChunk1 := randomContents(chunkSize)
contentsChunk2 := randomContents(chunkSize) contentsChunk2 := randomContents(chunkSize)
contentsChunk3 := randomContents(chunkSize) contentsChunk3 := randomContents(chunkSize)
contentsChunk4 := randomContents(chunkSize)
zeroChunk := make([]byte, int64(chunkSize))
fullContents := append(append(contentsChunk1, contentsChunk2...), contentsChunk3...) fullContents := append(append(contentsChunk1, contentsChunk2...), contentsChunk3...)
nn, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, bytes.NewReader(contentsChunk1)) writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
nn, err := io.Copy(writer, bytes.NewReader(contentsChunk1))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contentsChunk1))) c.Assert(nn, check.Equals, int64(len(contentsChunk1)))
fi, err := suite.StorageDriver.Stat(suite.ctx, filename) err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(fi, check.NotNil)
c.Assert(fi.Size(), check.Equals, int64(len(contentsChunk1)))
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, fi.Size(), bytes.NewReader(contentsChunk2)) curSize := writer.Size()
c.Assert(curSize, check.Equals, int64(len(contentsChunk1)))
writer, err = suite.StorageDriver.Writer(suite.ctx, filename, true)
c.Assert(err, check.IsNil)
c.Assert(writer.Size(), check.Equals, curSize)
nn, err = io.Copy(writer, bytes.NewReader(contentsChunk2))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contentsChunk2))) c.Assert(nn, check.Equals, int64(len(contentsChunk2)))
fi, err = suite.StorageDriver.Stat(suite.ctx, filename) err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(fi, check.NotNil)
c.Assert(fi.Size(), check.Equals, 2*chunkSize)
// Test re-writing the last chunk curSize = writer.Size()
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, fi.Size()-chunkSize, bytes.NewReader(contentsChunk2)) c.Assert(curSize, check.Equals, 2*chunkSize)
c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contentsChunk2)))
fi, err = suite.StorageDriver.Stat(suite.ctx, filename) writer, err = suite.StorageDriver.Writer(suite.ctx, filename, true)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(fi, check.NotNil) c.Assert(writer.Size(), check.Equals, curSize)
c.Assert(fi.Size(), check.Equals, 2*chunkSize)
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, fi.Size(), bytes.NewReader(fullContents[fi.Size():])) nn, err = io.Copy(writer, bytes.NewReader(fullContents[curSize:]))
c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(fullContents[curSize:])))
err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(fullContents[fi.Size():])))
received, err := suite.StorageDriver.GetContent(suite.ctx, filename) received, err := suite.StorageDriver.GetContent(suite.ctx, filename)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(received, check.DeepEquals, fullContents) c.Assert(received, check.DeepEquals, fullContents)
// Writing past size of file extends file (no offset error). We would like
// to write chunk 4 one chunk length past chunk 3. It should be successful
// and the resulting file will be 5 chunks long, with a chunk of all
// zeros.
fullContents = append(fullContents, zeroChunk...)
fullContents = append(fullContents, contentsChunk4...)
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, int64(len(fullContents))-chunkSize, bytes.NewReader(contentsChunk4))
c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, chunkSize)
fi, err = suite.StorageDriver.Stat(suite.ctx, filename)
c.Assert(err, check.IsNil)
c.Assert(fi, check.NotNil)
c.Assert(fi.Size(), check.Equals, int64(len(fullContents)))
received, err = suite.StorageDriver.GetContent(suite.ctx, filename)
c.Assert(err, check.IsNil)
c.Assert(len(received), check.Equals, len(fullContents))
c.Assert(received[chunkSize*3:chunkSize*4], check.DeepEquals, zeroChunk)
c.Assert(received[chunkSize*4:chunkSize*5], check.DeepEquals, contentsChunk4)
c.Assert(received, check.DeepEquals, fullContents)
// Ensure that negative offsets return correct error.
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, -1, bytes.NewReader(zeroChunk))
c.Assert(err, check.NotNil)
c.Assert(err, check.FitsTypeOf, storagedriver.InvalidOffsetError{})
c.Assert(err.(storagedriver.InvalidOffsetError).Path, check.Equals, filename)
c.Assert(err.(storagedriver.InvalidOffsetError).Offset, check.Equals, int64(-1))
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
} }
// TestReadNonexistentStream tests that reading a stream for a nonexistent path // TestReadNonexistentStream tests that reading a stream for a nonexistent path
@ -474,12 +455,12 @@ func (suite *DriverSuite) testContinueStreamAppend(c *check.C, chunkSize int64)
func (suite *DriverSuite) TestReadNonexistentStream(c *check.C) { func (suite *DriverSuite) TestReadNonexistentStream(c *check.C) {
filename := randomPath(32) filename := randomPath(32)
_, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) _, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(err, check.FitsTypeOf, storagedriver.PathNotFoundError{}) c.Assert(err, check.FitsTypeOf, storagedriver.PathNotFoundError{})
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true) c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
_, err = suite.StorageDriver.ReadStream(suite.ctx, filename, 64) _, err = suite.StorageDriver.Reader(suite.ctx, filename, 64)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(err, check.FitsTypeOf, storagedriver.PathNotFoundError{}) c.Assert(err, check.FitsTypeOf, storagedriver.PathNotFoundError{})
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true) c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
@ -800,7 +781,7 @@ func (suite *DriverSuite) TestStatCall(c *check.C) {
// TestPutContentMultipleTimes checks that if storage driver can overwrite the content // TestPutContentMultipleTimes checks that if storage driver can overwrite the content
// in the subsequent puts. Validates that PutContent does not have to work // in the subsequent puts. Validates that PutContent does not have to work
// with an offset like WriteStream does and overwrites the file entirely // with an offset like Writer does and overwrites the file entirely
// rather than writing the data to the [0,len(data)) of the file. // rather than writing the data to the [0,len(data)) of the file.
func (suite *DriverSuite) TestPutContentMultipleTimes(c *check.C) { func (suite *DriverSuite) TestPutContentMultipleTimes(c *check.C) {
filename := randomPath(32) filename := randomPath(32)
@ -842,7 +823,7 @@ func (suite *DriverSuite) TestConcurrentStreamReads(c *check.C) {
readContents := func() { readContents := func() {
defer wg.Done() defer wg.Done()
offset := rand.Int63n(int64(len(contents))) offset := rand.Int63n(int64(len(contents)))
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, offset) reader, err := suite.StorageDriver.Reader(suite.ctx, filename, offset)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
readContents, err := ioutil.ReadAll(reader) readContents, err := ioutil.ReadAll(reader)
@ -858,7 +839,7 @@ func (suite *DriverSuite) TestConcurrentStreamReads(c *check.C) {
} }
// TestConcurrentFileStreams checks that multiple *os.File objects can be passed // TestConcurrentFileStreams checks that multiple *os.File objects can be passed
// in to WriteStream concurrently without hanging. // in to Writer concurrently without hanging.
func (suite *DriverSuite) TestConcurrentFileStreams(c *check.C) { func (suite *DriverSuite) TestConcurrentFileStreams(c *check.C) {
numStreams := 32 numStreams := 32
@ -882,53 +863,54 @@ func (suite *DriverSuite) TestConcurrentFileStreams(c *check.C) {
wg.Wait() wg.Wait()
} }
// TODO (brianbland): evaluate the relevancy of this test
// TestEventualConsistency checks that if stat says that a file is a certain size, then // TestEventualConsistency checks that if stat says that a file is a certain size, then
// you can freely read from the file (this is the only guarantee that the driver needs to provide) // you can freely read from the file (this is the only guarantee that the driver needs to provide)
func (suite *DriverSuite) TestEventualConsistency(c *check.C) { // func (suite *DriverSuite) TestEventualConsistency(c *check.C) {
if testing.Short() { // if testing.Short() {
c.Skip("Skipping test in short mode") // c.Skip("Skipping test in short mode")
} // }
//
filename := randomPath(32) // filename := randomPath(32)
defer suite.deletePath(c, firstPart(filename)) // defer suite.deletePath(c, firstPart(filename))
//
var offset int64 // var offset int64
var misswrites int // var misswrites int
var chunkSize int64 = 32 // var chunkSize int64 = 32
//
for i := 0; i < 1024; i++ { // for i := 0; i < 1024; i++ {
contents := randomContents(chunkSize) // contents := randomContents(chunkSize)
read, err := suite.StorageDriver.WriteStream(suite.ctx, filename, offset, bytes.NewReader(contents)) // read, err := suite.StorageDriver.Writer(suite.ctx, filename, offset, bytes.NewReader(contents))
c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
//
fi, err := suite.StorageDriver.Stat(suite.ctx, filename) // fi, err := suite.StorageDriver.Stat(suite.ctx, filename)
c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
//
// We are most concerned with being able to read data as soon as Stat declares // // We are most concerned with being able to read data as soon as Stat declares
// it is uploaded. This is the strongest guarantee that some drivers (that guarantee // // it is uploaded. This is the strongest guarantee that some drivers (that guarantee
// at best eventual consistency) absolutely need to provide. // // at best eventual consistency) absolutely need to provide.
if fi.Size() == offset+chunkSize { // if fi.Size() == offset+chunkSize {
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, offset) // reader, err := suite.StorageDriver.Reader(suite.ctx, filename, offset)
c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
//
readContents, err := ioutil.ReadAll(reader) // readContents, err := ioutil.ReadAll(reader)
c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
//
c.Assert(readContents, check.DeepEquals, contents) // c.Assert(readContents, check.DeepEquals, contents)
//
reader.Close() // reader.Close()
offset += read // offset += read
} else { // } else {
misswrites++ // misswrites++
} // }
} // }
//
if misswrites > 0 { // if misswrites > 0 {
c.Log("There were " + string(misswrites) + " occurrences of a write not being instantly available.") // c.Log("There were " + string(misswrites) + " occurrences of a write not being instantly available.")
} // }
//
c.Assert(misswrites, check.Not(check.Equals), 1024) // c.Assert(misswrites, check.Not(check.Equals), 1024)
} // }
// BenchmarkPutGetEmptyFiles benchmarks PutContent/GetContent for 0B files // BenchmarkPutGetEmptyFiles benchmarks PutContent/GetContent for 0B files
func (suite *DriverSuite) BenchmarkPutGetEmptyFiles(c *check.C) { func (suite *DriverSuite) BenchmarkPutGetEmptyFiles(c *check.C) {
@ -968,22 +950,22 @@ func (suite *DriverSuite) benchmarkPutGetFiles(c *check.C, size int64) {
} }
} }
// BenchmarkStreamEmptyFiles benchmarks WriteStream/ReadStream for 0B files // BenchmarkStreamEmptyFiles benchmarks Writer/Reader for 0B files
func (suite *DriverSuite) BenchmarkStreamEmptyFiles(c *check.C) { func (suite *DriverSuite) BenchmarkStreamEmptyFiles(c *check.C) {
suite.benchmarkStreamFiles(c, 0) suite.benchmarkStreamFiles(c, 0)
} }
// BenchmarkStream1KBFiles benchmarks WriteStream/ReadStream for 1KB files // BenchmarkStream1KBFiles benchmarks Writer/Reader for 1KB files
func (suite *DriverSuite) BenchmarkStream1KBFiles(c *check.C) { func (suite *DriverSuite) BenchmarkStream1KBFiles(c *check.C) {
suite.benchmarkStreamFiles(c, 1024) suite.benchmarkStreamFiles(c, 1024)
} }
// BenchmarkStream1MBFiles benchmarks WriteStream/ReadStream for 1MB files // BenchmarkStream1MBFiles benchmarks Writer/Reader for 1MB files
func (suite *DriverSuite) BenchmarkStream1MBFiles(c *check.C) { func (suite *DriverSuite) BenchmarkStream1MBFiles(c *check.C) {
suite.benchmarkStreamFiles(c, 1024*1024) suite.benchmarkStreamFiles(c, 1024*1024)
} }
// BenchmarkStream1GBFiles benchmarks WriteStream/ReadStream for 1GB files // BenchmarkStream1GBFiles benchmarks Writer/Reader for 1GB files
func (suite *DriverSuite) BenchmarkStream1GBFiles(c *check.C) { func (suite *DriverSuite) BenchmarkStream1GBFiles(c *check.C) {
suite.benchmarkStreamFiles(c, 1024*1024*1024) suite.benchmarkStreamFiles(c, 1024*1024*1024)
} }
@ -998,11 +980,18 @@ func (suite *DriverSuite) benchmarkStreamFiles(c *check.C, size int64) {
for i := 0; i < c.N; i++ { for i := 0; i < c.N; i++ {
filename := path.Join(parentDir, randomPath(32)) filename := path.Join(parentDir, randomPath(32))
written, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, bytes.NewReader(randomContents(size))) writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
written, err := io.Copy(writer, bytes.NewReader(randomContents(size)))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(written, check.Equals, size) c.Assert(written, check.Equals, size)
rc, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil)
rc, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rc.Close() rc.Close()
} }
@ -1083,11 +1072,18 @@ func (suite *DriverSuite) testFileStreams(c *check.C, size int64) {
tf.Sync() tf.Sync()
tf.Seek(0, os.SEEK_SET) tf.Seek(0, os.SEEK_SET)
nn, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, tf) writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
nn, err := io.Copy(writer, tf)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, size) c.Assert(nn, check.Equals, size)
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil)
reader, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@ -1112,11 +1108,18 @@ func (suite *DriverSuite) writeReadCompare(c *check.C, filename string, contents
func (suite *DriverSuite) writeReadCompareStreams(c *check.C, filename string, contents []byte) { func (suite *DriverSuite) writeReadCompareStreams(c *check.C, filename string, contents []byte) {
defer suite.deletePath(c, firstPart(filename)) defer suite.deletePath(c, firstPart(filename))
nn, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, bytes.NewReader(contents)) writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
nn, err := io.Copy(writer, bytes.NewReader(contents))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contents))) c.Assert(nn, check.Equals, int64(len(contents)))
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil)
reader, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()

View file

@ -119,7 +119,7 @@ func (fr *fileReader) reader() (io.Reader, error) {
} }
// If we don't have a reader, open one up. // If we don't have a reader, open one up.
rc, err := fr.driver.ReadStream(fr.ctx, fr.path, fr.offset) rc, err := fr.driver.Reader(fr.ctx, fr.path, fr.offset)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case storagedriver.PathNotFoundError: case storagedriver.PathNotFoundError:

View file

@ -183,7 +183,7 @@ func TestFileReaderNonExistentFile(t *testing.T) {
// conditions that can arise when reading a layer. // conditions that can arise when reading a layer.
func TestFileReaderErrors(t *testing.T) { func TestFileReaderErrors(t *testing.T) {
// TODO(stevvooe): We need to cover error return types, driven by the // TODO(stevvooe): We need to cover error return types, driven by the
// errors returned via the HTTP API. For now, here is a incomplete list: // errors returned via the HTTP API. For now, here is an incomplete list:
// //
// 1. Layer Not Found: returned when layer is not found or access is // 1. Layer Not Found: returned when layer is not found or access is
// denied. // denied.

View file

@ -1,135 +0,0 @@
package storage
import (
"bytes"
"fmt"
"io"
"os"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
// fileWriter implements a remote file writer backed by a storage driver.
type fileWriter struct {
driver storagedriver.StorageDriver
ctx context.Context
// identifying fields
path string
// mutable fields
size int64 // size of the file, aka the current end
offset int64 // offset is the current write offset
err error // terminal error, if set, reader is closed
}
// fileWriterInterface makes the desired io compliant interface that the
// filewriter should implement.
type fileWriterInterface interface {
io.WriteSeeker
io.ReaderFrom
io.Closer
}
var _ fileWriterInterface = &fileWriter{}
// newFileWriter returns a prepared fileWriter for the driver and path. This
// could be considered similar to an "open" call on a regular filesystem.
func newFileWriter(ctx context.Context, driver storagedriver.StorageDriver, path string) (*fileWriter, error) {
fw := fileWriter{
driver: driver,
path: path,
ctx: ctx,
}
if fi, err := driver.Stat(ctx, path); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
// ignore, offset is zero
default:
return nil, err
}
} else {
if fi.IsDir() {
return nil, fmt.Errorf("cannot write to a directory")
}
fw.size = fi.Size()
}
return &fw, nil
}
// Write writes the buffer p at the current write offset.
func (fw *fileWriter) Write(p []byte) (n int, err error) {
nn, err := fw.ReadFrom(bytes.NewReader(p))
return int(nn), err
}
// ReadFrom reads reader r until io.EOF writing the contents at the current
// offset.
func (fw *fileWriter) ReadFrom(r io.Reader) (n int64, err error) {
if fw.err != nil {
return 0, fw.err
}
nn, err := fw.driver.WriteStream(fw.ctx, fw.path, fw.offset, r)
// We should forward the offset, whether or not there was an error.
// Basically, we keep the filewriter in sync with the reader's head. If an
// error is encountered, the whole thing should be retried but we proceed
// from an expected offset, even if the data didn't make it to the
// backend.
fw.offset += nn
if fw.offset > fw.size {
fw.size = fw.offset
}
return nn, err
}
// Seek moves the write position do the requested offest based on the whence
// argument, which can be os.SEEK_CUR, os.SEEK_END, or os.SEEK_SET.
func (fw *fileWriter) Seek(offset int64, whence int) (int64, error) {
if fw.err != nil {
return 0, fw.err
}
var err error
newOffset := fw.offset
switch whence {
case os.SEEK_CUR:
newOffset += int64(offset)
case os.SEEK_END:
newOffset = fw.size + int64(offset)
case os.SEEK_SET:
newOffset = int64(offset)
}
if newOffset < 0 {
err = fmt.Errorf("cannot seek to negative position")
} else {
// No problems, set the offset.
fw.offset = newOffset
}
return fw.offset, err
}
// Close closes the fileWriter for writing.
// Calling it once is valid and correct and it will
// return a nil error. Calling it subsequent times will
// detect that fw.err has been set and will return the error.
func (fw *fileWriter) Close() error {
if fw.err != nil {
return fw.err
}
fw.err = fmt.Errorf("filewriter@%v: closed", fw.path)
return nil
}

View file

@ -1,226 +0,0 @@
package storage
import (
"bytes"
"crypto/rand"
"io"
"os"
"testing"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/inmemory"
)
// TestSimpleWrite takes the fileWriter through common write operations
// ensuring data integrity.
func TestSimpleWrite(t *testing.T) {
content := make([]byte, 1<<20)
n, err := rand.Read(content)
if err != nil {
t.Fatalf("unexpected error building random data: %v", err)
}
if n != len(content) {
t.Fatalf("random read did't fill buffer")
}
dgst, err := digest.FromReader(bytes.NewReader(content))
if err != nil {
t.Fatalf("unexpected error digesting random content: %v", err)
}
driver := inmemory.New()
path := "/random"
ctx := context.Background()
fw, err := newFileWriter(ctx, driver, path)
if err != nil {
t.Fatalf("unexpected error creating fileWriter: %v", err)
}
defer fw.Close()
n, err = fw.Write(content)
if err != nil {
t.Fatalf("unexpected error writing content: %v", err)
}
if n != len(content) {
t.Fatalf("unexpected write length: %d != %d", n, len(content))
}
fr, err := newFileReader(ctx, driver, path, int64(len(content)))
if err != nil {
t.Fatalf("unexpected error creating fileReader: %v", err)
}
defer fr.Close()
verifier, err := digest.NewDigestVerifier(dgst)
if err != nil {
t.Fatalf("unexpected error getting digest verifier: %s", err)
}
io.Copy(verifier, fr)
if !verifier.Verified() {
t.Fatalf("unable to verify write data")
}
// Check the seek position is equal to the content length
end, err := fw.Seek(0, os.SEEK_END)
if err != nil {
t.Fatalf("unexpected error seeking: %v", err)
}
if end != int64(len(content)) {
t.Fatalf("write did not advance offset: %d != %d", end, len(content))
}
// Double the content
doubled := append(content, content...)
doubledgst, err := digest.FromReader(bytes.NewReader(doubled))
if err != nil {
t.Fatalf("unexpected error digesting doubled content: %v", err)
}
nn, err := fw.ReadFrom(bytes.NewReader(content))
if err != nil {
t.Fatalf("unexpected error doubling content: %v", err)
}
if nn != int64(len(content)) {
t.Fatalf("writeat was short: %d != %d", n, len(content))
}
fr, err = newFileReader(ctx, driver, path, int64(len(doubled)))
if err != nil {
t.Fatalf("unexpected error creating fileReader: %v", err)
}
defer fr.Close()
verifier, err = digest.NewDigestVerifier(doubledgst)
if err != nil {
t.Fatalf("unexpected error getting digest verifier: %s", err)
}
io.Copy(verifier, fr)
if !verifier.Verified() {
t.Fatalf("unable to verify write data")
}
// Check that Write updated the offset.
end, err = fw.Seek(0, os.SEEK_END)
if err != nil {
t.Fatalf("unexpected error seeking: %v", err)
}
if end != int64(len(doubled)) {
t.Fatalf("write did not advance offset: %d != %d", end, len(doubled))
}
// Now, we copy from one path to another, running the data through the
// fileReader to fileWriter, rather than the driver.Move command to ensure
// everything is working correctly.
fr, err = newFileReader(ctx, driver, path, int64(len(doubled)))
if err != nil {
t.Fatalf("unexpected error creating fileReader: %v", err)
}
defer fr.Close()
fw, err = newFileWriter(ctx, driver, "/copied")
if err != nil {
t.Fatalf("unexpected error creating fileWriter: %v", err)
}
defer fw.Close()
nn, err = io.Copy(fw, fr)
if err != nil {
t.Fatalf("unexpected error copying data: %v", err)
}
if nn != int64(len(doubled)) {
t.Fatalf("unexpected copy length: %d != %d", nn, len(doubled))
}
fr, err = newFileReader(ctx, driver, "/copied", int64(len(doubled)))
if err != nil {
t.Fatalf("unexpected error creating fileReader: %v", err)
}
defer fr.Close()
verifier, err = digest.NewDigestVerifier(doubledgst)
if err != nil {
t.Fatalf("unexpected error getting digest verifier: %s", err)
}
io.Copy(verifier, fr)
if !verifier.Verified() {
t.Fatalf("unable to verify write data")
}
}
func BenchmarkFileWriter(b *testing.B) {
b.StopTimer() // not sure how long setup above will take
for i := 0; i < b.N; i++ {
// Start basic fileWriter initialization
fw := fileWriter{
driver: inmemory.New(),
path: "/random",
}
ctx := context.Background()
if fi, err := fw.driver.Stat(ctx, fw.path); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
// ignore, offset is zero
default:
b.Fatalf("Failed to initialize fileWriter: %v", err.Error())
}
} else {
if fi.IsDir() {
b.Fatalf("Cannot write to a directory")
}
fw.size = fi.Size()
}
randomBytes := make([]byte, 1<<20)
_, err := rand.Read(randomBytes)
if err != nil {
b.Fatalf("unexpected error building random data: %v", err)
}
// End basic file writer initialization
b.StartTimer()
for j := 0; j < 100; j++ {
fw.Write(randomBytes)
}
b.StopTimer()
}
}
func BenchmarkfileWriter(b *testing.B) {
b.StopTimer() // not sure how long setup above will take
ctx := context.Background()
for i := 0; i < b.N; i++ {
bfw, err := newFileWriter(ctx, inmemory.New(), "/random")
if err != nil {
b.Fatalf("Failed to initialize fileWriter: %v", err.Error())
}
randomBytes := make([]byte, 1<<20)
_, err = rand.Read(randomBytes)
if err != nil {
b.Fatalf("unexpected error building random data: %v", err)
}
b.StartTimer()
for j := 0; j < 100; j++ {
bfw.Write(randomBytes)
}
b.StopTimer()
}
}

View file

@ -1,39 +1,34 @@
package registry package storage
import ( import (
"fmt" "fmt"
"os"
"github.com/docker/distribution" "github.com/docker/distribution"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
"github.com/docker/distribution/digest" "github.com/docker/distribution/digest"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2" "github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/reference" "github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/driver" "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/spf13/cobra"
) )
func markAndSweep(storageDriver driver.StorageDriver) error { func emit(format string, a ...interface{}) {
ctx := context.Background() fmt.Printf(format+"\n", a...)
}
// Construct a registry
registry, err := storage.NewRegistry(ctx, storageDriver)
if err != nil {
return fmt.Errorf("failed to construct registry: %v", err)
}
// MarkAndSweep performs a mark and sweep of registry data
func MarkAndSweep(ctx context.Context, storageDriver driver.StorageDriver, registry distribution.Namespace, dryRun bool) error {
repositoryEnumerator, ok := registry.(distribution.RepositoryEnumerator) repositoryEnumerator, ok := registry.(distribution.RepositoryEnumerator)
if !ok { if !ok {
return fmt.Errorf("coercion error: unable to convert Namespace to RepositoryEnumerator") return fmt.Errorf("unable to convert Namespace to RepositoryEnumerator")
} }
// mark // mark
markSet := make(map[digest.Digest]struct{}) markSet := make(map[digest.Digest]struct{})
err = repositoryEnumerator.Enumerate(ctx, func(repoName string) error { err := repositoryEnumerator.Enumerate(ctx, func(repoName string) error {
if dryRun {
emit(repoName)
}
var err error var err error
named, err := reference.ParseNamed(repoName) named, err := reference.ParseNamed(repoName)
if err != nil { if err != nil {
@ -51,11 +46,14 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator) manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator)
if !ok { if !ok {
return fmt.Errorf("coercion error: unable to convert ManifestService into ManifestEnumerator") return fmt.Errorf("unable to convert ManifestService into ManifestEnumerator")
} }
err = manifestEnumerator.Enumerate(ctx, func(dgst digest.Digest) error { err = manifestEnumerator.Enumerate(ctx, func(dgst digest.Digest) error {
// Mark the manifest's blob // Mark the manifest's blob
if dryRun {
emit("%s: marking manifest %s ", repoName, dgst)
}
markSet[dgst] = struct{}{} markSet[dgst] = struct{}{}
manifest, err := manifestService.Get(ctx, dgst) manifest, err := manifestService.Get(ctx, dgst)
@ -66,24 +64,17 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
descriptors := manifest.References() descriptors := manifest.References()
for _, descriptor := range descriptors { for _, descriptor := range descriptors {
markSet[descriptor.Digest] = struct{}{} markSet[descriptor.Digest] = struct{}{}
if dryRun {
emit("%s: marking blob %s", repoName, descriptor.Digest)
}
} }
switch manifest.(type) { switch manifest.(type) {
case *schema1.SignedManifest:
signaturesGetter, ok := manifestService.(distribution.SignaturesGetter)
if !ok {
return fmt.Errorf("coercion error: unable to convert ManifestSErvice into SignaturesGetter")
}
signatures, err := signaturesGetter.GetSignatures(ctx, dgst)
if err != nil {
return fmt.Errorf("failed to get signatures for signed manifest: %v", err)
}
for _, signatureDigest := range signatures {
markSet[signatureDigest] = struct{}{}
}
break
case *schema2.DeserializedManifest: case *schema2.DeserializedManifest:
config := manifest.(*schema2.DeserializedManifest).Config config := manifest.(*schema2.DeserializedManifest).Config
if dryRun {
emit("%s: marking configuration %s", repoName, config.Digest)
}
markSet[config.Digest] = struct{}{} markSet[config.Digest] = struct{}{}
break break
} }
@ -91,6 +82,17 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
return nil return nil
}) })
if err != nil {
// In certain situations such as unfinished uploads, deleting all
// tags in S3 or removing the _manifests folder manually, this
// error may be of type PathNotFound.
//
// In these cases we can continue marking other manifests safely.
if _, ok := err.(driver.PathNotFoundError); ok {
return nil
}
}
return err return err
}) })
@ -108,10 +110,19 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
} }
return nil return nil
}) })
if err != nil {
return fmt.Errorf("error enumerating blobs: %v", err)
}
if dryRun {
emit("\n%d blobs marked, %d blobs eligible for deletion", len(markSet), len(deleteSet))
}
// Construct vacuum // Construct vacuum
vacuum := storage.NewVacuum(ctx, storageDriver) vacuum := NewVacuum(ctx, storageDriver)
for dgst := range deleteSet { for dgst := range deleteSet {
if dryRun {
emit("blob eligible for deletion: %s", dgst)
continue
}
err = vacuum.RemoveBlob(string(dgst)) err = vacuum.RemoveBlob(string(dgst))
if err != nil { if err != nil {
return fmt.Errorf("failed to delete blob %s: %v\n", dgst, err) return fmt.Errorf("failed to delete blob %s: %v\n", dgst, err)
@ -120,31 +131,3 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
return err return err
} }
// GCCmd is the cobra command that corresponds to the garbage-collect subcommand
var GCCmd = &cobra.Command{
Use: "garbage-collect <config>",
Short: "`garbage-collects` deletes layers not referenced by any manifests",
Long: "`garbage-collects` deletes layers not referenced by any manifests",
Run: func(cmd *cobra.Command, args []string) {
config, err := resolveConfiguration(args)
if err != nil {
fmt.Fprintf(os.Stderr, "configuration error: %v\n", err)
cmd.Usage()
os.Exit(1)
}
driver, err := factory.Create(config.Storage.Type(), config.Storage.Parameters())
if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
os.Exit(1)
}
err = markAndSweep(driver)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to garbage collect: %v", err)
os.Exit(1)
}
},
}

View file

@ -1,17 +1,18 @@
package registry package storage
import ( import (
"io" "io"
"path"
"testing" "testing"
"github.com/docker/distribution" "github.com/docker/distribution"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
"github.com/docker/distribution/digest" "github.com/docker/distribution/digest"
"github.com/docker/distribution/reference" "github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/driver" "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/inmemory" "github.com/docker/distribution/registry/storage/driver/inmemory"
"github.com/docker/distribution/testutil" "github.com/docker/distribution/testutil"
"github.com/docker/libtrust"
) )
type image struct { type image struct {
@ -22,7 +23,11 @@ type image struct {
func createRegistry(t *testing.T, driver driver.StorageDriver) distribution.Namespace { func createRegistry(t *testing.T, driver driver.StorageDriver) distribution.Namespace {
ctx := context.Background() ctx := context.Background()
registry, err := storage.NewRegistry(ctx, driver, storage.EnableDelete) k, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
registry, err := NewRegistry(ctx, driver, EnableDelete, Schema1SigningKey(k))
if err != nil { if err != nil {
t.Fatalf("Failed to construct namespace") t.Fatalf("Failed to construct namespace")
} }
@ -139,13 +144,13 @@ func TestNoDeletionNoEffect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
inmemoryDriver := inmemory.New() inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver) registry := createRegistry(t, inmemory.New())
repo := makeRepository(t, registry, "palailogos") repo := makeRepository(t, registry, "palailogos")
manifestService, err := repo.Manifests(ctx) manifestService, err := repo.Manifests(ctx)
image1 := uploadRandomSchema1Image(t, repo) image1 := uploadRandomSchema1Image(t, repo)
image2 := uploadRandomSchema1Image(t, repo) image2 := uploadRandomSchema1Image(t, repo)
image3 := uploadRandomSchema2Image(t, repo) uploadRandomSchema2Image(t, repo)
// construct manifestlist for fun. // construct manifestlist for fun.
blobstatter := registry.BlobStatter() blobstatter := registry.BlobStatter()
@ -160,20 +165,48 @@ func TestNoDeletionNoEffect(t *testing.T) {
t.Fatalf("Failed to add manifest list: %v", err) t.Fatalf("Failed to add manifest list: %v", err)
} }
before := allBlobs(t, registry)
// Run GC // Run GC
err = markAndSweep(inmemoryDriver) err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
if err != nil {
t.Fatalf("Failed mark and sweep: %v", err)
}
after := allBlobs(t, registry)
if len(before) != len(after) {
t.Fatalf("Garbage collection affected storage: %d != %d", len(before), len(after))
}
}
func TestGCWithMissingManifests(t *testing.T) {
ctx := context.Background()
d := inmemory.New()
registry := createRegistry(t, d)
repo := makeRepository(t, registry, "testrepo")
uploadRandomSchema1Image(t, repo)
// Simulate a missing _manifests directory
revPath, err := pathFor(manifestRevisionsPathSpec{"testrepo"})
if err != nil {
t.Fatal(err)
}
_manifestsPath := path.Dir(revPath)
err = d.Delete(ctx, _manifestsPath)
if err != nil {
t.Fatal(err)
}
err = MarkAndSweep(context.Background(), d, registry, false)
if err != nil { if err != nil {
t.Fatalf("Failed mark and sweep: %v", err) t.Fatalf("Failed mark and sweep: %v", err)
} }
blobs := allBlobs(t, registry) blobs := allBlobs(t, registry)
if len(blobs) > 0 {
// the +1 at the end is for the manifestList t.Errorf("unexpected blobs after gc")
// the first +3 at the end for each manifest's blob
// the second +3 at the end for each manifest's signature/config layer
totalBlobCount := len(image1.layers) + len(image2.layers) + len(image3.layers) + 1 + 3 + 3
if len(blobs) != totalBlobCount {
t.Fatalf("Garbage collection affected storage")
} }
} }
@ -193,7 +226,7 @@ func TestDeletionHasEffect(t *testing.T) {
manifests.Delete(ctx, image3.manifestDigest) manifests.Delete(ctx, image3.manifestDigest)
// Run GC // Run GC
err = markAndSweep(inmemoryDriver) err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
if err != nil { if err != nil {
t.Fatalf("Failed mark and sweep: %v", err) t.Fatalf("Failed mark and sweep: %v", err)
} }
@ -327,7 +360,7 @@ func TestOrphanBlobDeleted(t *testing.T) {
uploadRandomSchema2Image(t, repo) uploadRandomSchema2Image(t, repo)
// Run GC // Run GC
err = markAndSweep(inmemoryDriver) err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
if err != nil { if err != nil {
t.Fatalf("Failed mark and sweep: %v", err) t.Fatalf("Failed mark and sweep: %v", err)
} }

View file

@ -35,7 +35,7 @@ type linkedBlobStore struct {
// control the repository blob link set to which the blob store // control the repository blob link set to which the blob store
// dispatches. This is required because manifest and layer blobs have not // dispatches. This is required because manifest and layer blobs have not
// yet been fully merged. At some point, this functionality should be // yet been fully merged. At some point, this functionality should be
// removed an the blob links folder should be merged. The first entry is // removed the blob links folder should be merged. The first entry is
// treated as the "canonical" link location and will be used for writes. // treated as the "canonical" link location and will be used for writes.
linkPathFns []linkPathFunc linkPathFns []linkPathFunc
@ -179,7 +179,7 @@ func (lbs *linkedBlobStore) Create(ctx context.Context, options ...distribution.
return nil, err return nil, err
} }
return lbs.newBlobUpload(ctx, uuid, path, startedAt) return lbs.newBlobUpload(ctx, uuid, path, startedAt, false)
} }
func (lbs *linkedBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) { func (lbs *linkedBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) {
@ -218,7 +218,7 @@ func (lbs *linkedBlobStore) Resume(ctx context.Context, id string) (distribution
return nil, err return nil, err
} }
return lbs.newBlobUpload(ctx, id, path, startedAt) return lbs.newBlobUpload(ctx, id, path, startedAt, true)
} }
func (lbs *linkedBlobStore) Delete(ctx context.Context, dgst digest.Digest) error { func (lbs *linkedBlobStore) Delete(ctx context.Context, dgst digest.Digest) error {
@ -312,18 +312,21 @@ func (lbs *linkedBlobStore) mount(ctx context.Context, sourceRepo reference.Name
} }
// newBlobUpload allocates a new upload controller with the given state. // newBlobUpload allocates a new upload controller with the given state.
func (lbs *linkedBlobStore) newBlobUpload(ctx context.Context, uuid, path string, startedAt time.Time) (distribution.BlobWriter, error) { func (lbs *linkedBlobStore) newBlobUpload(ctx context.Context, uuid, path string, startedAt time.Time, append bool) (distribution.BlobWriter, error) {
fw, err := newFileWriter(ctx, lbs.driver, path) fw, err := lbs.driver.Writer(ctx, path, append)
if err != nil { if err != nil {
return nil, err return nil, err
} }
bw := &blobWriter{ bw := &blobWriter{
blobStore: lbs, ctx: ctx,
id: uuid, blobStore: lbs,
startedAt: startedAt, id: uuid,
digester: digest.Canonical.New(), startedAt: startedAt,
fileWriter: *fw, digester: digest.Canonical.New(),
fileWriter: fw,
driver: lbs.driver,
path: path,
resumableDigestEnabled: lbs.resumableDigestEnabled, resumableDigestEnabled: lbs.resumableDigestEnabled,
} }
@ -381,8 +384,8 @@ var _ distribution.BlobDescriptorService = &linkedBlobStatter{}
func (lbs *linkedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) { func (lbs *linkedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
var ( var (
resolveErr error found bool
target digest.Digest target digest.Digest
) )
// try the many link path functions until we get success or an error that // try the many link path functions until we get success or an error that
@ -392,19 +395,20 @@ func (lbs *linkedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (dis
target, err = lbs.resolveWithLinkFunc(ctx, dgst, linkPathFn) target, err = lbs.resolveWithLinkFunc(ctx, dgst, linkPathFn)
if err == nil { if err == nil {
found = true
break // success! break // success!
} }
switch err := err.(type) { switch err := err.(type) {
case driver.PathNotFoundError: case driver.PathNotFoundError:
resolveErr = distribution.ErrBlobUnknown // move to the next linkPathFn, saving the error // do nothing, just move to the next linkPathFn
default: default:
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }
} }
if resolveErr != nil { if !found {
return distribution.Descriptor{}, resolveErr return distribution.Descriptor{}, distribution.ErrBlobUnknown
} }
if target != dgst { if target != dgst {

View file

@ -2,7 +2,6 @@ package storage
import ( import (
"fmt" "fmt"
"path"
"encoding/json" "encoding/json"
"github.com/docker/distribution" "github.com/docker/distribution"
@ -140,42 +139,3 @@ func (ms *manifestStore) Enumerate(ctx context.Context, ingester func(digest.Dig
}) })
return err return err
} }
// Only valid for schema1 signed manifests
func (ms *manifestStore) GetSignatures(ctx context.Context, manifestDigest digest.Digest) ([]digest.Digest, error) {
// sanity check that digest refers to a schema1 digest
manifest, err := ms.Get(ctx, manifestDigest)
if err != nil {
return nil, err
}
if _, ok := manifest.(*schema1.SignedManifest); !ok {
return nil, fmt.Errorf("digest %v is not for schema1 manifest", manifestDigest)
}
signaturesPath, err := pathFor(manifestSignaturesPathSpec{
name: ms.repository.Named().Name(),
revision: manifestDigest,
})
if err != nil {
return nil, err
}
signaturesPath = path.Join(signaturesPath, "sha256")
signaturePaths, err := ms.blobStore.driver.List(ctx, signaturesPath)
if err != nil {
return nil, err
}
var digests []digest.Digest
for _, sigPath := range signaturePaths {
sigdigest, err := digest.ParseDigest("sha256:" + path.Base(sigPath))
if err != nil {
// merely found not a digest
continue
}
digests = append(digests, sigdigest)
}
return digests, nil
}

View file

@ -52,15 +52,11 @@ func newManifestStoreTestEnv(t *testing.T, name reference.Named, tag string, opt
} }
func TestManifestStorage(t *testing.T) { func TestManifestStorage(t *testing.T) {
testManifestStorage(t, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect)
}
func TestManifestStorageDisabledSignatures(t *testing.T) {
k, err := libtrust.GenerateECP256PrivateKey() k, err := libtrust.GenerateECP256PrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testManifestStorage(t, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect, DisableSchema1Signatures, Schema1SigningKey(k)) testManifestStorage(t, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect, Schema1SigningKey(k))
} }
func testManifestStorage(t *testing.T, options ...RegistryOption) { func testManifestStorage(t *testing.T, options ...RegistryOption) {
@ -71,7 +67,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
equalSignatures := env.registry.(*registry).schema1SignaturesEnabled
m := schema1.Manifest{ m := schema1.Manifest{
Versioned: manifest.Versioned{ Versioned: manifest.Versioned{
@ -175,12 +170,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("fetched payload does not match original payload: %q != %q", fetchedManifest.Canonical, sm.Canonical) t.Fatalf("fetched payload does not match original payload: %q != %q", fetchedManifest.Canonical, sm.Canonical)
} }
if equalSignatures {
if !reflect.DeepEqual(fetchedManifest, sm) {
t.Fatalf("fetched manifest not equal: %#v != %#v", fetchedManifest.Manifest, sm.Manifest)
}
}
_, pl, err := fetchedManifest.Payload() _, pl, err := fetchedManifest.Payload()
if err != nil { if err != nil {
t.Fatalf("error getting payload %#v", err) t.Fatalf("error getting payload %#v", err)
@ -223,12 +212,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("fetched manifest not equal: %q != %q", byDigestManifest.Canonical, fetchedManifest.Canonical) t.Fatalf("fetched manifest not equal: %q != %q", byDigestManifest.Canonical, fetchedManifest.Canonical)
} }
if equalSignatures {
if !reflect.DeepEqual(fetchedByDigest, fetchedManifest) {
t.Fatalf("fetched manifest not equal: %#v != %#v", fetchedByDigest, fetchedManifest)
}
}
sigs, err := fetchedJWS.Signatures() sigs, err := fetchedJWS.Signatures()
if err != nil { if err != nil {
t.Fatalf("unable to extract signatures: %v", err) t.Fatalf("unable to extract signatures: %v", err)
@ -285,17 +268,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("unexpected error verifying manifest: %v", err) t.Fatalf("unexpected error verifying manifest: %v", err)
} }
// Assemble our payload and two signatures to get what we expect!
expectedJWS, err := libtrust.NewJSONSignature(payload, sigs[0], sigs2[0])
if err != nil {
t.Fatalf("unexpected error merging jws: %v", err)
}
expectedSigs, err := expectedJWS.Signatures()
if err != nil {
t.Fatalf("unexpected error getting expected signatures: %v", err)
}
_, pl, err = fetched.Payload() _, pl, err = fetched.Payload()
if err != nil { if err != nil {
t.Fatalf("error getting payload %#v", err) t.Fatalf("error getting payload %#v", err)
@ -315,19 +287,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("payloads are not equal") t.Fatalf("payloads are not equal")
} }
if equalSignatures {
receivedSigs, err := receivedJWS.Signatures()
if err != nil {
t.Fatalf("error getting signatures: %v", err)
}
for i, sig := range receivedSigs {
if !bytes.Equal(sig, expectedSigs[i]) {
t.Fatalf("mismatched signatures from remote: %v != %v", string(sig), string(expectedSigs[i]))
}
}
}
// Test deleting manifests // Test deleting manifests
err = ms.Delete(ctx, dgst) err = ms.Delete(ctx, dgst)
if err != nil { if err != nil {

View file

@ -30,8 +30,6 @@ const (
// revisions // revisions
// -> <manifest digest path> // -> <manifest digest path>
// -> link // -> link
// -> signatures
// <algorithm>/<digest>/link
// tags/<tag> // tags/<tag>
// -> current/link // -> current/link
// -> index // -> index
@ -62,8 +60,7 @@ const (
// //
// The third component of the repository directory is the manifests store, // The third component of the repository directory is the manifests store,
// which is made up of a revision store and tag store. Manifests are stored in // which is made up of a revision store and tag store. Manifests are stored in
// the blob store and linked into the revision store. Signatures are separated // the blob store and linked into the revision store.
// from the manifest payload data and linked into the blob store, as well.
// While the registry can save all revisions of a manifest, no relationship is // While the registry can save all revisions of a manifest, no relationship is
// implied as to the ordering of changes to a manifest. The tag store provides // implied as to the ordering of changes to a manifest. The tag store provides
// support for name, tag lookups of manifests, using "current/link" under a // support for name, tag lookups of manifests, using "current/link" under a
@ -77,8 +74,6 @@ const (
// manifestRevisionsPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/ // manifestRevisionsPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/
// manifestRevisionPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/ // manifestRevisionPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/
// manifestRevisionLinkPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/link // manifestRevisionLinkPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/link
// manifestSignaturesPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/signatures/
// manifestSignatureLinkPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/signatures/<algorithm>/<hex digest>/link
// //
// Tags: // Tags:
// //
@ -148,33 +143,6 @@ func pathFor(spec pathSpec) (string, error) {
} }
return path.Join(root, "link"), nil return path.Join(root, "link"), nil
case manifestSignaturesPathSpec:
root, err := pathFor(manifestRevisionPathSpec{
name: v.name,
revision: v.revision,
})
if err != nil {
return "", err
}
return path.Join(root, "signatures"), nil
case manifestSignatureLinkPathSpec:
root, err := pathFor(manifestSignaturesPathSpec{
name: v.name,
revision: v.revision,
})
if err != nil {
return "", err
}
signatureComponents, err := digestPathComponents(v.signature, false)
if err != nil {
return "", err
}
return path.Join(root, path.Join(append(signatureComponents, "link")...)), nil
case manifestTagsPathSpec: case manifestTagsPathSpec:
return path.Join(append(repoPrefix, v.name, "_manifests", "tags")...), nil return path.Join(append(repoPrefix, v.name, "_manifests", "tags")...), nil
case manifestTagPathSpec: case manifestTagPathSpec:
@ -325,26 +293,6 @@ type manifestRevisionLinkPathSpec struct {
func (manifestRevisionLinkPathSpec) pathSpec() {} func (manifestRevisionLinkPathSpec) pathSpec() {}
// manifestSignaturesPathSpec describes the path components for the directory
// containing all the signatures for the target blob. Entries are named with
// the underlying key id.
type manifestSignaturesPathSpec struct {
name string
revision digest.Digest
}
func (manifestSignaturesPathSpec) pathSpec() {}
// manifestSignatureLinkPathSpec describes the path components used to look up
// a signature file by the hash of its blob.
type manifestSignatureLinkPathSpec struct {
name string
revision digest.Digest
signature digest.Digest
}
func (manifestSignatureLinkPathSpec) pathSpec() {}
// manifestTagsPathSpec describes the path elements required to point to the // manifestTagsPathSpec describes the path elements required to point to the
// manifest tags directory. // manifest tags directory.
type manifestTagsPathSpec struct { type manifestTagsPathSpec struct {

View file

@ -26,21 +26,6 @@ func TestPathMapper(t *testing.T) {
}, },
expected: "/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/link", expected: "/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/link",
}, },
{
spec: manifestSignatureLinkPathSpec{
name: "foo/bar",
revision: "sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
signature: "sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
},
expected: "/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/signatures/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/link",
},
{
spec: manifestSignaturesPathSpec{
name: "foo/bar",
revision: "sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
},
expected: "/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/signatures",
},
{ {
spec: manifestTagsPathSpec{ spec: manifestTagsPathSpec{
name: "foo/bar", name: "foo/bar",
@ -113,7 +98,7 @@ func TestPathMapper(t *testing.T) {
// Add a few test cases to ensure we cover some errors // Add a few test cases to ensure we cover some errors
// Specify a path that requires a revision and get a digest validation error. // Specify a path that requires a revision and get a digest validation error.
badpath, err := pathFor(manifestSignaturesPathSpec{ badpath, err := pathFor(manifestRevisionPathSpec{
name: "foo/bar", name: "foo/bar",
}) })

View file

@ -12,14 +12,14 @@ import (
// registry is the top-level implementation of Registry for use in the storage // registry is the top-level implementation of Registry for use in the storage
// package. All instances should descend from this object. // package. All instances should descend from this object.
type registry struct { type registry struct {
blobStore *blobStore blobStore *blobStore
blobServer *blobServer blobServer *blobServer
statter *blobStatter // global statter service. statter *blobStatter // global statter service.
blobDescriptorCacheProvider cache.BlobDescriptorCacheProvider blobDescriptorCacheProvider cache.BlobDescriptorCacheProvider
deleteEnabled bool deleteEnabled bool
resumableDigestEnabled bool resumableDigestEnabled bool
schema1SignaturesEnabled bool schema1SigningKey libtrust.PrivateKey
schema1SigningKey libtrust.PrivateKey blobDescriptorServiceFactory distribution.BlobDescriptorServiceFactory
} }
// RegistryOption is the type used for functional options for NewRegistry. // RegistryOption is the type used for functional options for NewRegistry.
@ -46,17 +46,8 @@ func DisableDigestResumption(registry *registry) error {
return nil return nil
} }
// DisableSchema1Signatures is a functional option for NewRegistry. It disables
// signature storage and ensures all schema1 manifests will only be returned
// with a signature from a provided signing key.
func DisableSchema1Signatures(registry *registry) error {
registry.schema1SignaturesEnabled = false
return nil
}
// Schema1SigningKey returns a functional option for NewRegistry. It sets the // Schema1SigningKey returns a functional option for NewRegistry. It sets the
// signing key for adding a signature to all schema1 manifests. This should be // key for signing all schema1 manifests.
// used in conjunction with disabling signature store.
func Schema1SigningKey(key libtrust.PrivateKey) RegistryOption { func Schema1SigningKey(key libtrust.PrivateKey) RegistryOption {
return func(registry *registry) error { return func(registry *registry) error {
registry.schema1SigningKey = key registry.schema1SigningKey = key
@ -64,6 +55,15 @@ func Schema1SigningKey(key libtrust.PrivateKey) RegistryOption {
} }
} }
// BlobDescriptorServiceFactory returns a functional option for NewRegistry. It sets the
// factory to create BlobDescriptorServiceFactory middleware.
func BlobDescriptorServiceFactory(factory distribution.BlobDescriptorServiceFactory) RegistryOption {
return func(registry *registry) error {
registry.blobDescriptorServiceFactory = factory
return nil
}
}
// BlobDescriptorCacheProvider returns a functional option for // BlobDescriptorCacheProvider returns a functional option for
// NewRegistry. It creates a cached blob statter for use by the // NewRegistry. It creates a cached blob statter for use by the
// registry. // registry.
@ -106,9 +106,8 @@ func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, option
statter: statter, statter: statter,
pathFn: bs.path, pathFn: bs.path,
}, },
statter: statter, statter: statter,
resumableDigestEnabled: true, resumableDigestEnabled: true,
schema1SignaturesEnabled: true,
} }
for _, option := range options { for _, option := range options {
@ -190,16 +189,22 @@ func (repo *repository) Manifests(ctx context.Context, options ...distribution.M
manifestDirectoryPathSpec := manifestRevisionsPathSpec{name: repo.name.Name()} manifestDirectoryPathSpec := manifestRevisionsPathSpec{name: repo.name.Name()}
var statter distribution.BlobDescriptorService = &linkedBlobStatter{
blobStore: repo.blobStore,
repository: repo,
linkPathFns: manifestLinkPathFns,
}
if repo.registry.blobDescriptorServiceFactory != nil {
statter = repo.registry.blobDescriptorServiceFactory.BlobAccessController(statter)
}
blobStore := &linkedBlobStore{ blobStore := &linkedBlobStore{
ctx: ctx, ctx: ctx,
blobStore: repo.blobStore, blobStore: repo.blobStore,
repository: repo, repository: repo,
deleteEnabled: repo.registry.deleteEnabled, deleteEnabled: repo.registry.deleteEnabled,
blobAccessController: &linkedBlobStatter{ blobAccessController: statter,
blobStore: repo.blobStore,
repository: repo,
linkPathFns: manifestLinkPathFns,
},
// TODO(stevvooe): linkPath limits this blob store to only // TODO(stevvooe): linkPath limits this blob store to only
// manifests. This instance cannot be used for blob checks. // manifests. This instance cannot be used for blob checks.
@ -215,11 +220,6 @@ func (repo *repository) Manifests(ctx context.Context, options ...distribution.M
ctx: ctx, ctx: ctx,
repository: repo, repository: repo,
blobStore: blobStore, blobStore: blobStore,
signatures: &signatureStore{
ctx: ctx,
repository: repo,
blobStore: repo.blobStore,
},
}, },
schema2Handler: &schema2ManifestHandler{ schema2Handler: &schema2ManifestHandler{
ctx: ctx, ctx: ctx,
@ -258,6 +258,10 @@ func (repo *repository) Blobs(ctx context.Context) distribution.BlobStore {
statter = cache.NewCachedBlobStatter(repo.descriptorCache, statter) statter = cache.NewCachedBlobStatter(repo.descriptorCache, statter)
} }
if repo.registry.blobDescriptorServiceFactory != nil {
statter = repo.registry.blobDescriptorServiceFactory.BlobAccessController(statter)
}
return &linkedBlobStore{ return &linkedBlobStore{
registry: repo.registry, registry: repo.registry,
blobStore: repo.blobStore, blobStore: repo.blobStore,

View file

@ -1,15 +1,24 @@
package storage package storage
import ( import (
"errors"
"fmt" "fmt"
"net/url"
"encoding/json" "encoding/json"
"github.com/docker/distribution" "github.com/docker/distribution"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
"github.com/docker/distribution/digest" "github.com/docker/distribution/digest"
"github.com/docker/distribution/manifest/schema2" "github.com/docker/distribution/manifest/schema2"
) )
var (
errUnexpectedURL = errors.New("unexpected URL on layer")
errMissingURL = errors.New("missing URL on layer")
errInvalidURL = errors.New("invalid URL on layer")
)
//schema2ManifestHandler is a ManifestHandler that covers schema2 manifests. //schema2ManifestHandler is a ManifestHandler that covers schema2 manifests.
type schema2ManifestHandler struct { type schema2ManifestHandler struct {
repository *repository repository *repository
@ -80,7 +89,27 @@ func (ms *schema2ManifestHandler) verifyManifest(ctx context.Context, mnfst sche
} }
for _, fsLayer := range mnfst.References() { for _, fsLayer := range mnfst.References() {
_, err := ms.repository.Blobs(ctx).Stat(ctx, fsLayer.Digest) var err error
if fsLayer.MediaType != schema2.MediaTypeForeignLayer {
if len(fsLayer.URLs) == 0 {
_, err = ms.repository.Blobs(ctx).Stat(ctx, fsLayer.Digest)
} else {
err = errUnexpectedURL
}
} else {
// Clients download this layer from an external URL, so do not check for
// its presense.
if len(fsLayer.URLs) == 0 {
err = errMissingURL
}
for _, u := range fsLayer.URLs {
var pu *url.URL
pu, err = url.Parse(u)
if err != nil || (pu.Scheme != "http" && pu.Scheme != "https") || pu.Fragment != "" {
err = errInvalidURL
}
}
}
if err != nil { if err != nil {
if err != distribution.ErrBlobUnknown { if err != distribution.ErrBlobUnknown {
errs = append(errs, err) errs = append(errs, err)

View file

@ -0,0 +1,117 @@
package storage
import (
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/manifest"
"github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/registry/storage/driver/inmemory"
)
func TestVerifyManifestForeignLayer(t *testing.T) {
ctx := context.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver)
repo := makeRepository(t, registry, "test")
manifestService := makeManifestService(t, repo)
config, err := repo.Blobs(ctx).Put(ctx, schema2.MediaTypeConfig, nil)
if err != nil {
t.Fatal(err)
}
layer, err := repo.Blobs(ctx).Put(ctx, schema2.MediaTypeLayer, nil)
if err != nil {
t.Fatal(err)
}
foreignLayer := distribution.Descriptor{
Digest: "sha256:463435349086340864309863409683460843608348608934092322395278926a",
Size: 6323,
MediaType: schema2.MediaTypeForeignLayer,
}
template := schema2.Manifest{
Versioned: manifest.Versioned{
SchemaVersion: 2,
MediaType: schema2.MediaTypeManifest,
},
Config: config,
}
type testcase struct {
BaseLayer distribution.Descriptor
URLs []string
Err error
}
cases := []testcase{
{
foreignLayer,
nil,
errMissingURL,
},
{
layer,
[]string{"http://foo/bar"},
errUnexpectedURL,
},
{
foreignLayer,
[]string{"file:///local/file"},
errInvalidURL,
},
{
foreignLayer,
[]string{"http://foo/bar#baz"},
errInvalidURL,
},
{
foreignLayer,
[]string{""},
errInvalidURL,
},
{
foreignLayer,
[]string{"https://foo/bar", ""},
errInvalidURL,
},
{
foreignLayer,
[]string{"http://foo/bar"},
nil,
},
{
foreignLayer,
[]string{"https://foo/bar"},
nil,
},
}
for _, c := range cases {
m := template
l := c.BaseLayer
l.URLs = c.URLs
m.Layers = []distribution.Descriptor{l}
dm, err := schema2.FromStruct(m)
if err != nil {
t.Error(err)
continue
}
_, err = manifestService.Put(ctx, dm)
if verr, ok := err.(distribution.ErrManifestVerification); ok {
// Extract the first error
if len(verr) == 2 {
if _, ok = verr[1].(distribution.ErrManifestBlobUnknown); ok {
err = verr[0]
}
}
}
if err != c.Err {
t.Errorf("%#v: expected %v, got %v", l, c.Err, err)
}
}
}

View file

@ -1,131 +0,0 @@
package storage
import (
"path"
"sync"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
)
type signatureStore struct {
repository *repository
blobStore *blobStore
ctx context.Context
}
func (s *signatureStore) Get(dgst digest.Digest) ([][]byte, error) {
signaturesPath, err := pathFor(manifestSignaturesPathSpec{
name: s.repository.Named().Name(),
revision: dgst,
})
if err != nil {
return nil, err
}
// Need to append signature digest algorithm to path to get all items.
// Perhaps, this should be in the pathMapper but it feels awkward. This
// can be eliminated by implementing listAll on drivers.
signaturesPath = path.Join(signaturesPath, "sha256")
signaturePaths, err := s.blobStore.driver.List(s.ctx, signaturesPath)
if err != nil {
return nil, err
}
var wg sync.WaitGroup
type result struct {
index int
signature []byte
err error
}
ch := make(chan result)
bs := s.linkedBlobStore(s.ctx, dgst)
for i, sigPath := range signaturePaths {
sigdgst, err := digest.ParseDigest("sha256:" + path.Base(sigPath))
if err != nil {
context.GetLogger(s.ctx).Errorf("could not get digest from path: %q, skipping", sigPath)
continue
}
wg.Add(1)
go func(idx int, sigdgst digest.Digest) {
defer wg.Done()
context.GetLogger(s.ctx).
Debugf("fetching signature %q", sigdgst)
r := result{index: idx}
if p, err := bs.Get(s.ctx, sigdgst); err != nil {
context.GetLogger(s.ctx).
Errorf("error fetching signature %q: %v", sigdgst, err)
r.err = err
} else {
r.signature = p
}
ch <- r
}(i, sigdgst)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
// aggregrate the results
signatures := make([][]byte, len(signaturePaths))
loop:
for {
select {
case result := <-ch:
signatures[result.index] = result.signature
if result.err != nil && err == nil {
// only set the first one.
err = result.err
}
case <-done:
break loop
}
}
return signatures, err
}
func (s *signatureStore) Put(dgst digest.Digest, signatures ...[]byte) error {
bs := s.linkedBlobStore(s.ctx, dgst)
for _, signature := range signatures {
if _, err := bs.Put(s.ctx, "application/json", signature); err != nil {
return err
}
}
return nil
}
// linkedBlobStore returns the namedBlobStore of the signatures for the
// manifest with the given digest. Effectively, each signature link path
// layout is a unique linked blob store.
func (s *signatureStore) linkedBlobStore(ctx context.Context, revision digest.Digest) *linkedBlobStore {
linkpath := func(name string, dgst digest.Digest) (string, error) {
return pathFor(manifestSignatureLinkPathSpec{
name: name,
revision: revision,
signature: dgst,
})
}
return &linkedBlobStore{
ctx: ctx,
repository: s.repository,
blobStore: s.blobStore,
blobAccessController: &linkedBlobStatter{
blobStore: s.blobStore,
repository: s.repository,
linkPathFns: []linkPathFunc{linkpath},
},
linkPathFns: []linkPathFunc{linkpath},
}
}

View file

@ -18,7 +18,6 @@ type signedManifestHandler struct {
repository *repository repository *repository
blobStore *linkedBlobStore blobStore *linkedBlobStore
ctx context.Context ctx context.Context
signatures *signatureStore
} }
var _ ManifestHandler = &signedManifestHandler{} var _ ManifestHandler = &signedManifestHandler{}
@ -30,13 +29,6 @@ func (ms *signedManifestHandler) Unmarshal(ctx context.Context, dgst digest.Dige
signatures [][]byte signatures [][]byte
err error err error
) )
if ms.repository.schema1SignaturesEnabled {
// Fetch the signatures for the manifest
signatures, err = ms.signatures.Get(dgst)
if err != nil {
return nil, err
}
}
jsig, err := libtrust.NewJSONSignature(content, signatures...) jsig, err := libtrust.NewJSONSignature(content, signatures...)
if err != nil { if err != nil {
@ -47,8 +39,6 @@ func (ms *signedManifestHandler) Unmarshal(ctx context.Context, dgst digest.Dige
if err := jsig.Sign(ms.repository.schema1SigningKey); err != nil { if err := jsig.Sign(ms.repository.schema1SigningKey); err != nil {
return nil, err return nil, err
} }
} else if !ms.repository.schema1SignaturesEnabled {
return nil, fmt.Errorf("missing signing key with signature store disabled")
} }
// Extract the pretty JWS // Extract the pretty JWS
@ -90,18 +80,6 @@ func (ms *signedManifestHandler) Put(ctx context.Context, manifest distribution.
return "", err return "", err
} }
if ms.repository.schema1SignaturesEnabled {
// Grab each json signature and store them.
signatures, err := sm.Signatures()
if err != nil {
return "", err
}
if err := ms.signatures.Put(revision.Digest, signatures...); err != nil {
return "", err
}
}
return revision.Digest, nil return revision.Digest, nil
} }