Merge pull request #1 from docker/master

update
Eric Yang 2016-06-07 10:58:38 +08:00
commit a4f7559f6c
80 changed files with 3871 additions and 4693 deletions

View file

@@ -63,6 +63,19 @@ var (
 		Description:    "Returned when a service is not available",
 		HTTPStatusCode: http.StatusServiceUnavailable,
 	})
+
+	// ErrorCodeTooManyRequests is returned if a client attempts too many
+	// times to contact a service endpoint.
+	ErrorCodeTooManyRequests = Register("errcode", ErrorDescriptor{
+		Value:   "TOOMANYREQUESTS",
+		Message: "too many requests",
+		Description: `Returned when a client attempts to contact a
+		service too many times`,
+		// FIXME: go1.5 doesn't export http.StatusTooManyRequests while
+		// go1.6 does. Update the hardcoded value to the constant once
+		// Docker updates golang version to 1.6.
+		HTTPStatusCode: 429,
+	})
 )

 var nextCode = 1000
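
Usage sketch (illustrative, not part of the commit): client code can match the new code after the registry client decodes an error payload into the existing errcode types; the retry decision shown here is an assumed policy.

```go
// Sketch: react to the new TOOMANYREQUESTS code on the client side.
if errs, ok := err.(errcode.Errors); ok {
	for _, e := range errs {
		if errResp, ok := e.(errcode.Error); ok && errResp.Code == errcode.ErrorCodeTooManyRequests {
			// 429 from the registry: back off before retrying
		}
	}
}
```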

View file

@@ -19,31 +19,33 @@ import (
 type URLBuilder struct {
 	root   *url.URL // url root (ie http://localhost/)
 	router *mux.Router
+	relative bool
 }

 // NewURLBuilder creates a URLBuilder with provided root url object.
-func NewURLBuilder(root *url.URL) *URLBuilder {
+func NewURLBuilder(root *url.URL, relative bool) *URLBuilder {
 	return &URLBuilder{
 		root:   root,
 		router: Router(),
+		relative: relative,
 	}
 }
 // NewURLBuilderFromString works identically to NewURLBuilder except it takes
 // a string argument for the root, returning an error if it is not a valid
 // url.
-func NewURLBuilderFromString(root string) (*URLBuilder, error) {
+func NewURLBuilderFromString(root string, relative bool) (*URLBuilder, error) {
 	u, err := url.Parse(root)
 	if err != nil {
 		return nil, err
 	}

-	return NewURLBuilder(u), nil
+	return NewURLBuilder(u, relative), nil
 }
 // NewURLBuilderFromRequest uses information from an *http.Request to
 // construct the root url.
-func NewURLBuilderFromRequest(r *http.Request) *URLBuilder {
+func NewURLBuilderFromRequest(r *http.Request, relative bool) *URLBuilder {
 	var scheme string

 	forwardedProto := r.Header.Get("X-Forwarded-Proto")
@@ -85,7 +87,7 @@ func NewURLBuilderFromRequest(r *http.Request) *URLBuilder {
 		u.Path = requestPath[0 : index+1]
 	}

-	return NewURLBuilder(u)
+	return NewURLBuilder(u, relative)
 }
 // BuildBaseURL constructs a base url for the API, typically just "/v2/".
@@ -194,12 +196,13 @@ func (ub *URLBuilder) cloneRoute(name string) clonedRoute {
 	*route = *ub.router.GetRoute(name) // clone the route
 	*root = *ub.root

-	return clonedRoute{Route: route, root: root}
+	return clonedRoute{Route: route, root: root, relative: ub.relative}
 }

 type clonedRoute struct {
 	*mux.Route
 	root *url.URL
+	relative bool
 }

 func (cr clonedRoute) URL(pairs ...string) (*url.URL, error) {
@@ -208,6 +211,10 @@ func (cr clonedRoute) URL(pairs ...string) (*url.URL, error) {
 		return nil, err
 	}

+	if cr.relative {
+		return routeURL, nil
+	}
+
 	if routeURL.Scheme == "" && routeURL.User == nil && routeURL.Host == "" {
 		routeURL.Path = routeURL.Path[1:]
 	}
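
Usage sketch (illustrative, not part of the commit): with relative set, built URLs carry only the path, which is what lets the registry hand out relative Location headers.

```go
// Sketch: relative=true makes the builder emit paths only.
ub, err := v2.NewURLBuilderFromString("https://registry.example.com", true)
if err != nil {
	panic(err) // illustration only
}
base, _ := ub.BuildBaseURL() // "/v2/" instead of "https://registry.example.com/v2/"
```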

View file

@@ -92,8 +92,9 @@ func TestURLBuilder(t *testing.T) {
 		"https://localhost:5443",
 	}

+	doTest := func(relative bool) {
 	for _, root := range roots {
-		urlBuilder, err := NewURLBuilderFromString(root)
+		urlBuilder, err := NewURLBuilderFromString(root, relative)
 		if err != nil {
 			t.Fatalf("unexpected error creating urlbuilder: %v", err)
 		}
@@ -103,14 +104,19 @@ func TestURLBuilder(t *testing.T) {
 			if err != nil {
 				t.Fatalf("%s: error building url: %v", testCase.description, err)
 			}

-			expectedURL := root + testCase.expectedPath
+			expectedURL := testCase.expectedPath
+			if !relative {
+				expectedURL = root + expectedURL
+			}

 			if url != expectedURL {
 				t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
 			}
 		}
 	}
+	}
+	doTest(true)
+	doTest(false)
 }

 func TestURLBuilderWithPrefix(t *testing.T) {
@@ -121,8 +127,9 @@ func TestURLBuilderWithPrefix(t *testing.T) {
 		"https://localhost:5443/prefix/",
 	}

+	doTest := func(relative bool) {
 	for _, root := range roots {
-		urlBuilder, err := NewURLBuilderFromString(root)
+		urlBuilder, err := NewURLBuilderFromString(root, relative)
 		if err != nil {
 			t.Fatalf("unexpected error creating urlbuilder: %v", err)
 		}
@@ -133,13 +140,18 @@ func TestURLBuilderWithPrefix(t *testing.T) {
 			t.Fatalf("%s: error building url: %v", testCase.description, err)
 		}

-		expectedURL := root[0:len(root)-1] + testCase.expectedPath
+		expectedURL := testCase.expectedPath
+		if !relative {
+			expectedURL = root[0:len(root)-1] + expectedURL
+		}

 		if url != expectedURL {
 			t.Fatalf("%s: %q != %q", testCase.description, url, expectedURL)
 		}
 	}
+	}
+	doTest(true)
+	doTest(false)
 }

 type builderFromRequestTestCase struct {
@@ -197,13 +209,13 @@ func TestBuilderFromRequest(t *testing.T) {
 			},
 		},
 	}

+	doTest := func(relative bool) {
 	for _, tr := range testRequests {
 		var builder *URLBuilder
 		if tr.configHost.Scheme != "" && tr.configHost.Host != "" {
-			builder = NewURLBuilder(&tr.configHost)
+			builder = NewURLBuilder(&tr.configHost, relative)
 		} else {
-			builder = NewURLBuilderFromRequest(tr.request)
+			builder = NewURLBuilderFromRequest(tr.request, relative)
 		}

 		for _, testCase := range makeURLBuilderTestCases(builder) {
@@ -215,14 +227,20 @@ func TestBuilderFromRequest(t *testing.T) {
 			var expectedURL string
 			proto, ok := tr.request.Header["X-Forwarded-Proto"]
 			if !ok {
-				expectedURL = tr.base + testCase.expectedPath
+				expectedURL = testCase.expectedPath
+				if !relative {
+					expectedURL = tr.base + expectedURL
+				}
 			} else {
 				urlBase, err := url.Parse(tr.base)
 				if err != nil {
 					t.Fatal(err)
 				}
 				urlBase.Scheme = proto[0]
-				expectedURL = urlBase.String() + testCase.expectedPath
+				expectedURL = testCase.expectedPath
+				if !relative {
+					expectedURL = urlBase.String() + expectedURL
+				}
 			}

 			if buildURL != expectedURL {
@@ -230,6 +248,9 @@ func TestBuilderFromRequest(t *testing.T) {
 			}
 		}
 	}
+	}
+	doTest(true)
+	doTest(false)
 }

 func TestBuilderFromRequestWithPrefix(t *testing.T) {
@@ -270,12 +291,13 @@ func TestBuilderFromRequestWithPrefix(t *testing.T) {
 		},
 	}

+	var relative bool
 	for _, tr := range testRequests {
 		var builder *URLBuilder
 		if tr.configHost.Scheme != "" && tr.configHost.Host != "" {
-			builder = NewURLBuilder(&tr.configHost)
+			builder = NewURLBuilder(&tr.configHost, false)
 		} else {
-			builder = NewURLBuilderFromRequest(tr.request)
+			builder = NewURLBuilderFromRequest(tr.request, false)
 		}

 		for _, testCase := range makeURLBuilderTestCases(builder) {
@@ -283,17 +305,25 @@ func TestBuilderFromRequestWithPrefix(t *testing.T) {
 			if err != nil {
 				t.Fatalf("%s: error building url: %v", testCase.description, err)
 			}

 			var expectedURL string
 			proto, ok := tr.request.Header["X-Forwarded-Proto"]
 			if !ok {
-				expectedURL = tr.base[0:len(tr.base)-1] + testCase.expectedPath
+				expectedURL = testCase.expectedPath
+				if !relative {
+					expectedURL = tr.base[0:len(tr.base)-1] + expectedURL
+				}
 			} else {
 				urlBase, err := url.Parse(tr.base)
 				if err != nil {
 					t.Fatal(err)
 				}
 				urlBase.Scheme = proto[0]
-				expectedURL = urlBase.String()[0:len(urlBase.String())-1] + testCase.expectedPath
+				expectedURL = testCase.expectedPath
+				if !relative {
+					expectedURL = urlBase.String()[0:len(urlBase.String())-1] + expectedURL
+				}
 			}

 			if buildURL != expectedURL {

View file

@@ -46,7 +46,7 @@ func (htpasswd *htpasswd) authenticateUser(username string, password string) err
 // parseHTPasswd parses the contents of htpasswd. This will read all the
 // entries in the file, whether or not they are needed. An error is returned
-// if an syntax errors are encountered or if the reader fails.
+// if a syntax errors are encountered or if the reader fails.
 func parseHTPasswd(rd io.Reader) (map[string][]byte, error) {
 	entries := map[string][]byte{}
 	scanner := bufio.NewScanner(rd)

View file

@@ -25,7 +25,7 @@ type Challenge struct {
 type ChallengeManager interface {
 	// GetChallenges returns the challenges for the given
 	// endpoint URL.
-	GetChallenges(endpoint string) ([]Challenge, error)
+	GetChallenges(endpoint url.URL) ([]Challenge, error)

 	// AddResponse adds the response to the challenge
 	// manager. The challenges will be parsed out of
@@ -48,8 +48,10 @@ func NewSimpleChallengeManager() ChallengeManager {
 type simpleChallengeManager map[string][]Challenge

-func (m simpleChallengeManager) GetChallenges(endpoint string) ([]Challenge, error) {
-	challenges := m[endpoint]
+func (m simpleChallengeManager) GetChallenges(endpoint url.URL) ([]Challenge, error) {
+	endpoint.Host = strings.ToLower(endpoint.Host)
+
+	challenges := m[endpoint.String()]
 	return challenges, nil
 }

@@ -60,11 +62,10 @@ func (m simpleChallengeManager) AddResponse(resp *http.Response) error {
 	}
 	urlCopy := url.URL{
 		Path:   resp.Request.URL.Path,
-		Host:   resp.Request.URL.Host,
+		Host:   strings.ToLower(resp.Request.URL.Host),
 		Scheme: resp.Request.URL.Scheme,
 	}
 	m[urlCopy.String()] = challenges
 	return nil
 }
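
Usage sketch (illustrative, not part of the commit; manager stands for any ChallengeManager that has already recorded a 401 challenge from reg.example.com): because both AddResponse and GetChallenges lower-case the host, lookups are now case-insensitive in the hostname.

```go
// Sketch: a mixed-case URL still finds the stored challenges.
u, err := url.Parse("https://REG.Example.COM/v2/")
if err != nil {
	panic(err) // illustration only
}
challenges, err := manager.GetChallenges(*u)
if err == nil && len(challenges) > 0 {
	// challenges recorded under reg.example.com are found
}
```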

View file

@@ -1,7 +1,10 @@
 package auth

 import (
+	"fmt"
 	"net/http"
+	"net/url"
+	"strings"
 	"testing"
 )

@@ -36,3 +39,43 @@ func TestAuthChallengeParse(t *testing.T) {
 	}
 }
+
+func TestAuthChallengeNormalization(t *testing.T) {
+	testAuthChallengeNormalization(t, "reg.EXAMPLE.com")
+	testAuthChallengeNormalization(t, "bɿɒʜɔiɿ-ɿɘƚƨim-ƚol-ɒ-ƨʞnɒʜƚ.com")
+}
+
+func testAuthChallengeNormalization(t *testing.T, host string) {
+	scm := NewSimpleChallengeManager()
+
+	url, err := url.Parse(fmt.Sprintf("http://%s/v2/", host))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	resp := &http.Response{
+		Request: &http.Request{
+			URL: url,
+		},
+		Header:     make(http.Header),
+		StatusCode: http.StatusUnauthorized,
+	}
+	resp.Header.Add("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"https://%s/token\",service=\"registry.example.com\"", host))
+
+	err = scm.AddResponse(resp)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	lowered := *url
+	lowered.Host = strings.ToLower(lowered.Host)
+	c, err := scm.GetChallenges(lowered)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(c) == 0 {
+		t.Fatal("Expected challenge for lower-cased-host URL")
+	}
+}

View file

@@ -15,9 +15,17 @@ import (
 	"github.com/docker/distribution/registry/client/transport"
 )

-// ErrNoBasicAuthCredentials is returned if a request can't be authorized with
-// basic auth due to lack of credentials.
-var ErrNoBasicAuthCredentials = errors.New("no basic auth credentials")
+var (
+	// ErrNoBasicAuthCredentials is returned if a request can't be authorized with
+	// basic auth due to lack of credentials.
+	ErrNoBasicAuthCredentials = errors.New("no basic auth credentials")
+
+	// ErrNoToken is returned if a request is successful but the body does not
+	// contain an authorization token.
+	ErrNoToken = errors.New("authorization server did not include a token in the response")
+)
+
+const defaultClientID = "registry-client"

 // AuthenticationHandler is an interface for authorizing a request from
 // params from a "WWW-Authenticate" header for a single scheme.
@@ -36,6 +44,14 @@ type AuthenticationHandler interface {
 type CredentialStore interface {
 	// Basic returns basic auth for the given URL
 	Basic(*url.URL) (string, string)
+
+	// RefreshToken returns a refresh token for the
+	// given URL and service
+	RefreshToken(*url.URL, string) string
+
+	// SetRefreshToken sets the refresh token if none
+	// is provided for the given url and service
+	SetRefreshToken(realm *url.URL, service, token string)
 }

 // NewAuthorizer creates an authorizer which can handle multiple authentication
@@ -67,9 +83,7 @@ func (ea *endpointAuthorizer) ModifyRequest(req *http.Request) error {
 		Path:   req.URL.Path[:v2Root+4],
 	}

-	pingEndpoint := ping.String()
-
-	challenges, err := ea.challenges.GetChallenges(pingEndpoint)
+	challenges, err := ea.challenges.GetChallenges(ping)
 	if err != nil {
 		return err
 	}
@@ -105,27 +119,47 @@ type clock interface {
 type tokenHandler struct {
 	header    http.Header
 	creds     CredentialStore
-	scope     tokenScope
 	transport http.RoundTripper
 	clock     clock

+	offlineAccess bool
+	forceOAuth    bool
+	clientID      string
+	scopes        []Scope
+
 	tokenLock       sync.Mutex
 	tokenCache      string
 	tokenExpiration time.Time
-
-	additionalScopes map[string]struct{}
 }

-// tokenScope represents the scope at which a token will be requested.
-// This represents a specific action on a registry resource.
-type tokenScope struct {
-	Resource string
-	Scope    string
-	Actions  []string
+// Scope is a type which is serializable to a string
+// using the allow scope grammar.
+type Scope interface {
+	String() string
 }

-func (ts tokenScope) String() string {
-	return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ","))
+// RepositoryScope represents a token scope for access
+// to a repository.
+type RepositoryScope struct {
+	Repository string
+	Actions    []string
+}
+
+// String returns the string representation of the repository
+// using the scope grammar
+func (rs RepositoryScope) String() string {
+	return fmt.Sprintf("repository:%s:%s", rs.Repository, strings.Join(rs.Actions, ","))
+}
+
+// TokenHandlerOptions is used to configure a new token handler
+type TokenHandlerOptions struct {
+	Transport   http.RoundTripper
+	Credentials CredentialStore
+
+	OfflineAccess bool
+	ForceOAuth    bool
+	ClientID      string
+	Scopes        []Scope
 }

 // An implementation of clock for providing real time data.
@@ -137,22 +171,33 @@ func (realClock) Now() time.Time { return time.Now() }
 // NewTokenHandler creates a new AuthenticationHandler which supports
 // fetching tokens from a remote token server.
 func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope string, actions ...string) AuthenticationHandler {
-	return newTokenHandler(transport, creds, realClock{}, scope, actions...)
-}
-
-// newTokenHandler exposes the option to provide a clock to manipulate time in unit testing.
-func newTokenHandler(transport http.RoundTripper, creds CredentialStore, c clock, scope string, actions ...string) AuthenticationHandler {
-	return &tokenHandler{
-		transport: transport,
-		creds:     creds,
-		clock:     c,
-		scope: tokenScope{
-			Resource: "repository",
-			Scope:    scope,
-			Actions:  actions,
-		},
-		additionalScopes: map[string]struct{}{},
+	// Create options...
+	return NewTokenHandlerWithOptions(TokenHandlerOptions{
+		Transport:   transport,
+		Credentials: creds,
+		Scopes: []Scope{
+			RepositoryScope{
+				Repository: scope,
+				Actions:    actions,
+			},
+		},
+	})
+}
+
+// NewTokenHandlerWithOptions creates a new token handler using the provided
+// options structure.
+func NewTokenHandlerWithOptions(options TokenHandlerOptions) AuthenticationHandler {
+	handler := &tokenHandler{
+		transport:     options.Transport,
+		creds:         options.Credentials,
+		offlineAccess: options.OfflineAccess,
+		forceOAuth:    options.ForceOAuth,
+		clientID:      options.ClientID,
+		scopes:        options.Scopes,
+		clock:         realClock{},
 	}
+
+	return handler
 }
 func (th *tokenHandler) client() *http.Client {
@@ -169,122 +214,110 @@ func (th *tokenHandler) Scheme() string {
 func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
 	var additionalScopes []string
 	if fromParam := req.URL.Query().Get("from"); fromParam != "" {
-		additionalScopes = append(additionalScopes, tokenScope{
-			Resource: "repository",
-			Scope:    fromParam,
+		additionalScopes = append(additionalScopes, RepositoryScope{
+			Repository: fromParam,
 			Actions: []string{"pull"},
 		}.String())
 	}
-	if err := th.refreshToken(params, additionalScopes...); err != nil {
+
+	token, err := th.getToken(params, additionalScopes...)
+	if err != nil {
 		return err
 	}

-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.tokenCache))
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

 	return nil
 }

-func (th *tokenHandler) refreshToken(params map[string]string, additionalScopes ...string) error {
+func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) {
 	th.tokenLock.Lock()
 	defer th.tokenLock.Unlock()
+
+	scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
+	for _, scope := range th.scopes {
+		scopes = append(scopes, scope.String())
+	}
+
 	var addedScopes bool
 	for _, scope := range additionalScopes {
-		if _, ok := th.additionalScopes[scope]; !ok {
-			th.additionalScopes[scope] = struct{}{}
-			addedScopes = true
-		}
+		scopes = append(scopes, scope)
+		addedScopes = true
 	}

 	now := th.clock.Now()
 	if now.After(th.tokenExpiration) || addedScopes {
-		tr, err := th.fetchToken(params)
+		token, expiration, err := th.fetchToken(params, scopes)
 		if err != nil {
-			return err
+			return "", err
 		}
-		th.tokenCache = tr.Token
-		th.tokenExpiration = tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second)
+
+		// do not update cache for added scope tokens
+		if !addedScopes {
+			th.tokenCache = token
+			th.tokenExpiration = expiration
+		}
+
+		return token, nil
 	}

-	return nil
+	return th.tokenCache, nil
 }

-type tokenResponse struct {
-	Token       string    `json:"token"`
-	AccessToken string    `json:"access_token"`
-	ExpiresIn   int       `json:"expires_in"`
-	IssuedAt    time.Time `json:"issued_at"`
+type postTokenResponse struct {
+	AccessToken  string    `json:"access_token"`
+	RefreshToken string    `json:"refresh_token"`
+	ExpiresIn    int       `json:"expires_in"`
+	IssuedAt     time.Time `json:"issued_at"`
+	Scope        string    `json:"scope"`
 }

-func (th *tokenHandler) fetchToken(params map[string]string) (token *tokenResponse, err error) {
-	realm, ok := params["realm"]
-	if !ok {
-		return nil, errors.New("no realm specified for token auth challenge")
-	}
-
-	// TODO(dmcgowan): Handle empty scheme
-	realmURL, err := url.Parse(realm)
-	if err != nil {
-		return nil, fmt.Errorf("invalid token auth challenge realm: %s", err)
-	}
-
-	req, err := http.NewRequest("GET", realmURL.String(), nil)
-	if err != nil {
-		return nil, err
-	}
-
-	reqParams := req.URL.Query()
-	service := params["service"]
-	scope := th.scope.String()
-
-	if service != "" {
-		reqParams.Add("service", service)
-	}
-
-	for _, scopeField := range strings.Fields(scope) {
-		reqParams.Add("scope", scopeField)
-	}
-
-	for scope := range th.additionalScopes {
-		reqParams.Add("scope", scope)
-	}
-
-	if th.creds != nil {
-		username, password := th.creds.Basic(realmURL)
-		if username != "" && password != "" {
-			reqParams.Add("account", username)
-			req.SetBasicAuth(username, password)
-		}
-	}
-
-	req.URL.RawQuery = reqParams.Encode()
-
-	resp, err := th.client().Do(req)
-	if err != nil {
-		return nil, err
-	}
+func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
+	form := url.Values{}
+	form.Set("scope", strings.Join(scopes, " "))
+	form.Set("service", service)
+
+	clientID := th.clientID
+	if clientID == "" {
+		// Use default client, this is a required field
+		clientID = defaultClientID
+	}
+	form.Set("client_id", clientID)
+
+	if refreshToken != "" {
+		form.Set("grant_type", "refresh_token")
+		form.Set("refresh_token", refreshToken)
+	} else if th.creds != nil {
+		form.Set("grant_type", "password")
+		username, password := th.creds.Basic(realm)
+		form.Set("username", username)
+		form.Set("password", password)
+
+		// attempt to get a refresh token
+		form.Set("access_type", "offline")
+	} else {
+		// refuse to do oauth without a grant type
+		return "", time.Time{}, fmt.Errorf("no supported grant type")
+	}
+
+	resp, err := th.client().PostForm(realm.String(), form)
+	if err != nil {
+		return "", time.Time{}, err
+	}
 	defer resp.Body.Close()

 	if !client.SuccessStatus(resp.StatusCode) {
 		err := client.HandleErrorResponse(resp)
-		return nil, err
+		return "", time.Time{}, err
 	}

 	decoder := json.NewDecoder(resp.Body)

-	tr := new(tokenResponse)
-	if err = decoder.Decode(tr); err != nil {
-		return nil, fmt.Errorf("unable to decode token response: %s", err)
+	var tr postTokenResponse
+	if err = decoder.Decode(&tr); err != nil {
+		return "", time.Time{}, fmt.Errorf("unable to decode token response: %s", err)
 	}

-	// `access_token` is equivalent to `token` and if both are specified
-	// the choice is undefined.  Canonicalize `access_token` by sticking
-	// things in `token`.
-	if tr.AccessToken != "" {
-		tr.Token = tr.AccessToken
-	}
-
-	if tr.Token == "" {
-		return nil, errors.New("authorization server did not include a token in the response")
+	if tr.RefreshToken != "" && tr.RefreshToken != refreshToken {
+		th.creds.SetRefreshToken(realm, service, tr.RefreshToken)
 	}

 	if tr.ExpiresIn < minimumTokenLifetimeSeconds {
@@ -295,10 +328,128 @@ func (th *tokenHandler) fetchToken(params map[string]string) (token *tokenRespon
 	if tr.IssuedAt.IsZero() {
 		// issued_at is optional in the token response.
-		tr.IssuedAt = th.clock.Now()
+		tr.IssuedAt = th.clock.Now().UTC()
 	}

-	return tr, nil
+	return tr.AccessToken, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
+}
+
+type getTokenResponse struct {
+	Token        string    `json:"token"`
+	AccessToken  string    `json:"access_token"`
+	ExpiresIn    int       `json:"expires_in"`
+	IssuedAt     time.Time `json:"issued_at"`
+	RefreshToken string    `json:"refresh_token"`
+}
+
+func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
+	req, err := http.NewRequest("GET", realm.String(), nil)
+	if err != nil {
+		return "", time.Time{}, err
+	}
+
+	reqParams := req.URL.Query()
+
+	if service != "" {
+		reqParams.Add("service", service)
+	}
+
+	for _, scope := range scopes {
+		reqParams.Add("scope", scope)
+	}
+
+	if th.offlineAccess {
+		reqParams.Add("offline_token", "true")
+		clientID := th.clientID
+		if clientID == "" {
+			clientID = defaultClientID
+		}
+		reqParams.Add("client_id", clientID)
+	}
+
+	if th.creds != nil {
+		username, password := th.creds.Basic(realm)
+		if username != "" && password != "" {
+			reqParams.Add("account", username)
+			req.SetBasicAuth(username, password)
+		}
+	}
+
+	req.URL.RawQuery = reqParams.Encode()
+
+	resp, err := th.client().Do(req)
+	if err != nil {
+		return "", time.Time{}, err
+	}
+	defer resp.Body.Close()
+
+	if !client.SuccessStatus(resp.StatusCode) {
+		err := client.HandleErrorResponse(resp)
+		return "", time.Time{}, err
+	}
+
+	decoder := json.NewDecoder(resp.Body)
+
+	var tr getTokenResponse
+	if err = decoder.Decode(&tr); err != nil {
+		return "", time.Time{}, fmt.Errorf("unable to decode token response: %s", err)
+	}
+
+	if tr.RefreshToken != "" && th.creds != nil {
+		th.creds.SetRefreshToken(realm, service, tr.RefreshToken)
+	}
+
+	// `access_token` is equivalent to `token` and if both are specified
+	// the choice is undefined. Canonicalize `access_token` by sticking
+	// things in `token`.
+	if tr.AccessToken != "" {
+		tr.Token = tr.AccessToken
+	}
+
+	if tr.Token == "" {
+		return "", time.Time{}, ErrNoToken
+	}
+
+	if tr.ExpiresIn < minimumTokenLifetimeSeconds {
+		// The default/minimum lifetime.
+		tr.ExpiresIn = minimumTokenLifetimeSeconds
+		logrus.Debugf("Increasing token expiration to: %d seconds", tr.ExpiresIn)
+	}
+
+	if tr.IssuedAt.IsZero() {
+		// issued_at is optional in the token response.
+		tr.IssuedAt = th.clock.Now().UTC()
+	}
+
+	return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
+}
+
+func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
+	realm, ok := params["realm"]
+	if !ok {
+		return "", time.Time{}, errors.New("no realm specified for token auth challenge")
+	}
+
+	// TODO(dmcgowan): Handle empty scheme and relative realm
+	realmURL, err := url.Parse(realm)
+	if err != nil {
+		return "", time.Time{}, fmt.Errorf("invalid token auth challenge realm: %s", err)
+	}
+
+	service := params["service"]
+
+	var refreshToken string
+
+	if th.creds != nil {
+		refreshToken = th.creds.RefreshToken(realmURL, service)
+	}
+
+	if refreshToken != "" || th.forceOAuth {
+		return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes)
+	}
+
+	return th.fetchTokenWithBasicAuth(realmURL, service, scopes)
 }

 type basicHandler struct {
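
Construction sketch (illustrative; NewTokenHandlerWithOptions, TokenHandlerOptions, RepositoryScope, and NewAuthorizer are from this diff, while creds and challengeManager are assumed to already exist):

```go
// Sketch: wire the options-based token handler into a transport.
handler := auth.NewTokenHandlerWithOptions(auth.TokenHandlerOptions{
	Transport:     http.DefaultTransport,
	Credentials:   creds, // a CredentialStore; now also implements RefreshToken/SetRefreshToken
	OfflineAccess: true,  // ask the token server for a refresh token
	Scopes: []auth.Scope{
		auth.RepositoryScope{
			Repository: "library/ubuntu",
			Actions:    []string{"pull"},
		},
	},
})
rt := transport.NewTransport(nil, auth.NewAuthorizer(challengeManager, handler))
httpClient := &http.Client{Transport: rt}
```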

View file

@@ -82,12 +82,23 @@ func ping(manager ChallengeManager, endpoint, versionHeader string) ([]APIVersio
 type testCredentialStore struct {
 	username string
 	password string
+	refreshTokens map[string]string
 }

 func (tcs *testCredentialStore) Basic(*url.URL) (string, string) {
 	return tcs.username, tcs.password
 }

+func (tcs *testCredentialStore) RefreshToken(u *url.URL, service string) string {
+	return tcs.refreshTokens[service]
+}
+
+func (tcs *testCredentialStore) SetRefreshToken(u *url.URL, service string, token string) {
+	if tcs.refreshTokens != nil {
+		tcs.refreshTokens[service] = token
+	}
+}
+
 func TestEndpointAuthorizeToken(t *testing.T) {
 	service := "localhost.localdomain"
 	repo1 := "some/registry"
@@ -162,14 +173,11 @@ func TestEndpointAuthorizeToken(t *testing.T) {
 		t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
 	}

-	badCheck := func(a string) bool {
-		return a == "Bearer statictoken"
-	}
-
-	e2, c2 := testServerWithAuth(m, authenicate, badCheck)
+	e2, c2 := testServerWithAuth(m, authenicate, validCheck)
 	defer c2()

 	challengeManager2 := NewSimpleChallengeManager()
-	versions, err = ping(challengeManager2, e+"/v2/", "x-multi-api-version")
+	versions, err = ping(challengeManager2, e2+"/v2/", "x-multi-api-version")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -199,6 +207,161 @@ func TestEndpointAuthorizeToken(t *testing.T) {
 	}
 }

+func TestEndpointAuthorizeRefreshToken(t *testing.T) {
+	service := "localhost.localdomain"
+	repo1 := "some/registry"
+	repo2 := "other/registry"
+	scope1 := fmt.Sprintf("repository:%s:pull,push", repo1)
+	scope2 := fmt.Sprintf("repository:%s:pull,push", repo2)
+	refreshToken1 := "0123456790abcdef"
+	refreshToken2 := "0123456790fedcba"
+	tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
+		{
+			Request: testutil.Request{
+				Method: "POST",
+				Route:  "/token",
+				Body:   []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope1), service)),
+			},
+			Response: testutil.Response{
+				StatusCode: http.StatusOK,
+				Body:       []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken1)),
+			},
+		},
+		{
+			// In the future this test may fail and require using basic auth to get a different refresh token
+			Request: testutil.Request{
+				Method: "POST",
+				Route:  "/token",
+				Body:   []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope2), service)),
+			},
+			Response: testutil.Response{
+				StatusCode: http.StatusOK,
+				Body:       []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken2)),
+			},
+		},
+		{
+			Request: testutil.Request{
+				Method: "POST",
+				Route:  "/token",
+				Body:   []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken2, url.QueryEscape(scope2), service)),
+			},
+			Response: testutil.Response{
+				StatusCode: http.StatusOK,
+				Body:       []byte(`{"access_token":"badtoken","refresh_token":"%s"}`),
+			},
+		},
+	})
+	te, tc := testServer(tokenMap)
+	defer tc()
+
+	m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
+		{
+			Request: testutil.Request{
+				Method: "GET",
+				Route:  "/v2/hello",
+			},
+			Response: testutil.Response{
+				StatusCode: http.StatusAccepted,
+			},
+		},
+	})
+
+	authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
+	validCheck := func(a string) bool {
+		return a == "Bearer statictoken"
+	}
+	e, c := testServerWithAuth(m, authenicate, validCheck)
+	defer c()
+
+	challengeManager1 := NewSimpleChallengeManager()
+	versions, err := ping(challengeManager1, e+"/v2/", "x-api-version")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(versions) != 1 {
+		t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
+	}
+	if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
+		t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
+	}
+
+	creds := &testCredentialStore{
+		refreshTokens: map[string]string{
+			service: refreshToken1,
+		},
+	}
+	transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager1, NewTokenHandler(nil, creds, repo1, "pull", "push")))
+	client := &http.Client{Transport: transport1}
+
+	req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
+	resp, err := client.Do(req)
+	if err != nil {
+		t.Fatalf("Error sending get request: %s", err)
+	}
+	if resp.StatusCode != http.StatusAccepted {
+		t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
+	}
+
+	// Try with refresh token setting
+	e2, c2 := testServerWithAuth(m, authenicate, validCheck)
+	defer c2()
+
+	challengeManager2 := NewSimpleChallengeManager()
+	versions, err = ping(challengeManager2, e2+"/v2/", "x-api-version")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(versions) != 1 {
+		t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
+	}
+	if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
+		t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
+	}
+
+	transport2 := transport.NewTransport(nil, NewAuthorizer(challengeManager2, NewTokenHandler(nil, creds, repo2, "pull", "push")))
+	client2 := &http.Client{Transport: transport2}
+
+	req, _ = http.NewRequest("GET", e2+"/v2/hello", nil)
+	resp, err = client2.Do(req)
+	if err != nil {
+		t.Fatalf("Error sending get request: %s", err)
+	}
+	if resp.StatusCode != http.StatusAccepted {
+		t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
+	}
+
+	if creds.refreshTokens[service] != refreshToken2 {
+		t.Fatalf("Refresh token not set after change")
+	}
+
+	// Try with bad token
+	e3, c3 := testServerWithAuth(m, authenicate, validCheck)
+	defer c3()
+
+	challengeManager3 := NewSimpleChallengeManager()
+	versions, err = ping(challengeManager3, e3+"/v2/", "x-api-version")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
+		t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
+	}
+
+	transport3 := transport.NewTransport(nil, NewAuthorizer(challengeManager3, NewTokenHandler(nil, creds, repo2, "pull", "push")))
+	client3 := &http.Client{Transport: transport3}
+
+	req, _ = http.NewRequest("GET", e3+"/v2/hello", nil)
+	resp, err = client3.Do(req)
+	if err != nil {
+		t.Fatalf("Error sending get request: %s", err)
+	}
+	if resp.StatusCode != http.StatusUnauthorized {
+		t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
+	}
+}
+
 func basicAuth(username, password string) string {
 	auth := username + ":" + password
 	return base64.StdEncoding.EncodeToString([]byte(auth))
@@ -379,7 +542,19 @@ func TestEndpointAuthorizeTokenBasicWithExpiresIn(t *testing.T) {
 		t.Fatal(err)
 	}
 	clock := &fakeClock{current: time.Now()}
-	transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds)))
+	options := TokenHandlerOptions{
+		Transport:   nil,
+		Credentials: creds,
+		Scopes: []Scope{
+			RepositoryScope{
+				Repository: repo,
+				Actions:    []string{"pull", "push"},
+			},
+		},
+	}
+	tHandler := NewTokenHandlerWithOptions(options)
+	tHandler.(*tokenHandler).clock = clock
+	transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, tHandler, NewBasicHandler(creds)))
 	client := &http.Client{Transport: transport1}

 	// First call should result in a token exchange
@@ -517,7 +692,20 @@ func TestEndpointAuthorizeTokenBasicWithExpiresInAndIssuedAt(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds)))
+
+	options := TokenHandlerOptions{
+		Transport:   nil,
+		Credentials: creds,
+		Scopes: []Scope{
+			RepositoryScope{
+				Repository: repo,
+				Actions:    []string{"pull", "push"},
+			},
+		},
+	}
+	tHandler := NewTokenHandlerWithOptions(options)
+	tHandler.(*tokenHandler).clock = clock
+	transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, tHandler, NewBasicHandler(creds)))
 	client := &http.Client{Transport: transport1}

 	// First call should result in a token exchange

View file

@@ -6,7 +6,6 @@ import (
 	"io"
 	"io/ioutil"
 	"net/http"
-	"os"
 	"time"

 	"github.com/docker/distribution"
@@ -104,21 +103,8 @@ func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
 }

-func (hbu *httpBlobUpload) Seek(offset int64, whence int) (int64, error) {
-	newOffset := hbu.offset
-
-	switch whence {
-	case os.SEEK_CUR:
-		newOffset += int64(offset)
-	case os.SEEK_END:
-		newOffset += int64(offset)
-	case os.SEEK_SET:
-		newOffset = int64(offset)
-	}
-
-	hbu.offset = newOffset
-
-	return hbu.offset, nil
+func (hbu *httpBlobUpload) Size() int64 {
+	return hbu.offset
 }

 func (hbu *httpBlobUpload) ID() string {
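
Caller-side sketch (assumed; writer stands for the blob writer of an in-progress upload): the current offset is now read directly instead of being discovered through Seek.

```go
// Sketch: resume bookkeeping with the new accessor instead of Seek.
offset := writer.Size() // bytes accepted by the registry so far
fmt.Printf("resuming upload at byte %d\n", offset)
```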

View file

@@ -2,6 +2,7 @@ package client

 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -10,6 +11,10 @@ import (
 	"github.com/docker/distribution/registry/api/errcode"
 )

+// ErrNoErrorsInBody is returned when an HTTP response body parses to an empty
+// errcode.Errors slice.
+var ErrNoErrorsInBody = errors.New("no error details found in HTTP response body")
+
 // UnexpectedHTTPStatusError is returned when an unexpected HTTP status is
 // returned when making a registry api call.
 type UnexpectedHTTPStatusError struct {
@@ -17,18 +22,19 @@ type UnexpectedHTTPStatusError struct {
 }

 func (e *UnexpectedHTTPStatusError) Error() string {
-	return fmt.Sprintf("Received unexpected HTTP status: %s", e.Status)
+	return fmt.Sprintf("received unexpected HTTP status: %s", e.Status)
 }

 // UnexpectedHTTPResponseError is returned when an expected HTTP status code
 // is returned, but the content was unexpected and failed to be parsed.
 type UnexpectedHTTPResponseError struct {
 	ParseErr   error
+	StatusCode int
 	Response   []byte
 }

 func (e *UnexpectedHTTPResponseError) Error() string {
-	return fmt.Sprintf("Error parsing HTTP response: %s: %q", e.ParseErr.Error(), string(e.Response))
+	return fmt.Sprintf("error parsing HTTP %d response body: %s: %q", e.StatusCode, e.ParseErr.Error(), string(e.Response))
 }

 func parseHTTPErrorResponse(statusCode int, r io.Reader) error {
@@ -45,18 +51,37 @@ func parseHTTPErrorResponse(statusCode int, r io.Reader) error {
 	}

 	err = json.Unmarshal(body, &detailsErr)
 	if err == nil && detailsErr.Details != "" {
-		if statusCode == http.StatusUnauthorized {
+		switch statusCode {
+		case http.StatusUnauthorized:
 			return errcode.ErrorCodeUnauthorized.WithMessage(detailsErr.Details)
+		// FIXME: go1.5 doesn't export http.StatusTooManyRequests while
+		// go1.6 does. Update the hardcoded value to the constant once
+		// Docker updates golang version to 1.6.
+		case 429:
+			return errcode.ErrorCodeTooManyRequests.WithMessage(detailsErr.Details)
+		default:
+			return errcode.ErrorCodeUnknown.WithMessage(detailsErr.Details)
 		}
-		return errcode.ErrorCodeUnknown.WithMessage(detailsErr.Details)
 	}

 	if err := json.Unmarshal(body, &errors); err != nil {
 		return &UnexpectedHTTPResponseError{
 			ParseErr:   err,
+			StatusCode: statusCode,
 			Response:   body,
 		}
 	}
+
+	if len(errors) == 0 {
+		// If there was no error specified in the body, return
+		// UnexpectedHTTPResponseError.
+		return &UnexpectedHTTPResponseError{
+			ParseErr:   ErrNoErrorsInBody,
+			StatusCode: statusCode,
+			Response:   body,
+		}
+	}
+
 	return errors
 }
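
Caller-side sketch (assumed): the new sentinel distinguishes an error status carrying an empty error list from a genuine parse failure.

```go
// Sketch: detect the new sentinel when the registry returns an error
// status with no parseable errors in the body.
err := client.HandleErrorResponse(resp)
if respErr, ok := err.(*client.UnexpectedHTTPResponseError); ok && respErr.ParseErr == client.ErrNoErrorsInBody {
	log.Printf("HTTP %d carried no error details: %q", respErr.StatusCode, respErr.Response)
}
```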

View file

@@ -59,6 +59,21 @@ func TestHandleErrorResponseExpectedStatusCode400ValidBody(t *testing.T) {
 	}
 }

+func TestHandleErrorResponseExpectedStatusCode404EmptyErrorSlice(t *testing.T) {
+	json := `{"randomkey": "randomvalue"}`
+	response := &http.Response{
+		Status:     "404 Not Found",
+		StatusCode: 404,
+		Body:       nopCloser{bytes.NewBufferString(json)},
+	}
+	err := HandleErrorResponse(response)
+
+	expectedMsg := `error parsing HTTP 404 response body: no error details found in HTTP response body: "{\"randomkey\": \"randomvalue\"}"`
+	if !strings.Contains(err.Error(), expectedMsg) {
+		t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
+	}
+}
+
 func TestHandleErrorResponseExpectedStatusCode404InvalidBody(t *testing.T) {
 	json := "{invalid json}"
 	response := &http.Response{
@@ -68,7 +83,7 @@ func TestHandleErrorResponseExpectedStatusCode404InvalidBody(t *testing.T) {
 	}
 	err := HandleErrorResponse(response)

-	expectedMsg := "Error parsing HTTP response: invalid character 'i' looking for beginning of object key string: \"{invalid json}\""
+	expectedMsg := "error parsing HTTP 404 response body: invalid character 'i' looking for beginning of object key string: \"{invalid json}\""
 	if !strings.Contains(err.Error(), expectedMsg) {
 		t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
 	}
@@ -82,7 +97,7 @@ func TestHandleErrorResponseUnexpectedStatusCode501(t *testing.T) {
 	}
 	err := HandleErrorResponse(response)

-	expectedMsg := "Received unexpected HTTP status: 501 Not Implemented"
+	expectedMsg := "received unexpected HTTP status: 501 Not Implemented"
 	if !strings.Contains(err.Error(), expectedMsg) {
 		t.Errorf("Expected \"%s\", got: \"%s\"", expectedMsg, err.Error())
 	}

View file

@@ -62,7 +62,7 @@ func checkHTTPRedirect(req *http.Request, via []*http.Request) error {

 // NewRegistry creates a registry namespace which can be used to get a listing of repositories
 func NewRegistry(ctx context.Context, baseURL string, transport http.RoundTripper) (Registry, error) {
-	ub, err := v2.NewURLBuilderFromString(baseURL)
+	ub, err := v2.NewURLBuilderFromString(baseURL, false)
 	if err != nil {
 		return nil, err
 	}
@@ -133,7 +133,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri

 // NewRepository creates a new Repository for the given repository name and base URL.
 func NewRepository(ctx context.Context, name reference.Named, baseURL string, transport http.RoundTripper) (distribution.Repository, error) {
-	ub, err := v2.NewURLBuilderFromString(baseURL)
+	ub, err := v2.NewURLBuilderFromString(baseURL, false)
 	if err != nil {
 		return nil, err
 	}
@@ -308,6 +308,7 @@ check:
 	if err != nil {
 		return distribution.Descriptor{}, err
 	}
+	defer resp.Body.Close()

 	switch {
 	case resp.StatusCode >= 200 && resp.StatusCode < 400:
@@ -401,9 +402,9 @@ func (ms *manifests) Get(ctx context.Context, dgst digest.Digest, options ...dis
 	)

 	for _, option := range options {
-		if opt, ok := option.(withTagOption); ok {
-			digestOrTag = opt.tag
-			ref, err = reference.WithTag(ms.name, opt.tag)
+		if opt, ok := option.(distribution.WithTagOption); ok {
+			digestOrTag = opt.Tag
+			ref, err = reference.WithTag(ms.name, opt.Tag)
 			if err != nil {
 				return nil, err
 			}
@@ -464,21 +465,6 @@ func (ms *manifests) Get(ctx context.Context, dgst digest.Digest, options ...dis
 	return nil, HandleErrorResponse(resp)
 }

-// WithTag allows a tag to be passed into Put which enables the client
-// to build a correct URL.
-func WithTag(tag string) distribution.ManifestServiceOption {
-	return withTagOption{tag}
-}
-
-type withTagOption struct{ tag string }
-
-func (o withTagOption) Apply(m distribution.ManifestService) error {
-	if _, ok := m.(*manifests); ok {
-		return nil
-	}
-	return fmt.Errorf("withTagOption is a client-only option")
-}
-
 // Put puts a manifest. A tag can be specified using an options parameter which uses some shared state to hold the
 // tag name in order to build the correct upload URL.
 func (ms *manifests) Put(ctx context.Context, m distribution.Manifest, options ...distribution.ManifestServiceOption) (digest.Digest, error) {
@@ -486,9 +472,9 @@ func (ms *manifests) Put(ctx context.Context, m distribution.Manifest, options .
 	var tagged bool

 	for _, option := range options {
-		if opt, ok := option.(withTagOption); ok {
+		if opt, ok := option.(distribution.WithTagOption); ok {
 			var err error
-			ref, err = reference.WithTag(ref, opt.tag)
+			ref, err = reference.WithTag(ref, opt.Tag)
 			if err != nil {
 				return "", err
 			}
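
Call-site sketch: the tag option now comes from the distribution package rather than this client package, as the test updates below also show.

```go
// Sketch: WithTag moved to the distribution package.
manifest, err := ms.Get(ctx, dgst, distribution.WithTag("latest"))
```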

View file

@@ -710,7 +710,7 @@ func TestV1ManifestFetch(t *testing.T) {
 		t.Fatal(err)
 	}

-	manifest, err = ms.Get(ctx, dgst, WithTag("latest"))
+	manifest, err = ms.Get(ctx, dgst, distribution.WithTag("latest"))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -723,7 +723,7 @@ func TestV1ManifestFetch(t *testing.T) {
 		t.Fatal(err)
 	}

-	manifest, err = ms.Get(ctx, dgst, WithTag("badcontenttype"))
+	manifest, err = ms.Get(ctx, dgst, distribution.WithTag("badcontenttype"))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -761,7 +761,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
 	if !ok {
 		panic("wrong type for client manifest service")
 	}
-	_, err = clientManifestService.Get(ctx, d1, WithTag("latest"), AddEtagToTag("latest", d1.String()))
+	_, err = clientManifestService.Get(ctx, d1, distribution.WithTag("latest"), AddEtagToTag("latest", d1.String()))
 	if err != distribution.ErrManifestNotModified {
 		t.Fatal(err)
 	}
@@ -861,7 +861,7 @@ func TestManifestPut(t *testing.T) {
 		t.Fatal(err)
 	}

-	if _, err := ms.Put(ctx, m1, WithTag(m1.Tag)); err != nil {
+	if _, err := ms.Put(ctx, m1, distribution.WithTag(m1.Tag)); err != nil {
 		t.Fatal(err)
 	}

View file

@@ -43,7 +43,6 @@ var headerConfig = http.Header{
 // 200 OK response.
 func TestCheckAPI(t *testing.T) {
 	env := newTestEnv(t, false)
-
 	baseURL, err := env.builder.BuildBaseURL()
 	if err != nil {
 		t.Fatalf("unexpected error building base url: %v", err)
@@ -294,6 +293,79 @@ func TestBlobDelete(t *testing.T) {
 	testBlobDelete(t, env, args)
 }

+func TestRelativeURL(t *testing.T) {
+	config := configuration.Configuration{
+		Storage: configuration.Storage{
+			"inmemory": configuration.Parameters{},
+		},
+	}
+	config.HTTP.Headers = headerConfig
+	config.HTTP.RelativeURLs = false
+	env := newTestEnvWithConfig(t, &config)
+	ref, _ := reference.WithName("foo/bar")
+
+	uploadURLBaseAbs, _ := startPushLayer(t, env, ref)
+
+	u, err := url.Parse(uploadURLBaseAbs)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !u.IsAbs() {
+		t.Fatal("Relative URL returned from blob upload chunk with non-relative configuration")
+	}
+
+	args := makeBlobArgs(t)
+	resp, err := doPushLayer(t, env.builder, ref, args.layerDigest, uploadURLBaseAbs, args.layerFile)
+	if err != nil {
+		t.Fatalf("unexpected error doing layer push relative url: %v", err)
+	}
+	checkResponse(t, "relativeurl blob upload", resp, http.StatusCreated)
+	u, err = url.Parse(resp.Header.Get("Location"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !u.IsAbs() {
+		t.Fatal("Relative URL returned from blob upload with non-relative configuration")
+	}
+
+	config.HTTP.RelativeURLs = true
+	args = makeBlobArgs(t)
+	uploadURLBaseRelative, _ := startPushLayer(t, env, ref)
+	u, err = url.Parse(uploadURLBaseRelative)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if u.IsAbs() {
+		t.Fatal("Absolute URL returned from blob upload chunk with relative configuration")
+	}
+
+	// Start a new upload in absolute mode to get a valid base URL
+	config.HTTP.RelativeURLs = false
+	uploadURLBaseAbs, _ = startPushLayer(t, env, ref)
+	u, err = url.Parse(uploadURLBaseAbs)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !u.IsAbs() {
+		t.Fatal("Relative URL returned from blob upload chunk with non-relative configuration")
+	}
+
+	// Complete upload with relative URLs enabled to ensure the final location is relative
+	config.HTTP.RelativeURLs = true
+	resp, err = doPushLayer(t, env.builder, ref, args.layerDigest, uploadURLBaseAbs, args.layerFile)
+	if err != nil {
+		t.Fatalf("unexpected error doing layer push relative url: %v", err)
+	}
+	checkResponse(t, "relativeurl blob upload", resp, http.StatusCreated)
+	u, err = url.Parse(resp.Header.Get("Location"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if u.IsAbs() {
+		t.Fatal("Absolute URL returned from blob upload with relative configuration")
+	}
+}
+
 func TestBlobDeleteDisabled(t *testing.T) {
 	deleteEnabled := false
 	env := newTestEnv(t, deleteEnabled)
@@ -349,7 +421,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
 	// ------------------------------------------
 	// Start an upload, check the status then cancel
-	uploadURLBase, uploadUUID := startPushLayer(t, env.builder, imageName)
+	uploadURLBase, uploadUUID := startPushLayer(t, env, imageName)

 	// A status check should work
 	resp, err = http.Get(uploadURLBase)
@@ -384,7 +456,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
 	// -----------------------------------------
 	// Do layer push with an empty body and different digest
-	uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName)
+	uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
 	resp, err = doPushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, bytes.NewReader([]byte{}))
 	if err != nil {
 		t.Fatalf("unexpected error doing bad layer push: %v", err)
@@ -400,7 +472,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
 		t.Fatalf("unexpected error digesting empty buffer: %v", err)
 	}

-	uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName)
+	uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
 	pushLayer(t, env.builder, imageName, zeroDigest, uploadURLBase, bytes.NewReader([]byte{}))

 	// -----------------------------------------
@@ -413,7 +485,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
 		t.Fatalf("unexpected error digesting empty tar: %v", err)
 	}

-	uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName)
+	uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
 	pushLayer(t, env.builder, imageName, emptyDigest, uploadURLBase, bytes.NewReader(emptyTar))

 	// ------------------------------------------
@@ -421,7 +493,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
 	layerLength, _ := layerFile.Seek(0, os.SEEK_END)
 	layerFile.Seek(0, os.SEEK_SET)

-	uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName)
+	uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
 	pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)

 	// ------------------------------------------
@@ -435,7 +507,7 @@ func testBlobAPI(t *testing.T, env *testEnv, args blobArgs) *testEnv {
 	canonicalDigest := canonicalDigester.Digest()

 	layerFile.Seek(0, 0)
-	uploadURLBase, uploadUUID = startPushLayer(t, env.builder, imageName)
+	uploadURLBase, uploadUUID = startPushLayer(t, env, imageName)
 	uploadURLBase, dgst := pushChunk(t, env.builder, imageName, uploadURLBase, layerFile, layerLength)
 	finishUpload(t, env.builder, imageName, uploadURLBase, dgst)
@@ -585,7 +657,7 @@ func testBlobDelete(t *testing.T, env *testEnv, args blobArgs) {
 	// Reupload previously deleted blob
 	layerFile.Seek(0, os.SEEK_SET)

-	uploadURLBase, _ := startPushLayer(t, env.builder, imageName)
+	uploadURLBase, _ := startPushLayer(t, env, imageName)
 	pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)

 	layerFile.Seek(0, os.SEEK_SET)
@@ -625,7 +697,7 @@ func TestDeleteDisabled(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Error building blob URL")
 	}
-	uploadURLBase, _ := startPushLayer(t, env.builder, imageName)
+	uploadURLBase, _ := startPushLayer(t, env, imageName)
 	pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)

 	resp, err := httpDelete(layerURL)
@@ -651,7 +723,7 @@ func TestDeleteReadOnly(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Error building blob URL")
 	}
-	uploadURLBase, _ := startPushLayer(t, env.builder, imageName)
+	uploadURLBase, _ := startPushLayer(t, env, imageName)
 	pushLayer(t, env.builder, imageName, layerDigest, uploadURLBase, layerFile)

 	env.app.readOnly = true
@@ -854,7 +926,7 @@ func testManifestAPISchema1(t *testing.T, env *testEnv, imageName reference.Name
 	}

 	// TODO(stevvooe): Add a test case where we take a mostly valid registry,
-	// tamper with the content and ensure that we get a unverified manifest
+	// tamper with the content and ensure that we get an unverified manifest
 	// error.

 	// Push 2 random layers
@@ -871,7 +943,7 @@ func testManifestAPISchema1(t *testing.T, env *testEnv, imageName reference.Name
 		expectedLayers[dgst] = rs
 		unsignedManifest.FSLayers[i].BlobSum = dgst

-		uploadURLBase, _ := startPushLayer(t, env.builder, imageName)
+		uploadURLBase, _ := startPushLayer(t, env, imageName)
 		pushLayer(t, env.builder, imageName, dgst, uploadURLBase, rs)
 	}
@ -995,13 +1067,13 @@ func testManifestAPISchema1(t *testing.T, env *testEnv, imageName reference.Name
t.Fatalf("error decoding fetched manifest: %v", err) t.Fatalf("error decoding fetched manifest: %v", err)
} }
// check two signatures were roundtripped // check only 1 signature is returned
signatures, err = fetchedManifestByDigest.Signatures() signatures, err = fetchedManifestByDigest.Signatures()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(signatures) != 2 { if len(signatures) != 1 {
t.Fatalf("expected 2 signature from manifest, got: %d", len(signatures)) t.Fatalf("expected 2 signature from manifest, got: %d", len(signatures))
} }
@ -1177,7 +1249,7 @@ func testManifestAPISchema2(t *testing.T, env *testEnv, imageName reference.Name
}`) }`)
sampleConfigDigest := digest.FromBytes(sampleConfig) sampleConfigDigest := digest.FromBytes(sampleConfig)
uploadURLBase, _ := startPushLayer(t, env.builder, imageName) uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, sampleConfigDigest, uploadURLBase, bytes.NewReader(sampleConfig)) pushLayer(t, env.builder, imageName, sampleConfigDigest, uploadURLBase, bytes.NewReader(sampleConfig))
manifest.Config.Digest = sampleConfigDigest manifest.Config.Digest = sampleConfigDigest
manifest.Config.Size = int64(len(sampleConfig)) manifest.Config.Size = int64(len(sampleConfig))
@ -1210,7 +1282,7 @@ func testManifestAPISchema2(t *testing.T, env *testEnv, imageName reference.Name
expectedLayers[dgst] = rs expectedLayers[dgst] = rs
manifest.Layers[i].Digest = dgst manifest.Layers[i].Digest = dgst
uploadURLBase, _ := startPushLayer(t, env.builder, imageName) uploadURLBase, _ := startPushLayer(t, env, imageName)
pushLayer(t, env.builder, imageName, dgst, uploadURLBase, rs) pushLayer(t, env.builder, imageName, dgst, uploadURLBase, rs)
} }
@ -1842,7 +1914,7 @@ func newTestEnvWithConfig(t *testing.T, config *configuration.Configuration) *te
app := NewApp(ctx, config) app := NewApp(ctx, config)
server := httptest.NewServer(handlers.CombinedLoggingHandler(os.Stderr, app)) server := httptest.NewServer(handlers.CombinedLoggingHandler(os.Stderr, app))
builder, err := v2.NewURLBuilderFromString(server.URL + config.HTTP.Prefix) builder, err := v2.NewURLBuilderFromString(server.URL+config.HTTP.Prefix, false)
if err != nil { if err != nil {
t.Fatalf("error creating url builder: %v", err) t.Fatalf("error creating url builder: %v", err)
@ -1904,21 +1976,33 @@ func putManifest(t *testing.T, msg, url, contentType string, v interface{}) *htt
return resp return resp
} }
func startPushLayer(t *testing.T, ub *v2.URLBuilder, name reference.Named) (location string, uuid string) { func startPushLayer(t *testing.T, env *testEnv, name reference.Named) (location string, uuid string) {
layerUploadURL, err := ub.BuildBlobUploadURL(name) layerUploadURL, err := env.builder.BuildBlobUploadURL(name)
if err != nil { if err != nil {
t.Fatalf("unexpected error building layer upload url: %v", err) t.Fatalf("unexpected error building layer upload url: %v", err)
} }
u, err := url.Parse(layerUploadURL)
if err != nil {
t.Fatalf("error parsing layer upload URL: %v", err)
}
base, err := url.Parse(env.server.URL)
if err != nil {
t.Fatalf("error parsing server URL: %v", err)
}
layerUploadURL = base.ResolveReference(u).String()
resp, err := http.Post(layerUploadURL, "", nil) resp, err := http.Post(layerUploadURL, "", nil)
if err != nil { if err != nil {
t.Fatalf("unexpected error starting layer push: %v", err) t.Fatalf("unexpected error starting layer push: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
checkResponse(t, fmt.Sprintf("pushing starting layer push %v", name.String()), resp, http.StatusAccepted) checkResponse(t, fmt.Sprintf("pushing starting layer push %v", name.String()), resp, http.StatusAccepted)
u, err := url.Parse(resp.Header.Get("Location")) u, err = url.Parse(resp.Header.Get("Location"))
if err != nil { if err != nil {
t.Fatalf("error parsing location header: %v", err) t.Fatalf("error parsing location header: %v", err)
} }
@ -1943,7 +2027,6 @@ func doPushLayer(t *testing.T, ub *v2.URLBuilder, name reference.Named, dgst dig
u.RawQuery = url.Values{ u.RawQuery = url.Values{
"_state": u.Query()["_state"], "_state": u.Query()["_state"],
"digest": []string{dgst.String()}, "digest": []string{dgst.String()},
}.Encode() }.Encode()
@ -2211,8 +2294,7 @@ func createRepository(env *testEnv, t *testing.T, imageName string, tag string)
expectedLayers[dgst] = rs expectedLayers[dgst] = rs
unsignedManifest.FSLayers[i].BlobSum = dgst unsignedManifest.FSLayers[i].BlobSum = dgst
uploadURLBase, _ := startPushLayer(t, env, imageNameRef)
uploadURLBase, _ := startPushLayer(t, env.builder, imageNameRef)
pushLayer(t, env.builder, imageNameRef, dgst, uploadURLBase, rs) pushLayer(t, env.builder, imageNameRef, dgst, uploadURLBase, rs)
} }
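Note on the startPushLayer rewrite above: the handler may now return a relative Location header, so the helper resolves it against the test server's base URL before POSTing to it. The resolution step in isolation, with illustrative values only:

    package main

    import (
    	"fmt"
    	"net/url"
    )

    func main() {
    	base, _ := url.Parse("http://127.0.0.1:5000")     // test server root (hypothetical)
    	loc, _ := url.Parse("/v2/foo/blobs/uploads/1234") // relative Location header (hypothetical)
    	// ResolveReference returns absolute references unchanged, so the same
    	// code handles both relative and fully qualified Location values.
    	fmt.Println(base.ResolveReference(loc))
    }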

View file

@@ -155,6 +155,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
app.configureRedis(config)
app.configureLogHook(config)
+options := registrymiddleware.GetRegistryOptions()
if config.Compatibility.Schema1.TrustKey != "" {
app.trustKey, err = libtrust.LoadKeyFile(config.Compatibility.Schema1.TrustKey)
if err != nil {
@@ -169,6 +170,8 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
}
}
+options = append(options, storage.Schema1SigningKey(app.trustKey))
if config.HTTP.Host != "" {
u, err := url.Parse(config.HTTP.Host)
if err != nil {
@@ -177,17 +180,10 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
app.httpHost = *u
}
-options := []storage.RegistryOption{}
if app.isCache {
options = append(options, storage.DisableDigestResumption)
}
-if config.Compatibility.Schema1.DisableSignatureStore {
-options = append(options, storage.DisableSchema1Signatures)
-options = append(options, storage.Schema1SigningKey(app.trustKey))
-}
// configure deletion
if d, ok := config.Storage["delete"]; ok {
e, ok := d["enabled"]
@@ -258,7 +254,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
}
}
-app.registry, err = applyRegistryMiddleware(app.Context, app.registry, config.Middleware["registry"])
+app.registry, err = applyRegistryMiddleware(app, app.registry, config.Middleware["registry"])
if err != nil {
panic(err)
}
@@ -634,6 +630,8 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
context.Errors = append(context.Errors, v2.ErrorCodeNameUnknown.WithDetail(err))
case distribution.ErrRepositoryNameInvalid:
context.Errors = append(context.Errors, v2.ErrorCodeNameInvalid.WithDetail(err))
+case errcode.Error:
+context.Errors = append(context.Errors, err)
}
if err := errcode.ServeJSON(w, context.Errors); err != nil {
@@ -647,7 +645,7 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
repository,
app.eventBridge(context, r))
-context.Repository, err = applyRepoMiddleware(context.Context, context.Repository, app.Config.Middleware["repository"])
+context.Repository, err = applyRepoMiddleware(app, context.Repository, app.Config.Middleware["repository"])
if err != nil {
ctxu.GetLogger(context).Errorf("error initializing repository middleware: %v", err)
context.Errors = append(context.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
@@ -721,9 +719,9 @@ func (app *App) context(w http.ResponseWriter, r *http.Request) *Context {
// A "host" item in the configuration takes precedence over
// X-Forwarded-Proto and X-Forwarded-Host headers, and the
// hostname in the request.
-context.urlBuilder = v2.NewURLBuilder(&app.httpHost)
+context.urlBuilder = v2.NewURLBuilder(&app.httpHost, false)
} else {
-context.urlBuilder = v2.NewURLBuilderFromRequest(r)
+context.urlBuilder = v2.NewURLBuilderFromRequest(r, app.Config.HTTP.RelativeURLs)
}
return context
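The extra boolean threaded through v2.NewURLBuilder and v2.NewURLBuilderFromRequest is the new relative-URL switch, driven here by app.Config.HTTP.RelativeURLs. A hedged sketch of the intended difference (the repository reference is illustrative):

    root, _ := url.Parse("https://registry.example.com")
    ref, _ := reference.ParseNamed("foo/bar")
    abs := v2.NewURLBuilder(root, false) // fully qualified URLs, as before
    rel := v2.NewURLBuilder(root, true)  // path-only URLs for Location headers
    absURL, _ := abs.BuildBlobUploadURL(ref) // "https://registry.example.com/v2/foo/bar/blobs/uploads/"
    relURL, _ := rel.BuildBlobUploadURL(ref) // "/v2/foo/bar/blobs/uploads/"

A registry behind a TLS-terminating proxy benefits most: a path-only Location survives host and scheme rewriting.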

View file

@@ -160,7 +160,7 @@ func TestNewApp(t *testing.T) {
app := NewApp(ctx, &config)
server := httptest.NewServer(app)
-builder, err := v2.NewURLBuilderFromString(server.URL)
+builder, err := v2.NewURLBuilderFromString(server.URL, false)
if err != nil {
t.Fatalf("error creating urlbuilder: %v", err)
}

View file

@@ -4,7 +4,6 @@ import (
"fmt"
"net/http"
"net/url"
-"os"
"github.com/docker/distribution"
ctxu "github.com/docker/distribution/context"
@@ -76,28 +75,14 @@ func blobUploadDispatcher(ctx *Context, r *http.Request) http.Handler {
}
buh.Upload = upload
-if state.Offset > 0 {
+if size := upload.Size(); size != buh.State.Offset {
-// Seek the blob upload to the correct spot if it's non-zero.
-// These error conditions should be rare and demonstrate really
-// problems. We basically cancel the upload and tell the client to
-// start over.
-if nn, err := upload.Seek(buh.State.Offset, os.SEEK_SET); err != nil {
defer upload.Close()
-ctxu.GetLogger(ctx).Infof("error seeking blob upload: %v", err)
+ctxu.GetLogger(ctx).Infof("upload resumed at wrong offset: %d != %d", size, buh.State.Offset)
-return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
-upload.Cancel(buh)
-})
-} else if nn != buh.State.Offset {
-defer upload.Close()
-ctxu.GetLogger(ctx).Infof("seek to wrong offest: %d != %d", nn, buh.State.Offset)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buh.Errors = append(buh.Errors, v2.ErrorCodeBlobUploadInvalid.WithDetail(err))
upload.Cancel(buh)
})
}
-}
return closeResources(handler, buh.Upload)
}
@@ -239,10 +224,7 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht
return
}
-size := buh.State.Offset
+size := buh.Upload.Size()
-if offset, err := buh.Upload.Seek(0, os.SEEK_CUR); err == nil {
-size = offset
-}
desc, err := buh.Upload.Commit(buh, distribution.Descriptor{
Digest: dgst,
@@ -257,6 +239,8 @@ func (buh *blobUploadHandler) PutBlobUploadComplete(w http.ResponseWriter, r *ht
switch err := err.(type) {
case distribution.ErrBlobInvalidDigest:
buh.Errors = append(buh.Errors, v2.ErrorCodeDigestInvalid.WithDetail(err))
+case errcode.Error:
+buh.Errors = append(buh.Errors, err)
default:
switch err {
case distribution.ErrAccessDenied:
@@ -308,21 +292,10 @@ func (buh *blobUploadHandler) CancelBlobUpload(w http.ResponseWriter, r *http.Re
// uploads always start at a 0 offset. This allows disabling resumable push by
// always returning a 0 offset on check status.
func (buh *blobUploadHandler) blobUploadResponse(w http.ResponseWriter, r *http.Request, fresh bool) error {
-var offset int64
-if !fresh {
-var err error
-offset, err = buh.Upload.Seek(0, os.SEEK_CUR)
-if err != nil {
-ctxu.GetLogger(buh).Errorf("unable get current offset of blob upload: %v", err)
-return err
-}
-}
// TODO(stevvooe): Need a better way to manage the upload state automatically.
buh.State.Name = buh.Repository.Named().Name()
buh.State.UUID = buh.Upload.ID()
-buh.State.Offset = offset
+buh.State.Offset = buh.Upload.Size()
buh.State.StartedAt = buh.Upload.StartedAt()
token, err := hmacKey(buh.Config.HTTP.Secret).packUploadState(buh.State)
@@ -341,13 +314,14 @@ func (buh *blobUploadHandler) blobUploadResponse(w http.ResponseWriter, r *http.
return err
}
-endRange := offset
+endRange := buh.Upload.Size()
if endRange > 0 {
endRange = endRange - 1
}
w.Header().Set("Docker-Upload-UUID", buh.UUID)
w.Header().Set("Location", uploadURL)
w.Header().Set("Content-Length", "0")
w.Header().Set("Range", fmt.Sprintf("0-%d", endRange))

View file

@@ -20,7 +20,7 @@ func closeResources(handler http.Handler, closers ...io.Closer) http.Handler {
})
}
-// copyFullPayload copies the payload of a HTTP request to destWriter. If it
+// copyFullPayload copies the payload of an HTTP request to destWriter. If it
// receives less content than expected, and the client disconnected during the
// upload, it avoids sending a 400 error to keep the logs cleaner.
func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWriter io.Writer, context ctxu.Context, action string, errSlice *errcode.Errors) error {
@@ -46,7 +46,11 @@ func copyFullPayload(responseWriter http.ResponseWriter, r *http.Request, destWr
// instead of showing 0 for the HTTP status.
responseWriter.WriteHeader(499)
-ctxu.GetLogger(context).Error("client disconnected during " + action)
+ctxu.GetLoggerWithFields(context, map[interface{}]interface{}{
+"error": err,
+"copied": copied,
+"contentLength": r.ContentLength,
+}, "error", "copied", "contentLength").Error("client disconnected during " + action)
return errors.New("client disconnected")
default:
}

View file

@@ -86,7 +86,11 @@ func (imh *imageManifestHandler) GetImageManifest(w http.ResponseWriter, r *http
return
}
-manifest, err = manifests.Get(imh, imh.Digest)
+var options []distribution.ManifestServiceOption
+if imh.Tag != "" {
+options = append(options, distribution.WithTag(imh.Tag))
+}
+manifest, err = manifests.Get(imh, imh.Digest, options...)
if err != nil {
imh.Errors = append(imh.Errors, v2.ErrorCodeManifestUnknown.WithDetail(err))
return
@@ -245,7 +249,11 @@ func (imh *imageManifestHandler) PutImageManifest(w http.ResponseWriter, r *http
return
}
-_, err = manifests.Put(imh, manifest)
+var options []distribution.ManifestServiceOption
+if imh.Tag != "" {
+options = append(options, distribution.WithTag(imh.Tag))
+}
+_, err = manifests.Put(imh, manifest, options...)
if err != nil {
// TODO(stevvooe): These error handling switches really need to be
// handled by an app global mapper.
@@ -275,6 +283,8 @@ func (imh *imageManifestHandler) PutImageManifest(w http.ResponseWriter, r *http
}
}
}
+case errcode.Error:
+imh.Errors = append(imh.Errors, err)
default:
imh.Errors = append(imh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
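Both handlers now forward the request's tag to the manifest service as a distribution.WithTag option. The same pattern from a caller's point of view, sketched (repo stands for any distribution.Repository; error handling elided):

    manifests, _ := repo.Manifests(ctx)
    // fetch by tag: leave the digest empty and pass the tag as an option
    m, err := manifests.Get(ctx, "", distribution.WithTag("latest"))
    // put under a tag: the option tells the service which tag to move
    _, err = manifests.Put(ctx, m, distribution.WithTag("latest"))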

View file

@@ -41,6 +41,8 @@ func (th *tagsHandler) GetTags(w http.ResponseWriter, r *http.Request) {
switch err := err.(type) {
case distribution.ErrRepositoryUnknown:
th.Errors = append(th.Errors, v2.ErrorCodeNameUnknown.WithDetail(map[string]string{"name": th.Repository.Named().Name()}))
+case errcode.Error:
+th.Errors = append(th.Errors, err)
default:
th.Errors = append(th.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
}
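This is the same errcode.Error pass-through added to the other handlers in this commit: an error that already carries a registered error code keeps its own HTTP status instead of being flattened into UNKNOWN. The pattern, condensed into a sketch:

    // hedged sketch of the pass-through used in these handlers
    func appendAPIError(errs errcode.Errors, err error) errcode.Errors {
    	switch err := err.(type) {
    	case errcode.Error:
    		return append(errs, err) // preserve the upstream code and status
    	default:
    		return append(errs, errcode.ErrorCodeUnknown.WithDetail(err))
    	}
    }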

View file

@@ -5,6 +5,7 @@ import (
"github.com/docker/distribution"
"github.com/docker/distribution/context"
+"github.com/docker/distribution/registry/storage"
)
// InitFunc is the type of a RegistryMiddleware factory function and is
@@ -12,6 +13,7 @@
type InitFunc func(ctx context.Context, registry distribution.Namespace, options map[string]interface{}) (distribution.Namespace, error)
var middlewares map[string]InitFunc
+var registryoptions []storage.RegistryOption
// Register is used to register an InitFunc for
// a RegistryMiddleware backend with the given name.
@@ -38,3 +40,15 @@ func Get(ctx context.Context, name string, options map[string]interface{}, regis
return nil, fmt.Errorf("no registry middleware registered with name: %s", name)
}
+// RegisterOptions adds more options to RegistryOption list. Options get applied before
+// any other configuration-based options.
+func RegisterOptions(options ...storage.RegistryOption) error {
+registryoptions = append(registryoptions, options...)
+return nil
+}
+// GetRegistryOptions returns list of RegistryOption.
+func GetRegistryOptions() []storage.RegistryOption {
+return registryoptions
+}
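RegisterOptions feeds the options list that NewApp now seeds via GetRegistryOptions (see app.go above), letting a middleware contribute storage options before configuration-based ones are appended. A plausible use, assuming a hypothetical middleware package:

    package mymiddleware // hypothetical

    import (
    	middleware "github.com/docker/distribution/registry/middleware/registry"
    	"github.com/docker/distribution/registry/storage"
    )

    func init() {
    	// applied before any configuration-based options
    	_ = middleware.RegisterOptions(storage.DisableDigestResumption)
    }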

View file

@@ -25,6 +25,13 @@ func (c credentials) Basic(u *url.URL) (string, string) {
return up.username, up.password
}
+func (c credentials) RefreshToken(u *url.URL, service string) string {
+return ""
+}
+func (c credentials) SetRefreshToken(u *url.URL, service, token string) {
+}
// configureAuth stores credentials for challenge responses
func configureAuth(username, password string) (auth.CredentialStore, error) {
creds := map[string]userpass{

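The two no-op methods bring this test credential store in line with the auth.CredentialStore interface, which now includes OAuth2-style refresh-token accessors alongside Basic. Its shape, as implied by the methods implemented above (a sketch inferred from usage, not copied from the interface definition):

    type CredentialStore interface {
    	Basic(*url.URL) (string, string)
    	RefreshToken(u *url.URL, service string) string
    	SetRefreshToken(u *url.URL, service, token string)
    }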
View file

@@ -132,8 +132,15 @@ func makeTestEnv(t *testing.T, name string) *testEnv {
t.Fatalf("unable to create tempdir: %s", err)
}
+localDriver, err := filesystem.FromParameters(map[string]interface{}{
+"rootdirectory": truthDir,
+})
+if err != nil {
+t.Fatalf("unable to create filesystem driver: %s", err)
+}
// todo: create a tempfile area here
-localRegistry, err := storage.NewRegistry(ctx, filesystem.New(truthDir), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption)
+localRegistry, err := storage.NewRegistry(ctx, localDriver, storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption)
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
@@ -142,7 +149,14 @@ func makeTestEnv(t *testing.T, name string) *testEnv {
t.Fatalf("unexpected error getting repo: %v", err)
}
-truthRegistry, err := storage.NewRegistry(ctx, filesystem.New(cacheDir), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()))
+cacheDriver, err := filesystem.FromParameters(map[string]interface{}{
+"rootdirectory": cacheDir,
+})
+if err != nil {
+t.Fatalf("unable to create filesystem driver: %s", err)
+}
+truthRegistry, err := storage.NewRegistry(ctx, cacheDriver, storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()))
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
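filesystem.New(dir) gives way to filesystem.FromParameters, which accepts the same map-shaped parameters that the YAML configuration produces and can now fail, hence the added error checks. In isolation (the path is illustrative):

    driver, err := filesystem.FromParameters(map[string]interface{}{
    	"rootdirectory": "/var/lib/registry",
    })
    if err != nil {
    	t.Fatalf("unable to create filesystem driver: %s", err)
    }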

View file

@@ -60,12 +60,6 @@ func (sm statsManifest) Put(ctx context.Context, manifest distribution.Manifest,
return sm.manifests.Put(ctx, manifest)
}
-/*func (sm statsManifest) Enumerate(ctx context.Context, manifests []distribution.Manifest, last distribution.Manifest) (n int, err error) {
-sm.stats["enumerate"]++
-return sm.manifests.Enumerate(ctx, manifests, last)
-}
-*/
type mockChallenger struct {
sync.Mutex
count int
@@ -75,7 +69,6 @@ type mockChallenger struct {
func (m *mockChallenger) tryEstablishChallenges(context.Context) error {
m.Lock()
defer m.Unlock()
m.count++
return nil
}
@@ -93,9 +86,15 @@ func newManifestStoreTestEnv(t *testing.T, name, tag string) *manifestStoreTestE
if err != nil {
t.Fatalf("unable to parse reference: %s", err)
}
+k, err := libtrust.GenerateECP256PrivateKey()
+if err != nil {
+t.Fatal(err)
+}
ctx := context.Background()
-truthRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()))
+truthRegistry, err := storage.NewRegistry(ctx, inmemory.New(),
+storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()),
+storage.Schema1SigningKey(k))
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
@@ -117,7 +116,7 @@ func newManifestStoreTestEnv(t *testing.T, name, tag string) *manifestStoreTestE
t.Fatalf(err.Error())
}
-localRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption)
+localRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption, storage.Schema1SigningKey(k))
if err != nil {
t.Fatalf("error creating registry: %v", err)
}
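Both registries in this test env now receive a storage.Schema1SigningKey option built from a freshly generated libtrust key; with the schema1 signature store removed elsewhere in this commit, the registry appears to re-sign schema1 manifests itself rather than persisting client signatures. The wiring in isolation (mirrors the diff above):

    k, err := libtrust.GenerateECP256PrivateKey()
    if err != nil {
    	t.Fatal(err)
    }
    reg, err := storage.NewRegistry(ctx, inmemory.New(), storage.Schema1SigningKey(k))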

View file

@@ -22,13 +22,13 @@ import (
type proxyingRegistry struct {
embedded distribution.Namespace // provides local registry functionality
scheduler *scheduler.TTLExpirationScheduler
-remoteURL string
+remoteURL url.URL
authChallenger authChallenger
}
// NewRegistryPullThroughCache creates a registry acting as a pull through cache
func NewRegistryPullThroughCache(ctx context.Context, registry distribution.Namespace, driver driver.StorageDriver, config configuration.Proxy) (distribution.Namespace, error) {
-_, err := url.Parse(config.RemoteURL)
+remoteURL, err := url.Parse(config.RemoteURL)
if err != nil {
return nil, err
}
@@ -99,9 +99,9 @@ func NewRegistryPullThroughCache(ctx context.Context, registry distribution.Name
return &proxyingRegistry{
embedded: registry,
scheduler: s,
-remoteURL: config.RemoteURL,
+remoteURL: *remoteURL,
authChallenger: &remoteAuthChallenger{
-remoteURL: config.RemoteURL,
+remoteURL: *remoteURL,
cm: auth.NewSimpleChallengeManager(),
cs: cs,
},
@@ -131,7 +131,7 @@ func (pr *proxyingRegistry) Repository(ctx context.Context, name reference.Named
return nil, err
}
-remoteRepo, err := client.NewRepository(ctx, name, pr.remoteURL, tr)
+remoteRepo, err := client.NewRepository(ctx, name, pr.remoteURL.String(), tr)
if err != nil {
return nil, err
}
@@ -182,7 +182,7 @@ type authChallenger interface {
}
type remoteAuthChallenger struct {
-remoteURL string
+remoteURL url.URL
sync.Mutex
cm auth.ChallengeManager
cs auth.CredentialStore
@@ -201,8 +201,9 @@ func (r *remoteAuthChallenger) tryEstablishChallenges(ctx context.Context) error
r.Lock()
defer r.Unlock()
-remoteURL := r.remoteURL + "/v2/"
+remoteURL := r.remoteURL
-challenges, err := r.cm.GetChallenges(remoteURL)
+remoteURL.Path = "/v2/"
+challenges, err := r.cm.GetChallenges(r.remoteURL)
if err != nil {
return err
}
@@ -212,7 +213,7 @@ func (r *remoteAuthChallenger) tryEstablishChallenges(ctx context.Context) error
}
// establish challenge type with upstream
-if err := ping(r.cm, remoteURL, challengeHeader); err != nil {
+if err := ping(r.cm, remoteURL.String(), challengeHeader); err != nil {
return err
}
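Switching remoteURL from string to url.URL removes the string concatenation: the challenger copies the struct and sets Path on the copy, which cannot corrupt the stored base. The underlying Go behavior:

    package main

    import (
    	"fmt"
    	"net/url"
    )

    func main() {
    	base, _ := url.Parse("https://upstream.example.com")
    	u := *base      // url.URL is a plain struct; assignment copies it
    	u.Path = "/v2/" // mutates only the copy
    	fmt.Println(u.String())    // https://upstream.example.com/v2/
    	fmt.Println(base.String()) // https://upstream.example.com
    }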

View file

@@ -1,6 +1,7 @@
package proxy
import (
+"reflect"
"sort"
"sync"
"testing"
@@ -92,7 +93,7 @@ func TestGet(t *testing.T) {
t.Fatalf("Expected 1 auth challenge call, got %#v", proxyTags.authChallenger)
}
-if d != remoteDesc {
+if !reflect.DeepEqual(d, remoteDesc) {
t.Fatal("unable to get put tag")
}
@@ -101,7 +102,7 @@ func TestGet(t *testing.T) {
t.Fatal("remote tag not pulled into store")
}
-if local != remoteDesc {
+if !reflect.DeepEqual(local, remoteDesc) {
t.Fatalf("unexpected descriptor pulled through")
}
@@ -121,7 +122,7 @@ func TestGet(t *testing.T) {
t.Fatalf("Expected 2 auth challenge calls, got %#v", proxyTags.authChallenger)
}
-if d != newRemoteDesc {
+if !reflect.DeepEqual(d, newRemoteDesc) {
t.Fatal("unable to get put tag")
}
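The == comparisons become reflect.DeepEqual here and in the storage tests below, presumably because distribution.Descriptor is no longer a comparable struct (a slice-typed field makes == a compile error). Illustrated with a hypothetical descriptor type:

    package main

    import (
    	"fmt"
    	"reflect"
    )

    type desc struct {
    	Digest string
    	URLs   []string // slice field: == between desc values will not compile
    }

    func main() {
    	a := desc{Digest: "sha256:abc", URLs: []string{"https://mirror.example"}}
    	b := desc{Digest: "sha256:abc", URLs: []string{"https://mirror.example"}}
    	fmt.Println(reflect.DeepEqual(a, b)) // true
    }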

View file

@@ -267,7 +267,7 @@ func logLevel(level configuration.Loglevel) log.Level {
return l
}
-// panicHandler add a HTTP handler to web app. The handler recover the happening
+// panicHandler adds an HTTP handler to the web app. The handler recovers the happening
// panic. logrus.Panic transmits panic message to pre-config log hooks, which is
// defined in config.yml.
func panicHandler(handler http.Handler) http.Handler {

View file

@@ -1,7 +1,14 @@
package registry
import (
+"fmt"
+"os"
+"github.com/docker/distribution/context"
+"github.com/docker/distribution/registry/storage"
+"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/docker/distribution/version"
+"github.com/docker/libtrust"
"github.com/spf13/cobra"
)
@@ -10,6 +17,7 @@ var showVersion bool
func init() {
RootCmd.AddCommand(ServeCmd)
RootCmd.AddCommand(GCCmd)
+GCCmd.Flags().BoolVarP(&dryRun, "dry-run", "d", false, "do everything except remove the blobs")
RootCmd.Flags().BoolVarP(&showVersion, "version", "v", false, "show the version and exit")
}
@@ -26,3 +34,51 @@ var RootCmd = &cobra.Command{
cmd.Usage()
},
}
+var dryRun bool
+// GCCmd is the cobra command that corresponds to the garbage-collect subcommand
+var GCCmd = &cobra.Command{
+Use: "garbage-collect <config>",
+Short: "`garbage-collect` deletes layers not referenced by any manifests",
+Long: "`garbage-collect` deletes layers not referenced by any manifests",
+Run: func(cmd *cobra.Command, args []string) {
+config, err := resolveConfiguration(args)
+if err != nil {
+fmt.Fprintf(os.Stderr, "configuration error: %v\n", err)
+cmd.Usage()
+os.Exit(1)
+}
+driver, err := factory.Create(config.Storage.Type(), config.Storage.Parameters())
+if err != nil {
+fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
+os.Exit(1)
+}
+ctx := context.Background()
+ctx, err = configureLogging(ctx, config)
+if err != nil {
+fmt.Fprintf(os.Stderr, "unable to configure logging with config: %s", err)
+os.Exit(1)
+}
+k, err := libtrust.GenerateECP256PrivateKey()
+if err != nil {
+fmt.Fprint(os.Stderr, err)
+os.Exit(1)
+}
+registry, err := storage.NewRegistry(ctx, driver, storage.Schema1SigningKey(k))
+if err != nil {
+fmt.Fprintf(os.Stderr, "failed to construct registry: %v", err)
+os.Exit(1)
+}
+err = storage.MarkAndSweep(ctx, driver, registry, dryRun)
+if err != nil {
+fmt.Fprintf(os.Stderr, "failed to garbage collect: %v", err)
+os.Exit(1)
+}
+},
+}
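Given the flag registered in init above, a dry run would look something like the line below; per the flag's own help text, it does everything except remove the blobs:

    registry garbage-collect --dry-run /path/to/config.yml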

View file

@@ -7,6 +7,8 @@ import (
"io"
"io/ioutil"
"os"
+"path"
+"reflect"
"testing"
"github.com/docker/distribution"
@@ -41,10 +43,7 @@ func TestWriteSeek(t *testing.T) {
}
contents := []byte{1, 2, 3}
blobUpload.Write(contents)
-offset, err := blobUpload.Seek(0, os.SEEK_CUR)
+offset := blobUpload.Size()
-if err != nil {
-t.Fatalf("unexpected error in blobUpload.Seek: %s", err)
-}
if offset != int64(len(contents)) {
t.Fatalf("unexpected value for blobUpload offset: %v != %v", offset, len(contents))
}
@@ -86,6 +85,15 @@ func TestSimpleBlobUpload(t *testing.T) {
t.Fatalf("unexpected error during upload cancellation: %v", err)
}
+// get the enclosing directory
+uploadPath := path.Dir(blobUpload.(*blobWriter).path)
+// ensure state was cleaned up
+_, err = driver.List(ctx, uploadPath)
+if err == nil {
+t.Fatal("files in upload path after cleanup")
+}
// Do a resume, get unknown upload
blobUpload, err = bs.Resume(ctx, blobUpload.ID())
if err != distribution.ErrBlobUploadUnknown {
@@ -113,11 +121,7 @@ func TestSimpleBlobUpload(t *testing.T) {
t.Fatalf("layer data write incomplete")
}
-offset, err := blobUpload.Seek(0, os.SEEK_CUR)
+offset := blobUpload.Size()
-if err != nil {
-t.Fatalf("unexpected error seeking layer upload: %v", err)
-}
if offset != nn {
t.Fatalf("blobUpload not updated with correct offset: %v != %v", offset, nn)
}
@@ -135,6 +139,13 @@ func TestSimpleBlobUpload(t *testing.T) {
t.Fatalf("unexpected error finishing layer upload: %v", err)
}
+// ensure state was cleaned up
+uploadPath = path.Dir(blobUpload.(*blobWriter).path)
+_, err = driver.List(ctx, uploadPath)
+if err == nil {
+t.Fatal("files in upload path after commit")
+}
// After finishing an upload, it should no longer exist.
if _, err := bs.Resume(ctx, blobUpload.ID()); err != distribution.ErrBlobUploadUnknown {
t.Fatalf("expected layer upload to be unknown, got %v", err)
@@ -146,7 +157,7 @@ func TestSimpleBlobUpload(t *testing.T) {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, bs)
}
-if statDesc != desc {
+if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
}
@@ -400,7 +411,7 @@ func TestBlobMount(t *testing.T) {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, sbs)
}
-if statDesc != desc {
+if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
}
@@ -426,7 +437,7 @@ func TestBlobMount(t *testing.T) {
t.Fatalf("unexpected error mounting layer: %v", err)
}
-if ebm.Descriptor != desc {
+if !reflect.DeepEqual(ebm.Descriptor, desc) {
t.Fatalf("descriptors not equal: %v != %v", ebm.Descriptor, desc)
}
@@ -436,7 +447,7 @@ func TestBlobMount(t *testing.T) {
t.Fatalf("unexpected error checking for existence: %v, %#v", err, bs)
}
-if statDesc != desc {
+if !reflect.DeepEqual(statDesc, desc) {
t.Fatalf("descriptors not equal: %v != %v", statDesc, desc)
}

View file

@@ -75,7 +75,6 @@ func (bs *blobStore) Put(ctx context.Context, mediaType string, p []byte) (distr
}
// TODO(stevvooe): Write out mediatype here, as well.
return distribution.Descriptor{
Size: int64(len(p)),

View file

@@ -18,9 +18,10 @@ var (
errResumableDigestNotAvailable = errors.New("resumable digest not available")
)
-// layerWriter is used to control the various aspects of resumable
+// blobWriter is used to control the various aspects of resumable
-// layer upload. It implements the LayerUpload interface.
+// blob upload.
type blobWriter struct {
+ctx context.Context
blobStore *linkedBlobStore
id string
@@ -28,11 +29,12 @@ type blobWriter struct {
digester digest.Digester
written int64 // track the contiguous write
-// implementes io.WriteSeeker, io.ReaderFrom and io.Closer to satisfy
+fileWriter storagedriver.FileWriter
-// LayerUpload Interface
+driver storagedriver.StorageDriver
-fileWriter
+path string
resumableDigestEnabled bool
+committed bool
}
var _ distribution.BlobWriter = &blobWriter{}
@@ -51,10 +53,12 @@ func (bw *blobWriter) StartedAt() time.Time {
func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
context.GetLogger(ctx).Debug("(*blobWriter).Commit")
-if err := bw.fileWriter.Close(); err != nil {
+if err := bw.fileWriter.Commit(); err != nil {
return distribution.Descriptor{}, err
}
+bw.Close()
canonical, err := bw.validateBlob(ctx, desc)
if err != nil {
return distribution.Descriptor{}, err
@@ -77,6 +81,7 @@ func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor)
return distribution.Descriptor{}, err
}
+bw.committed = true
return canonical, nil
}
@@ -84,23 +89,34 @@ func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor)
// the writer and canceling the operation.
func (bw *blobWriter) Cancel(ctx context.Context) error {
context.GetLogger(ctx).Debug("(*blobWriter).Rollback")
+if err := bw.fileWriter.Cancel(); err != nil {
+return err
+}
+if err := bw.Close(); err != nil {
+context.GetLogger(ctx).Errorf("error closing blobwriter: %s", err)
+}
if err := bw.removeResources(ctx); err != nil {
return err
}
-bw.Close()
return nil
}
+func (bw *blobWriter) Size() int64 {
+return bw.fileWriter.Size()
+}
func (bw *blobWriter) Write(p []byte) (int, error) {
// Ensure that the current write offset matches how many bytes have been
// written to the digester. If not, we need to update the digest state to
// match the current write position.
-if err := bw.resumeDigestAt(bw.blobStore.ctx, bw.offset); err != nil && err != errResumableDigestNotAvailable {
+if err := bw.resumeDigest(bw.blobStore.ctx); err != nil && err != errResumableDigestNotAvailable {
return 0, err
}
-n, err := io.MultiWriter(&bw.fileWriter, bw.digester.Hash()).Write(p)
+n, err := io.MultiWriter(bw.fileWriter, bw.digester.Hash()).Write(p)
bw.written += int64(n)
return n, err
@@ -110,19 +126,19 @@ func (bw *blobWriter) ReadFrom(r io.Reader) (n int64, err error) {
// Ensure that the current write offset matches how many bytes have been
// written to the digester. If not, we need to update the digest state to
// match the current write position.
-if err := bw.resumeDigestAt(bw.blobStore.ctx, bw.offset); err != nil && err != errResumableDigestNotAvailable {
+if err := bw.resumeDigest(bw.blobStore.ctx); err != nil && err != errResumableDigestNotAvailable {
return 0, err
}
-nn, err := bw.fileWriter.ReadFrom(io.TeeReader(r, bw.digester.Hash()))
+nn, err := io.Copy(io.MultiWriter(bw.fileWriter, bw.digester.Hash()), r)
bw.written += nn
return nn, err
}
func (bw *blobWriter) Close() error {
-if bw.err != nil {
+if bw.committed {
-return bw.err
+return errors.New("blobwriter close after commit")
}
if err := bw.storeHashState(bw.blobStore.ctx); err != nil {
@@ -148,8 +164,10 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
}
}
+var size int64
// Stat the on disk file
-if fi, err := bw.fileWriter.driver.Stat(ctx, bw.path); err != nil {
+if fi, err := bw.driver.Stat(ctx, bw.path); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
// NOTE(stevvooe): We really don't care if the file is
@@ -165,23 +183,23 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
return distribution.Descriptor{}, fmt.Errorf("unexpected directory at upload location %q", bw.path)
}
-bw.size = fi.Size()
+size = fi.Size()
}
if desc.Size > 0 {
-if desc.Size != bw.size {
+if desc.Size != size {
return distribution.Descriptor{}, distribution.ErrBlobInvalidLength
}
} else {
// if provided 0 or negative length, we can assume caller doesn't know or
// care about length.
-desc.Size = bw.size
+desc.Size = size
}
// TODO(stevvooe): This section is very meandering. Need to be broken down
// to be a lot more clear.
-if err := bw.resumeDigestAt(ctx, bw.size); err == nil {
+if err := bw.resumeDigest(ctx); err == nil {
canonical = bw.digester.Digest()
if canonical.Algorithm() == desc.Digest.Algorithm() {
@@ -206,7 +224,7 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
// the same, we don't need to read the data from the backend. This is
// because we've written the entire file in the lifecycle of the
// current instance.
-if bw.written == bw.size && digest.Canonical == desc.Digest.Algorithm() {
+if bw.written == size && digest.Canonical == desc.Digest.Algorithm() {
canonical = bw.digester.Digest()
verified = desc.Digest == canonical
}
@@ -223,7 +241,7 @@ func (bw *blobWriter) validateBlob(ctx context.Context, desc distribution.Descri
}
// Read the file from the backend driver and validate it.
-fr, err := newFileReader(ctx, bw.fileWriter.driver, bw.path, desc.Size)
+fr, err := newFileReader(ctx, bw.driver, bw.path, desc.Size)
if err != nil {
return distribution.Descriptor{}, err
}
@@ -357,7 +375,7 @@ func (bw *blobWriter) Reader() (io.ReadCloser, error) {
// todo(richardscothern): Change to exponential backoff, i=0.5, e=2, n=4
try := 1
for try <= 5 {
-_, err := bw.fileWriter.driver.Stat(bw.ctx, bw.path)
+_, err := bw.driver.Stat(bw.ctx, bw.path)
if err == nil {
break
}
@@ -371,7 +389,7 @@ func (bw *blobWriter) Reader() (io.ReadCloser, error) {
}
}
-readCloser, err := bw.fileWriter.driver.ReadStream(bw.ctx, bw.path, 0)
+readCloser, err := bw.driver.Reader(bw.ctx, bw.path, 0)
if err != nil {
return nil, err
}
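The seek-based writer is gone: blobWriter now wraps a storagedriver.FileWriter plus explicit driver and path fields. From the calls made in this file (Write through io.MultiWriter, Size, Commit, Cancel, Close), the interface has roughly this shape (a sketch inferred from usage, not copied from the source):

    type FileWriter interface {
    	io.WriteCloser
    	Size() int64   // bytes written so far; replaces Seek(0, os.SEEK_CUR)
    	Cancel() error // abandon the upload and discard written content
    	Commit() error // flush and make the written content durable
    }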

View file

@@ -4,8 +4,6 @@ package storage
import (
"fmt"
-"io"
-"os"
"path"
"strconv"
@@ -19,24 +17,18 @@ import (
_ "github.com/stevvooe/resumable/sha512"
)
-// resumeDigestAt attempts to restore the state of the internal hash function
+// resumeDigest attempts to restore the state of the internal hash function
-// by loading the most recent saved hash state less than or equal to the given
+// by loading the most recent saved hash state equal to the current size of the blob.
-// offset. Any unhashed bytes remaining less than the given offset are hashed
+func (bw *blobWriter) resumeDigest(ctx context.Context) error {
-// from the content uploaded so far.
-func (bw *blobWriter) resumeDigestAt(ctx context.Context, offset int64) error {
if !bw.resumableDigestEnabled {
return errResumableDigestNotAvailable
}
-if offset < 0 {
-return fmt.Errorf("cannot resume hash at negative offset: %d", offset)
-}
h, ok := bw.digester.Hash().(resumable.Hash)
if !ok {
return errResumableDigestNotAvailable
}
+offset := bw.fileWriter.Size()
if offset == int64(h.Len()) {
// State of digester is already at the requested offset.
return nil
@@ -49,24 +41,12 @@ func (bw *blobWriter) resumeDigestAt(ctx context.Context, offset int64) error {
return fmt.Errorf("unable to get stored hash states with offset %d: %s", offset, err)
}
-// Find the highest stored hashState with offset less than or equal to
+// Find the highest stored hashState with offset equal to
// the requested offset.
for _, hashState := range hashStates {
if hashState.offset == offset {
hashStateMatch = hashState
break // Found an exact offset match.
-} else if hashState.offset < offset && hashState.offset > hashStateMatch.offset {
-// This offset is closer to the requested offset.
-hashStateMatch = hashState
-} else if hashState.offset > offset {
-// Remove any stored hash state with offsets higher than this one
-// as writes to this resumed hasher will make those invalid. This
-// is probably okay to skip for now since we don't expect anyone to
-// use the API in this way. For that reason, we don't treat an
-// an error here as a fatal error, but only log it.
-if err := bw.driver.Delete(ctx, hashState.path); err != nil {
-logrus.Errorf("unable to delete stale hash state %q: %s", hashState.path, err)
-}
}
}
@@ -86,20 +66,7 @@ func (bw *blobWriter) resumeDigestAt(ctx context.Context, offset int64) error {
// Mind the gap.
if gapLen := offset - int64(h.Len()); gapLen > 0 {
-// Need to read content from the upload to catch up to the desired offset.
+return errResumableDigestNotAvailable
-fr, err := newFileReader(ctx, bw.driver, bw.path, bw.size)
-if err != nil {
-return err
-}
-defer fr.Close()
-if _, err = fr.Seek(int64(h.Len()), os.SEEK_SET); err != nil {
-return fmt.Errorf("unable to seek to layer reader offset %d: %s", h.Len(), err)
-}
-if _, err := io.CopyN(h, fr, gapLen); err != nil {
-return err
-}
}
return nil
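resumeDigest drops both the offset parameter and the gap-hashing fallback: it resumes only when a stored hash state matches the current upload size exactly, and otherwise returns errResumableDigestNotAvailable so the caller falls back to reading the blob back for full verification. The decision, condensed:

    // condensed restatement of the logic above, not new behavior
    h, ok := bw.digester.Hash().(resumable.Hash)
    if !ok {
    	return errResumableDigestNotAvailable
    }
    offset := bw.fileWriter.Size()
    if offset == int64(h.Len()) {
    	return nil // digester is already caught up
    }
    // ...restore a stored state recorded at exactly this offset, else:
    return errResumableDigestNotAvailable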

View file

@@ -1,6 +1,7 @@
package cachecheck
import (
+"reflect"
"testing"
"github.com/docker/distribution"
@@ -79,7 +80,7 @@ func checkBlobDescriptorCacheSetAndRead(t *testing.T, ctx context.Context, provi
t.Fatalf("unexpected error statting fake2:abc: %v", err)
}
-if expected != desc {
+if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
@@ -89,7 +90,7 @@
t.Fatalf("descriptor not returned for canonical key: %v", err)
}
-if expected != desc {
+if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
@@ -99,7 +100,7 @@
t.Fatalf("expected blob unknown in global cache: %v, %v", err, desc)
}
-if desc != expected {
+if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
@@ -109,7 +110,7 @@
t.Fatalf("unexpected error checking global descriptor: %v", err)
}
-if desc != expected {
+if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}
@@ -126,7 +127,7 @@
t.Fatalf("unexpected error getting descriptor: %v", err)
}
-if desc != expected {
+if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", desc, expected)
}
@@ -137,7 +138,7 @@
expected.MediaType = "application/octet-stream" // expect original mediatype in global
-if desc != expected {
+if !reflect.DeepEqual(desc, expected) {
t.Fatalf("unexpected descriptor: %#v != %#v", desc, expected)
}
}
@@ -163,7 +164,7 @@ func checkBlobDescriptorCacheClear(t *testing.T, ctx context.Context, provider c
t.Fatalf("unexpected error statting fake2:abc: %v", err)
}
-if expected != desc {
+if !reflect.DeepEqual(expected, desc) {
t.Fatalf("unexpected descriptor: %#v != %#v", expected, desc)
}

View file

@ -3,6 +3,7 @@
package azure package azure
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
@ -26,6 +27,7 @@ const (
paramAccountKey = "accountkey" paramAccountKey = "accountkey"
paramContainer = "container" paramContainer = "container"
paramRealm = "realm" paramRealm = "realm"
maxChunkSize = 4 * 1024 * 1024
) )
type driver struct { type driver struct {
@ -117,18 +119,21 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
if _, err := d.client.DeleteBlobIfExists(d.container, path); err != nil { if _, err := d.client.DeleteBlobIfExists(d.container, path); err != nil {
return err return err
} }
if err := d.client.CreateBlockBlob(d.container, path); err != nil { writer, err := d.Writer(ctx, path, false)
if err != nil {
return err return err
} }
bs := newAzureBlockStorage(d.client) defer writer.Close()
bw := newRandomBlobWriter(&bs, azure.MaxBlobBlockSize) _, err = writer.Write(contents)
_, err := bw.WriteBlobAt(d.container, path, 0, bytes.NewReader(contents)) if err != nil {
return err return err
}
return writer.Commit()
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
if ok, err := d.client.BlobExists(d.container, path); err != nil { if ok, err := d.client.BlobExists(d.container, path); err != nil {
return nil, err return nil, err
} else if !ok { } else if !ok {
@ -153,25 +158,38 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return resp, nil return resp, nil
} }
// WriteStream stores the contents of the provided io.ReadCloser at a location // Writer returns a FileWriter which will store the content written to it
// designated by the given path. // at the location designated by "path" after the call to Commit.
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (int64, error) { func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
if blobExists, err := d.client.BlobExists(d.container, path); err != nil { blobExists, err := d.client.BlobExists(d.container, path)
return 0, err
} else if !blobExists {
err := d.client.CreateBlockBlob(d.container, path)
if err != nil { if err != nil {
return 0, err return nil, err
}
var size int64
if blobExists {
if append {
blobProperties, err := d.client.GetBlobProperties(d.container, path)
if err != nil {
return nil, err
}
size = blobProperties.ContentLength
} else {
err := d.client.DeleteBlob(d.container, path)
if err != nil {
return nil, err
} }
} }
if offset < 0 { } else {
return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset} if append {
return nil, storagedriver.PathNotFoundError{Path: path}
}
err := d.client.PutAppendBlob(d.container, path, nil)
if err != nil {
return nil, err
}
} }
bs := newAzureBlockStorage(d.client) return d.newWriter(path, size), nil
bw := newRandomBlobWriter(&bs, azure.MaxBlobBlockSize)
zw := newZeroFillWriter(&bw)
return zw.Write(d.container, path, offset, reader)
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -236,6 +254,9 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
} }
list := directDescendants(blobs, path) list := directDescendants(blobs, path)
if path != "" && len(list) == 0 {
return nil, storagedriver.PathNotFoundError{Path: path}
}
return list, nil return list, nil
} }
@ -361,6 +382,101 @@ func (d *driver) listBlobs(container, virtPath string) ([]string, error) {
} }
func is404(err error) bool { func is404(err error) bool {
e, ok := err.(azure.AzureStorageServiceError) statusCodeErr, ok := err.(azure.AzureStorageServiceError)
return ok && e.StatusCode == http.StatusNotFound return ok && statusCodeErr.StatusCode == http.StatusNotFound
}
type writer struct {
driver *driver
path string
size int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(path string, size int64) storagedriver.FileWriter {
return &writer{
driver: d,
path: path,
size: size,
bw: bufio.NewWriterSize(&blockWriter{
client: d.client,
container: d.container,
path: path,
}, maxChunkSize),
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
n, err := w.bw.Write(p)
w.size += int64(n)
return n, err
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.bw.Flush()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
return w.driver.client.DeleteBlob(w.driver.container, w.path)
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
w.committed = true
return w.bw.Flush()
}
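For orientation, here is a minimal sketch of how a storagedriver.FileWriter such as this one is typically driven from inside the driver package. The path, ctx, reader and the error handling are illustrative assumptions, not part of this change:

    // Illustrative sketch: stream data into the driver's FileWriter,
    // committing on success and cancelling on failure.
    fw, err := d.Writer(ctx, "/some/path", false) // false: start a fresh blob (hypothetical path)
    if err != nil {
        return err
    }
    if _, err := io.Copy(fw, reader); err != nil {
        fw.Cancel() // removes the partially written blob
        fw.Close()
        return err
    }
    if err := fw.Commit(); err != nil { // flushes any buffered bytes
        fw.Close()
        return err
    }
    return fw.Close()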
type blockWriter struct {
client azure.BlobStorageClient
container string
path string
}
func (bw *blockWriter) Write(p []byte) (int, error) {
n := 0
for offset := 0; offset < len(p); offset += maxChunkSize {
chunkSize := maxChunkSize
if offset+chunkSize > len(p) {
chunkSize = len(p) - offset
}
err := bw.client.AppendBlock(bw.container, bw.path, p[offset:offset+chunkSize])
if err != nil {
return n, err
}
n += chunkSize
}
return n, nil
} }


@@ -1,24 +0,0 @@
package azure
import (
"fmt"
"io"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
// azureBlockStorage is adaptor between azure.BlobStorageClient and
// blockStorage interface.
type azureBlockStorage struct {
azure.BlobStorageClient
}
func (b *azureBlockStorage) GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error) {
return b.BlobStorageClient.GetBlobRange(container, blob, fmt.Sprintf("%v-%v", start, start+length-1))
}
func newAzureBlockStorage(b azure.BlobStorageClient) azureBlockStorage {
a := azureBlockStorage{}
a.BlobStorageClient = b
return a
}


@@ -1,155 +0,0 @@
package azure
import (
"bytes"
"fmt"
"io"
"io/ioutil"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
type StorageSimulator struct {
blobs map[string]*BlockBlob
}
type BlockBlob struct {
blocks map[string]*DataBlock
blockList []string
}
type DataBlock struct {
data []byte
committed bool
}
func (s *StorageSimulator) path(container, blob string) string {
return fmt.Sprintf("%s/%s", container, blob)
}
func (s *StorageSimulator) BlobExists(container, blob string) (bool, error) {
_, ok := s.blobs[s.path(container, blob)]
return ok, nil
}
func (s *StorageSimulator) GetBlob(container, blob string) (io.ReadCloser, error) {
bb, ok := s.blobs[s.path(container, blob)]
if !ok {
return nil, fmt.Errorf("blob not found")
}
var readers []io.Reader
for _, bID := range bb.blockList {
readers = append(readers, bytes.NewReader(bb.blocks[bID].data))
}
return ioutil.NopCloser(io.MultiReader(readers...)), nil
}
func (s *StorageSimulator) GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error) {
r, err := s.GetBlob(container, blob)
if err != nil {
return nil, err
}
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
return ioutil.NopCloser(bytes.NewReader(b[start : start+length])), nil
}
func (s *StorageSimulator) CreateBlockBlob(container, blob string) error {
path := s.path(container, blob)
bb := &BlockBlob{
blocks: make(map[string]*DataBlock),
blockList: []string{},
}
s.blobs[path] = bb
return nil
}
func (s *StorageSimulator) PutBlock(container, blob, blockID string, chunk []byte) error {
path := s.path(container, blob)
bb, ok := s.blobs[path]
if !ok {
return fmt.Errorf("blob not found")
}
data := make([]byte, len(chunk))
copy(data, chunk)
bb.blocks[blockID] = &DataBlock{data: data, committed: false} // add block to blob
return nil
}
func (s *StorageSimulator) GetBlockList(container, blob string, blockType azure.BlockListType) (azure.BlockListResponse, error) {
resp := azure.BlockListResponse{}
bb, ok := s.blobs[s.path(container, blob)]
if !ok {
return resp, fmt.Errorf("blob not found")
}
// Iterate committed blocks (in order)
if blockType == azure.BlockListTypeAll || blockType == azure.BlockListTypeCommitted {
for _, blockID := range bb.blockList {
b := bb.blocks[blockID]
block := azure.BlockResponse{
Name: blockID,
Size: int64(len(b.data)),
}
resp.CommittedBlocks = append(resp.CommittedBlocks, block)
}
}
// Iterate uncommitted blocks (in no order)
if blockType == azure.BlockListTypeAll || blockType == azure.BlockListTypeUncommitted {
for blockID, b := range bb.blocks {
block := azure.BlockResponse{
Name: blockID,
Size: int64(len(b.data)),
}
if !b.committed {
resp.UncommittedBlocks = append(resp.UncommittedBlocks, block)
}
}
}
return resp, nil
}
func (s *StorageSimulator) PutBlockList(container, blob string, blocks []azure.Block) error {
bb, ok := s.blobs[s.path(container, blob)]
if !ok {
return fmt.Errorf("blob not found")
}
var blockIDs []string
for _, v := range blocks {
bl, ok := bb.blocks[v.ID]
if !ok { // check if block ID exists
return fmt.Errorf("Block id '%s' not found", v.ID)
}
bl.committed = true
blockIDs = append(blockIDs, v.ID)
}
// Mark all other blocks uncommitted
for k, b := range bb.blocks {
inList := false
for _, v := range blockIDs {
if k == v {
inList = true
break
}
}
if !inList {
b.committed = false
}
}
bb.blockList = blockIDs
return nil
}
func NewStorageSimulator() StorageSimulator {
return StorageSimulator{
blobs: make(map[string]*BlockBlob),
}
}
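As a usage note, a short sketch of the simulator's two-phase block semantics, mirroring Azure's PutBlock/PutBlockList flow; the container, blob and block ID names are made up:

    s := NewStorageSimulator()
    _ = s.CreateBlockBlob("container", "blob")
    // Stage a block: it is uncommitted and not yet visible to readers.
    _ = s.PutBlock("container", "blob", "id-1", []byte("hello"))
    // Commit the block list; only now does GetBlob return the data.
    _ = s.PutBlockList("container", "blob", []azure.Block{
        {ID: "id-1", Status: azure.BlockStatusUncommitted},
    })
    r, _ := s.GetBlob("container", "blob") // reads "hello"
    _ = r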


@@ -1,60 +0,0 @@
package azure
import (
"encoding/base64"
"fmt"
"math/rand"
"sync"
"time"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
type blockIDGenerator struct {
pool map[string]bool
r *rand.Rand
m sync.Mutex
}
// Generate returns an unused random block ID and adds the generated ID
// to the list of used IDs so that the same block name is not used again.
func (b *blockIDGenerator) Generate() string {
b.m.Lock()
defer b.m.Unlock()
var id string
for {
id = toBlockID(int(b.r.Int()))
if !b.exists(id) {
break
}
}
b.pool[id] = true
return id
}
func (b *blockIDGenerator) exists(id string) bool {
_, used := b.pool[id]
return used
}
func (b *blockIDGenerator) Feed(blocks azure.BlockListResponse) {
b.m.Lock()
defer b.m.Unlock()
for _, bl := range append(blocks.CommittedBlocks, blocks.UncommittedBlocks...) {
b.pool[bl.Name] = true
}
}
func newBlockIDGenerator() *blockIDGenerator {
return &blockIDGenerator{
pool: make(map[string]bool),
r: rand.New(rand.NewSource(time.Now().UnixNano()))}
}
// toBlockID converts the given integer to a base64-encoded block ID of fixed length.
func toBlockID(i int) string {
s := fmt.Sprintf("%029d", i) // zero-pad so that all block IDs have the same length
return base64.StdEncoding.EncodeToString([]byte(s))
}
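A brief sketch of the intended use, assuming a client, container and blob already exist; these names are placeholders:

    gen := newBlockIDGenerator()
    // Seed the generator with the IDs already present on the blob so
    // that freshly generated IDs cannot collide with them.
    blocks, _ := client.GetBlockList(container, blob, azure.BlockListTypeAll)
    gen.Feed(blocks)
    id := gen.Generate() // an unused, fixed-length, base64-encoded block ID
    _ = id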


@@ -1,74 +0,0 @@
package azure
import (
"math"
"testing"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
func Test_blockIdGenerator(t *testing.T) {
r := newBlockIDGenerator()
for i := 1; i <= 10; i++ {
if expected := i - 1; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
if id := r.Generate(); id == "" {
t.Fatal("returned empty id")
}
if expected := i; len(r.pool) != expected {
t.Fatalf("rand pool has wrong number of items: %d, expected:%d", len(r.pool), expected)
}
}
}
func Test_blockIdGenerator_Feed(t *testing.T) {
r := newBlockIDGenerator()
if expected := 0; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
// feed empty list
blocks := azure.BlockListResponse{}
r.Feed(blocks)
if expected := 0; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
// feed blocks
blocks = azure.BlockListResponse{
CommittedBlocks: []azure.BlockResponse{
{"1", 1},
{"2", 2},
},
UncommittedBlocks: []azure.BlockResponse{
{"3", 3},
}}
r.Feed(blocks)
if expected := 3; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
// feed same block IDs with committed/uncommitted place changed
blocks = azure.BlockListResponse{
CommittedBlocks: []azure.BlockResponse{
{"3", 3},
},
UncommittedBlocks: []azure.BlockResponse{
{"1", 1},
}}
r.Feed(blocks)
if expected := 3; len(r.pool) != expected {
t.Fatalf("rand pool had wrong number of items: %d, expected:%d", len(r.pool), expected)
}
}
func Test_toBlockId(t *testing.T) {
min := 0
max := math.MaxInt64
if len(toBlockID(min)) != len(toBlockID(max)) {
t.Fatalf("different-sized blockIDs are returned")
}
}


@@ -1,208 +0,0 @@
package azure
import (
"fmt"
"io"
"io/ioutil"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
// blockStorage is the interface required from a block storage service
// client implementation
type blockStorage interface {
CreateBlockBlob(container, blob string) error
GetBlob(container, blob string) (io.ReadCloser, error)
GetSectionReader(container, blob string, start, length int64) (io.ReadCloser, error)
PutBlock(container, blob, blockID string, chunk []byte) error
GetBlockList(container, blob string, blockType azure.BlockListType) (azure.BlockListResponse, error)
PutBlockList(container, blob string, blocks []azure.Block) error
}
// randomBlobWriter enables random-access semantics on Azure block blobs
// by allowing chunks of arbitrary length to be written at arbitrary offsets
// within the blob. Azure Blob Storage does not normally support random
// access on block blobs; this writer works around that by downloading,
// splitting and re-uploading the overlapping blocks, discarding any blocks
// that are overwritten entirely.
type randomBlobWriter struct {
bs blockStorage
blockSize int
}
func newRandomBlobWriter(bs blockStorage, blockSize int) randomBlobWriter {
return randomBlobWriter{bs: bs, blockSize: blockSize}
}
// WriteBlobAt writes the given chunk to the specified position of an existing blob.
// The offset must be less than or equal to the current size of the blob.
func (r *randomBlobWriter) WriteBlobAt(container, blob string, offset int64, chunk io.Reader) (int64, error) {
rand := newBlockIDGenerator()
blocks, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeCommitted)
if err != nil {
return 0, err
}
rand.Feed(blocks) // load existing block IDs
// Check for write offset for existing blob
size := getBlobSize(blocks)
if offset < 0 || offset > size {
return 0, fmt.Errorf("wrong offset for Write: %v", offset)
}
// Upload the new chunk as blocks
blockList, nn, err := r.writeChunkToBlocks(container, blob, chunk, rand)
if err != nil {
return 0, err
}
// For non-append operations, existing blocks may need to be split
if offset != size {
// Split the block on the left end (if any)
leftBlocks, err := r.blocksLeftSide(container, blob, offset, rand)
if err != nil {
return 0, err
}
blockList = append(leftBlocks, blockList...)
// Split the block on the right end (if any)
rightBlocks, err := r.blocksRightSide(container, blob, offset, nn, rand)
if err != nil {
return 0, err
}
blockList = append(blockList, rightBlocks...)
} else {
// Use existing block list
var existingBlocks []azure.Block
for _, v := range blocks.CommittedBlocks {
existingBlocks = append(existingBlocks, azure.Block{ID: v.Name, Status: azure.BlockStatusCommitted})
}
blockList = append(existingBlocks, blockList...)
}
// Put block list
return nn, r.bs.PutBlockList(container, blob, blockList)
}
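To make the splitting concrete, a sketch using the StorageSimulator above; the offsets and contents are illustrative:

    s := NewStorageSimulator()
    rw := newRandomBlobWriter(&s, 5) // 5-byte blocks
    _ = rw.bs.CreateBlockBlob("c", "b")
    // Lay down "AAAAABBBBBCCC" as committed blocks of sizes 5, 5 and 3.
    _, _ = rw.WriteBlobAt("c", "b", 0, strings.NewReader("AAAAABBBBBCCC"))
    // Overwrite two bytes at offset 6: blocksLeftSide keeps "AAAAA" and
    // re-uploads the split "B"; blocksRightSide re-uploads the split "BB"
    // and keeps "CCC". The committed blob then reads "AAAAABXXBBCCC".
    nn, _ := rw.WriteBlobAt("c", "b", 6, strings.NewReader("XX"))
    _ = nn // nn == 2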
func (r *randomBlobWriter) GetSize(container, blob string) (int64, error) {
blocks, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeCommitted)
if err != nil {
return 0, err
}
return getBlobSize(blocks), nil
}
// writeChunkToBlocks writes the given chunk to one or more blocks within the
// specified blob and returns their block representations. The blocks are not
// yet committed.
func (r *randomBlobWriter) writeChunkToBlocks(container, blob string, chunk io.Reader, rand *blockIDGenerator) ([]azure.Block, int64, error) {
var newBlocks []azure.Block
var nn int64
// Read chunks of at most size N except the last chunk to
// maximize block size and minimize block count.
buf := make([]byte, r.blockSize)
for {
n, err := io.ReadFull(chunk, buf)
if err == io.EOF {
break
} else if err != nil && err != io.ErrUnexpectedEOF {
// a short final read yields io.ErrUnexpectedEOF and is expected;
// any other error is fatal
return newBlocks, nn, err
}
nn += int64(n)
data := buf[:n]
blockID := rand.Generate()
if err := r.bs.PutBlock(container, blob, blockID, data); err != nil {
return newBlocks, nn, err
}
newBlocks = append(newBlocks, azure.Block{ID: blockID, Status: azure.BlockStatusUncommitted})
}
return newBlocks, nn, nil
}
// blocksLeftSide returns the blocks that are going to be at the left side of
// the writeOffset: [0, writeOffset) by identifying blocks that will remain
// the same and splitting blocks and reuploading them as needed.
func (r *randomBlobWriter) blocksLeftSide(container, blob string, writeOffset int64, rand *blockIDGenerator) ([]azure.Block, error) {
var left []azure.Block
bx, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeAll)
if err != nil {
return left, err
}
o := writeOffset
elapsed := int64(0)
for _, v := range bx.CommittedBlocks {
blkSize := int64(v.Size)
if o >= blkSize { // use existing block
left = append(left, azure.Block{ID: v.Name, Status: azure.BlockStatusCommitted})
o -= blkSize
elapsed += blkSize
} else if o > 0 { // current block needs to be split
start := elapsed
size := o
part, err := r.bs.GetSectionReader(container, blob, start, size)
if err != nil {
return left, err
}
newBlockID := rand.Generate()
data, err := ioutil.ReadAll(part)
if err != nil {
return left, err
}
if err = r.bs.PutBlock(container, blob, newBlockID, data); err != nil {
return left, err
}
left = append(left, azure.Block{ID: newBlockID, Status: azure.BlockStatusUncommitted})
break
}
}
return left, nil
}
// blocksRightSide returns the blocks that are going to be at the right side of
// the written chunk: [writeOffset+size, +inf) by identifying blocks that will remain
// the same and splitting blocks and reuploading them as needed.
func (r *randomBlobWriter) blocksRightSide(container, blob string, writeOffset int64, chunkSize int64, rand *blockIDGenerator) ([]azure.Block, error) {
var right []azure.Block
bx, err := r.bs.GetBlockList(container, blob, azure.BlockListTypeAll)
if err != nil {
return nil, err
}
re := writeOffset + chunkSize - 1 // right end of written chunk
var elapsed int64
for _, v := range bx.CommittedBlocks {
var (
bs = elapsed // left end of current block
be = elapsed + int64(v.Size) - 1 // right end of current block
)
if bs > re { // take the block as is
right = append(right, azure.Block{ID: v.Name, Status: azure.BlockStatusCommitted})
} else if be > re { // current block needs to be split
part, err := r.bs.GetSectionReader(container, blob, re+1, be-(re+1)+1)
if err != nil {
return right, err
}
newBlockID := rand.Generate()
data, err := ioutil.ReadAll(part)
if err != nil {
return right, err
}
if err = r.bs.PutBlock(container, blob, newBlockID, data); err != nil {
return right, err
}
right = append(right, azure.Block{ID: newBlockID, Status: azure.BlockStatusUncommitted})
}
elapsed += int64(v.Size)
}
return right, nil
}
func getBlobSize(blocks azure.BlockListResponse) int64 {
var n int64
for _, v := range blocks.CommittedBlocks {
n += int64(v.Size)
}
return n
}


@@ -1,339 +0,0 @@
package azure
import (
"bytes"
"io"
"io/ioutil"
"math/rand"
"reflect"
"strings"
"testing"
azure "github.com/Azure/azure-sdk-for-go/storage"
)
func TestRandomWriter_writeChunkToBlocks(t *testing.T) {
s := NewStorageSimulator()
rw := newRandomBlobWriter(&s, 3)
rand := newBlockIDGenerator()
c := []byte("AAABBBCCCD")
if err := rw.bs.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
bw, nn, err := rw.writeChunkToBlocks("a", "b", bytes.NewReader(c), rand)
if err != nil {
t.Fatal(err)
}
if expected := int64(len(c)); nn != expected {
t.Fatalf("wrong nn:%v, expected:%v", nn, expected)
}
if expected := 4; len(bw) != expected {
t.Fatal("unexpected written block count")
}
bx, err := s.GetBlockList("a", "b", azure.BlockListTypeAll)
if err != nil {
t.Fatal(err)
}
if expected := 0; len(bx.CommittedBlocks) != expected {
t.Fatal("unexpected committed block count")
}
if expected := 4; len(bx.UncommittedBlocks) != expected {
t.Fatalf("unexpected uncommitted block count: %d -- %#v", len(bx.UncommittedBlocks), bx)
}
if err := rw.bs.PutBlockList("a", "b", bw); err != nil {
t.Fatal(err)
}
r, err := rw.bs.GetBlob("a", "b")
if err != nil {
t.Fatal(err)
}
assertBlobContents(t, r, c)
}
func TestRandomWriter_blocksLeftSide(t *testing.T) {
blob := "AAAAABBBBBCCC"
cases := []struct {
offset int64
expectedBlob string
expectedPattern []azure.BlockStatus
}{
{0, "", []azure.BlockStatus{}}, // write to beginning, discard all
{13, blob, []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // write to end, no change
{1, "A", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // write at 1
{5, "AAAAA", []azure.BlockStatus{azure.BlockStatusCommitted}}, // write just after first block
{6, "AAAAAB", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusUncommitted}}, // split the second block
{9, "AAAAABBBB", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusUncommitted}}, // write just after first block
}
for _, c := range cases {
s := NewStorageSimulator()
rw := newRandomBlobWriter(&s, 5)
rand := newBlockIDGenerator()
if err := rw.bs.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
bw, _, err := rw.writeChunkToBlocks("a", "b", strings.NewReader(blob), rand)
if err != nil {
t.Fatal(err)
}
if err := rw.bs.PutBlockList("a", "b", bw); err != nil {
t.Fatal(err)
}
bx, err := rw.blocksLeftSide("a", "b", c.offset, rand)
if err != nil {
t.Fatal(err)
}
bs := []azure.BlockStatus{}
for _, v := range bx {
bs = append(bs, v.Status)
}
if !reflect.DeepEqual(bs, c.expectedPattern) {
t.Logf("Committed blocks %v", bw)
t.Fatalf("For offset %v: Expected pattern: %v, Got: %v\n(Returned: %v)", c.offset, c.expectedPattern, bs, bx)
}
if err := rw.bs.PutBlockList("a", "b", bx); err != nil {
t.Fatal(err)
}
r, err := rw.bs.GetBlob("a", "b")
if err != nil {
t.Fatal(err)
}
cout, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
outBlob := string(cout)
if outBlob != c.expectedBlob {
t.Fatalf("wrong blob contents: %v, expected: %v", outBlob, c.expectedBlob)
}
}
}
func TestRandomWriter_blocksRightSide(t *testing.T) {
blob := "AAAAABBBBBCCC"
cases := []struct {
offset int64
size int64
expectedBlob string
expectedPattern []azure.BlockStatus
}{
{0, 100, "", []azure.BlockStatus{}}, // overwrite the entire blob
{0, 3, "AABBBBBCCC", []azure.BlockStatus{azure.BlockStatusUncommitted, azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // split first block
{4, 1, "BBBBBCCC", []azure.BlockStatus{azure.BlockStatusCommitted, azure.BlockStatusCommitted}}, // write to last char of first block
{1, 6, "BBBCCC", []azure.BlockStatus{azure.BlockStatusUncommitted, azure.BlockStatusCommitted}}, // overwrite splits first and second block, last block remains
{3, 8, "CC", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // overwrite a block in middle block, split end block
{10, 1, "CC", []azure.BlockStatus{azure.BlockStatusUncommitted}}, // overwrite first byte of rightmost block
{11, 2, "", []azure.BlockStatus{}}, // overwrite the rightmost index
{13, 20, "", []azure.BlockStatus{}}, // append to the end
}
for _, c := range cases {
s := NewStorageSimulator()
rw := newRandomBlobWriter(&s, 5)
rand := newBlockIDGenerator()
if err := rw.bs.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
bw, _, err := rw.writeChunkToBlocks("a", "b", strings.NewReader(blob), rand)
if err != nil {
t.Fatal(err)
}
if err := rw.bs.PutBlockList("a", "b", bw); err != nil {
t.Fatal(err)
}
bx, err := rw.blocksRightSide("a", "b", c.offset, c.size, rand)
if err != nil {
t.Fatal(err)
}
bs := []azure.BlockStatus{}
for _, v := range bx {
bs = append(bs, v.Status)
}
if !reflect.DeepEqual(bs, c.expectedPattern) {
t.Logf("Committed blocks %v", bw)
t.Fatalf("For offset %v-size:%v: Expected pattern: %v, Got: %v\n(Returned: %v)", c.offset, c.size, c.expectedPattern, bs, bx)
}
if err := rw.bs.PutBlockList("a", "b", bx); err != nil {
t.Fatal(err)
}
r, err := rw.bs.GetBlob("a", "b")
if err != nil {
t.Fatal(err)
}
cout, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
outBlob := string(cout)
if outBlob != c.expectedBlob {
t.Fatalf("For offset %v-size:%v: wrong blob contents: %v, expected: %v", c.offset, c.size, outBlob, c.expectedBlob)
}
}
}
func TestRandomWriter_Write_NewBlob(t *testing.T) {
var (
s = NewStorageSimulator()
rw = newRandomBlobWriter(&s, 1024*3) // 3 KB blocks
blob = randomContents(1024 * 7) // 7 KB blob
)
if err := rw.bs.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
if _, err := rw.WriteBlobAt("a", "b", 10, bytes.NewReader(blob)); err == nil {
t.Fatal("expected error, got nil")
}
if _, err := rw.WriteBlobAt("a", "b", 100000, bytes.NewReader(blob)); err == nil {
t.Fatal("expected error, got nil")
}
if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(blob)); err != nil {
t.Fatal(err)
} else if expected := int64(len(blob)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, blob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if len(bx.CommittedBlocks) != 3 {
t.Fatalf("got wrong number of committed blocks: %v", len(bx.CommittedBlocks))
}
// Replace first 512 bytes
leftChunk := randomContents(512)
blob = append(leftChunk, blob[512:]...)
if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(leftChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(leftChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, blob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if expected := 4; len(bx.CommittedBlocks) != expected {
t.Fatalf("got wrong number of committed blocks: %v, expected: %v", len(bx.CommittedBlocks), expected)
}
// Replace last 512 bytes with 1024 bytes
rightChunk := randomContents(1024)
offset := int64(len(blob) - 512)
blob = append(blob[:offset], rightChunk...)
if nn, err := rw.WriteBlobAt("a", "b", offset, bytes.NewReader(rightChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(rightChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, blob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if expected := 5; len(bx.CommittedBlocks) != expected {
t.Fatalf("got wrong number of committed blocks: %v, expected: %v", len(bx.CommittedBlocks), expected)
}
// Replace 2K-4K (overlaps 2 blocks from L/R)
newChunk := randomContents(1024 * 2)
offset = 1024 * 2
blob = append(append(blob[:offset], newChunk...), blob[offset+int64(len(newChunk)):]...)
if nn, err := rw.WriteBlobAt("a", "b", offset, bytes.NewReader(newChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(newChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, blob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if expected := 6; len(bx.CommittedBlocks) != expected {
t.Fatalf("got wrong number of committed blocks: %v, expected: %v\n%v", len(bx.CommittedBlocks), expected, bx.CommittedBlocks)
}
// Replace the entire blob
newBlob := randomContents(1024 * 30)
if nn, err := rw.WriteBlobAt("a", "b", 0, bytes.NewReader(newBlob)); err != nil {
t.Fatal(err)
} else if expected := int64(len(newBlob)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := rw.bs.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, newBlob)
}
if bx, err := rw.bs.GetBlockList("a", "b", azure.BlockListTypeCommitted); err != nil {
t.Fatal(err)
} else if expected := 10; len(bx.CommittedBlocks) != expected {
t.Fatalf("got wrong number of committed blocks: %v, expected: %v\n%v", len(bx.CommittedBlocks), expected, bx.CommittedBlocks)
} else if expected, size := int64(1024*30), getBlobSize(bx); size != expected {
t.Fatalf("committed block size does not indicate blob size")
}
}
func Test_getBlobSize(t *testing.T) {
// with some committed blocks
if expected, size := int64(151), getBlobSize(azure.BlockListResponse{
CommittedBlocks: []azure.BlockResponse{
{"A", 100},
{"B", 50},
{"C", 1},
},
UncommittedBlocks: []azure.BlockResponse{
{"D", 200},
}}); expected != size {
t.Fatalf("wrong blob size: %v, expected: %v", size, expected)
}
// with no committed blocks
if expected, size := int64(0), getBlobSize(azure.BlockListResponse{
UncommittedBlocks: []azure.BlockResponse{
{"A", 100},
{"B", 50},
{"C", 1},
{"D", 200},
}}); expected != size {
t.Fatalf("wrong blob size: %v, expected: %v", size, expected)
}
}
func assertBlobContents(t *testing.T, r io.Reader, expected []byte) {
out, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(out, expected) {
t.Fatalf("wrong blob contents. size: %v, expected: %v", len(out), len(expected))
}
}
func randomContents(length int64) []byte {
b := make([]byte, length)
for i := range b {
b[i] = byte(rand.Intn(1 << 8))
}
return b
}


@@ -1,49 +0,0 @@
package azure
import (
"bytes"
"io"
)
type blockBlobWriter interface {
GetSize(container, blob string) (int64, error)
WriteBlobAt(container, blob string, offset int64, chunk io.Reader) (int64, error)
}
// zeroFillWriter enables writing to an offset beyond a block blob's current
// size by offering the chunk to the underlying writer as contiguous data,
// with the gap in between filled with NUL (zero) bytes.
type zeroFillWriter struct {
blockBlobWriter
}
func newZeroFillWriter(b blockBlobWriter) zeroFillWriter {
w := zeroFillWriter{}
w.blockBlobWriter = b
return w
}
// Write writes the given chunk to the specified existing blob even if the
// offset is beyond the blob's current size. Gaps are filled with zeros. The
// returned byte count does not include the zeros written.
func (z *zeroFillWriter) Write(container, blob string, offset int64, chunk io.Reader) (int64, error) {
size, err := z.blockBlobWriter.GetSize(container, blob)
if err != nil {
return 0, err
}
var reader io.Reader
var zeroPadding int64
if offset <= size {
reader = chunk
} else {
zeroPadding = offset - size
offset = size // adjust offset to be the append index
zeros := bytes.NewReader(make([]byte, zeroPadding))
reader = io.MultiReader(zeros, chunk)
}
nn, err := z.blockBlobWriter.WriteBlobAt(container, blob, offset, reader)
nn -= zeroPadding
return nn, err
}
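A short sketch of the zero-fill behavior, again against the simulator; the sizes are illustrative:

    s := NewStorageSimulator()
    bw := newRandomBlobWriter(&s, 1024)
    zw := newZeroFillWriter(&bw)
    _ = s.CreateBlockBlob("c", "b")
    // The blob is empty, so writing 4 bytes at offset 10 first fills
    // bytes 0-9 with zeros and then appends the chunk.
    nn, _ := zw.Write("c", "b", 10, bytes.NewReader([]byte("data")))
    _ = nn // nn == 4: zero padding is not counted in the returned total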


@@ -1,126 +0,0 @@
package azure
import (
"bytes"
"testing"
)
func Test_zeroFillWrite_AppendNoGap(t *testing.T) {
s := NewStorageSimulator()
bw := newRandomBlobWriter(&s, 1024*1)
zw := newZeroFillWriter(&bw)
if err := s.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
firstChunk := randomContents(1024*3 + 512)
if nn, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(firstChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, firstChunk)
}
secondChunk := randomContents(256)
if nn, err := zw.Write("a", "b", int64(len(firstChunk)), bytes.NewReader(secondChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(secondChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, append(firstChunk, secondChunk...))
}
}
func Test_zeroFillWrite_StartWithGap(t *testing.T) {
s := NewStorageSimulator()
bw := newRandomBlobWriter(&s, 1024*2)
zw := newZeroFillWriter(&bw)
if err := s.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
chunk := randomContents(1024 * 5)
padding := int64(1024*2 + 256)
if nn, err := zw.Write("a", "b", padding, bytes.NewReader(chunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(chunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, append(make([]byte, padding), chunk...))
}
}
func Test_zeroFillWrite_AppendWithGap(t *testing.T) {
s := NewStorageSimulator()
bw := newRandomBlobWriter(&s, 1024*2)
zw := newZeroFillWriter(&bw)
if err := s.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
firstChunk := randomContents(1024*3 + 512)
if _, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil {
t.Fatal(err)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, firstChunk)
}
secondChunk := randomContents(256)
padding := int64(1024 * 4)
if nn, err := zw.Write("a", "b", int64(len(firstChunk))+padding, bytes.NewReader(secondChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(secondChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, append(firstChunk, append(make([]byte, padding), secondChunk...)...))
}
}
func Test_zeroFillWrite_LiesWithinSize(t *testing.T) {
s := NewStorageSimulator()
bw := newRandomBlobWriter(&s, 1024*2)
zw := newZeroFillWriter(&bw)
if err := s.CreateBlockBlob("a", "b"); err != nil {
t.Fatal(err)
}
firstChunk := randomContents(1024 * 3)
if _, err := zw.Write("a", "b", 0, bytes.NewReader(firstChunk)); err != nil {
t.Fatal(err)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, firstChunk)
}
// in this case, zerofill won't be used
secondChunk := randomContents(256)
if nn, err := zw.Write("a", "b", 0, bytes.NewReader(secondChunk)); err != nil {
t.Fatal(err)
} else if expected := int64(len(secondChunk)); expected != nn {
t.Fatalf("wrong written bytes count: %v, expected: %v", nn, expected)
}
if out, err := s.GetBlob("a", "b"); err != nil {
t.Fatal(err)
} else {
assertBlobContents(t, out, append(secondChunk, firstChunk[len(secondChunk):]...))
}
}


@@ -102,10 +102,10 @@ func (base *Base) PutContent(ctx context.Context, path string, content []byte) e
return base.setDriverName(base.StorageDriver.PutContent(ctx, path, content)) return base.setDriverName(base.StorageDriver.PutContent(ctx, path, content))
} }
// ReadStream wraps ReadStream of underlying storage driver. // Reader wraps Reader of underlying storage driver.
func (base *Base) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (base *Base) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
ctx, done := context.WithTrace(ctx) ctx, done := context.WithTrace(ctx)
defer done("%s.ReadStream(%q, %d)", base.Name(), path, offset) defer done("%s.Reader(%q, %d)", base.Name(), path, offset)
if offset < 0 { if offset < 0 {
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset, DriverName: base.StorageDriver.Name()} return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset, DriverName: base.StorageDriver.Name()}
@@ -115,25 +115,21 @@ func (base *Base) ReadStream(ctx context.Context, path string, offset int64) (io
return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()} return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
} }
rc, e := base.StorageDriver.ReadStream(ctx, path, offset) rc, e := base.StorageDriver.Reader(ctx, path, offset)
return rc, base.setDriverName(e) return rc, base.setDriverName(e)
} }
// WriteStream wraps WriteStream of underlying storage driver. // Writer wraps Writer of underlying storage driver.
func (base *Base) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (nn int64, err error) { func (base *Base) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
ctx, done := context.WithTrace(ctx) ctx, done := context.WithTrace(ctx)
defer done("%s.WriteStream(%q, %d)", base.Name(), path, offset) defer done("%s.Writer(%q, %v)", base.Name(), path, append)
if offset < 0 {
return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset, DriverName: base.StorageDriver.Name()}
}
if !storagedriver.PathRegexp.MatchString(path) { if !storagedriver.PathRegexp.MatchString(path) {
return 0, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()} return nil, storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
} }
i64, e := base.StorageDriver.WriteStream(ctx, path, offset, reader) writer, e := base.StorageDriver.Writer(ctx, path, append)
return i64, base.setDriverName(e) return writer, base.setDriverName(e)
} }
// Stat wraps Stat of underlying storage driver. // Stat wraps Stat of underlying storage driver.


@@ -0,0 +1,145 @@
package base
import (
"io"
"sync"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
type regulator struct {
storagedriver.StorageDriver
*sync.Cond
available uint64
}
// NewRegulator wraps the given driver and is used to regulate concurrent calls
// to the given storage driver to a maximum of the given limit. This is useful
// for storage drivers that would otherwise create an unbounded number of OS
// threads if allowed to be called unregulated.
func NewRegulator(driver storagedriver.StorageDriver, limit uint64) storagedriver.StorageDriver {
return &regulator{
StorageDriver: driver,
Cond: sync.NewCond(&sync.Mutex{}),
available: limit,
}
}
func (r *regulator) enter() {
r.L.Lock()
for r.available == 0 {
r.Wait()
}
r.available--
r.L.Unlock()
}
func (r *regulator) exit() {
r.L.Lock()
// We only need to signal to a waiting FS operation if we're already at the
// limit of threads used
if r.available == 0 {
r.Signal()
}
r.available++
r.L.Unlock()
}
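enter and exit together implement a counting semaphore on top of sync.Cond. A minimal usage sketch; fsDriver, ctx and the limit of 100 are assumptions for illustration:

    // Allow at most 100 concurrent operations against the wrapped driver;
    // the 101st caller blocks in enter() until another operation exits.
    regulated := NewRegulator(fsDriver, 100)
    content, err := regulated.GetContent(ctx, "/some/path")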
// Name returns the human-readable "name" of the driver, useful in error
// messages and logging. By convention, this will just be the registration
// name, but drivers may provide other information here.
func (r *regulator) Name() string {
r.enter()
defer r.exit()
return r.StorageDriver.Name()
}
// GetContent retrieves the content stored at "path" as a []byte.
// This should primarily be used for small objects.
func (r *regulator) GetContent(ctx context.Context, path string) ([]byte, error) {
r.enter()
defer r.exit()
return r.StorageDriver.GetContent(ctx, path)
}
// PutContent stores the []byte content at a location designated by "path".
// This should primarily be used for small objects.
func (r *regulator) PutContent(ctx context.Context, path string, content []byte) error {
r.enter()
defer r.exit()
return r.StorageDriver.PutContent(ctx, path, content)
}
// Reader retrieves an io.ReadCloser for the content stored at "path"
// with a given byte offset.
// May be used to resume reading a stream by providing a nonzero offset.
func (r *regulator) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Reader(ctx, path, offset)
}
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (r *regulator) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Writer(ctx, path, append)
}
// Stat retrieves the FileInfo for the given path, including the current
// size in bytes and the creation time.
func (r *regulator) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
r.enter()
defer r.exit()
return r.StorageDriver.Stat(ctx, path)
}
// List returns a list of the objects that are direct descendants of the
// given path.
func (r *regulator) List(ctx context.Context, path string) ([]string, error) {
r.enter()
defer r.exit()
return r.StorageDriver.List(ctx, path)
}
// Move moves an object stored at sourcePath to destPath, removing the
// original object.
// Note: This may be no more efficient than a copy followed by a delete for
// many implementations.
func (r *regulator) Move(ctx context.Context, sourcePath string, destPath string) error {
r.enter()
defer r.exit()
return r.StorageDriver.Move(ctx, sourcePath, destPath)
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (r *regulator) Delete(ctx context.Context, path string) error {
r.enter()
defer r.exit()
return r.StorageDriver.Delete(ctx, path)
}
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// May return an ErrUnsupportedMethod in certain StorageDriver
// implementations.
func (r *regulator) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
r.enter()
defer r.exit()
return r.StorageDriver.URLFor(ctx, path, options)
}


@@ -11,7 +11,14 @@ import (
var driverFactories = make(map[string]StorageDriverFactory) var driverFactories = make(map[string]StorageDriverFactory)
// StorageDriverFactory is a factory interface for creating storagedriver.StorageDriver interfaces // StorageDriverFactory is a factory interface for creating storagedriver.StorageDriver interfaces
// Storage drivers should call Register() with a factory to make the driver available by name // Storage drivers should call Register() with a factory to make the driver available by name.
// Individual StorageDriver implementations generally register with the factory via the Register
// func (below) in their init() funcs, and as such they should be imported anonymously before use.
// See below for an example of how to register and get a StorageDriver for S3
//
// import _ "github.com/docker/distribution/registry/storage/driver/s3-aws"
// s3Driver, err = factory.Create("s3", storageParams)
// // assuming no error, s3Driver is the StorageDriver that communicates with S3 according to storageParams
type StorageDriverFactory interface { type StorageDriverFactory interface {
// Create returns a new storagedriver.StorageDriver with the given parameters // Create returns a new storagedriver.StorageDriver with the given parameters
// Parameters will vary by driver and may be ignored // Parameters will vary by driver and may be ignored
@@ -21,6 +28,8 @@ type StorageDriverFactory interface {
// Register makes a storage driver available by the provided name. // Register makes a storage driver available by the provided name.
// If Register is called twice with the same name or if driver factory is nil, it panics. // If Register is called twice with the same name or if driver factory is nil, it panics.
// Additionally, it is not concurrency safe. Most Storage Drivers call this function
// in their init() functions. See the documentation for StorageDriverFactory for more.
func Register(name string, factory StorageDriverFactory) { func Register(name string, factory StorageDriverFactory) {
if factory == nil { if factory == nil {
panic("Must not provide nil StorageDriverFactory") panic("Must not provide nil StorageDriverFactory")
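A self-contained sketch of the registration pattern described in the comment above, using the filesystem driver rather than S3; the root directory is a made-up example:

    package main

    import (
        "fmt"

        "github.com/docker/distribution/registry/storage/driver/factory"
        // Imported for its init() side effect: registering the "filesystem" driver.
        _ "github.com/docker/distribution/registry/storage/driver/filesystem"
    )

    func main() {
        driver, err := factory.Create("filesystem", map[string]interface{}{
            "rootdirectory": "/tmp/registry", // example path
        })
        if err != nil {
            panic(err)
        }
        fmt.Println(driver.Name()) // prints "filesystem"
    }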


@@ -1,12 +1,15 @@
package filesystem package filesystem
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
"reflect"
"strconv"
"time" "time"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
@@ -15,8 +18,23 @@ import (
"github.com/docker/distribution/registry/storage/driver/factory" "github.com/docker/distribution/registry/storage/driver/factory"
) )
const driverName = "filesystem" const (
const defaultRootDirectory = "/var/lib/registry" driverName = "filesystem"
defaultRootDirectory = "/var/lib/registry"
defaultMaxThreads = uint64(100)
// minThreads is the minimum value for the maxthreads configuration
// parameter. If the configured value is less than this, we use
// minThreads instead.
minThreads = uint64(25)
)
// DriverParameters represents all configuration options available for the
// filesystem driver
type DriverParameters struct {
RootDirectory string
MaxThreads uint64
}
func init() { func init() {
factory.Register(driverName, &filesystemDriverFactory{}) factory.Register(driverName, &filesystemDriverFactory{})
@@ -26,7 +44,7 @@ func init() {
type filesystemDriverFactory struct{} type filesystemDriverFactory struct{}
func (factory *filesystemDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func (factory *filesystemDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters), nil return FromParameters(parameters)
} }
type driver struct { type driver struct {
@@ -46,25 +64,72 @@ type Driver struct {
// FromParameters constructs a new Driver with a given parameters map // FromParameters constructs a new Driver with a given parameters map
// Optional Parameters: // Optional Parameters:
// - rootdirectory // - rootdirectory
func FromParameters(parameters map[string]interface{}) *Driver { // - maxthreads
var rootDirectory = defaultRootDirectory func FromParameters(parameters map[string]interface{}) (*Driver, error) {
params, err := fromParametersImpl(parameters)
if err != nil || params == nil {
return nil, err
}
return New(*params), nil
}
func fromParametersImpl(parameters map[string]interface{}) (*DriverParameters, error) {
var (
err error
maxThreads = defaultMaxThreads
rootDirectory = defaultRootDirectory
)
if parameters != nil { if parameters != nil {
rootDir, ok := parameters["rootdirectory"] if rootDir, ok := parameters["rootdirectory"]; ok {
if ok {
rootDirectory = fmt.Sprint(rootDir) rootDirectory = fmt.Sprint(rootDir)
} }
// Get maximum number of threads for blocking filesystem operations,
// if specified
threads := parameters["maxthreads"]
switch v := threads.(type) {
case string:
if maxThreads, err = strconv.ParseUint(v, 0, 64); err != nil {
return nil, fmt.Errorf("maxthreads parameter must be an integer, %v invalid", threads)
} }
return New(rootDirectory) case uint64:
maxThreads = v
case int, int32, int64:
val := reflect.ValueOf(v).Convert(reflect.TypeOf(threads)).Int()
// If threads is negative, casting to uint64 would wrap around and
// yield an enormous thread limit. Let's be sensible here.
if val > 0 {
maxThreads = uint64(val)
}
case uint, uint32:
maxThreads = reflect.ValueOf(v).Convert(reflect.TypeOf(threads)).Uint()
case nil:
// do nothing
default:
return nil, fmt.Errorf("invalid value for maxthreads: %#v", threads)
}
if maxThreads < minThreads {
maxThreads = minThreads
}
}
params := &DriverParameters{
RootDirectory: rootDirectory,
MaxThreads: maxThreads,
}
return params, nil
} }
// New constructs a new Driver with a given rootDirectory // New constructs a new Driver with a given rootDirectory
func New(rootDirectory string) *Driver { func New(params DriverParameters) *Driver {
fsDriver := &driver{rootDirectory: params.RootDirectory}
return &Driver{ return &Driver{
baseEmbed: baseEmbed{ baseEmbed: baseEmbed{
Base: base.Base{ Base: base.Base{
StorageDriver: &driver{ StorageDriver: base.NewRegulator(fsDriver, params.MaxThreads),
rootDirectory: rootDirectory,
},
}, },
}, },
} }
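Putting the parameter handling together, a configuration sketch; the values are examples only:

    // "maxthreads" may arrive as a string (e.g. from YAML) or as any
    // integer type; values below minThreads (25) are clamped upward.
    driver, err := FromParameters(map[string]interface{}{
        "rootdirectory": "/var/lib/registry",
        "maxthreads":    200,
    })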
@@ -78,7 +143,7 @@ func (d *driver) Name() string {
// GetContent retrieves the content stored at "path" as a []byte. // GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) { func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
rc, err := d.ReadStream(ctx, path, 0) rc, err := d.Reader(ctx, path, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -94,16 +159,22 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
// PutContent stores the []byte content at a location designated by "path". // PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte) error { func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte) error {
if _, err := d.WriteStream(ctx, subPath, 0, bytes.NewReader(contents)); err != nil { writer, err := d.Writer(ctx, subPath, false)
if err != nil {
return err return err
} }
defer writer.Close()
return os.Truncate(d.fullPath(subPath), int64(len(contents))) _, err = io.Copy(writer, bytes.NewReader(contents))
if err != nil {
writer.Cancel()
return err
}
return writer.Commit()
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
file, err := os.OpenFile(d.fullPath(path), os.O_RDONLY, 0644) file, err := os.OpenFile(d.fullPath(path), os.O_RDONLY, 0644)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@@ -125,40 +196,36 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return file, nil return file, nil
} }
// WriteStream stores the contents of the provided io.Reader at a location func (d *driver) Writer(ctx context.Context, subPath string, append bool) (storagedriver.FileWriter, error) {
// designated by the given path.
func (d *driver) WriteStream(ctx context.Context, subPath string, offset int64, reader io.Reader) (nn int64, err error) {
// TODO(stevvooe): This needs to be a requirement.
// if !path.IsAbs(subPath) {
// return fmt.Errorf("absolute path required: %q", subPath)
// }
fullPath := d.fullPath(subPath) fullPath := d.fullPath(subPath)
parentDir := path.Dir(fullPath) parentDir := path.Dir(fullPath)
if err := os.MkdirAll(parentDir, 0777); err != nil { if err := os.MkdirAll(parentDir, 0777); err != nil {
return 0, err return nil, err
} }
fp, err := os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE, 0666) fp, err := os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE, 0666)
if err != nil { if err != nil {
// TODO(stevvooe): A few missing conditions in storage driver: return nil, err
// 1. What if the path is already a directory?
// 2. Should number 1 be exposed explicitly in storagedriver?
// 2. Can this path not exist, even if we create above?
return 0, err
} }
defer fp.Close()
nn, err = fp.Seek(offset, os.SEEK_SET) var offset int64
if !append {
err := fp.Truncate(0)
if err != nil { if err != nil {
return 0, err fp.Close()
return nil, err
}
} else {
n, err := fp.Seek(0, os.SEEK_END)
if err != nil {
fp.Close()
return nil, err
}
offset = int64(n)
} }
if nn != offset { return newFileWriter(fp, offset), nil
return 0, fmt.Errorf("bad seek to %v, expected %v in fp=%v", offset, nn, fp)
}
return io.Copy(fp, reader)
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@@ -286,3 +353,88 @@ func (fi fileInfo) ModTime() time.Time {
func (fi fileInfo) IsDir() bool { func (fi fileInfo) IsDir() bool {
return fi.FileInfo.IsDir() return fi.FileInfo.IsDir()
} }
type fileWriter struct {
file *os.File
size int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func newFileWriter(file *os.File, size int64) *fileWriter {
return &fileWriter{
file: file,
size: size,
bw: bufio.NewWriter(file),
}
}
func (fw *fileWriter) Write(p []byte) (int, error) {
if fw.closed {
return 0, fmt.Errorf("already closed")
} else if fw.committed {
return 0, fmt.Errorf("already committed")
} else if fw.cancelled {
return 0, fmt.Errorf("already cancelled")
}
n, err := fw.bw.Write(p)
fw.size += int64(n)
return n, err
}
func (fw *fileWriter) Size() int64 {
return fw.size
}
func (fw *fileWriter) Close() error {
if fw.closed {
return fmt.Errorf("already closed")
}
if err := fw.bw.Flush(); err != nil {
return err
}
if err := fw.file.Sync(); err != nil {
return err
}
if err := fw.file.Close(); err != nil {
return err
}
fw.closed = true
return nil
}
func (fw *fileWriter) Cancel() error {
if fw.closed {
return fmt.Errorf("already closed")
}
fw.cancelled = true
fw.file.Close()
return os.Remove(fw.file.Name())
}
func (fw *fileWriter) Commit() error {
if fw.closed {
return fmt.Errorf("already closed")
} else if fw.committed {
return fmt.Errorf("already committed")
} else if fw.cancelled {
return fmt.Errorf("already cancelled")
}
if err := fw.bw.Flush(); err != nil {
return err
}
if err := fw.file.Sync(); err != nil {
return err
}
fw.committed = true
return nil
}


@@ -3,6 +3,7 @@ package filesystem
import ( import (
"io/ioutil" "io/ioutil"
"os" "os"
"reflect"
"testing" "testing"
storagedriver "github.com/docker/distribution/registry/storage/driver" storagedriver "github.com/docker/distribution/registry/storage/driver"
@@ -20,7 +21,93 @@ func init() {
} }
defer os.Remove(root) defer os.Remove(root)
driver, err := FromParameters(map[string]interface{}{
"rootdirectory": root,
})
if err != nil {
panic(err)
}
testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) { testsuites.RegisterSuite(func() (storagedriver.StorageDriver, error) {
return New(root), nil return driver, nil
}, testsuites.NeverSkip) }, testsuites.NeverSkip)
} }
func TestFromParametersImpl(t *testing.T) {
tests := []struct {
params map[string]interface{} // technically the yaml can contain anything
expected DriverParameters
pass bool
}{
// check we use default threads and root dirs
{
params: map[string]interface{}{},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: defaultMaxThreads,
},
pass: true,
},
// Testing initiation with a string maxThreads which can't be parsed
{
params: map[string]interface{}{
"maxthreads": "fail",
},
expected: DriverParameters{},
pass: false,
},
{
params: map[string]interface{}{
"maxthreads": "100",
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: uint64(100),
},
pass: true,
},
{
params: map[string]interface{}{
"maxthreads": 100,
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: uint64(100),
},
pass: true,
},
// check that we use minimum thread counts
{
params: map[string]interface{}{
"maxthreads": 1,
},
expected: DriverParameters{
RootDirectory: defaultRootDirectory,
MaxThreads: minThreads,
},
pass: true,
},
}
for _, item := range tests {
params, err := fromParametersImpl(item.params)
if !item.pass {
// We only need to assert that expected failures have an error
if err == nil {
t.Fatalf("expected error configuring filesystem driver with invalid param: %+v", item.params)
}
continue
}
if err != nil {
t.Fatalf("unexpected error creating filesystem driver: %s", err)
}
// Note that we get a pointer to params back
if !reflect.DeepEqual(*params, item.expected) {
t.Fatalf("unexpected params from filesystem driver. expected %+v, got %+v", item.expected, params)
}
}
}


@@ -7,11 +7,8 @@
// Because gcs is a key, value store the Stat call does not support last modification // Because gcs is a key, value store the Stat call does not support last modification
// time for directories (directories are an abstraction for key, value stores) // time for directories (directories are an abstraction for key, value stores)
// //
// Keep in mind that gcs guarantees only eventual consistency, so do not assume // Note that the contents of incomplete uploads are not accessible even though
// that a successful write will mean immediate access to the data written (although // Stat returns their length
// in most regions a new object put has guaranteed read after write). The only true
// guarantee is that once you call Stat and receive a certain file size, that much of
// the file is already accessible.
// //
// +build include_gcs // +build include_gcs
@@ -25,7 +22,10 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"time" "time"
@@ -34,7 +34,6 @@ import (
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
"golang.org/x/oauth2/jwt" "golang.org/x/oauth2/jwt"
"google.golang.org/api/googleapi" "google.golang.org/api/googleapi"
storageapi "google.golang.org/api/storage/v1"
"google.golang.org/cloud" "google.golang.org/cloud"
"google.golang.org/cloud/storage" "google.golang.org/cloud/storage"
@@ -46,8 +45,18 @@ import (
"github.com/docker/distribution/registry/storage/driver/factory" "github.com/docker/distribution/registry/storage/driver/factory"
) )
const driverName = "gcs" const (
const dummyProjectID = "<unknown>" driverName = "gcs"
dummyProjectID = "<unknown>"
uploadSessionContentType = "application/x-docker-upload-session"
minChunkSize = 256 * 1024
defaultChunkSize = 20 * minChunkSize
maxTries = 5
)
var rangeHeader = regexp.MustCompile(`^bytes=([0-9]+)-([0-9]+)$`)
// driverParameters is a struct that encapsulates all of the driver parameters after all values have been set // driverParameters is a struct that encapsulates all of the driver parameters after all values have been set
type driverParameters struct { type driverParameters struct {
@@ -57,6 +66,7 @@ type driverParameters struct {
privateKey []byte privateKey []byte
client *http.Client client *http.Client
rootDirectory string rootDirectory string
chunkSize int
} }
func init() { func init() {
@@ -79,6 +89,7 @@ type driver struct {
email string email string
privateKey []byte privateKey []byte
rootDirectory string rootDirectory string
chunkSize int
} }
// FromParameters constructs a new Driver with a given parameters map // FromParameters constructs a new Driver with a given parameters map
@@ -95,6 +106,31 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
rootDirectory = "" rootDirectory = ""
} }
chunkSize := defaultChunkSize
chunkSizeParam, ok := parameters["chunksize"]
if ok {
switch v := chunkSizeParam.(type) {
case string:
vv, err := strconv.Atoi(v)
if err != nil {
return nil, fmt.Errorf("chunksize parameter must be an integer, %v invalid", chunkSizeParam)
}
chunkSize = vv
case int, uint, int32, uint32, uint64, int64:
chunkSize = int(reflect.ValueOf(v).Convert(reflect.TypeOf(chunkSize)).Int())
default:
return nil, fmt.Errorf("invalid valud for chunksize: %#v", chunkSizeParam)
}
if chunkSize < minChunkSize {
return nil, fmt.Errorf("The chunksize %#v parameter should be a number that is larger than or equal to %d", chunkSize, minChunkSize)
}
if chunkSize%minChunkSize != 0 {
return nil, fmt.Errorf("chunksize should be a multiple of %d", minChunkSize)
}
}
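For example, a parameters map that passes this validation might look as follows; the bucket name is a placeholder:

    params := map[string]interface{}{
        "bucket": "my-bucket", // placeholder
        // Must be a multiple of minChunkSize (256 KB); the default is
        // 20 * minChunkSize (5 MB) when the key is omitted.
        "chunksize": 2 * 1024 * 1024, // 2 MB = 8 * minChunkSize
    }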
var ts oauth2.TokenSource var ts oauth2.TokenSource
jwtConf := new(jwt.Config) jwtConf := new(jwt.Config)
if keyfile, ok := parameters["keyfile"]; ok { if keyfile, ok := parameters["keyfile"]; ok {
@@ -113,7 +149,6 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
params := driverParameters{ params := driverParameters{
@@ -122,6 +157,7 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
email: jwtConf.Email, email: jwtConf.Email,
privateKey: jwtConf.PrivateKey, privateKey: jwtConf.PrivateKey,
client: oauth2.NewClient(context.Background(), ts), client: oauth2.NewClient(context.Background(), ts),
chunkSize: chunkSize,
} }
return New(params) return New(params)
@@ -133,12 +169,16 @@ func New(params driverParameters) (storagedriver.StorageDriver, error) {
if rootDirectory != "" { if rootDirectory != "" {
rootDirectory += "/" rootDirectory += "/"
} }
if params.chunkSize <= 0 || params.chunkSize%minChunkSize != 0 {
return nil, fmt.Errorf("Invalid chunksize: %d is not a positive multiple of %d", params.chunkSize, minChunkSize)
}
d := &driver{ d := &driver{
bucket: params.bucket, bucket: params.bucket,
rootDirectory: rootDirectory, rootDirectory: rootDirectory,
email: params.email, email: params.email,
privateKey: params.privateKey, privateKey: params.privateKey,
client: params.client, client: params.client,
chunkSize: params.chunkSize,
} }
return &base.Base{ return &base.Base{
@@ -155,7 +195,17 @@ func (d *driver) Name() string {
// GetContent retrieves the content stored at "path" as a []byte. // GetContent retrieves the content stored at "path" as a []byte.
// This should primarily be used for small objects. // This should primarily be used for small objects.
func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) { func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) {
rc, err := d.ReadStream(context, path, 0) gcsContext := d.context(context)
name := d.pathToKey(path)
var rc io.ReadCloser
err := retry(func() error {
var err error
rc, err = storage.NewReader(gcsContext, d.bucket, name)
return err
})
if err == storage.ErrObjectNotExist {
return nil, storagedriver.PathNotFoundError{Path: path}
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -171,44 +221,28 @@ func (d *driver) GetContent(context ctx.Context, path string) ([]byte, error) {
// PutContent stores the []byte content at a location designated by "path". // PutContent stores the []byte content at a location designated by "path".
// This should primarily be used for small objects. // This should primarily be used for small objects.
func (d *driver) PutContent(context ctx.Context, path string, contents []byte) error { func (d *driver) PutContent(context ctx.Context, path string, contents []byte) error {
return retry(func() error {
wc := storage.NewWriter(d.context(context), d.bucket, d.pathToKey(path)) wc := storage.NewWriter(d.context(context), d.bucket, d.pathToKey(path))
wc.ContentType = "application/octet-stream" wc.ContentType = "application/octet-stream"
defer wc.Close() return putContentsClose(wc, contents)
_, err := wc.Write(contents) })
return err
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" // Reader retrieves an io.ReadCloser for the content stored at "path"
// with a given byte offset. // with a given byte offset.
// May be used to resume reading a stream by providing a nonzero offset. // May be used to resume reading a stream by providing a nonzero offset.
func (d *driver) ReadStream(context ctx.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(context ctx.Context, path string, offset int64) (io.ReadCloser, error) {
name := d.pathToKey(path) res, err := getObject(d.client, d.bucket, d.pathToKey(path), offset)
// copied from google.golang.org/cloud/storage#NewReader :
// to set the additional "Range" header
u := &url.URL{
Scheme: "https",
Host: "storage.googleapis.com",
Path: fmt.Sprintf("/%s/%s", d.bucket, name),
}
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil { if err != nil {
return nil, err if res != nil {
}
if offset > 0 {
req.Header.Set("Range", fmt.Sprintf("bytes=%v-", offset))
}
res, err := d.client.Do(req)
if err != nil {
return nil, err
}
if res.StatusCode == http.StatusNotFound { if res.StatusCode == http.StatusNotFound {
res.Body.Close() res.Body.Close()
return nil, storagedriver.PathNotFoundError{Path: path} return nil, storagedriver.PathNotFoundError{Path: path}
} }
if res.StatusCode == http.StatusRequestedRangeNotSatisfiable { if res.StatusCode == http.StatusRequestedRangeNotSatisfiable {
res.Body.Close() res.Body.Close()
obj, err := storageStatObject(d.context(context), d.bucket, name) obj, err := storageStatObject(d.context(context), d.bucket, d.pathToKey(path))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -217,103 +251,278 @@ func (d *driver) ReadStream(context ctx.Context, path string, offset int64) (io.
} }
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset} return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
} }
if res.StatusCode < 200 || res.StatusCode > 299 { }
res.Body.Close() return nil, err
return nil, fmt.Errorf("storage: can't read object %v/%v, status code: %v", d.bucket, name, res.Status) }
if res.Header.Get("Content-Type") == uploadSessionContentType {
defer res.Body.Close()
return nil, storagedriver.PathNotFoundError{Path: path}
} }
return res.Body, nil return res.Body, nil
} }
// WriteStream stores the contents of the provided io.ReadCloser at a func getObject(client *http.Client, bucket string, name string, offset int64) (*http.Response, error) {
// location designated by the given path. // copied from google.golang.org/cloud/storage#NewReader :
// May be used to resume writing a stream by providing a nonzero offset. // to set the additional "Range" header
// The offset must be no larger than the CurrentSize for this path. u := &url.URL{
func (d *driver) WriteStream(context ctx.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) { Scheme: "https",
if offset < 0 { Host: "storage.googleapis.com",
return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset} Path: fmt.Sprintf("/%s/%s", bucket, name),
} }
req, err := http.NewRequest("GET", u.String(), nil)
if offset == 0 {
return d.writeCompletely(context, path, 0, reader)
}
service, err := storageapi.New(d.client)
if err != nil { if err != nil {
return 0, err return nil, err
} }
objService := storageapi.NewObjectsService(service) if offset > 0 {
var obj *storageapi.Object req.Header.Set("Range", fmt.Sprintf("bytes=%v-", offset))
err = retry(5, func() error { }
o, err := objService.Get(d.bucket, d.pathToKey(path)).Do() var res *http.Response
obj = o err = retry(func() error {
var err error
res, err = client.Do(req)
return err return err
}) })
// obj, err := retry(5, objService.Get(d.bucket, d.pathToKey(path)).Do) if err != nil {
return nil, err
}
return res, googleapi.CheckMediaResponse(res)
}
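A hypothetical caller, showing how the Range header turns getObject into a resumable read (bucket and object names are placeholders):

	// readFrom sketches resuming a read at a byte offset via getObject;
	// it assumes this file's imports (net/http, io/ioutil).
	func readFrom(client *http.Client, offset int64) ([]byte, error) {
		res, err := getObject(client, "my-bucket", "path/to/object", offset)
		if err != nil {
			return nil, err
		}
		defer res.Body.Close()
		return ioutil.ReadAll(res.Body) // bytes offset..EOF
	}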
// Writer returns a FileWriter which will store the content written to it
// at the location designated by "path" after the call to Commit.
func (d *driver) Writer(context ctx.Context, path string, append bool) (storagedriver.FileWriter, error) {
writer := &writer{
client: d.client,
bucket: d.bucket,
name: d.pathToKey(path),
buffer: make([]byte, d.chunkSize),
}
if append {
err := writer.init(path)
if err != nil {
return nil, err
}
}
return writer, nil
}
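A sketch of the intended FileWriter lifecycle from the caller's side (the path and data are illustrative; the methods are the ones implemented below):

	// uploadBlob sketches the Writer contract: Write buffers, Commit makes
	// the content visible, Close releases the writer. Path is a placeholder.
	func uploadBlob(d storagedriver.StorageDriver, context ctx.Context, data []byte) error {
		w, err := d.Writer(context, "/uploads/blob", false) // append=false starts fresh
		if err != nil {
			return err
		}
		if _, err := w.Write(data); err != nil {
			w.Cancel() // best effort: discard any partially written state
			return err
		}
		if err := w.Commit(); err != nil {
			return err
		}
		return w.Close()
	}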
type writer struct {
client *http.Client
bucket string
name string
size int64
offset int64
closed bool
sessionURI string
buffer []byte
buffSize int
}
// Cancel removes any written content from this FileWriter.
func (w *writer) Cancel() error {
err := w.checkClosed()
if err != nil {
return err
}
w.closed = true
err = storageDeleteObject(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
if err != nil {
if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound {
err = nil
}
}
}
return err
}
func (w *writer) Close() error {
if w.closed {
return nil
}
w.closed = true
err := w.writeChunk()
if err != nil {
return err
}
// Copy the remaining bytes from the buffer to the upload session.
// Normally buffSize will be smaller than minChunkSize. However, in the
// unlikely event that the upload session failed to start, this number
// could be higher. In that case we can safely clip the remaining bytes
// to minChunkSize.
if w.buffSize > minChunkSize {
w.buffSize = minChunkSize
}
// commit the writes by updating the upload session
err = retry(func() error {
wc := storage.NewWriter(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
wc.ContentType = uploadSessionContentType
wc.Metadata = map[string]string{
"Session-URI": w.sessionURI,
"Offset": strconv.FormatInt(w.offset, 10),
}
return putContentsClose(wc, w.buffer[0:w.buffSize])
})
if err != nil {
return err
}
w.size = w.offset + int64(w.buffSize)
w.buffSize = 0
return nil
}
func putContentsClose(wc *storage.Writer, contents []byte) error {
size := len(contents)
var nn int
var err error
for nn < size {
// write into the enclosing err, so the post-loop check can see failures
var n int
n, err = wc.Write(contents[nn:size])
nn += n
if err != nil {
break
}
}
if err != nil {
wc.CloseWithError(err)
return err
}
return wc.Close()
}
// Commit flushes all content written to this FileWriter and makes it
// available for future calls to StorageDriver.GetContent and
// StorageDriver.Reader.
func (w *writer) Commit() error {
if err := w.checkClosed(); err != nil {
return err
}
w.closed = true
// no session started yet; just perform a simple upload
if w.sessionURI == "" {
err := retry(func() error {
wc := storage.NewWriter(cloud.NewContext(dummyProjectID, w.client), w.bucket, w.name)
wc.ContentType = "application/octet-stream"
return putContentsClose(wc, w.buffer[0:w.buffSize])
})
if err != nil {
return err
}
w.size = w.offset + int64(w.buffSize)
w.buffSize = 0
return nil
}
size := w.offset + int64(w.buffSize)
var nn int
// loop must be performed at least once to ensure the file is committed even when
// the buffer is empty
for {
n, err := putChunk(w.client, w.sessionURI, w.buffer[nn:w.buffSize], w.offset, size)
nn += int(n)
w.offset += n
w.size = w.offset
if err != nil {
w.buffSize = copy(w.buffer, w.buffer[nn:w.buffSize])
return err
}
if nn == w.buffSize {
break
}
}
w.buffSize = 0
return nil
}
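Worked example of the closing Content-Range that Commit sends through putChunk (the numbers are assumed):

	// finalContentRange shows the header arithmetic for the last chunk.
	func finalContentRange() string {
		from, buffered := int64(5242880), int64(100) // values assumed
		size := from + buffered                      // 5242980: final object size
		return fmt.Sprintf("bytes %d-%d/%d", from, from+buffered-1, size)
		// "bytes 5242880-5242979/5242980" closes the session; an empty
		// buffer would instead send "bytes */5242880" (the from == to+1
		// case in putChunk below).
	}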
func (w *writer) checkClosed() error {
if w.closed {
return fmt.Errorf("Writer already closed")
}
return nil
}
func (w *writer) writeChunk() error {
var err error
// chunks can be uploaded only in multiples of minChunkSize
// chunkSize is a multiple of minChunkSize less than or equal to buffSize
chunkSize := w.buffSize - (w.buffSize % minChunkSize)
if chunkSize == 0 {
return nil
}
// if there is no sessionURI yet, obtain one by starting the session
if w.sessionURI == "" {
w.sessionURI, err = startSession(w.client, w.bucket, w.name)
}
if err != nil {
return err
}
nn, err := putChunk(w.client, w.sessionURI, w.buffer[0:chunkSize], w.offset, -1)
w.offset += nn
if w.offset > w.size {
w.size = w.offset
}
// shift the remaining bytes to the start of the buffer
w.buffSize = copy(w.buffer, w.buffer[int(nn):w.buffSize])
return err
}
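Worked example of the clipping above (sizes assumed): with 600 KiB buffered and minChunkSize = 262144 (256 KiB), writeChunk uploads two full units and keeps the tail buffered.

	// clipExample traces the writeChunk arithmetic with assumed sizes.
	func clipExample() (chunk, remainder int) {
		buffSize := 600 * 1024                       // 614400 bytes buffered
		chunk = buffSize - (buffSize % minChunkSize) // 524288: two 256 KiB units
		remainder = buffSize - chunk                 // 90112 bytes stay for the next flush
		return chunk, remainder
	}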
func (w *writer) Write(p []byte) (int, error) {
err := w.checkClosed()
if err != nil { if err != nil {
return 0, err return 0, err
} }
// cannot append more chunks, so redo from scratch var nn int
if obj.ComponentCount >= 1023 { for nn < len(p) {
return d.writeCompletely(context, path, offset, reader) n := copy(w.buffer[w.buffSize:], p[nn:])
} w.buffSize += n
if w.buffSize == cap(w.buffer) {
// skip from reader err = w.writeChunk()
objSize := int64(obj.Size)
nn, err := skip(reader, objSize-offset)
if err != nil { if err != nil {
return nn, err break
} }
}
nn += n
}
return nn, err
}
// Size <= offset // Size returns the number of bytes written to this FileWriter.
partName := fmt.Sprintf("%v#part-%d#", d.pathToKey(path), obj.ComponentCount) func (w *writer) Size() int64 {
gcsContext := d.context(context) return w.size
wc := storage.NewWriter(gcsContext, d.bucket, partName) }
wc.ContentType = "application/octet-stream"
if objSize < offset { func (w *writer) init(path string) error {
err = writeZeros(wc, offset-objSize) res, err := getObject(w.client, w.bucket, w.name, 0)
if err != nil { if err != nil {
wc.CloseWithError(err) return err
return nn, err
} }
defer res.Body.Close()
if res.Header.Get("Content-Type") != uploadSessionContentType {
return storagedriver.PathNotFoundError{Path: path}
} }
n, err := io.Copy(wc, reader) offset, err := strconv.ParseInt(res.Header.Get("X-Goog-Meta-Offset"), 10, 64)
if err != nil { if err != nil {
wc.CloseWithError(err) return err
return nn, err
} }
err = wc.Close() buffer, err := ioutil.ReadAll(res.Body)
if err != nil { if err != nil {
return nn, err return err
} }
// wc was closed successfully, so the temporary part exists, schedule it for deletion at the end w.sessionURI = res.Header.Get("X-Goog-Meta-Session-URI")
// of the function w.buffSize = copy(w.buffer, buffer)
defer storageDeleteObject(gcsContext, d.bucket, partName) w.offset = offset
w.size = offset + int64(w.buffSize)
req := &storageapi.ComposeRequest{ return nil
Destination: &storageapi.Object{Bucket: obj.Bucket, Name: obj.Name, ContentType: obj.ContentType},
SourceObjects: []*storageapi.ComposeRequestSourceObjects{
{
Name: obj.Name,
Generation: obj.Generation,
}, {
Name: partName,
Generation: wc.Object().Generation,
}},
}
err = retry(5, func() error { _, err := objService.Compose(d.bucket, obj.Name, req).Do(); return err })
if err == nil {
nn = nn + n
}
return nn, err
} }
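The resume path works because GCS surfaces the custom metadata written by Close as X-Goog-Meta-* response headers; a hypothetical helper showing what init recovers:

	// parseSessionHeaders sketches how init reads back the session state
	// stored by Close; the header names follow from the Metadata keys above.
	func parseSessionHeaders(h http.Header) (sessionURI string, offset int64, err error) {
		offset, err = strconv.ParseInt(h.Get("X-Goog-Meta-Offset"), 10, 64)
		if err != nil {
			return "", 0, err
		}
		return h.Get("X-Goog-Meta-Session-URI"), offset, nil
	}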
type request func() error type request func() error
func retry(maxTries int, req request) error { func retry(req request) error {
backoff := time.Second backoff := time.Second
var err error var err error
for i := 0; i < maxTries; i++ { for i := 0; i < maxTries; i++ {
@ -335,53 +544,6 @@ func retry(maxTries int, req request) error {
return err return err
} }
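The retry body between the two hunks is elided here; a minimal sketch consistent with the visible pieces (doubling backoff, giving up after maxTries; the retry-eligibility test is an assumption):

	// retrySketch approximates the elided body: retry with doubling backoff,
	// but only for plausibly transient failures (assumed: googleapi 429/5xx).
	func retrySketch(req request) error {
		backoff := time.Second
		var err error
		for i := 0; i < maxTries; i++ {
			if err = req(); err == nil {
				return nil
			}
			status, ok := err.(*googleapi.Error)
			if !ok || (status.Code != 429 && status.Code < http.StatusInternalServerError) {
				return err // not worth retrying
			}
			time.Sleep(backoff)
			backoff *= 2
		}
		return err
	}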
func (d *driver) writeCompletely(context ctx.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) {
wc := storage.NewWriter(d.context(context), d.bucket, d.pathToKey(path))
wc.ContentType = "application/octet-stream"
defer wc.Close()
// Copy the first offset bytes of the existing contents
// (padded with zeros if needed) into the writer
if offset > 0 {
existing, err := d.ReadStream(context, path, 0)
if err != nil {
return 0, err
}
defer existing.Close()
n, err := io.CopyN(wc, existing, offset)
if err == io.EOF {
err = writeZeros(wc, offset-n)
}
if err != nil {
return 0, err
}
}
return io.Copy(wc, reader)
}
func skip(reader io.Reader, count int64) (int64, error) {
if count <= 0 {
return 0, nil
}
return io.CopyN(ioutil.Discard, reader, count)
}
func writeZeros(wc io.Writer, count int64) error {
buf := make([]byte, 32*1024)
for count > 0 {
size := cap(buf)
if int64(size) > count {
size = int(count)
}
n, err := wc.Write(buf[0:size])
if err != nil {
return err
}
count = count - int64(n)
}
return nil
}
// Stat retrieves the FileInfo for the given path, including the current // Stat retrieves the FileInfo for the given path, including the current
// size in bytes and the creation time. // size in bytes and the creation time.
func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo, error) { func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo, error) {
@ -390,6 +552,9 @@ func (d *driver) Stat(context ctx.Context, path string) (storagedriver.FileInfo,
gcsContext := d.context(context) gcsContext := d.context(context)
obj, err := storageStatObject(gcsContext, d.bucket, d.pathToKey(path)) obj, err := storageStatObject(gcsContext, d.bucket, d.pathToKey(path))
if err == nil { if err == nil {
if obj.ContentType == uploadSessionContentType {
return nil, storagedriver.PathNotFoundError{Path: path}
}
fi = storagedriver.FileInfoFields{ fi = storagedriver.FileInfoFields{
Path: path, Path: path,
Size: obj.Size, Size: obj.Size,
@ -440,15 +605,10 @@ func (d *driver) List(context ctx.Context, path string) ([]string, error) {
} }
for _, object := range objects.Results { for _, object := range objects.Results {
// GCS does not guarantee strong consistency between // GCS does not guarantee strong consistency between
// DELETE and LIST operationsCheck that the object is not deleted, // DELETE and LIST operations. Check that the object is not deleted,
// so filter out any objects with a non-zero time-deleted // and filter out any objects with a non-zero time-deleted
if object.Deleted.IsZero() { if object.Deleted.IsZero() && object.ContentType != uploadSessionContentType {
name := object.Name list = append(list, d.keyToPath(object.Name))
// Ignore objects with names that end with '#' (these are uploaded parts)
if name[len(name)-1] != '#' {
name = d.keyToPath(name)
list = append(list, name)
}
} }
} }
for _, subpath := range objects.Prefixes { for _, subpath := range objects.Prefixes {
@ -474,7 +634,7 @@ func (d *driver) Move(context ctx.Context, sourcePath string, destPath string) e
gcsContext := d.context(context) gcsContext := d.context(context)
_, err := storageCopyObject(gcsContext, d.bucket, d.pathToKey(sourcePath), d.bucket, d.pathToKey(destPath), nil) _, err := storageCopyObject(gcsContext, d.bucket, d.pathToKey(sourcePath), d.bucket, d.pathToKey(destPath), nil)
if err != nil { if err != nil {
if status := err.(*googleapi.Error); status != nil { if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound { if status.Code == http.StatusNotFound {
return storagedriver.PathNotFoundError{Path: sourcePath} return storagedriver.PathNotFoundError{Path: sourcePath}
} }
@ -482,7 +642,7 @@ func (d *driver) Move(context ctx.Context, sourcePath string, destPath string) e
return err return err
} }
err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(sourcePath)) err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(sourcePath))
// if deleting the file fails, log the error, but do not fail; the file was succesfully copied, // if deleting the file fails, log the error, but do not fail; the file was successfully copied,
// and the original should eventually be cleaned when purging the uploads folder. // and the original should eventually be cleaned when purging the uploads folder.
if err != nil { if err != nil {
logrus.Infof("error deleting file: %v due to %v", sourcePath, err) logrus.Infof("error deleting file: %v due to %v", sourcePath, err)
@ -545,7 +705,7 @@ func (d *driver) Delete(context ctx.Context, path string) error {
} }
err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(path)) err = storageDeleteObject(gcsContext, d.bucket, d.pathToKey(path))
if err != nil { if err != nil {
if status := err.(*googleapi.Error); status != nil { if status, ok := err.(*googleapi.Error); ok {
if status.Code == http.StatusNotFound { if status.Code == http.StatusNotFound {
return storagedriver.PathNotFoundError{Path: path} return storagedriver.PathNotFoundError{Path: path}
} }
@ -555,14 +715,14 @@ func (d *driver) Delete(context ctx.Context, path string) error {
} }
func storageDeleteObject(context context.Context, bucket string, name string) error { func storageDeleteObject(context context.Context, bucket string, name string) error {
return retry(5, func() error { return retry(func() error {
return storage.DeleteObject(context, bucket, name) return storage.DeleteObject(context, bucket, name)
}) })
} }
func storageStatObject(context context.Context, bucket string, name string) (*storage.Object, error) { func storageStatObject(context context.Context, bucket string, name string) (*storage.Object, error) {
var obj *storage.Object var obj *storage.Object
err := retry(5, func() error { err := retry(func() error {
var err error var err error
obj, err = storage.StatObject(context, bucket, name) obj, err = storage.StatObject(context, bucket, name)
return err return err
@ -572,7 +732,7 @@ func storageStatObject(context context.Context, bucket string, name string) (*st
func storageListObjects(context context.Context, bucket string, q *storage.Query) (*storage.Objects, error) { func storageListObjects(context context.Context, bucket string, q *storage.Query) (*storage.Objects, error) {
var objs *storage.Objects var objs *storage.Objects
err := retry(5, func() error { err := retry(func() error {
var err error var err error
objs, err = storage.ListObjects(context, bucket, q) objs, err = storage.ListObjects(context, bucket, q)
return err return err
@ -582,7 +742,7 @@ func storageListObjects(context context.Context, bucket string, q *storage.Query
func storageCopyObject(context context.Context, srcBucket, srcName string, destBucket, destName string, attrs *storage.ObjectAttrs) (*storage.Object, error) { func storageCopyObject(context context.Context, srcBucket, srcName string, destBucket, destName string, attrs *storage.ObjectAttrs) (*storage.Object, error) {
var obj *storage.Object var obj *storage.Object
err := retry(5, func() error { err := retry(func() error {
var err error var err error
obj, err = storage.CopyObject(context, srcBucket, srcName, destBucket, destName, attrs) obj, err = storage.CopyObject(context, srcBucket, srcName, destBucket, destName, attrs)
return err return err
@ -626,6 +786,80 @@ func (d *driver) URLFor(context ctx.Context, path string, options map[string]int
return storage.SignedURL(d.bucket, name, opts) return storage.SignedURL(d.bucket, name, opts)
} }
func startSession(client *http.Client, bucket string, name string) (uri string, err error) {
u := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: fmt.Sprintf("/upload/storage/v1/b/%v/o", bucket),
RawQuery: fmt.Sprintf("uploadType=resumable&name=%v", name),
}
err = retry(func() error {
req, err := http.NewRequest("POST", u.String(), nil)
if err != nil {
return err
}
req.Header.Set("X-Upload-Content-Type", "application/octet-stream")
req.Header.Set("Content-Length", "0")
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
err = googleapi.CheckMediaResponse(resp)
if err != nil {
return err
}
uri = resp.Header.Get("Location")
return nil
})
return uri, err
}
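Hypothetical usage of the two helpers together (bucket and object names assumed): start a session, then stream the first chunk while the total size is still unknown.

	// uploadFirstChunk sketches the resumable handshake; names are placeholders.
	func uploadFirstChunk(client *http.Client, buf []byte) (string, int64, error) {
		uri, err := startSession(client, "my-bucket", "path/to/object")
		if err != nil {
			return "", 0, err
		}
		// totalSize -1 sends "Content-Range: bytes 0-N/*": size not yet known
		n, err := putChunk(client, uri, buf, 0, -1)
		return uri, n, err
	}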
func putChunk(client *http.Client, sessionURI string, chunk []byte, from int64, totalSize int64) (int64, error) {
bytesPut := int64(0)
err := retry(func() error {
req, err := http.NewRequest("PUT", sessionURI, bytes.NewReader(chunk))
if err != nil {
return err
}
length := int64(len(chunk))
to := from + length - 1
size := "*"
if totalSize >= 0 {
size = strconv.FormatInt(totalSize, 10)
}
req.Header.Set("Content-Type", "application/octet-stream")
if from == to+1 {
req.Header.Set("Content-Range", fmt.Sprintf("bytes */%v", size))
} else {
req.Header.Set("Content-Range", fmt.Sprintf("bytes %v-%v/%v", from, to, size))
}
req.Header.Set("Content-Length", strconv.FormatInt(length, 10))
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if totalSize < 0 && resp.StatusCode == 308 {
groups := rangeHeader.FindStringSubmatch(resp.Header.Get("Range"))
end, err := strconv.ParseInt(groups[2], 10, 64)
if err != nil {
return err
}
bytesPut = end - from + 1
return nil
}
err = googleapi.CheckMediaResponse(resp)
if err != nil {
return err
}
bytesPut = to - from + 1
return nil
})
return bytesPut, err
}
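Worked example of the 308 bookkeeping above (header value assumed): after sending "bytes 0-262143/*", a 308 response whose Range header reads "bytes=0-200703" means the server persisted only part of the chunk, and the writer must re-send the tail.

	// persistedBytes traces the 308 path with an assumed Range header value.
	func persistedBytes() int64 {
		from := int64(0) // offset this chunk started at (assumed)
		groups := rangeHeader.FindStringSubmatch("bytes=0-200703")
		end, _ := strconv.ParseInt(groups[2], 10, 64)
		return end - from + 1 // 200704: reported as bytesPut; the rest is re-sent
	}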
func (d *driver) context(context ctx.Context) context.Context { func (d *driver) context(context ctx.Context) context.Context {
return cloud.WithContext(context, dummyProjectID, d.client) return cloud.WithContext(context, dummyProjectID, d.client)
} }

View file

@ -75,6 +75,7 @@ func init() {
email: email, email: email,
privateKey: privateKey, privateKey: privateKey,
client: oauth2.NewClient(ctx.Background(), ts), client: oauth2.NewClient(ctx.Background(), ts),
chunkSize: defaultChunkSize,
} }
return New(parameters) return New(parameters)
@ -85,6 +86,102 @@ func init() {
}, skipGCS) }, skipGCS)
} }
// Test committing a FileWriter without having called Write
func TestCommitEmpty(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
driver, err := gcsDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
filename := "/test"
ctx := ctx.Background()
writer, err := driver.Writer(ctx, filename, false)
defer driver.Delete(ctx, filename)
if err != nil {
t.Fatalf("driver.Writer: unexpected error: %v", err)
}
err = writer.Commit()
if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err)
}
err = writer.Close()
if err != nil {
t.Fatalf("writer.Close: unexpected error: %v", err)
}
if writer.Size() != 0 {
t.Fatalf("writer.Size: %d != 0", writer.Size())
}
readContents, err := driver.GetContent(ctx, filename)
if err != nil {
t.Fatalf("driver.GetContent: unexpected error: %v", err)
}
if len(readContents) != 0 {
t.Fatalf("len(driver.GetContent(..)): %d != 0", len(readContents))
}
}
// Test committing a FileWriter after having written exactly
// defaultChunkSize bytes.
func TestCommit(t *testing.T) {
if skipGCS() != "" {
t.Skip(skipGCS())
}
validRoot, err := ioutil.TempDir("", "driver-")
if err != nil {
t.Fatalf("unexpected error creating temporary directory: %v", err)
}
defer os.Remove(validRoot)
driver, err := gcsDriverConstructor(validRoot)
if err != nil {
t.Fatalf("unexpected error creating rooted driver: %v", err)
}
filename := "/test"
ctx := ctx.Background()
contents := make([]byte, defaultChunkSize)
writer, err := driver.Writer(ctx, filename, false)
defer driver.Delete(ctx, filename)
if err != nil {
t.Fatalf("driver.Writer: unexpected error: %v", err)
}
_, err = writer.Write(contents)
if err != nil {
t.Fatalf("writer.Write: unexpected error: %v", err)
}
err = writer.Commit()
if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err)
}
err = writer.Close()
if err != nil {
t.Fatalf("writer.Close: unexpected error: %v", err)
}
if writer.Size() != int64(len(contents)) {
t.Fatalf("writer.Size: %d != %d", writer.Size(), len(contents))
}
readContents, err := driver.GetContent(ctx, filename)
if err != nil {
t.Fatalf("driver.GetContent: unexpected error: %v", err)
}
if len(readContents) != len(contents) {
t.Fatalf("len(driver.GetContent(..)): %d != %d", len(readContents), len(contents))
}
}
func TestRetry(t *testing.T) { func TestRetry(t *testing.T) {
if skipGCS() != "" { if skipGCS() != "" {
t.Skip(skipGCS()) t.Skip(skipGCS())
@ -100,7 +197,7 @@ func TestRetry(t *testing.T) {
} }
} }
err := retry(2, func() error { err := retry(func() error {
return &googleapi.Error{ return &googleapi.Error{
Code: 503, Code: 503,
Message: "google api error", Message: "google api error",
@ -108,7 +205,7 @@ func TestRetry(t *testing.T) {
}) })
assertError("googleapi: Error 503: google api error", err) assertError("googleapi: Error 503: google api error", err)
err = retry(2, func() error { err = retry(func() error {
return &googleapi.Error{ return &googleapi.Error{
Code: 404, Code: 404,
Message: "google api error", Message: "google api error",
@ -116,7 +213,7 @@ func TestRetry(t *testing.T) {
}) })
assertError("googleapi: Error 404: google api error", err) assertError("googleapi: Error 404: google api error", err)
err = retry(2, func() error { err = retry(func() error {
return fmt.Errorf("error") return fmt.Errorf("error")
}) })
assertError("error", err) assertError("error", err)

View file

@ -1,7 +1,6 @@
package inmemory package inmemory
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -74,7 +73,7 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
d.mutex.RLock() d.mutex.RLock()
defer d.mutex.RUnlock() defer d.mutex.RUnlock()
rc, err := d.ReadStream(ctx, path, 0) rc, err := d.Reader(ctx, path, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,7 +87,9 @@ func (d *driver) PutContent(ctx context.Context, p string, contents []byte) erro
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
f, err := d.root.mkfile(p) normalized := normalize(p)
f, err := d.root.mkfile(normalized)
if err != nil { if err != nil {
// TODO(stevvooe): Again, we need to clarify when this is not a // TODO(stevvooe): Again, we need to clarify when this is not a
// directory in StorageDriver API. // directory in StorageDriver API.
@ -101,9 +102,9 @@ func (d *driver) PutContent(ctx context.Context, p string, contents []byte) erro
return nil return nil
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
d.mutex.RLock() d.mutex.RLock()
defer d.mutex.RUnlock() defer d.mutex.RUnlock()
@ -111,10 +112,10 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset} return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
} }
path = normalize(path) normalized := normalize(path)
found := d.root.find(path) found := d.root.find(normalized)
if found.path() != path { if found.path() != normalized {
return nil, storagedriver.PathNotFoundError{Path: path} return nil, storagedriver.PathNotFoundError{Path: path}
} }
@ -125,46 +126,24 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return ioutil.NopCloser(found.(*file).sectionReader(offset)), nil return ioutil.NopCloser(found.(*file).sectionReader(offset)), nil
} }
// WriteStream stores the contents of the provided io.ReadCloser at a location // Writer returns a FileWriter which will store the content written to it
// designated by the given path. // at the location designated by "path" after the call to Commit.
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (nn int64, err error) { func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
if offset < 0 {
return 0, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
normalized := normalize(path) normalized := normalize(path)
f, err := d.root.mkfile(normalized) f, err := d.root.mkfile(normalized)
if err != nil { if err != nil {
return 0, fmt.Errorf("not a file") return nil, fmt.Errorf("not a file")
} }
// Unlock while we are reading from the source, in case we are reading if !append {
// from the same mfs instance. This can be fixed by a more granular f.truncate()
// locking model.
d.mutex.Unlock()
d.mutex.RLock() // Take the readlock to block other writers.
var buf bytes.Buffer
nn, err = buf.ReadFrom(reader)
if err != nil {
// TODO(stevvooe): This condition is odd and we may need to clarify:
// we've read nn bytes from reader but have written nothing to the
// backend. What is the correct return value? Really, the caller needs
// to know that the reader has been advanced and reattempting the
// operation is incorrect.
d.mutex.RUnlock()
d.mutex.Lock()
return nn, err
} }
d.mutex.RUnlock() return d.newWriter(f), nil
d.mutex.Lock()
f.WriteAt(buf.Bytes(), offset)
return nn, err
} }
// Stat returns info about the provided path. // Stat returns info about the provided path.
@ -173,7 +152,7 @@ func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo,
defer d.mutex.RUnlock() defer d.mutex.RUnlock()
normalized := normalize(path) normalized := normalize(path)
found := d.root.find(path) found := d.root.find(normalized)
if found.path() != normalized { if found.path() != normalized {
return nil, storagedriver.PathNotFoundError{Path: path} return nil, storagedriver.PathNotFoundError{Path: path}
@ -260,3 +239,74 @@ func (d *driver) Delete(ctx context.Context, path string) error {
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{} return "", storagedriver.ErrUnsupportedMethod{}
} }
type writer struct {
d *driver
f *file
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(f *file) storagedriver.FileWriter {
return &writer{
d: d,
f: f,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
w.d.mutex.Lock()
defer w.d.mutex.Unlock()
return w.f.WriteAt(p, int64(len(w.f.data)))
}
func (w *writer) Size() int64 {
w.d.mutex.RLock()
defer w.d.mutex.RUnlock()
return int64(len(w.f.data))
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return nil
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
w.d.mutex.Lock()
defer w.d.mutex.Unlock()
return w.d.root.delete(w.f.path())
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
w.committed = true
return nil
}
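A sketch of the append semantics implemented above (path assumed; errors elided for brevity): append=false truncates first, and Write always lands at the current end of the data.

	// appendTwice sketches the inmemory writer: truncate-then-write, then
	// reopen with append=true to continue at the end. Path is a placeholder.
	func appendTwice(d *driver, ctx context.Context) error {
		w, _ := d.Writer(ctx, "/demo", false) // truncates via f.truncate()
		w.Write([]byte("hello"))
		w.Commit()
		w.Close()

		w, _ = d.Writer(ctx, "/demo", true) // resumes at the current end
		w.Write([]byte(" world"))           // content is now "hello world"
		w.Commit()
		return w.Close()
	}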

View file

@ -18,7 +18,7 @@ import (
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware" storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
) )
// cloudFrontStorageMiddleware provides an simple implementation of layerHandler that // cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
// constructs temporary signed CloudFront URLs from the storagedriver layer URL, // constructs temporary signed CloudFront URLs from the storagedriver layer URL,
// then issues HTTP Temporary Redirects to this CloudFront content URL. // then issues HTTP Temporary Redirects to this CloudFront content URL.
type cloudFrontStorageMiddleware struct { type cloudFrontStorageMiddleware struct {

View file

@ -0,0 +1,50 @@
package middleware
import (
"fmt"
"net/url"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
)
type redirectStorageMiddleware struct {
storagedriver.StorageDriver
scheme string
host string
}
var _ storagedriver.StorageDriver = &redirectStorageMiddleware{}
func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
o, ok := options["baseurl"]
if !ok {
return nil, fmt.Errorf("no baseurl provided")
}
b, ok := o.(string)
if !ok {
return nil, fmt.Errorf("baseurl must be a string")
}
u, err := url.Parse(b)
if err != nil {
return nil, fmt.Errorf("unable to parse redirect baseurl: %s", b)
}
if u.Scheme == "" {
return nil, fmt.Errorf("no scheme specified for redirect baseurl")
}
if u.Host == "" {
return nil, fmt.Errorf("no host specified for redirect baseurl")
}
return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host}, nil
}
func (r *redirectStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
u := &url.URL{Scheme: r.scheme, Host: r.host, Path: path}
return u.String(), nil
}
func init() {
storagemiddleware.Register("redirect", storagemiddleware.InitFunc(newRedirectStorageMiddleware))
}
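Hypothetical wiring (the baseurl value is assumed): the middleware delegates every storage operation to the wrapped driver and only rewrites URLFor results onto the configured scheme and host.

	// wrapWithRedirect sketches constructing and using the middleware;
	// the baseurl and path are placeholders.
	func wrapWithRedirect(sd storagedriver.StorageDriver, ctx context.Context) (string, error) {
		mw, err := newRedirectStorageMiddleware(sd, map[string]interface{}{
			"baseurl": "https://cdn.example.com",
		})
		if err != nil {
			return "", err
		}
		return mw.URLFor(ctx, "/some/path", nil)
		// -> "https://cdn.example.com/some/path"
	}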

View file

@ -0,0 +1,58 @@
package middleware
import (
"testing"
check "gopkg.in/check.v1"
)
func Test(t *testing.T) { check.TestingT(t) }
type MiddlewareSuite struct{}
var _ = check.Suite(&MiddlewareSuite{})
func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
options := make(map[string]interface{})
_, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.ErrorMatches, "no baseurl provided")
}
func (s *MiddlewareSuite) TestMissingScheme(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "example.com"
_, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl")
}
func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "https://example.com:5443"
middleware, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware)
c.Assert(ok, check.Equals, true)
c.Assert(m.scheme, check.Equals, "https")
c.Assert(m.host, check.Equals, "example.com:5443")
url, err := middleware.URLFor(nil, "/rick/data", nil)
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com:5443/rick/data")
}
func (s *MiddlewareSuite) TestHTTP(c *check.C) {
options := make(map[string]interface{})
options["baseurl"] = "http://example.com"
middleware, err := newRedirectStorageMiddleware(nil, options)
c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware)
c.Assert(ok, check.Equals, true)
c.Assert(m.scheme, check.Equals, "http")
c.Assert(m.host, check.Equals, "example.com")
url, err := middleware.URLFor(nil, "morty/data", nil)
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "http://example.com/morty/data")
}

View file

@ -20,7 +20,6 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
@ -75,9 +74,6 @@ type driver struct {
ChunkSize int64 ChunkSize int64
Encrypt bool Encrypt bool
RootDirectory string RootDirectory string
pool sync.Pool // pool []byte buffers used for WriteStream
zeros []byte // shared, zero-valued buffer used for WriteStream
} }
type baseEmbed struct { type baseEmbed struct {
@ -99,8 +95,7 @@ type Driver struct {
// - encrypt // - encrypt
func FromParameters(parameters map[string]interface{}) (*Driver, error) { func FromParameters(parameters map[string]interface{}) (*Driver, error) {
// Providing no values for these is valid in case the user is authenticating // Providing no values for these is valid in case the user is authenticating
// with an IAM on an ec2 instance (in which case the instance credentials will
// be summoned when GetAuth is called)
accessKey, ok := parameters["accesskeyid"] accessKey, ok := parameters["accesskeyid"]
if !ok { if !ok {
return nil, fmt.Errorf("No accesskeyid parameter provided") return nil, fmt.Errorf("No accesskeyid parameter provided")
@ -220,11 +215,6 @@ func New(params DriverParameters) (*Driver, error) {
ChunkSize: params.ChunkSize, ChunkSize: params.ChunkSize,
Encrypt: params.Encrypt, Encrypt: params.Encrypt,
RootDirectory: params.RootDirectory, RootDirectory: params.RootDirectory,
zeros: make([]byte, params.ChunkSize),
}
d.pool.New = func() interface{} {
return make([]byte, d.ChunkSize)
} }
return &Driver{ return &Driver{
@ -256,9 +246,9 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
return parseError(path, d.Bucket.Put(d.ossPath(path), contents, d.getContentType(), getPermissions(), d.getOptions())) return parseError(path, d.Bucket.Put(d.ossPath(path), contents, d.getContentType(), getPermissions(), d.getOptions()))
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
headers := make(http.Header) headers := make(http.Header)
headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-") headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-")
@ -279,315 +269,37 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return resp.Body, nil return resp.Body, nil
} }
// WriteStream stores the contents of the provided io.Reader at a // Writer returns a FileWriter which will store the content written to it
// location designated by the given path. The driver will know it has // at the location designated by "path" after the call to Commit.
// received the full contents when the reader returns io.EOF. The number func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
// of successfully READ bytes will be returned, even if an error is key := d.ossPath(path)
// returned. May be used to resume writing a stream by providing a nonzero if !append {
// offset. Offsets past the current size will write from the position // TODO (brianbland): cancel other uploads at this path
// beyond the end of the file. multi, err := d.Bucket.InitMulti(key, d.getContentType(), getPermissions(), d.getOptions())
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) {
partNumber := 1
bytesRead := 0
var putErrChan chan error
parts := []oss.Part{}
var part oss.Part
done := make(chan struct{}) // stopgap to free up waiting goroutines
multi, err := d.Bucket.InitMulti(d.ossPath(path), d.getContentType(), getPermissions(), d.getOptions())
if err != nil { if err != nil {
return 0, err return nil, err
} }
return d.newWriter(key, multi, nil), nil
buf := d.getbuf()
// We never want to leave a dangling multipart upload, our only consistent state is
// when there is a whole object at path. This is in order to remain consistent with
// the stat call.
//
// Note that if the machine dies before executing the defer, we will be left with a dangling
// multipart upload, which will eventually be cleaned up, but we will lose all of the progress
// made prior to the machine crashing.
defer func() {
if putErrChan != nil {
if putErr := <-putErrChan; putErr != nil {
err = putErr
} }
} multis, _, err := d.Bucket.ListMulti(key, "")
if len(parts) > 0 {
if multi == nil {
// Parts should be empty if the multi is not initialized
panic("Unreachable")
} else {
if multi.Complete(parts) != nil {
multi.Abort()
}
}
}
d.putbuf(buf) // needs to be here to pick up new buf value
close(done) // free up any waiting goroutines
}()
// Fills from 0 to total from current
fromSmallCurrent := func(total int64) error {
current, err := d.ReadStream(ctx, path, 0)
if err != nil { if err != nil {
return err return nil, parseError(path, err)
} }
for _, multi := range multis {
bytesRead = 0 if key != multi.Key {
for int64(bytesRead) < total { continue
//The loop should very rarely enter a second iteration }
nn, err := current.Read(buf[bytesRead:total]) parts, err := multi.ListParts()
bytesRead += nn
if err != nil { if err != nil {
if err != io.EOF { return nil, parseError(path, err)
return err
} }
var multiSize int64
break for _, part := range parts {
multiSize += part.Size
} }
return d.newWriter(key, multi, parts), nil
} }
return nil return nil, storagedriver.PathNotFoundError{Path: path}
}
// Fills from parameter to chunkSize from reader
fromReader := func(from int64) error {
bytesRead = 0
for from+int64(bytesRead) < d.ChunkSize {
nn, err := reader.Read(buf[from+int64(bytesRead):])
totalRead += int64(nn)
bytesRead += nn
if err != nil {
if err != io.EOF {
return err
}
break
}
}
if putErrChan == nil {
putErrChan = make(chan error)
} else {
if putErr := <-putErrChan; putErr != nil {
putErrChan = nil
return putErr
}
}
go func(bytesRead int, from int64, buf []byte) {
defer d.putbuf(buf) // this buffer gets dropped after this call
// DRAGONS(stevvooe): There are few things one might want to know
// about this section. First, the putErrChan is expecting an error
// and a nil or just a nil to come through the channel. This is
// covered by the silly defer below. The other aspect is the OSS
// retry backoff to deal with RequestTimeout errors. Even though
// the underlying OSS library should handle it, it doesn't seem to
// be part of the shouldRetry function (see denverdino/aliyungo/oss).
defer func() {
select {
case putErrChan <- nil: // for some reason, we do this no matter what.
case <-done:
return // ensure we don't leak the goroutine
}
}()
if bytesRead <= 0 {
return
}
var err error
var part oss.Part
part, err = multi.PutPartWithTimeout(int(partNumber), bytes.NewReader(buf[0:int64(bytesRead)+from]), defaultTimeout)
if err != nil {
logrus.Errorf("error putting part, aborting: %v", err)
select {
case putErrChan <- err:
case <-done:
return // don't leak the goroutine
}
}
// parts and partNumber are safe, because this function is the
// only one modifying them and we force it to be executed
// serially.
parts = append(parts, part)
partNumber++
}(bytesRead, from, buf)
buf = d.getbuf() // use a new buffer for the next call
return nil
}
if offset > 0 {
resp, err := d.Bucket.Head(d.ossPath(path), nil)
if err != nil {
if ossErr, ok := err.(*oss.Error); !ok || ossErr.StatusCode != http.StatusNotFound {
return 0, err
}
}
currentLength := int64(0)
if err == nil {
currentLength = resp.ContentLength
}
if currentLength >= offset {
if offset < d.ChunkSize {
// chunkSize > currentLength >= offset
if err = fromSmallCurrent(offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// currentLength >= offset >= chunkSize
_, part, err = multi.PutPartCopy(partNumber,
oss.CopyOptions{CopySourceOptions: "bytes=0-" + strconv.FormatInt(offset-1, 10)},
d.Bucket.Path(d.ossPath(path)))
if err != nil {
return 0, err
}
parts = append(parts, part)
partNumber++
}
} else {
// Fills between parameters with 0s but only when to - from <= chunkSize
fromZeroFillSmall := func(from, to int64) error {
bytesRead = 0
for from+int64(bytesRead) < to {
nn, err := bytes.NewReader(d.zeros).Read(buf[from+int64(bytesRead) : to])
bytesRead += nn
if err != nil {
return err
}
}
return nil
}
// Fills between parameters with 0s, making new parts
fromZeroFillLarge := func(from, to int64) error {
bytesRead64 := int64(0)
for to-(from+bytesRead64) >= d.ChunkSize {
part, err := multi.PutPartWithTimeout(int(partNumber), bytes.NewReader(d.zeros), defaultTimeout)
if err != nil {
return err
}
bytesRead64 += d.ChunkSize
parts = append(parts, part)
partNumber++
}
return fromZeroFillSmall(0, (to-from)%d.ChunkSize)
}
// currentLength < offset
if currentLength < d.ChunkSize {
if offset < d.ChunkSize {
// chunkSize > offset > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// offset >= chunkSize > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, d.ChunkSize); err != nil {
return totalRead, err
}
part, err = multi.PutPartWithTimeout(int(partNumber), bytes.NewReader(buf), defaultTimeout)
if err != nil {
return totalRead, err
}
parts = append(parts, part)
partNumber++
// Zero fill from chunkSize up to offset, then some reader
if err = fromZeroFillLarge(d.ChunkSize, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+(offset%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
} else {
// offset > currentLength >= chunkSize
_, part, err = multi.PutPartCopy(partNumber,
oss.CopyOptions{},
d.Bucket.Path(d.ossPath(path)))
if err != nil {
return 0, err
}
parts = append(parts, part)
partNumber++
// Zero fill from currentLength up to offset, then some reader
if err = fromZeroFillLarge(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader((offset - currentLength) % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+((offset-currentLength)%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
}
}
for {
if err = fromReader(0); err != nil {
return totalRead, err
}
if int64(bytesRead) < d.ChunkSize {
break
}
}
return totalRead, nil
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -778,12 +490,181 @@ func (d *driver) getContentType() string {
return "application/octet-stream" return "application/octet-stream"
} }
// getbuf returns a buffer from the driver's pool with length d.ChunkSize. // writer attempts to upload parts to S3 in a buffered fashion where the last
func (d *driver) getbuf() []byte { // part is at least as large as the chunksize, so the multipart upload could be
return d.pool.Get().([]byte) // cleanly resumed in the future. This is violated if Close is called after less
// than a full chunk is written.
type writer struct {
driver *driver
key string
multi *oss.Multi
parts []oss.Part
size int64
readyPart []byte
pendingPart []byte
closed bool
committed bool
cancelled bool
} }
func (d *driver) putbuf(p []byte) { func (d *driver) newWriter(key string, multi *oss.Multi, parts []oss.Part) storagedriver.FileWriter {
copy(p, d.zeros) var size int64
d.pool.Put(p) for _, part := range parts {
size += part.Size
}
return &writer{
driver: d,
key: key,
multi: multi,
parts: parts,
size: size,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
// If the last written part is smaller than minChunkSize, we need to make a
// new multipart upload :sadface:
if len(w.parts) > 0 && int(w.parts[len(w.parts)-1].Size) < minChunkSize {
err := w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return 0, err
}
multi, err := w.driver.Bucket.InitMulti(w.key, w.driver.getContentType(), getPermissions(), w.driver.getOptions())
if err != nil {
return 0, err
}
w.multi = multi
// If the entire written file is smaller than minChunkSize, we need to make
// a new part from scratch :double sad face:
if w.size < minChunkSize {
contents, err := w.driver.Bucket.Get(w.key)
if err != nil {
return 0, err
}
w.parts = nil
w.readyPart = contents
} else {
// Otherwise we can use the old file as the new first part
_, part, err := multi.PutPartCopy(1, oss.CopyOptions{}, w.driver.Bucket.Name+"/"+w.key)
if err != nil {
return 0, err
}
w.parts = []oss.Part{part}
}
}
var n int
for len(p) > 0 {
// If no parts are ready to write, fill up the first part
if neededBytes := int(w.driver.ChunkSize) - len(w.readyPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.readyPart = append(w.readyPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
} else {
w.readyPart = append(w.readyPart, p...)
n += len(p)
p = nil
}
}
if neededBytes := int(w.driver.ChunkSize) - len(w.pendingPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.pendingPart = append(w.pendingPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
err := w.flushPart()
if err != nil {
w.size += int64(n)
return n, err
}
} else {
w.pendingPart = append(w.pendingPart, p...)
n += len(p)
p = nil
}
}
}
w.size += int64(n)
return n, nil
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.flushPart()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
err := w.multi.Abort()
return err
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
err := w.flushPart()
if err != nil {
return err
}
w.committed = true
err = w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return err
}
return nil
}
// flushPart flushes buffers to write a part to S3.
// Only called by Write (with both buffers full) and Close/Commit (always)
func (w *writer) flushPart() error {
if len(w.readyPart) == 0 && len(w.pendingPart) == 0 {
// nothing to write
return nil
}
if len(w.pendingPart) < int(w.driver.ChunkSize) {
// closing with a small pending part
// combine ready and pending to avoid writing a small part
w.readyPart = append(w.readyPart, w.pendingPart...)
w.pendingPart = nil
}
part, err := w.multi.PutPart(len(w.parts)+1, bytes.NewReader(w.readyPart))
if err != nil {
return err
}
w.parts = append(w.parts, part)
w.readyPart = w.pendingPart
w.pendingPart = nil
return nil
} }
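Worked trace of the two-buffer scheme (ChunkSize assumed to be 5 MiB): Write fills readyPart, then pendingPart, and flushes only when both are full, so the last uploaded part is never undersized.

	// Trace of the buffering invariant with an assumed 5 MiB ChunkSize.
	const mib = 1 << 20
	// After Write has consumed 12*mib bytes in total:
	//   parts:       [5*mib]  // one part flushed when both buffers filled
	//   readyPart:    5*mib   // next candidate part
	//   pendingPart:  2*mib   // tail
	// Close then merges ready+pending into a single 7*mib final part, so
	// every uploaded part stays >= ChunkSize and the multipart upload
	// remains cleanly resumable.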

View file

@ -1,3 +0,0 @@
// Package rados implements the rados storage driver backend. Support can be
// enabled by including the "include_rados" build tag.
package rados

View file

@ -1,632 +0,0 @@
// +build include_rados
package rados
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"path"
"strconv"
log "github.com/Sirupsen/logrus"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/base"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/docker/distribution/uuid"
"github.com/noahdesu/go-ceph/rados"
)
const driverName = "rados"
// Prefix for all stored blobs
const objectBlobPrefix = "blob:"
// Stripe objects into 4M chunks
const defaultChunkSize = 4 << 20
const defaultXattrTotalSizeName = "total-size"
// Max number of keys fetched from omap at each read operation
const defaultKeysFetched = 1
//DriverParameters A struct that encapsulates all of the driver parameters after all values have been set
type DriverParameters struct {
poolname string
username string
chunksize uint64
}
func init() {
factory.Register(driverName, &radosDriverFactory{})
}
// radosDriverFactory implements the factory.StorageDriverFactory interface
type radosDriverFactory struct{}
func (factory *radosDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters)
}
type driver struct {
Conn *rados.Conn
Ioctx *rados.IOContext
chunksize uint64
}
type baseEmbed struct {
base.Base
}
// Driver is a storagedriver.StorageDriver implementation backed by Ceph RADOS
// Objects are stored at absolute keys in the provided bucket.
type Driver struct {
baseEmbed
}
// FromParameters constructs a new Driver with a given parameters map
// Required parameters:
// - poolname: the ceph pool name
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
pool, ok := parameters["poolname"]
if !ok {
return nil, fmt.Errorf("No poolname parameter provided")
}
username, ok := parameters["username"]
if !ok {
username = ""
}
chunksize := uint64(defaultChunkSize)
chunksizeParam, ok := parameters["chunksize"]
if ok {
chunksize, ok = chunksizeParam.(uint64)
if !ok {
return nil, fmt.Errorf("The chunksize parameter should be a number")
}
}
params := DriverParameters{
fmt.Sprint(pool),
fmt.Sprint(username),
chunksize,
}
return New(params)
}
// New constructs a new Driver
func New(params DriverParameters) (*Driver, error) {
var conn *rados.Conn
var err error
if params.username != "" {
log.Infof("Opening connection to pool %s using user %s", params.poolname, params.username)
conn, err = rados.NewConnWithUser(params.username)
} else {
log.Infof("Opening connection to pool %s", params.poolname)
conn, err = rados.NewConn()
}
if err != nil {
return nil, err
}
err = conn.ReadDefaultConfigFile()
if err != nil {
return nil, err
}
err = conn.Connect()
if err != nil {
return nil, err
}
log.Infof("Connected")
ioctx, err := conn.OpenIOContext(params.poolname)
log.Infof("Connected to pool %s", params.poolname)
if err != nil {
return nil, err
}
d := &driver{
Ioctx: ioctx,
Conn: conn,
chunksize: params.chunksize,
}
return &Driver{
baseEmbed: baseEmbed{
Base: base.Base{
StorageDriver: d,
},
},
}, nil
}
// Implement the storagedriver.StorageDriver interface
func (d *driver) Name() string {
return driverName
}
// GetContent retrieves the content stored at "path" as a []byte.
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
rc, err := d.ReadStream(ctx, path, 0)
if err != nil {
return nil, err
}
defer rc.Close()
p, err := ioutil.ReadAll(rc)
if err != nil {
return nil, err
}
return p, nil
}
// PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
if _, err := d.WriteStream(ctx, path, 0, bytes.NewReader(contents)); err != nil {
return err
}
return nil
}
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.
type readStreamReader struct {
driver *driver
oid string
size uint64
offset uint64
}
func (r *readStreamReader) Read(b []byte) (n int, err error) {
// Determine the part available to read
bufferOffset := uint64(0)
bufferSize := uint64(len(b))
// End of the object, read less than the buffer size
if bufferSize > r.size-r.offset {
bufferSize = r.size - r.offset
}
// Fill `b`
for bufferOffset < bufferSize {
// Get the offset in the object chunk
chunkedOid, chunkedOffset := r.driver.getChunkNameFromOffset(r.oid, r.offset)
// Determine the best size to read
bufferEndOffset := bufferSize
if bufferEndOffset-bufferOffset > r.driver.chunksize-chunkedOffset {
bufferEndOffset = bufferOffset + (r.driver.chunksize - chunkedOffset)
}
// Read the chunk
n, err = r.driver.Ioctx.Read(chunkedOid, b[bufferOffset:bufferEndOffset], chunkedOffset)
if err != nil {
return int(bufferOffset), err
}
bufferOffset += uint64(n)
r.offset += uint64(n)
}
// EOF if the offset is at the end of the object
if r.offset == r.size {
return int(bufferOffset), io.EOF
}
return int(bufferOffset), nil
}
func (r *readStreamReader) Close() error {
return nil
}
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
// get oid from filename
oid, err := d.getOid(path)
if err != nil {
return nil, err
}
// get object stat
stat, err := d.Stat(ctx, path)
if err != nil {
return nil, err
}
if offset > stat.Size() {
return nil, storagedriver.InvalidOffsetError{Path: path, Offset: offset}
}
return &readStreamReader{
driver: d,
oid: oid,
size: uint64(stat.Size()),
offset: uint64(offset),
}, nil
}
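For reference, the striping scheme this removed driver used: a blob identified by oid is stored as fixed-size chunk objects named oid-0, oid-1, and so on. A hypothetical standalone version of the offset mapping (the real getChunkNameFromOffset is defined outside this hunk):

	// chunkAddr sketches how a byte offset maps to (chunk object name,
	// offset within that chunk); assumed to mirror getChunkNameFromOffset.
	func chunkAddr(oid string, offset, chunksize uint64) (string, uint64) {
		chunkID := offset / chunksize
		return oid + "-" + strconv.FormatUint(chunkID, 10), offset % chunksize
	}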
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) {
buf := make([]byte, d.chunksize)
totalRead = 0
oid, err := d.getOid(path)
if err != nil {
switch err.(type) {
// Trying to write new object, generate new blob identifier for it
case storagedriver.PathNotFoundError:
oid = d.generateOid()
err = d.putOid(path, oid)
if err != nil {
return 0, err
}
default:
return 0, err
}
} else {
// Check total object size only for existing ones
totalSize, err := d.getXattrTotalSize(ctx, oid)
if err != nil {
return 0, err
}
		// If offset is after the current object size, fill the gap with zeros
		for totalSize < uint64(offset) {
			sizeToWrite := d.chunksize
			if uint64(offset)-totalSize < sizeToWrite {
				sizeToWrite = uint64(offset) - totalSize
			}
			chunkName, chunkOffset := d.getChunkNameFromOffset(oid, totalSize)
			err = d.Ioctx.Write(chunkName, buf[:sizeToWrite], chunkOffset)
			if err != nil {
				return totalRead, err
			}
			totalSize += sizeToWrite
		}
}
	// Write loop: read chunk-aligned slices from `reader` and store them as chunk objects
for {
// Align to chunk size
sizeRead := uint64(0)
sizeToRead := uint64(offset+totalRead) % d.chunksize
if sizeToRead == 0 {
sizeToRead = d.chunksize
}
// Read from `reader`
for sizeRead < sizeToRead {
nn, err := reader.Read(buf[sizeRead:sizeToRead])
sizeRead += uint64(nn)
if err != nil {
if err != io.EOF {
return totalRead, err
}
break
}
}
// End of file and nothing was read
if sizeRead == 0 {
break
}
// Write chunk object
chunkName, chunkOffset := d.getChunkNameFromOffset(oid, uint64(offset+totalRead))
err = d.Ioctx.Write(chunkName, buf[:sizeRead], uint64(chunkOffset))
if err != nil {
return totalRead, err
}
// Update total object size as xattr in the first chunk of the object
err = d.setXattrTotalSize(oid, uint64(offset+totalRead)+sizeRead)
if err != nil {
return totalRead, err
}
totalRead += int64(sizeRead)
// End of file
if sizeRead < sizeToRead {
break
}
}
return totalRead, nil
}
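// Editor's illustration (not in the original source): the alignment rule in
// the write loop above reads at most one chunk per pass, first topping up the
// chunk the write starts in. With an assumed chunksize of 4 and offset 6:
//
//	pass 1: (6+0) % 4 = 2 -> read 2 bytes, filling chunk 1 to its boundary
//	pass 2: (6+2) % 4 = 0 -> read 4 bytes, all of chunk 2
//	pass 3: (6+6) % 4 = 0 -> read 4 bytes, all of chunk 3, and so on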
// Stat retrieves the FileInfo for the given path, including the current size
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
// get oid from filename
oid, err := d.getOid(path)
if err != nil {
return nil, err
}
	// an empty oid denotes a virtual directory
if oid == "" {
return storagedriver.FileInfoInternal{
FileInfoFields: storagedriver.FileInfoFields{
Path: path,
Size: 0,
IsDir: true,
},
}, nil
}
// stat first chunk
stat, err := d.Ioctx.Stat(oid + "-0")
if err != nil {
return nil, err
}
// get total size of chunked object
totalSize, err := d.getXattrTotalSize(ctx, oid)
if err != nil {
return nil, err
}
return storagedriver.FileInfoInternal{
FileInfoFields: storagedriver.FileInfoFields{
Path: path,
Size: int64(totalSize),
ModTime: stat.ModTime,
},
}, nil
}
// List returns a list of the objects that are direct descendants of the given path.
func (d *driver) List(ctx context.Context, dirPath string) ([]string, error) {
files, err := d.listDirectoryOid(dirPath)
if err != nil {
return nil, storagedriver.PathNotFoundError{Path: dirPath}
}
keys := make([]string, 0, len(files))
for k := range files {
if k != dirPath {
keys = append(keys, path.Join(dirPath, k))
}
}
return keys, nil
}
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
// Get oid
oid, err := d.getOid(sourcePath)
if err != nil {
return err
}
// Move reference
err = d.putOid(destPath, oid)
if err != nil {
return err
}
// Delete old reference
err = d.deleteOid(sourcePath)
if err != nil {
return err
}
return nil
}
// Delete recursively deletes all objects stored at "path" and its subpaths.
func (d *driver) Delete(ctx context.Context, objectPath string) error {
// Get oid
oid, err := d.getOid(objectPath)
if err != nil {
return err
}
// Deleting virtual directory
if oid == "" {
objects, err := d.listDirectoryOid(objectPath)
if err != nil {
return err
}
for object := range objects {
err = d.Delete(ctx, path.Join(objectPath, object))
if err != nil {
return err
}
}
} else {
// Delete object chunks
totalSize, err := d.getXattrTotalSize(ctx, oid)
if err != nil {
return err
}
for offset := uint64(0); offset < totalSize; offset += d.chunksize {
chunkName, _ := d.getChunkNameFromOffset(oid, offset)
err = d.Ioctx.Delete(chunkName)
if err != nil {
return err
}
}
// Delete reference
err = d.deleteOid(objectPath)
if err != nil {
return err
}
}
return nil
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{}
}
// Generate a blob identifier
func (d *driver) generateOid() string {
return objectBlobPrefix + uuid.Generate().String()
}
// Reference an object and its hierarchy
func (d *driver) putOid(objectPath string, oid string) error {
directory := path.Dir(objectPath)
base := path.Base(objectPath)
createParentReference := true
	// If the directory already holds a reference, its parent hierarchy exists,
	// so skip re-referencing the parents after creating this reference
if oid == "" {
firstReference, err := d.Ioctx.GetOmapValues(directory, "", "", 1)
if (err == nil) && (len(firstReference) > 0) {
createParentReference = false
}
}
oids := map[string][]byte{
base: []byte(oid),
}
// Reference object
err := d.Ioctx.SetOmap(directory, oids)
if err != nil {
return err
}
	// Ensure the parent virtual directories exist
if createParentReference {
return d.putOid(directory, "")
}
return nil
}
// Get the object identifier from an object name
func (d *driver) getOid(objectPath string) (string, error) {
directory := path.Dir(objectPath)
base := path.Base(objectPath)
files, err := d.Ioctx.GetOmapValues(directory, "", base, 1)
if (err != nil) || (files[base] == nil) {
return "", storagedriver.PathNotFoundError{Path: objectPath}
}
return string(files[base]), nil
}
// List the objects of a virtual directory
func (d *driver) listDirectoryOid(path string) (list map[string][]byte, err error) {
return d.Ioctx.GetAllOmapValues(path, "", "", defaultKeysFetched)
}
// Remove a file from the file hierarchy
func (d *driver) deleteOid(objectPath string) error {
// Remove object reference
directory := path.Dir(objectPath)
base := path.Base(objectPath)
err := d.Ioctx.RmOmapKeys(directory, []string{base})
if err != nil {
return err
}
// Remove virtual directory if empty (no more references)
firstReference, err := d.Ioctx.GetOmapValues(directory, "", "", 1)
if err != nil {
return err
}
if len(firstReference) == 0 {
// Delete omap
err := d.Ioctx.Delete(directory)
if err != nil {
return err
}
// Remove reference on parent omaps
if directory != "" {
return d.deleteOid(directory)
}
}
return nil
}
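// Editor's note (illustration; the path is hypothetical): the omap scheme
// above stores each path component as a key on its parent object, so writing
// "/docker/registry/v2/data" yields, through putOid's recursion:
//
//	omap "/docker/registry/v2": key "data"     -> blob oid
//	omap "/docker/registry":    key "v2"       -> "" (virtual directory)
//	omap "/docker":             key "registry" -> ""
//	omap "/":                   key "docker"   -> ""
//
// deleteOid walks the same chain upward, deleting each directory omap as it
// becomes empty.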
// Takes an offset into a chunked object and returns the name of the chunk
// object and the offset within that chunk
func (d *driver) getChunkNameFromOffset(oid string, offset uint64) (string, uint64) {
chunkID := offset / d.chunksize
chunkedOid := oid + "-" + strconv.FormatInt(int64(chunkID), 10)
chunkedOffset := offset % d.chunksize
return chunkedOid, chunkedOffset
}
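// Worked example (editor's note, assuming a 4 MiB chunksize): an absolute
// offset of 10 MiB into the object with oid "x" maps to
//
//	chunkID       = 10 MiB / 4 MiB = 2     -> chunk object "x-2"
//	chunkedOffset = 10 MiB % 4 MiB = 2 MiB -> 2 MiB into that chunk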
// Set the total size of a chunked object `oid`
func (d *driver) setXattrTotalSize(oid string, size uint64) error {
// Convert uint64 `size` to []byte
xattr := make([]byte, binary.MaxVarintLen64)
binary.LittleEndian.PutUint64(xattr, size)
// Save the total size as a xattr in the first chunk
return d.Ioctx.SetXattr(oid+"-0", defaultXattrTotalSizeName, xattr)
}
// Get the total size of the chunked object `oid` stored as xattr
func (d *driver) getXattrTotalSize(ctx context.Context, oid string) (uint64, error) {
// Fetch xattr as []byte
xattr := make([]byte, binary.MaxVarintLen64)
xattrLength, err := d.Ioctx.GetXattr(oid+"-0", defaultXattrTotalSizeName, xattr)
if err != nil {
return 0, err
}
if xattrLength != len(xattr) {
context.GetLogger(ctx).Errorf("object %s xattr length mismatch: %d != %d", oid, xattrLength, len(xattr))
return 0, storagedriver.PathNotFoundError{Path: oid}
}
	// Convert []byte to uint64
totalSize := binary.LittleEndian.Uint64(xattr)
return totalSize, nil
}
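// Editor's note: the two helpers above round-trip a fixed-width little-endian
// uint64 through a binary.MaxVarintLen64-sized (10-byte) buffer; only the
// first 8 bytes carry data, and the read side treats any other stored length
// as a missing object. Standalone sketch of the encoding:
//
//	xattr := make([]byte, binary.MaxVarintLen64)
//	binary.LittleEndian.PutUint64(xattr, 123456789)
//	size := binary.LittleEndian.Uint64(xattr) // 123456789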

View file

@ -1,40 +0,0 @@
// +build include_rados

package rados
import (
"os"
"testing"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/testsuites"
"gopkg.in/check.v1"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
func init() {
poolname := os.Getenv("RADOS_POOL")
username := os.Getenv("RADOS_USER")
driverConstructor := func() (storagedriver.StorageDriver, error) {
parameters := DriverParameters{
poolname,
username,
defaultChunkSize,
}
return New(parameters)
}
skipCheck := func() string {
if poolname == "" {
return "RADOS_POOL must be set to run Rado tests"
}
return ""
}
testsuites.RegisterSuite(driverConstructor, skipCheck)
}

View file

@ -18,12 +18,11 @@ import (
 	"io/ioutil"
 	"net/http"
 	"reflect"
+	"sort"
 	"strconv"
 	"strings"
-	"sync"
 	"time"

-	"github.com/Sirupsen/logrus"
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/awserr"
 	"github.com/aws/aws-sdk-go/aws/credentials"
@ -60,7 +59,9 @@ type DriverParameters struct {
 	SecretKey      string
 	Bucket         string
 	Region         string
+	RegionEndpoint string
 	Encrypt        bool
+	KeyID          string
 	Secure         bool
 	ChunkSize      int64
 	RootDirectory  string
@ -80,6 +81,8 @@ func init() {
 		"ap-northeast-1",
 		"ap-northeast-2",
 		"sa-east-1",
+		"cn-north-1",
+		"us-gov-west-1",
 	} {
 		validRegions[region] = struct{}{}
 	}
@ -101,11 +104,9 @@ type driver struct {
 	Bucket        string
 	ChunkSize     int64
 	Encrypt       bool
+	KeyID         string
 	RootDirectory string
 	StorageClass  string
-
-	pool  sync.Pool // pool []byte buffers used for WriteStream
-	zeros []byte    // shared, zero-valued buffer used for WriteStream
 }

 type baseEmbed struct {
@ -129,51 +130,78 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
 	// Providing no values for these is valid in case the user is authenticating
 	// with an IAM on an ec2 instance (in which case the instance credentials will
 	// be summoned when GetAuth is called)
-	accessKey, ok := parameters["accesskey"]
-	if !ok {
+	accessKey := parameters["accesskey"]
+	if accessKey == nil {
 		accessKey = ""
 	}
-	secretKey, ok := parameters["secretkey"]
-	if !ok {
+	secretKey := parameters["secretkey"]
+	if secretKey == nil {
 		secretKey = ""
 	}

+	regionEndpoint := parameters["regionendpoint"]
+	if regionEndpoint == nil {
+		regionEndpoint = ""
+	}
+
 	regionName, ok := parameters["region"]
-	if !ok || fmt.Sprint(regionName) == "" {
+	if regionName == nil || fmt.Sprint(regionName) == "" {
 		return nil, fmt.Errorf("No region parameter provided")
 	}
 	region := fmt.Sprint(regionName)
-	_, ok = validRegions[region]
-	if !ok {
-		return nil, fmt.Errorf("Invalid region provided: %v", region)
+	// Don't check the region value if a custom endpoint is provided.
+	if regionEndpoint == "" {
+		if _, ok = validRegions[region]; !ok {
+			return nil, fmt.Errorf("Invalid region provided: %v", region)
+		}
 	}

-	bucket, ok := parameters["bucket"]
-	if !ok || fmt.Sprint(bucket) == "" {
+	bucket := parameters["bucket"]
+	if bucket == nil || fmt.Sprint(bucket) == "" {
 		return nil, fmt.Errorf("No bucket parameter provided")
 	}

 	encryptBool := false
-	encrypt, ok := parameters["encrypt"]
-	if ok {
-		encryptBool, ok = encrypt.(bool)
-		if !ok {
-			return nil, fmt.Errorf("The encrypt parameter should be a boolean")
-		}
+	encrypt := parameters["encrypt"]
+	switch encrypt := encrypt.(type) {
+	case string:
+		b, err := strconv.ParseBool(encrypt)
+		if err != nil {
+			return nil, fmt.Errorf("The encrypt parameter should be a boolean")
+		}
+		encryptBool = b
+	case bool:
+		encryptBool = encrypt
+	case nil:
+		// do nothing
+	default:
+		return nil, fmt.Errorf("The encrypt parameter should be a boolean")
 	}

 	secureBool := true
-	secure, ok := parameters["secure"]
-	if ok {
-		secureBool, ok = secure.(bool)
-		if !ok {
-			return nil, fmt.Errorf("The secure parameter should be a boolean")
-		}
+	secure := parameters["secure"]
+	switch secure := secure.(type) {
+	case string:
+		b, err := strconv.ParseBool(secure)
+		if err != nil {
+			return nil, fmt.Errorf("The secure parameter should be a boolean")
+		}
+		secureBool = b
+	case bool:
+		secureBool = secure
+	case nil:
+		// do nothing
+	default:
+		return nil, fmt.Errorf("The secure parameter should be a boolean")
+	}
+
+	keyID := parameters["keyid"]
+	if keyID == nil {
+		keyID = ""
 	}

 	chunkSize := int64(defaultChunkSize)
-	chunkSizeParam, ok := parameters["chunksize"]
-	if ok {
-		switch v := chunkSizeParam.(type) {
-		case string:
-			vv, err := strconv.ParseInt(v, 0, 64)
+	chunkSizeParam := parameters["chunksize"]
+	switch v := chunkSizeParam.(type) {
+	case string:
+		vv, err := strconv.ParseInt(v, 0, 64)
@ -185,23 +213,24 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
 		chunkSize = v
 	case int, uint, int32, uint32, uint64:
 		chunkSize = reflect.ValueOf(v).Convert(reflect.TypeOf(chunkSize)).Int()
+	case nil:
+		// do nothing
 	default:
-		return nil, fmt.Errorf("invalid valud for chunksize: %#v", chunkSizeParam)
+		return nil, fmt.Errorf("invalid value for chunksize: %#v", chunkSizeParam)
 	}

 	if chunkSize < minChunkSize {
 		return nil, fmt.Errorf("The chunksize %#v parameter should be a number that is larger than or equal to %d", chunkSize, minChunkSize)
 	}
-	}

-	rootDirectory, ok := parameters["rootdirectory"]
-	if !ok {
+	rootDirectory := parameters["rootdirectory"]
+	if rootDirectory == nil {
 		rootDirectory = ""
 	}

 	storageClass := s3.StorageClassStandard
-	storageClassParam, ok := parameters["storageclass"]
-	if ok {
+	storageClassParam := parameters["storageclass"]
+	if storageClassParam != nil {
 		storageClassString, ok := storageClassParam.(string)
 		if !ok {
 			return nil, fmt.Errorf("The storageclass parameter must be one of %v, %v invalid", []string{s3.StorageClassStandard, s3.StorageClassReducedRedundancy}, storageClassParam)
@ -214,8 +243,8 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
 		storageClass = storageClassString
 	}

-	userAgent, ok := parameters["useragent"]
-	if !ok {
+	userAgent := parameters["useragent"]
+	if userAgent == nil {
 		userAgent = ""
 	}
@ -224,7 +253,9 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
 		fmt.Sprint(secretKey),
 		fmt.Sprint(bucket),
 		region,
+		fmt.Sprint(regionEndpoint),
 		encryptBool,
+		fmt.Sprint(keyID),
 		secureBool,
 		chunkSize,
 		fmt.Sprint(rootDirectory),
@ -239,7 +270,9 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
 // bucketName
 func New(params DriverParameters) (*Driver, error) {
 	awsConfig := aws.NewConfig()
-	creds := credentials.NewChainCredentials([]credentials.Provider{
-		&credentials.StaticProvider{
-			Value: credentials.Value{
-				AccessKeyID:     params.AccessKey,
+	var creds *credentials.Credentials
+	if params.RegionEndpoint == "" {
+		creds = credentials.NewChainCredentials([]credentials.Provider{
+			&credentials.StaticProvider{
+				Value: credentials.Value{
+					AccessKeyID:     params.AccessKey,
@ -251,10 +284,23 @@ func New(params DriverParameters) (*Driver, error) {
-		&ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(session.New())},
-	})
+			&ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(session.New())},
+		})
+	} else {
+		creds = credentials.NewChainCredentials([]credentials.Provider{
+			&credentials.StaticProvider{
+				Value: credentials.Value{
+					AccessKeyID:     params.AccessKey,
+					SecretAccessKey: params.SecretKey,
+				},
+			},
+			&credentials.EnvProvider{},
+		})
+		awsConfig.WithS3ForcePathStyle(true)
+		awsConfig.WithEndpoint(params.RegionEndpoint)
+	}

 	awsConfig.WithCredentials(creds)
 	awsConfig.WithRegion(params.Region)
 	awsConfig.WithDisableSSL(!params.Secure)
-	// awsConfig.WithMaxRetries(10)

 	if params.UserAgent != "" {
 		awsConfig.WithHTTPClient(&http.Client{
@ -284,13 +330,9 @@ func New(params DriverParameters) (*Driver, error) {
 		Bucket:        params.Bucket,
 		ChunkSize:     params.ChunkSize,
 		Encrypt:       params.Encrypt,
+		KeyID:         params.KeyID,
 		RootDirectory: params.RootDirectory,
 		StorageClass:  params.StorageClass,
-		zeros:         make([]byte, params.ChunkSize),
-	}
-
-	d.pool.New = func() interface{} {
-		return make([]byte, d.ChunkSize)
 	}

 	return &Driver{
@ -310,7 +352,7 @@ func (d *driver) Name() string {

 // GetContent retrieves the content stored at "path" as a []byte.
 func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
-	reader, err := d.ReadStream(ctx, path, 0)
+	reader, err := d.Reader(ctx, path, 0)
 	if err != nil {
 		return nil, err
 	}
@ -325,15 +367,16 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
 		ContentType:          d.getContentType(),
 		ACL:                  d.getACL(),
 		ServerSideEncryption: d.getEncryptionMode(),
+		SSEKMSKeyId:          d.getSSEKMSKeyID(),
 		StorageClass:         d.getStorageClass(),
 		Body:                 bytes.NewReader(contents),
 	})
 	return parseError(path, err)
 }

-// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a
+// Reader retrieves an io.ReadCloser for the content stored at "path" with a
 // given byte offset.
-func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
+func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
 	resp, err := d.S3.GetObject(&s3.GetObjectInput{
 		Bucket: aws.String(d.Bucket),
 		Key:    aws.String(d.s3Path(path)),
@ -350,372 +393,53 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
 	return resp.Body, nil
 }

-// WriteStream stores the contents of the provided io.Reader at a
-// location designated by the given path. The driver will know it has
-// received the full contents when the reader returns io.EOF. The number
-// of successfully READ bytes will be returned, even if an error is
-// returned. May be used to resume writing a stream by providing a nonzero
-// offset. Offsets past the current size will write from the position
-// beyond the end of the file.
-func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) {
-	var partNumber int64 = 1
-	bytesRead := 0
-	var putErrChan chan error
-	parts := []*s3.CompletedPart{}
-	done := make(chan struct{}) // stopgap to free up waiting goroutines
-
-	resp, err := d.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
-		Bucket:               aws.String(d.Bucket),
-		Key:                  aws.String(d.s3Path(path)),
-		ContentType:          d.getContentType(),
-		ACL:                  d.getACL(),
-		ServerSideEncryption: d.getEncryptionMode(),
-		StorageClass:         d.getStorageClass(),
-	})
-	if err != nil {
-		return 0, err
-	}
+// Writer returns a FileWriter which will store the content written to it
+// at the location designated by "path" after the call to Commit.
+func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
+	key := d.s3Path(path)
+	if !append {
+		// TODO (brianbland): cancel other uploads at this path
+		resp, err := d.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
+			Bucket:               aws.String(d.Bucket),
+			Key:                  aws.String(key),
+			ContentType:          d.getContentType(),
+			ACL:                  d.getACL(),
+			ServerSideEncryption: d.getEncryptionMode(),
+			SSEKMSKeyId:          d.getSSEKMSKeyID(),
+			StorageClass:         d.getStorageClass(),
+		})
+		if err != nil {
+			return nil, err
+		}
+		return d.newWriter(key, *resp.UploadId, nil), nil
uploadID := resp.UploadId
buf := d.getbuf()
// We never want to leave a dangling multipart upload, our only consistent state is
// when there is a whole object at path. This is in order to remain consistent with
// the stat call.
//
// Note that if the machine dies before executing the defer, we will be left with a dangling
// multipart upload, which will eventually be cleaned up, but we will lose all of the progress
// made prior to the machine crashing.
defer func() {
if putErrChan != nil {
if putErr := <-putErrChan; putErr != nil {
err = putErr
} }
} resp, err := d.S3.ListMultipartUploads(&s3.ListMultipartUploadsInput{
if len(parts) > 0 {
_, err := d.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String(d.Bucket), Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)), Prefix: aws.String(key),
UploadId: uploadID,
MultipartUpload: &s3.CompletedMultipartUpload{
Parts: parts,
},
}) })
if err != nil { if err != nil {
// TODO (brianbland): log errors here return nil, parseError(path, err)
d.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ }
for _, multi := range resp.Uploads {
if key != *multi.Key {
continue
}
resp, err := d.S3.ListParts(&s3.ListPartsInput{
Bucket: aws.String(d.Bucket), Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)), Key: aws.String(key),
UploadId: uploadID, UploadId: multi.UploadId,
})
}
}
d.putbuf(buf) // needs to be here to pick up new buf value
close(done) // free up any waiting goroutines
}()
// Fills from 0 to total from current
fromSmallCurrent := func(total int64) error {
current, err := d.ReadStream(ctx, path, 0)
if err != nil {
return err
}
bytesRead = 0
for int64(bytesRead) < total {
//The loop should very rarely enter a second iteration
nn, err := current.Read(buf[bytesRead:total])
bytesRead += nn
if err != nil {
if err != io.EOF {
return err
}
break
}
}
return nil
}
// Fills from parameter to chunkSize from reader
fromReader := func(from int64) error {
bytesRead = 0
for from+int64(bytesRead) < d.ChunkSize {
nn, err := reader.Read(buf[from+int64(bytesRead):])
totalRead += int64(nn)
bytesRead += nn
if err != nil {
if err != io.EOF {
return err
}
break
}
}
if putErrChan == nil {
putErrChan = make(chan error)
} else {
if putErr := <-putErrChan; putErr != nil {
putErrChan = nil
return putErr
}
}
go func(bytesRead int, from int64, buf []byte) {
defer d.putbuf(buf) // this buffer gets dropped after this call
// DRAGONS(stevvooe): There are few things one might want to know
// about this section. First, the putErrChan is expecting an error
// and a nil or just a nil to come through the channel. This is
// covered by the silly defer below. The other aspect is the s3
// retry backoff to deal with RequestTimeout errors. Even though
// the underlying s3 library should handle it, it doesn't seem to
// be part of the shouldRetry function (see AdRoll/goamz/s3).
defer func() {
select {
case putErrChan <- nil: // for some reason, we do this no matter what.
case <-done:
return // ensure we don't leak the goroutine
}
}()
if bytesRead <= 0 {
return
}
resp, err := d.S3.UploadPart(&s3.UploadPartInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)),
PartNumber: aws.Int64(partNumber),
UploadId: uploadID,
Body: bytes.NewReader(buf[0 : int64(bytesRead)+from]),
}) })
if err != nil { if err != nil {
logrus.Errorf("error putting part, aborting: %v", err) return nil, parseError(path, err)
select {
case putErrChan <- err:
case <-done:
return // don't leak the goroutine
} }
var multiSize int64
for _, part := range resp.Parts {
multiSize += *part.Size
} }
return d.newWriter(key, *multi.UploadId, resp.Parts), nil
// parts and partNumber are safe, because this function is the
// only one modifying them and we force it to be executed
// serially.
parts = append(parts, &s3.CompletedPart{
ETag: resp.ETag,
PartNumber: aws.Int64(partNumber),
})
partNumber++
}(bytesRead, from, buf)
buf = d.getbuf() // use a new buffer for the next call
return nil
} }
return nil, storagedriver.PathNotFoundError{Path: path}
if offset > 0 {
resp, err := d.S3.HeadObject(&s3.HeadObjectInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)),
})
if err != nil {
if s3Err, ok := err.(awserr.Error); !ok || s3Err.Code() != "NoSuchKey" {
return 0, err
}
}
currentLength := int64(0)
if err == nil && resp.ContentLength != nil {
currentLength = *resp.ContentLength
}
if currentLength >= offset {
if offset < d.ChunkSize {
// chunkSize > currentLength >= offset
if err = fromSmallCurrent(offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// currentLength >= offset >= chunkSize
resp, err := d.S3.UploadPartCopy(&s3.UploadPartCopyInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)),
PartNumber: aws.Int64(partNumber),
UploadId: uploadID,
CopySource: aws.String(d.Bucket + "/" + d.s3Path(path)),
CopySourceRange: aws.String("bytes=0-" + strconv.FormatInt(offset-1, 10)),
})
if err != nil {
return 0, err
}
parts = append(parts, &s3.CompletedPart{
ETag: resp.CopyPartResult.ETag,
PartNumber: aws.Int64(partNumber),
})
partNumber++
}
} else {
// Fills between parameters with 0s but only when to - from <= chunkSize
fromZeroFillSmall := func(from, to int64) error {
bytesRead = 0
for from+int64(bytesRead) < to {
nn, err := bytes.NewReader(d.zeros).Read(buf[from+int64(bytesRead) : to])
bytesRead += nn
if err != nil {
return err
}
}
return nil
}
// Fills between parameters with 0s, making new parts
fromZeroFillLarge := func(from, to int64) error {
bytesRead64 := int64(0)
for to-(from+bytesRead64) >= d.ChunkSize {
resp, err := d.S3.UploadPart(&s3.UploadPartInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)),
PartNumber: aws.Int64(partNumber),
UploadId: uploadID,
Body: bytes.NewReader(d.zeros),
})
if err != nil {
return err
}
bytesRead64 += d.ChunkSize
parts = append(parts, &s3.CompletedPart{
ETag: resp.ETag,
PartNumber: aws.Int64(partNumber),
})
partNumber++
}
return fromZeroFillSmall(0, (to-from)%d.ChunkSize)
}
// currentLength < offset
if currentLength < d.ChunkSize {
if offset < d.ChunkSize {
// chunkSize > offset > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// offset >= chunkSize > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, d.ChunkSize); err != nil {
return totalRead, err
}
resp, err := d.S3.UploadPart(&s3.UploadPartInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)),
PartNumber: aws.Int64(partNumber),
UploadId: uploadID,
Body: bytes.NewReader(buf),
})
if err != nil {
return totalRead, err
}
parts = append(parts, &s3.CompletedPart{
ETag: resp.ETag,
PartNumber: aws.Int64(partNumber),
})
partNumber++
//Zero fill from chunkSize up to offset, then some reader
if err = fromZeroFillLarge(d.ChunkSize, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+(offset%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
} else {
// offset > currentLength >= chunkSize
resp, err := d.S3.UploadPartCopy(&s3.UploadPartCopyInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(d.s3Path(path)),
PartNumber: aws.Int64(partNumber),
UploadId: uploadID,
CopySource: aws.String(d.Bucket + "/" + d.s3Path(path)),
})
if err != nil {
return 0, err
}
parts = append(parts, &s3.CompletedPart{
ETag: resp.CopyPartResult.ETag,
PartNumber: aws.Int64(partNumber),
})
partNumber++
//Zero fill from currentLength up to offset, then some reader
if err = fromZeroFillLarge(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader((offset - currentLength) % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+((offset-currentLength)%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
}
}
for {
if err = fromReader(0); err != nil {
return totalRead, err
}
if int64(bytesRead) < d.ChunkSize {
break
}
}
return totalRead, nil
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -826,6 +550,7 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e
 		ContentType:          d.getContentType(),
 		ACL:                  d.getACL(),
 		ServerSideEncryption: d.getEncryptionMode(),
+		SSEKMSKeyId:          d.getSSEKMSKeyID(),
 		StorageClass:         d.getStorageClass(),
 		CopySource:           aws.String(d.Bucket + "/" + d.s3Path(sourcePath)),
 	})
@ -937,9 +662,19 @@ func parseError(path string, err error) error {
 }

 func (d *driver) getEncryptionMode() *string {
-	if d.Encrypt {
+	if !d.Encrypt {
+		return nil
+	}
+	if d.KeyID == "" {
 		return aws.String("AES256")
 	}
+	return aws.String("aws:kms")
+}
+
+func (d *driver) getSSEKMSKeyID() *string {
+	if d.KeyID != "" {
+		return aws.String(d.KeyID)
+	}
 	return nil
 }
@ -955,12 +690,271 @@ func (d *driver) getStorageClass() *string {
 	return aws.String(d.StorageClass)
 }

-// getbuf returns a buffer from the driver's pool with length d.ChunkSize.
-func (d *driver) getbuf() []byte {
-	return d.pool.Get().([]byte)
-}
-
-func (d *driver) putbuf(p []byte) {
-	copy(p, d.zeros)
-	d.pool.Put(p)
-}
+// writer attempts to upload parts to S3 in a buffered fashion where the last
+// part is at least as large as the chunksize, so the multipart upload could be
+// cleanly resumed in the future. This is violated if Close is called after less
+// than a full chunk is written.
+type writer struct {
+	driver      *driver
+	key         string
+	uploadID    string
+	parts       []*s3.Part
+	size        int64
+	readyPart   []byte
+	pendingPart []byte
+	closed      bool
+	committed   bool
+	cancelled   bool
+}
+
+func (d *driver) newWriter(key, uploadID string, parts []*s3.Part) storagedriver.FileWriter {
+	var size int64
+	for _, part := range parts {
+		size += *part.Size
+	}
+	return &writer{
+		driver:   d,
+		key:      key,
+		uploadID: uploadID,
+		parts:    parts,
+		size:     size,
+	}
+}
type completedParts []*s3.CompletedPart
func (a completedParts) Len() int { return len(a) }
func (a completedParts) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
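// (Editor's note: the sort is required because S3's CompleteMultipartUpload
// call rejects part lists that are not in ascending PartNumber order.)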
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
// If the last written part is smaller than minChunkSize, we need to make a
// new multipart upload :sadface:
if len(w.parts) > 0 && int(*w.parts[len(w.parts)-1].Size) < minChunkSize {
var completedUploadedParts completedParts
for _, part := range w.parts {
completedUploadedParts = append(completedUploadedParts, &s3.CompletedPart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
sort.Sort(completedUploadedParts)
_, err := w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
UploadId: aws.String(w.uploadID),
MultipartUpload: &s3.CompletedMultipartUpload{
Parts: completedUploadedParts,
},
})
if err != nil {
w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
UploadId: aws.String(w.uploadID),
})
return 0, err
}
resp, err := w.driver.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
ContentType: w.driver.getContentType(),
ACL: w.driver.getACL(),
ServerSideEncryption: w.driver.getEncryptionMode(),
StorageClass: w.driver.getStorageClass(),
})
if err != nil {
return 0, err
}
w.uploadID = *resp.UploadId
// If the entire written file is smaller than minChunkSize, we need to make
// a new part from scratch :double sad face:
if w.size < minChunkSize {
resp, err := w.driver.S3.GetObject(&s3.GetObjectInput{
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
})
defer resp.Body.Close()
if err != nil {
return 0, err
}
w.parts = nil
w.readyPart, err = ioutil.ReadAll(resp.Body)
if err != nil {
return 0, err
}
} else {
// Otherwise we can use the old file as the new first part
copyPartResp, err := w.driver.S3.UploadPartCopy(&s3.UploadPartCopyInput{
Bucket: aws.String(w.driver.Bucket),
CopySource: aws.String(w.driver.Bucket + "/" + w.key),
Key: aws.String(w.key),
PartNumber: aws.Int64(1),
UploadId: resp.UploadId,
})
if err != nil {
return 0, err
}
w.parts = []*s3.Part{
{
ETag: copyPartResp.CopyPartResult.ETag,
PartNumber: aws.Int64(1),
Size: aws.Int64(w.size),
},
}
}
}
var n int
for len(p) > 0 {
// If no parts are ready to write, fill up the first part
if neededBytes := int(w.driver.ChunkSize) - len(w.readyPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.readyPart = append(w.readyPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
} else {
w.readyPart = append(w.readyPart, p...)
n += len(p)
p = nil
}
}
if neededBytes := int(w.driver.ChunkSize) - len(w.pendingPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.pendingPart = append(w.pendingPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
err := w.flushPart()
if err != nil {
w.size += int64(n)
return n, err
}
} else {
w.pendingPart = append(w.pendingPart, p...)
n += len(p)
p = nil
}
}
}
w.size += int64(n)
return n, nil
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.flushPart()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
_, err := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
UploadId: aws.String(w.uploadID),
})
return err
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
err := w.flushPart()
if err != nil {
return err
}
w.committed = true
var completedUploadedParts completedParts
for _, part := range w.parts {
completedUploadedParts = append(completedUploadedParts, &s3.CompletedPart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
sort.Sort(completedUploadedParts)
_, err = w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
UploadId: aws.String(w.uploadID),
MultipartUpload: &s3.CompletedMultipartUpload{
Parts: completedUploadedParts,
},
})
if err != nil {
w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
UploadId: aws.String(w.uploadID),
})
return err
}
return nil
}
// flushPart flushes buffers to write a part to S3.
// Only called by Write (with both buffers full) and Close/Commit (always)
func (w *writer) flushPart() error {
if len(w.readyPart) == 0 && len(w.pendingPart) == 0 {
// nothing to write
return nil
}
if len(w.pendingPart) < int(w.driver.ChunkSize) {
// closing with a small pending part
// combine ready and pending to avoid writing a small part
w.readyPart = append(w.readyPart, w.pendingPart...)
w.pendingPart = nil
}
partNumber := aws.Int64(int64(len(w.parts) + 1))
resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{
Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key),
PartNumber: partNumber,
UploadId: aws.String(w.uploadID),
Body: bytes.NewReader(w.readyPart),
})
if err != nil {
return err
}
w.parts = append(w.parts, &s3.Part{
ETag: resp.ETag,
PartNumber: partNumber,
Size: aws.Int64(int64(len(w.readyPart))),
})
w.readyPart = w.pendingPart
w.pendingPart = nil
return nil
}
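// Editor's illustration of the two-buffer rotation above (assuming a toy
// ChunkSize of 4 bytes): a single Write of "0123456789" buffers "0123" into
// readyPart and "4567" into pendingPart, at which point flushPart uploads
// "0123" as part 1 and promotes "4567" to readyPart; the tail "89" lands in
// pendingPart. A subsequent Close or Commit combines the remainders and
// uploads "456789" as part 2, so only the final part may be undersized.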

View file

@ -27,9 +27,11 @@ func init() {
 	secretKey := os.Getenv("AWS_SECRET_KEY")
 	bucket := os.Getenv("S3_BUCKET")
 	encrypt := os.Getenv("S3_ENCRYPT")
+	keyID := os.Getenv("S3_KEY_ID")
 	secure := os.Getenv("S3_SECURE")
 	region := os.Getenv("AWS_REGION")
 	root, err := ioutil.TempDir("", "driver-")
+	regionEndpoint := os.Getenv("REGION_ENDPOINT")
 	if err != nil {
 		panic(err)
 	}
@ -57,7 +59,9 @@ func init() {
 		secretKey,
 		bucket,
 		region,
+		regionEndpoint,
 		encryptBool,
+		keyID,
 		secureBool,
 		minChunkSize,
 		rootDirectory,

View file

@ -21,10 +21,8 @@ import (
 	"reflect"
 	"strconv"
 	"strings"
-	"sync"
 	"time"

-	"github.com/Sirupsen/logrus"
 	"github.com/docker/goamz/aws"
 	"github.com/docker/goamz/s3"
@ -79,9 +77,6 @@ type driver struct {
 	Encrypt       bool
 	RootDirectory string
 	StorageClass  s3.StorageClass
-
-	pool  sync.Pool // pool []byte buffers used for WriteStream
-	zeros []byte    // shared, zero-valued buffer used for WriteStream
 }

 type baseEmbed struct {
@ -301,11 +296,6 @@ func New(params DriverParameters) (*Driver, error) {
 		Encrypt:       params.Encrypt,
 		RootDirectory: params.RootDirectory,
 		StorageClass:  params.StorageClass,
-		zeros:         make([]byte, params.ChunkSize),
-	}
-
-	d.pool.New = func() interface{} {
-		return make([]byte, d.ChunkSize)
 	}

 	return &Driver{
@ -337,9 +327,9 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e
 	return parseError(path, d.Bucket.Put(d.s3Path(path), contents, d.getContentType(), getPermissions(), d.getOptions()))
 }

-// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a
+// Reader retrieves an io.ReadCloser for the content stored at "path" with a
 // given byte offset.
-func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
+func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
 	headers := make(http.Header)
 	headers.Add("Range", "bytes="+strconv.FormatInt(offset, 10)+"-")
@ -354,343 +344,37 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
 	return resp.Body, nil
 }

-// WriteStream stores the contents of the provided io.Reader at a
-// location designated by the given path. The driver will know it has
-// received the full contents when the reader returns io.EOF. The number
-// of successfully READ bytes will be returned, even if an error is
-// returned. May be used to resume writing a stream by providing a nonzero
-// offset. Offsets past the current size will write from the position
-// beyond the end of the file.
-func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (totalRead int64, err error) {
-	partNumber := 1
-	bytesRead := 0
-	var putErrChan chan error
-	parts := []s3.Part{}
-	var part s3.Part
-	done := make(chan struct{}) // stopgap to free up waiting goroutines
-
-	multi, err := d.Bucket.InitMulti(d.s3Path(path), d.getContentType(), getPermissions(), d.getOptions())
-	if err != nil {
-		return 0, err
-	}
+// Writer returns a FileWriter which will store the content written to it
+// at the location designated by "path" after the call to Commit.
+func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
+	key := d.s3Path(path)
+	if !append {
+		// TODO (brianbland): cancel other uploads at this path
+		multi, err := d.Bucket.InitMulti(key, d.getContentType(), getPermissions(), d.getOptions())
+		if err != nil {
+			return nil, err
+		}
+		return d.newWriter(key, multi, nil), nil
buf := d.getbuf()
// We never want to leave a dangling multipart upload, our only consistent state is
// when there is a whole object at path. This is in order to remain consistent with
// the stat call.
//
// Note that if the machine dies before executing the defer, we will be left with a dangling
// multipart upload, which will eventually be cleaned up, but we will lose all of the progress
// made prior to the machine crashing.
defer func() {
if putErrChan != nil {
if putErr := <-putErrChan; putErr != nil {
err = putErr
} }
} multis, _, err := d.Bucket.ListMulti(key, "")
if len(parts) > 0 {
if multi == nil {
// Parts should be empty if the multi is not initialized
panic("Unreachable")
} else {
if multi.Complete(parts) != nil {
multi.Abort()
}
}
}
d.putbuf(buf) // needs to be here to pick up new buf value
close(done) // free up any waiting goroutines
}()
// Fills from 0 to total from current
fromSmallCurrent := func(total int64) error {
current, err := d.ReadStream(ctx, path, 0)
if err != nil { if err != nil {
return err return nil, parseError(path, err)
} }
for _, multi := range multis {
bytesRead = 0 if key != multi.Key {
for int64(bytesRead) < total { continue
//The loop should very rarely enter a second iteration }
nn, err := current.Read(buf[bytesRead:total]) parts, err := multi.ListParts()
bytesRead += nn
if err != nil { if err != nil {
if err != io.EOF { return nil, parseError(path, err)
return err
} }
var multiSize int64
break for _, part := range parts {
multiSize += part.Size
} }
return d.newWriter(key, multi, parts), nil
} }
return nil return nil, storagedriver.PathNotFoundError{Path: path}
}
// Fills from parameter to chunkSize from reader
fromReader := func(from int64) error {
bytesRead = 0
for from+int64(bytesRead) < d.ChunkSize {
nn, err := reader.Read(buf[from+int64(bytesRead):])
totalRead += int64(nn)
bytesRead += nn
if err != nil {
if err != io.EOF {
return err
}
break
}
}
if putErrChan == nil {
putErrChan = make(chan error)
} else {
if putErr := <-putErrChan; putErr != nil {
putErrChan = nil
return putErr
}
}
go func(bytesRead int, from int64, buf []byte) {
defer d.putbuf(buf) // this buffer gets dropped after this call
// DRAGONS(stevvooe): There are few things one might want to know
// about this section. First, the putErrChan is expecting an error
// and a nil or just a nil to come through the channel. This is
// covered by the silly defer below. The other aspect is the s3
// retry backoff to deal with RequestTimeout errors. Even though
// the underlying s3 library should handle it, it doesn't seem to
// be part of the shouldRetry function (see AdRoll/goamz/s3).
defer func() {
select {
case putErrChan <- nil: // for some reason, we do this no matter what.
case <-done:
return // ensure we don't leak the goroutine
}
}()
if bytesRead <= 0 {
return
}
var err error
var part s3.Part
loop:
for retries := 0; retries < 5; retries++ {
part, err = multi.PutPart(int(partNumber), bytes.NewReader(buf[0:int64(bytesRead)+from]))
if err == nil {
break // success!
}
// NOTE(stevvooe): This retry code tries to only retry under
// conditions where the s3 package does not. We may add s3
// error codes to the below if we see others bubble up in the
// application. Right now, the most troubling is
// RequestTimeout, which seems to only triggered when a tcp
// connection to s3 slows to a crawl. If the RequestTimeout
// ends up getting added to the s3 library and we don't see
// other errors, this retry loop can be removed.
switch err := err.(type) {
case *s3.Error:
switch err.Code {
case "RequestTimeout":
// allow retries on only this error.
default:
break loop
}
}
backoff := 100 * time.Millisecond * time.Duration(retries+1)
logrus.Errorf("error putting part, retrying after %v: %v", err, backoff.String())
time.Sleep(backoff)
}
if err != nil {
logrus.Errorf("error putting part, aborting: %v", err)
select {
case putErrChan <- err:
case <-done:
return // don't leak the goroutine
}
}
// parts and partNumber are safe, because this function is the
// only one modifying them and we force it to be executed
// serially.
parts = append(parts, part)
partNumber++
}(bytesRead, from, buf)
buf = d.getbuf() // use a new buffer for the next call
return nil
}
if offset > 0 {
resp, err := d.Bucket.Head(d.s3Path(path), nil)
if err != nil {
if s3Err, ok := err.(*s3.Error); !ok || s3Err.Code != "NoSuchKey" {
return 0, err
}
}
currentLength := int64(0)
if err == nil {
currentLength = resp.ContentLength
}
if currentLength >= offset {
if offset < d.ChunkSize {
// chunkSize > currentLength >= offset
if err = fromSmallCurrent(offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// currentLength >= offset >= chunkSize
_, part, err = multi.PutPartCopy(partNumber,
s3.CopyOptions{CopySourceOptions: "bytes=0-" + strconv.FormatInt(offset-1, 10)},
d.Bucket.Name+"/"+d.s3Path(path))
if err != nil {
return 0, err
}
parts = append(parts, part)
partNumber++
}
} else {
// Fills between parameters with 0s but only when to - from <= chunkSize
fromZeroFillSmall := func(from, to int64) error {
bytesRead = 0
for from+int64(bytesRead) < to {
nn, err := bytes.NewReader(d.zeros).Read(buf[from+int64(bytesRead) : to])
bytesRead += nn
if err != nil {
return err
}
}
return nil
}
// Fills between parameters with 0s, making new parts
fromZeroFillLarge := func(from, to int64) error {
bytesRead64 := int64(0)
for to-(from+bytesRead64) >= d.ChunkSize {
part, err := multi.PutPart(int(partNumber), bytes.NewReader(d.zeros))
if err != nil {
return err
}
bytesRead64 += d.ChunkSize
parts = append(parts, part)
partNumber++
}
return fromZeroFillSmall(0, (to-from)%d.ChunkSize)
}
// currentLength < offset
if currentLength < d.ChunkSize {
if offset < d.ChunkSize {
// chunkSize > offset > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset); err != nil {
return totalRead, err
}
if totalRead+offset < d.ChunkSize {
return totalRead, nil
}
} else {
// offset >= chunkSize > currentLength
if err = fromSmallCurrent(currentLength); err != nil {
return totalRead, err
}
if err = fromZeroFillSmall(currentLength, d.ChunkSize); err != nil {
return totalRead, err
}
part, err = multi.PutPart(int(partNumber), bytes.NewReader(buf))
if err != nil {
return totalRead, err
}
parts = append(parts, part)
partNumber++
//Zero fill from chunkSize up to offset, then some reader
if err = fromZeroFillLarge(d.ChunkSize, offset); err != nil {
return totalRead, err
}
if err = fromReader(offset % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+(offset%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
} else {
// offset > currentLength >= chunkSize
_, part, err = multi.PutPartCopy(partNumber,
s3.CopyOptions{},
d.Bucket.Name+"/"+d.s3Path(path))
if err != nil {
return 0, err
}
parts = append(parts, part)
partNumber++
//Zero fill from currentLength up to offset, then some reader
if err = fromZeroFillLarge(currentLength, offset); err != nil {
return totalRead, err
}
if err = fromReader((offset - currentLength) % d.ChunkSize); err != nil {
return totalRead, err
}
if totalRead+((offset-currentLength)%d.ChunkSize) < d.ChunkSize {
return totalRead, nil
}
}
}
}
for {
if err = fromReader(0); err != nil {
return totalRead, err
}
if int64(bytesRead) < d.ChunkSize {
break
}
}
return totalRead, nil
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@ -882,12 +566,181 @@ func (d *driver) getContentType() string {
return "application/octet-stream" return "application/octet-stream"
} }
// getbuf returns a buffer from the driver's pool with length d.ChunkSize. // writer attempts to upload parts to S3 in a buffered fashion where the last
func (d *driver) getbuf() []byte { // part is at least as large as the chunksize, so the multipart upload could be
return d.pool.Get().([]byte) // cleanly resumed in the future. This is violated if Close is called after less
// than a full chunk is written.
type writer struct {
driver *driver
key string
multi *s3.Multi
parts []s3.Part
size int64
readyPart []byte
pendingPart []byte
closed bool
committed bool
cancelled bool
} }
func (d *driver) putbuf(p []byte) { func (d *driver) newWriter(key string, multi *s3.Multi, parts []s3.Part) storagedriver.FileWriter {
copy(p, d.zeros) var size int64
d.pool.Put(p) for _, part := range parts {
size += part.Size
}
return &writer{
driver: d,
key: key,
multi: multi,
parts: parts,
size: size,
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
// If the last written part is smaller than minChunkSize, we need to make a
// new multipart upload :sadface:
if len(w.parts) > 0 && int(w.parts[len(w.parts)-1].Size) < minChunkSize {
err := w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return 0, err
}
multi, err := w.driver.Bucket.InitMulti(w.key, w.driver.getContentType(), getPermissions(), w.driver.getOptions())
if err != nil {
return 0, err
}
w.multi = multi
// If the entire written file is smaller than minChunkSize, we need to make
// a new part from scratch :double sad face:
if w.size < minChunkSize {
contents, err := w.driver.Bucket.Get(w.key)
if err != nil {
return 0, err
}
w.parts = nil
w.readyPart = contents
} else {
// Otherwise we can use the old file as the new first part
_, part, err := multi.PutPartCopy(1, s3.CopyOptions{}, w.driver.Bucket.Name+"/"+w.key)
if err != nil {
return 0, err
}
w.parts = []s3.Part{part}
}
}
var n int
for len(p) > 0 {
// If no parts are ready to write, fill up the first part
if neededBytes := int(w.driver.ChunkSize) - len(w.readyPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.readyPart = append(w.readyPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
} else {
w.readyPart = append(w.readyPart, p...)
n += len(p)
p = nil
}
}
if neededBytes := int(w.driver.ChunkSize) - len(w.pendingPart); neededBytes > 0 {
if len(p) >= neededBytes {
w.pendingPart = append(w.pendingPart, p[:neededBytes]...)
n += neededBytes
p = p[neededBytes:]
err := w.flushPart()
if err != nil {
w.size += int64(n)
return n, err
}
} else {
w.pendingPart = append(w.pendingPart, p...)
n += len(p)
p = nil
}
}
}
w.size += int64(n)
return n, nil
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
w.closed = true
return w.flushPart()
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
err := w.multi.Abort()
return err
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
err := w.flushPart()
if err != nil {
return err
}
w.committed = true
err = w.multi.Complete(w.parts)
if err != nil {
w.multi.Abort()
return err
}
return nil
}
// flushPart flushes buffers to write a part to S3.
// Only called by Write (with both buffers full) and Close/Commit (always)
func (w *writer) flushPart() error {
if len(w.readyPart) == 0 && len(w.pendingPart) == 0 {
// nothing to write
return nil
}
if len(w.pendingPart) < int(w.driver.ChunkSize) {
// closing with a small pending part
// combine ready and pending to avoid writing a small part
w.readyPart = append(w.readyPart, w.pendingPart...)
w.pendingPart = nil
}
part, err := w.multi.PutPart(len(w.parts)+1, bytes.NewReader(w.readyPart))
if err != nil {
return err
}
w.parts = append(w.parts, part)
w.readyPart = w.pendingPart
w.pendingPart = nil
return nil
}

View file

@ -34,7 +34,11 @@ func (version Version) Minor() uint {
 const CurrentVersion Version = "0.1"

 // StorageDriver defines methods that a Storage Driver must implement for a
-// filesystem-like key/value object storage.
+// filesystem-like key/value object storage. Storage Drivers are automatically
+// registered via an internal registration mechanism, and generally created
+// via the StorageDriverFactory interface (https://godoc.org/github.com/docker/distribution/registry/storage/driver/factory).
+// Please see the aforementioned factory package for example code showing how to get an instance
+// of a StorageDriver
 type StorageDriver interface {
 	// Name returns the human-readable "name" of the driver, useful in error
 	// messages and logging. By convention, this will just be the registration
@ -49,15 +53,14 @@ type StorageDriver interface {
 	// This should primarily be used for small objects.
 	PutContent(ctx context.Context, path string, content []byte) error

-	// ReadStream retrieves an io.ReadCloser for the content stored at "path"
+	// Reader retrieves an io.ReadCloser for the content stored at "path"
 	// with a given byte offset.
 	// May be used to resume reading a stream by providing a nonzero offset.
-	ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error)
+	Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error)

-	// WriteStream stores the contents of the provided io.ReadCloser at a
-	// location designated by the given path.
-	// May be used to resume writing a stream by providing a nonzero offset.
-	WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (nn int64, err error)
+	// Writer returns a FileWriter which will store the content written to it
+	// at the location designated by "path" after the call to Commit.
+	Writer(ctx context.Context, path string, append bool) (FileWriter, error)

 	// Stat retrieves the FileInfo for the given path, including the current
 	// size in bytes and the creation time.
@ -83,6 +86,25 @@ type StorageDriver interface {
 	URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error)
 }
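// Editor's sketch of the factory usage referenced in the comment above (the
// parameters map is illustrative; "inmemory" is one registered driver name):
//
//	import "github.com/docker/distribution/registry/storage/driver/factory"
//
//	d, err := factory.Create("inmemory", map[string]interface{}{})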
// FileWriter provides an abstraction for an opened writable file-like object in
// the storage backend. The FileWriter must flush all content written to it on
// the call to Close, but is only required to make its content readable on a
// call to Commit.
type FileWriter interface {
io.WriteCloser
// Size returns the number of bytes written to this FileWriter.
Size() int64
// Cancel removes any written content from this FileWriter.
Cancel() error
// Commit flushes all content written to this FileWriter and makes it
// available for future calls to StorageDriver.GetContent and
// StorageDriver.Reader.
Commit() error
}
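To make the Write/Commit/Close/Cancel contract above concrete, here is a minimal caller-side sketch; the package name, helper name, and path are hypothetical, not part of this interface:

package storageutil

import (
	"github.com/docker/distribution/context"
	storagedriver "github.com/docker/distribution/registry/storage/driver"
)

// UploadBlob writes data, commits it so it becomes readable, and closes the
// writer; on a failed write it cancels to discard the partial content.
func UploadBlob(ctx context.Context, d storagedriver.StorageDriver, data []byte) error {
	fw, err := d.Writer(ctx, "/example/blob", false) // append=false starts a new file
	if err != nil {
		return err
	}
	if _, err := fw.Write(data); err != nil {
		fw.Cancel() // best-effort cleanup of partially written content
		fw.Close()
		return err
	}
	if err := fw.Commit(); err != nil { // make the content readable
		fw.Close()
		return err
	}
	return fw.Close() // flush and release the writer
}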
// PathRegexp is the regular expression which each file path must match. A // PathRegexp is the regular expression which each file path must match. A
// file path is absolute, beginning with a slash and containing a positive // file path is absolute, beginning with a slash and containing a positive
// number of path components separated by slashes, where each component is // number of path components separated by slashes, where each component is

View file

@@ -16,8 +16,8 @@
package swift package swift
import ( import (
"bufio"
"bytes" "bytes"
"crypto/md5"
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"crypto/tls" "crypto/tls"
@@ -49,6 +49,9 @@ const defaultChunkSize = 20 * 1024 * 1024
// minChunkSize defines the minimum size of a segment // minChunkSize defines the minimum size of a segment
const minChunkSize = 1 << 20 const minChunkSize = 1 << 20
// contentType defines the Content-Type header associated with stored segments
const contentType = "application/octet-stream"
// readAfterWriteTimeout defines the time we wait before an object appears after having been uploaded // readAfterWriteTimeout defines the time we wait before an object appears after having been uploaded
var readAfterWriteTimeout = 15 * time.Second var readAfterWriteTimeout = 15 * time.Second
@@ -66,6 +69,7 @@ type Parameters struct {
DomainID string DomainID string
TrustID string TrustID string
Region string Region string
AuthVersion int
Container string Container string
Prefix string Prefix string
InsecureSkipVerify bool InsecureSkipVerify bool
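A hedged configuration sketch for the new AuthVersion knob follows; all values are placeholders, the Username field name is assumed from the surrounding driver code, and per the ncw/swift client a zero AuthVersion lets the library detect the auth version from the URL while 1–3 force it:

package main

import (
	"github.com/docker/distribution/registry/storage/driver/swift"
)

func main() {
	params := swift.Parameters{
		Username:    "registry",                          // assumed field name; placeholder value
		Password:    "secret",                            // placeholder credential
		AuthURL:     "https://keystone.example.com/v2.0", // placeholder endpoint
		Region:      "RegionOne",
		AuthVersion: 2, // new in this change: force Keystone v2 instead of auto-detection
		Container:   "registry-storage",
	}
	d, err := swift.New(params)
	if err != nil {
		panic(err)
	}
	_ = d
}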
@@ -171,6 +175,7 @@ func New(params Parameters) (*Driver, error) {
ApiKey: params.Password, ApiKey: params.Password,
AuthUrl: params.AuthURL, AuthUrl: params.AuthURL,
Region: params.Region, Region: params.Region,
AuthVersion: params.AuthVersion,
UserAgent: "distribution/" + version.Version, UserAgent: "distribution/" + version.Version,
Tenant: params.Tenant, Tenant: params.Tenant,
TenantId: params.TenantID, TenantId: params.TenantID,
@@ -277,21 +282,21 @@ func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
if err == swift.ObjectNotFound { if err == swift.ObjectNotFound {
return nil, storagedriver.PathNotFoundError{Path: path} return nil, storagedriver.PathNotFoundError{Path: path}
} }
return content, nil return content, err
} }
// PutContent stores the []byte content at a location designated by "path". // PutContent stores the []byte content at a location designated by "path".
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error { func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
err := d.Conn.ObjectPutBytes(d.Container, d.swiftPath(path), contents, d.getContentType()) err := d.Conn.ObjectPutBytes(d.Container, d.swiftPath(path), contents, contentType)
if err == swift.ObjectNotFound { if err == swift.ObjectNotFound {
return storagedriver.PathNotFoundError{Path: path} return storagedriver.PathNotFoundError{Path: path}
} }
return err return err
} }
// ReadStream retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset. // given byte offset.
func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.ReadCloser, error) { func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
headers := make(swift.Headers) headers := make(swift.Headers)
headers["Range"] = "bytes=" + strconv.FormatInt(offset, 10) + "-" headers["Range"] = "bytes=" + strconv.FormatInt(offset, 10) + "-"
@@ -305,224 +310,46 @@ func (d *driver) ReadStream(ctx context.Context, path string, offset int64) (io.
return file, err return file, err
} }
// WriteStream stores the contents of the provided io.Reader at a // Writer returns a FileWriter which will store the content written to it
// location designated by the given path. The driver will know it has // at the location designated by "path" after the call to Commit.
// received the full contents when the reader returns io.EOF. The number func (d *driver) Writer(ctx context.Context, path string, append bool) (storagedriver.FileWriter, error) {
// of successfully READ bytes will be returned, even if an error is
// returned. May be used to resume writing a stream by providing a nonzero
// offset. Offsets past the current size will write from the position
// beyond the end of the file.
func (d *driver) WriteStream(ctx context.Context, path string, offset int64, reader io.Reader) (int64, error) {
var ( var (
segments []swift.Object segments []swift.Object
multi io.Reader segmentsPath string
paddingReader io.Reader err error
currentLength int64
cursor int64
segmentPath string
) )
partNumber := 1 if !append {
chunkSize := int64(d.ChunkSize) segmentsPath, err = d.swiftSegmentPath(path)
zeroBuf := make([]byte, d.ChunkSize) if err != nil {
hash := md5.New() return nil, err
getSegment := func() string {
return fmt.Sprintf("%s/%016d", segmentPath, partNumber)
} }
} else {
max := func(a int64, b int64) int64 {
if a > b {
return a
}
return b
}
createManifest := true
info, headers, err := d.Conn.Object(d.Container, d.swiftPath(path)) info, headers, err := d.Conn.Object(d.Container, d.swiftPath(path))
if err == nil { if err == swift.ObjectNotFound {
return nil, storagedriver.PathNotFoundError{Path: path}
} else if err != nil {
return nil, err
}
manifest, ok := headers["X-Object-Manifest"] manifest, ok := headers["X-Object-Manifest"]
if !ok { if !ok {
if segmentPath, err = d.swiftSegmentPath(path); err != nil { segmentsPath, err = d.swiftSegmentPath(path)
return 0, err if err != nil {
return nil, err
} }
if err := d.Conn.ObjectMove(d.Container, d.swiftPath(path), d.Container, getSegment()); err != nil { if err := d.Conn.ObjectMove(d.Container, d.swiftPath(path), d.Container, getSegmentPath(segmentsPath, len(segments))); err != nil {
return 0, err return nil, err
} }
segments = append(segments, info) segments = []swift.Object{info}
} else { } else {
_, segmentPath = parseManifest(manifest) _, segmentsPath = parseManifest(manifest)
if segments, err = d.getAllSegments(segmentPath); err != nil { if segments, err = d.getAllSegments(segmentsPath); err != nil {
return 0, err return nil, err
} }
createManifest = false
}
currentLength = info.Bytes
} else if err == swift.ObjectNotFound {
if segmentPath, err = d.swiftSegmentPath(path); err != nil {
return 0, err
}
} else {
return 0, err
}
// First, we skip the existing segments that are not modified by this call
for i := range segments {
if offset < cursor+segments[i].Bytes {
break
}
cursor += segments[i].Bytes
hash.Write([]byte(segments[i].Hash))
partNumber++
}
// We reached the end of the file but we haven't reached 'offset' yet
// Therefore we add blocks of zeros
if offset >= currentLength {
for offset-currentLength >= chunkSize {
// Insert a block a zero
headers, err := d.Conn.ObjectPut(d.Container, getSegment(), bytes.NewReader(zeroBuf), false, "", d.getContentType(), nil)
if err != nil {
if err == swift.ObjectNotFound {
return 0, storagedriver.PathNotFoundError{Path: getSegment()}
}
return 0, err
}
currentLength += chunkSize
partNumber++
hash.Write([]byte(headers["Etag"]))
}
cursor = currentLength
paddingReader = bytes.NewReader(zeroBuf)
} else if offset-cursor > 0 {
// Offset is inside the current segment : we need to read the
// data from the beginning of the segment to offset
file, _, err := d.Conn.ObjectOpen(d.Container, getSegment(), false, nil)
if err != nil {
if err == swift.ObjectNotFound {
return 0, storagedriver.PathNotFoundError{Path: getSegment()}
}
return 0, err
}
defer file.Close()
paddingReader = file
}
readers := []io.Reader{}
if paddingReader != nil {
readers = append(readers, io.LimitReader(paddingReader, offset-cursor))
}
readers = append(readers, io.LimitReader(reader, chunkSize-(offset-cursor)))
multi = io.MultiReader(readers...)
writeSegment := func(segment string) (finished bool, bytesRead int64, err error) {
currentSegment, err := d.Conn.ObjectCreate(d.Container, segment, false, "", d.getContentType(), nil)
if err != nil {
if err == swift.ObjectNotFound {
return false, bytesRead, storagedriver.PathNotFoundError{Path: segment}
}
return false, bytesRead, err
}
segmentHash := md5.New()
writer := io.MultiWriter(currentSegment, segmentHash)
n, err := io.Copy(writer, multi)
if err != nil {
return false, bytesRead, err
}
if n > 0 {
defer func() {
closeError := currentSegment.Close()
if err != nil {
err = closeError
}
hexHash := hex.EncodeToString(segmentHash.Sum(nil))
hash.Write([]byte(hexHash))
}()
bytesRead += n - max(0, offset-cursor)
}
if n < chunkSize {
// We wrote all the data
if cursor+n < currentLength {
// Copy the end of the chunk
headers := make(swift.Headers)
headers["Range"] = "bytes=" + strconv.FormatInt(cursor+n, 10) + "-" + strconv.FormatInt(cursor+chunkSize, 10)
file, _, err := d.Conn.ObjectOpen(d.Container, d.swiftPath(path), false, headers)
if err != nil {
if err == swift.ObjectNotFound {
return false, bytesRead, storagedriver.PathNotFoundError{Path: path}
}
return false, bytesRead, err
}
_, copyErr := io.Copy(writer, file)
if err := file.Close(); err != nil {
if err == swift.ObjectNotFound {
return false, bytesRead, storagedriver.PathNotFoundError{Path: path}
}
return false, bytesRead, err
}
if copyErr != nil {
return false, bytesRead, copyErr
} }
} }
return true, bytesRead, nil return d.newWriter(path, segmentsPath, segments), nil
}
multi = io.LimitReader(reader, chunkSize)
cursor += chunkSize
partNumber++
return false, bytesRead, nil
}
finished := false
read := int64(0)
bytesRead := int64(0)
for finished == false {
finished, read, err = writeSegment(getSegment())
bytesRead += read
if err != nil {
return bytesRead, err
}
}
for ; partNumber < len(segments); partNumber++ {
hash.Write([]byte(segments[partNumber].Hash))
}
if createManifest {
if err := d.createManifest(path, d.Container+"/"+segmentPath); err != nil {
return 0, err
}
}
expectedHash := hex.EncodeToString(hash.Sum(nil))
waitingTime := readAfterWriteWait
endTime := time.Now().Add(readAfterWriteTimeout)
for {
var infos swift.Object
if infos, _, err = d.Conn.Object(d.Container, d.swiftPath(path)); err == nil {
if strings.Trim(infos.Hash, "\"") == expectedHash {
return bytesRead, nil
}
err = fmt.Errorf("Timeout expired while waiting for segments of %s to show up", path)
}
if time.Now().Add(waitingTime).After(endTime) {
break
}
time.Sleep(waitingTime)
waitingTime *= 2
}
return bytesRead, err
} }
// Stat retrieves the FileInfo for the given path, including the current size // Stat retrieves the FileInfo for the given path, including the current size
@@ -551,8 +378,15 @@ func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo,
fi.IsDir = true fi.IsDir = true
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
} else if obj.Name == swiftPath { } else if obj.Name == swiftPath {
// On Swift 1.12, the 'bytes' field is always 0 // The file exists. But on Swift 1.12, the 'bytes' field is always 0 so
// so we need to do a second HEAD request // we need to do a separate HEAD request.
break
}
}
// Don't trust an empty `objects` slice. A container listing can be
// outdated. For files, we can make a HEAD request on the object which
// reports existence (at least) much more reliably.
info, _, err := d.Conn.Object(d.Container, swiftPath) info, _, err := d.Conn.Object(d.Container, swiftPath)
if err != nil { if err != nil {
if err == swift.ObjectNotFound { if err == swift.ObjectNotFound {
@@ -564,10 +398,6 @@ func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo,
fi.Size = info.Bytes fi.Size = info.Bytes
fi.ModTime = info.LastModified fi.ModTime = info.LastModified
return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
}
}
return nil, storagedriver.PathNotFoundError{Path: path}
} }
// List returns a list of the objects that are direct descendants of the given path. // List returns a list of the objects that are direct descendants of the given path.
@@ -763,22 +593,59 @@ func (d *driver) swiftSegmentPath(path string) (string, error) {
return strings.TrimLeft(strings.TrimRight(d.Prefix+"/segments/"+path[0:3]+"/"+path[3:], "/"), "/"), nil return strings.TrimLeft(strings.TrimRight(d.Prefix+"/segments/"+path[0:3]+"/"+path[3:], "/"), "/"), nil
} }
func (d *driver) getContentType() string {
return "application/octet-stream"
}
func (d *driver) getAllSegments(path string) ([]swift.Object, error) { func (d *driver) getAllSegments(path string) ([]swift.Object, error) {
// A simple container listing works 99.9% of the time.
segments, err := d.Conn.ObjectsAll(d.Container, &swift.ObjectsOpts{Prefix: path}) segments, err := d.Conn.ObjectsAll(d.Container, &swift.ObjectsOpts{Prefix: path})
if err != nil {
if err == swift.ContainerNotFound { if err == swift.ContainerNotFound {
return nil, storagedriver.PathNotFoundError{Path: path} return nil, storagedriver.PathNotFoundError{Path: path}
} }
return segments, err return nil, err
}
// Build a lookup table by object name.
hasObjectName := make(map[string]struct{})
for _, segment := range segments {
hasObjectName[segment.Name] = struct{}{}
}
// The container listing might be outdated (i.e. not contain all existing
// segment objects yet) because of temporary inconsistency (Swift is only
// eventually consistent!). Check its completeness.
segmentNumber := 0
for {
segmentNumber++
segmentPath := getSegmentPath(path, segmentNumber)
if _, seen := hasObjectName[segmentPath]; seen {
continue
}
// This segment is missing in the container listing. Use a more reliable
// request to check its existence. (HEAD requests on segments are
// guaranteed to return the correct metadata, except for the pathological
// case of an outage of large parts of the Swift cluster or its network,
// since every segment is only written once.)
segment, _, err := d.Conn.Object(d.Container, segmentPath)
switch err {
case nil:
// Found a new segment -> keep going, more might be missing.
segments = append(segments, segment)
continue
case swift.ObjectNotFound:
// This segment is missing. Since we upload segments sequentially,
// there won't be any more segments after it.
return segments, nil
default:
return nil, err // unexpected error
}
}
} }
func (d *driver) createManifest(path string, segments string) error { func (d *driver) createManifest(path string, segments string) error {
headers := make(swift.Headers) headers := make(swift.Headers)
headers["X-Object-Manifest"] = segments headers["X-Object-Manifest"] = segments
manifest, err := d.Conn.ObjectCreate(d.Container, d.swiftPath(path), false, "", d.getContentType(), headers) manifest, err := d.Conn.ObjectCreate(d.Container, d.swiftPath(path), false, "", contentType, headers)
if err != nil { if err != nil {
if err == swift.ObjectNotFound { if err == swift.ObjectNotFound {
return storagedriver.PathNotFoundError{Path: path} return storagedriver.PathNotFoundError{Path: path}
@@ -810,3 +677,159 @@ func generateSecret() (string, error) {
} }
return hex.EncodeToString(secretBytes[:]), nil return hex.EncodeToString(secretBytes[:]), nil
} }
func getSegmentPath(segmentsPath string, partNumber int) string {
return fmt.Sprintf("%s/%016d", segmentsPath, partNumber)
}
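The zero-padded part number is what keeps segments in order: Swift's X-Object-Manifest support concatenates segment objects in lexicographic name order, and a fixed 16-digit suffix prevents segment 10 from sorting before segment 2. For instance (segments path illustrative):

	getSegmentPath("seg/3ab/cdef", 3) // "seg/3ab/cdef/0000000000000003"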
type writer struct {
driver *driver
path string
segmentsPath string
size int64
bw *bufio.Writer
closed bool
committed bool
cancelled bool
}
func (d *driver) newWriter(path, segmentsPath string, segments []swift.Object) storagedriver.FileWriter {
var size int64
for _, segment := range segments {
size += segment.Bytes
}
return &writer{
driver: d,
path: path,
segmentsPath: segmentsPath,
size: size,
bw: bufio.NewWriterSize(&segmentWriter{
conn: d.Conn,
container: d.Container,
segmentsPath: segmentsPath,
segmentNumber: len(segments) + 1,
maxChunkSize: d.ChunkSize,
}, d.ChunkSize),
}
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, fmt.Errorf("already closed")
} else if w.committed {
return 0, fmt.Errorf("already committed")
} else if w.cancelled {
return 0, fmt.Errorf("already cancelled")
}
n, err := w.bw.Write(p)
w.size += int64(n)
return n, err
}
func (w *writer) Size() int64 {
return w.size
}
func (w *writer) Close() error {
if w.closed {
return fmt.Errorf("already closed")
}
if err := w.bw.Flush(); err != nil {
return err
}
if !w.committed && !w.cancelled {
if err := w.driver.createManifest(w.path, w.driver.Container+"/"+w.segmentsPath); err != nil {
return err
}
if err := w.waitForSegmentsToShowUp(); err != nil {
return err
}
}
w.closed = true
return nil
}
func (w *writer) Cancel() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
}
w.cancelled = true
return w.driver.Delete(context.Background(), w.path)
}
func (w *writer) Commit() error {
if w.closed {
return fmt.Errorf("already closed")
} else if w.committed {
return fmt.Errorf("already committed")
} else if w.cancelled {
return fmt.Errorf("already cancelled")
}
if err := w.bw.Flush(); err != nil {
return err
}
if err := w.driver.createManifest(w.path, w.driver.Container+"/"+w.segmentsPath); err != nil {
return err
}
w.committed = true
return w.waitForSegmentsToShowUp()
}
func (w *writer) waitForSegmentsToShowUp() error {
var err error
waitingTime := readAfterWriteWait
endTime := time.Now().Add(readAfterWriteTimeout)
for {
var info swift.Object
if info, _, err = w.driver.Conn.Object(w.driver.Container, w.driver.swiftPath(w.path)); err == nil {
if info.Bytes == w.size {
break
}
err = fmt.Errorf("Timeout expired while waiting for segments of %s to show up", w.path)
}
if time.Now().Add(waitingTime).After(endTime) {
break
}
time.Sleep(waitingTime)
waitingTime *= 2
}
return err
}
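Note the probing strategy above: the sleep doubles every round (an initial readAfterWriteWait of w yields waits of w, 2w, 4w, ...) and the loop gives up once the next wait would run past readAfterWriteTimeout (15 seconds, per the constant near the top of this file), trading a handful of extra HEAD requests for tolerance of Swift's eventual consistency.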
type segmentWriter struct {
conn swift.Connection
container string
segmentsPath string
segmentNumber int
maxChunkSize int
}
func (sw *segmentWriter) Write(p []byte) (int, error) {
n := 0
for offset := 0; offset < len(p); offset += sw.maxChunkSize {
chunkSize := sw.maxChunkSize
if offset+chunkSize > len(p) {
chunkSize = len(p) - offset
}
_, err := sw.conn.ObjectPut(sw.container, getSegmentPath(sw.segmentsPath, sw.segmentNumber), bytes.NewReader(p[offset:offset+chunkSize]), false, "", contentType, nil)
if err != nil {
return n, err
}
sw.segmentNumber++
n += chunkSize
}
return n, nil
}
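As a worked example of the chunking loop above: with the default 20 MiB chunk size, handing Write a 45 MiB buffer uploads three segment objects of 20, 20, and 5 MiB and advances segmentNumber by three; n only reaches the full 45 MiB if every upload succeeds, so a mid-loop error reports exactly the bytes durably stored.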

View file

@@ -33,6 +33,7 @@ func init() {
trustID string trustID string
container string container string
region string region string
AuthVersion int
insecureSkipVerify bool insecureSkipVerify bool
secretKey string secretKey string
accessKey string accessKey string
@@ -52,6 +53,7 @@ func init() {
trustID = os.Getenv("SWIFT_TRUST_ID") trustID = os.Getenv("SWIFT_TRUST_ID")
container = os.Getenv("SWIFT_CONTAINER_NAME") container = os.Getenv("SWIFT_CONTAINER_NAME")
region = os.Getenv("SWIFT_REGION_NAME") region = os.Getenv("SWIFT_REGION_NAME")
AuthVersion, _ = strconv.Atoi(os.Getenv("SWIFT_AUTH_VERSION"))
insecureSkipVerify, _ = strconv.ParseBool(os.Getenv("SWIFT_INSECURESKIPVERIFY")) insecureSkipVerify, _ = strconv.ParseBool(os.Getenv("SWIFT_INSECURESKIPVERIFY"))
secretKey = os.Getenv("SWIFT_SECRET_KEY") secretKey = os.Getenv("SWIFT_SECRET_KEY")
accessKey = os.Getenv("SWIFT_ACCESS_KEY") accessKey = os.Getenv("SWIFT_ACCESS_KEY")
@@ -85,6 +87,7 @@ func init() {
domainID, domainID,
trustID, trustID,
region, region,
AuthVersion,
container, container,
root, root,
insecureSkipVerify, insecureSkipVerify,

View file

@@ -282,11 +282,19 @@ func (suite *DriverSuite) TestWriteReadLargeStreams(c *check.C) {
var fileSize int64 = 5 * 1024 * 1024 * 1024 var fileSize int64 = 5 * 1024 * 1024 * 1024
contents := newRandReader(fileSize) contents := newRandReader(fileSize)
written, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, io.TeeReader(contents, checksum))
writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
written, err := io.Copy(writer, io.TeeReader(contents, checksum))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(written, check.Equals, fileSize) c.Assert(written, check.Equals, fileSize)
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil)
reader, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@@ -296,9 +304,9 @@ func (suite *DriverSuite) TestWriteReadLargeStreams(c *check.C) {
c.Assert(writtenChecksum.Sum(nil), check.DeepEquals, checksum.Sum(nil)) c.Assert(writtenChecksum.Sum(nil), check.DeepEquals, checksum.Sum(nil))
} }
// TestReadStreamWithOffset tests that the appropriate data is streamed when // TestReaderWithOffset tests that the appropriate data is streamed when
// reading with a given offset. // reading with a given offset.
func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) { func (suite *DriverSuite) TestReaderWithOffset(c *check.C) {
filename := randomPath(32) filename := randomPath(32)
defer suite.deletePath(c, firstPart(filename)) defer suite.deletePath(c, firstPart(filename))
@@ -311,7 +319,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
err := suite.StorageDriver.PutContent(suite.ctx, filename, append(append(contentsChunk1, contentsChunk2...), contentsChunk3...)) err := suite.StorageDriver.PutContent(suite.ctx, filename, append(append(contentsChunk1, contentsChunk2...), contentsChunk3...))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) reader, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@@ -320,7 +328,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
c.Assert(readContents, check.DeepEquals, append(append(contentsChunk1, contentsChunk2...), contentsChunk3...)) c.Assert(readContents, check.DeepEquals, append(append(contentsChunk1, contentsChunk2...), contentsChunk3...))
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, chunkSize) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, chunkSize)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@@ -329,7 +337,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
c.Assert(readContents, check.DeepEquals, append(contentsChunk2, contentsChunk3...)) c.Assert(readContents, check.DeepEquals, append(contentsChunk2, contentsChunk3...))
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, chunkSize*2) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, chunkSize*2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@@ -338,7 +346,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
c.Assert(readContents, check.DeepEquals, contentsChunk3) c.Assert(readContents, check.DeepEquals, contentsChunk3)
// Ensure we get an invalid offset error for negative offsets. // Ensure we get an invalid offset error for negative offsets.
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, -1) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, -1)
c.Assert(err, check.FitsTypeOf, storagedriver.InvalidOffsetError{}) c.Assert(err, check.FitsTypeOf, storagedriver.InvalidOffsetError{})
c.Assert(err.(storagedriver.InvalidOffsetError).Offset, check.Equals, int64(-1)) c.Assert(err.(storagedriver.InvalidOffsetError).Offset, check.Equals, int64(-1))
c.Assert(err.(storagedriver.InvalidOffsetError).Path, check.Equals, filename) c.Assert(err.(storagedriver.InvalidOffsetError).Path, check.Equals, filename)
@@ -347,7 +355,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
// Read past the end of the content and make sure we get a reader that // Read past the end of the content and make sure we get a reader that
// returns 0 bytes and io.EOF // returns 0 bytes and io.EOF
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, chunkSize*3) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, chunkSize*3)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@@ -357,7 +365,7 @@ func (suite *DriverSuite) TestReadStreamWithOffset(c *check.C) {
c.Assert(n, check.Equals, 0) c.Assert(n, check.Equals, 0)
// Check the N-1 boundary condition, ensuring we get 1 byte then io.EOF. // Check the N-1 boundary condition, ensuring we get 1 byte then io.EOF.
reader, err = suite.StorageDriver.ReadStream(suite.ctx, filename, chunkSize*3-1) reader, err = suite.StorageDriver.Reader(suite.ctx, filename, chunkSize*3-1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@@ -395,78 +403,51 @@ func (suite *DriverSuite) testContinueStreamAppend(c *check.C, chunkSize int64)
contentsChunk1 := randomContents(chunkSize) contentsChunk1 := randomContents(chunkSize)
contentsChunk2 := randomContents(chunkSize) contentsChunk2 := randomContents(chunkSize)
contentsChunk3 := randomContents(chunkSize) contentsChunk3 := randomContents(chunkSize)
contentsChunk4 := randomContents(chunkSize)
zeroChunk := make([]byte, int64(chunkSize))
fullContents := append(append(contentsChunk1, contentsChunk2...), contentsChunk3...) fullContents := append(append(contentsChunk1, contentsChunk2...), contentsChunk3...)
nn, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, bytes.NewReader(contentsChunk1)) writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
nn, err := io.Copy(writer, bytes.NewReader(contentsChunk1))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contentsChunk1))) c.Assert(nn, check.Equals, int64(len(contentsChunk1)))
fi, err := suite.StorageDriver.Stat(suite.ctx, filename) err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(fi, check.NotNil)
c.Assert(fi.Size(), check.Equals, int64(len(contentsChunk1)))
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, fi.Size(), bytes.NewReader(contentsChunk2)) curSize := writer.Size()
c.Assert(curSize, check.Equals, int64(len(contentsChunk1)))
writer, err = suite.StorageDriver.Writer(suite.ctx, filename, true)
c.Assert(err, check.IsNil)
c.Assert(writer.Size(), check.Equals, curSize)
nn, err = io.Copy(writer, bytes.NewReader(contentsChunk2))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contentsChunk2))) c.Assert(nn, check.Equals, int64(len(contentsChunk2)))
fi, err = suite.StorageDriver.Stat(suite.ctx, filename) err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(fi, check.NotNil)
c.Assert(fi.Size(), check.Equals, 2*chunkSize)
// Test re-writing the last chunk curSize = writer.Size()
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, fi.Size()-chunkSize, bytes.NewReader(contentsChunk2)) c.Assert(curSize, check.Equals, 2*chunkSize)
c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contentsChunk2)))
fi, err = suite.StorageDriver.Stat(suite.ctx, filename) writer, err = suite.StorageDriver.Writer(suite.ctx, filename, true)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(fi, check.NotNil) c.Assert(writer.Size(), check.Equals, curSize)
c.Assert(fi.Size(), check.Equals, 2*chunkSize)
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, fi.Size(), bytes.NewReader(fullContents[fi.Size():])) nn, err = io.Copy(writer, bytes.NewReader(fullContents[curSize:]))
c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(fullContents[curSize:])))
err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(fullContents[fi.Size():])))
received, err := suite.StorageDriver.GetContent(suite.ctx, filename) received, err := suite.StorageDriver.GetContent(suite.ctx, filename)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(received, check.DeepEquals, fullContents) c.Assert(received, check.DeepEquals, fullContents)
// Writing past size of file extends file (no offset error). We would like
// to write chunk 4 one chunk length past chunk 3. It should be successful
// and the resulting file will be 5 chunks long, with a chunk of all
// zeros.
fullContents = append(fullContents, zeroChunk...)
fullContents = append(fullContents, contentsChunk4...)
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, int64(len(fullContents))-chunkSize, bytes.NewReader(contentsChunk4))
c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, chunkSize)
fi, err = suite.StorageDriver.Stat(suite.ctx, filename)
c.Assert(err, check.IsNil)
c.Assert(fi, check.NotNil)
c.Assert(fi.Size(), check.Equals, int64(len(fullContents)))
received, err = suite.StorageDriver.GetContent(suite.ctx, filename)
c.Assert(err, check.IsNil)
c.Assert(len(received), check.Equals, len(fullContents))
c.Assert(received[chunkSize*3:chunkSize*4], check.DeepEquals, zeroChunk)
c.Assert(received[chunkSize*4:chunkSize*5], check.DeepEquals, contentsChunk4)
c.Assert(received, check.DeepEquals, fullContents)
// Ensure that negative offsets return correct error.
nn, err = suite.StorageDriver.WriteStream(suite.ctx, filename, -1, bytes.NewReader(zeroChunk))
c.Assert(err, check.NotNil)
c.Assert(err, check.FitsTypeOf, storagedriver.InvalidOffsetError{})
c.Assert(err.(storagedriver.InvalidOffsetError).Path, check.Equals, filename)
c.Assert(err.(storagedriver.InvalidOffsetError).Offset, check.Equals, int64(-1))
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
} }
// TestReadNonexistentStream tests that reading a stream for a nonexistent path // TestReadNonexistentStream tests that reading a stream for a nonexistent path
@@ -474,12 +455,12 @@ func (suite *DriverSuite) testContinueStreamAppend(c *check.C, chunkSize int64)
func (suite *DriverSuite) TestReadNonexistentStream(c *check.C) { func (suite *DriverSuite) TestReadNonexistentStream(c *check.C) {
filename := randomPath(32) filename := randomPath(32)
_, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) _, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(err, check.FitsTypeOf, storagedriver.PathNotFoundError{}) c.Assert(err, check.FitsTypeOf, storagedriver.PathNotFoundError{})
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true) c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
_, err = suite.StorageDriver.ReadStream(suite.ctx, filename, 64) _, err = suite.StorageDriver.Reader(suite.ctx, filename, 64)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(err, check.FitsTypeOf, storagedriver.PathNotFoundError{}) c.Assert(err, check.FitsTypeOf, storagedriver.PathNotFoundError{})
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true) c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
@@ -800,7 +781,7 @@ func (suite *DriverSuite) TestStatCall(c *check.C) {
// TestPutContentMultipleTimes checks that the storage driver can overwrite content // TestPutContentMultipleTimes checks that the storage driver can overwrite content
// in the subsequent puts. Validates that PutContent does not have to work // in the subsequent puts. Validates that PutContent does not have to work
// with an offset like WriteStream does and overwrites the file entirely // with an offset like Writer does and overwrites the file entirely
// rather than writing the data to the [0,len(data)) of the file. // rather than writing the data to the [0,len(data)) of the file.
func (suite *DriverSuite) TestPutContentMultipleTimes(c *check.C) { func (suite *DriverSuite) TestPutContentMultipleTimes(c *check.C) {
filename := randomPath(32) filename := randomPath(32)
@@ -842,7 +823,7 @@ func (suite *DriverSuite) TestConcurrentStreamReads(c *check.C) {
readContents := func() { readContents := func() {
defer wg.Done() defer wg.Done()
offset := rand.Int63n(int64(len(contents))) offset := rand.Int63n(int64(len(contents)))
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, offset) reader, err := suite.StorageDriver.Reader(suite.ctx, filename, offset)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
readContents, err := ioutil.ReadAll(reader) readContents, err := ioutil.ReadAll(reader)
@@ -858,7 +839,7 @@ func (suite *DriverSuite) TestConcurrentStreamReads(c *check.C) {
} }
// TestConcurrentFileStreams checks that multiple *os.File objects can be passed // TestConcurrentFileStreams checks that multiple *os.File objects can be passed
// in to WriteStream concurrently without hanging. // in to Writer concurrently without hanging.
func (suite *DriverSuite) TestConcurrentFileStreams(c *check.C) { func (suite *DriverSuite) TestConcurrentFileStreams(c *check.C) {
numStreams := 32 numStreams := 32
@@ -882,53 +863,54 @@ func (suite *DriverSuite) TestConcurrentFileStreams(c *check.C) {
wg.Wait() wg.Wait()
} }
// TODO (brianbland): evaluate the relevancy of this test
// TestEventualConsistency checks that if stat says that a file is a certain size, then // TestEventualConsistency checks that if stat says that a file is a certain size, then
// you can freely read from the file (this is the only guarantee that the driver needs to provide) // you can freely read from the file (this is the only guarantee that the driver needs to provide)
func (suite *DriverSuite) TestEventualConsistency(c *check.C) { // func (suite *DriverSuite) TestEventualConsistency(c *check.C) {
if testing.Short() { // if testing.Short() {
c.Skip("Skipping test in short mode") // c.Skip("Skipping test in short mode")
} // }
//
filename := randomPath(32) // filename := randomPath(32)
defer suite.deletePath(c, firstPart(filename)) // defer suite.deletePath(c, firstPart(filename))
//
var offset int64 // var offset int64
var misswrites int // var misswrites int
var chunkSize int64 = 32 // var chunkSize int64 = 32
//
for i := 0; i < 1024; i++ { // for i := 0; i < 1024; i++ {
contents := randomContents(chunkSize) // contents := randomContents(chunkSize)
read, err := suite.StorageDriver.WriteStream(suite.ctx, filename, offset, bytes.NewReader(contents)) // read, err := suite.StorageDriver.Writer(suite.ctx, filename, offset, bytes.NewReader(contents))
c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
//
fi, err := suite.StorageDriver.Stat(suite.ctx, filename) // fi, err := suite.StorageDriver.Stat(suite.ctx, filename)
c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
//
// We are most concerned with being able to read data as soon as Stat declares // // We are most concerned with being able to read data as soon as Stat declares
// it is uploaded. This is the strongest guarantee that some drivers (that guarantee // // it is uploaded. This is the strongest guarantee that some drivers (that guarantee
// at best eventual consistency) absolutely need to provide. // // at best eventual consistency) absolutely need to provide.
if fi.Size() == offset+chunkSize { // if fi.Size() == offset+chunkSize {
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, offset) // reader, err := suite.StorageDriver.Reader(suite.ctx, filename, offset)
c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
//
readContents, err := ioutil.ReadAll(reader) // readContents, err := ioutil.ReadAll(reader)
c.Assert(err, check.IsNil) // c.Assert(err, check.IsNil)
//
c.Assert(readContents, check.DeepEquals, contents) // c.Assert(readContents, check.DeepEquals, contents)
//
reader.Close() // reader.Close()
offset += read // offset += read
} else { // } else {
misswrites++ // misswrites++
} // }
} // }
//
if misswrites > 0 { // if misswrites > 0 {
c.Log("There were " + string(misswrites) + " occurrences of a write not being instantly available.") // c.Log("There were " + string(misswrites) + " occurrences of a write not being instantly available.")
} // }
//
c.Assert(misswrites, check.Not(check.Equals), 1024) // c.Assert(misswrites, check.Not(check.Equals), 1024)
} // }
// BenchmarkPutGetEmptyFiles benchmarks PutContent/GetContent for 0B files // BenchmarkPutGetEmptyFiles benchmarks PutContent/GetContent for 0B files
func (suite *DriverSuite) BenchmarkPutGetEmptyFiles(c *check.C) { func (suite *DriverSuite) BenchmarkPutGetEmptyFiles(c *check.C) {
@@ -968,22 +950,22 @@ func (suite *DriverSuite) benchmarkPutGetFiles(c *check.C, size int64) {
} }
} }
// BenchmarkStreamEmptyFiles benchmarks WriteStream/ReadStream for 0B files // BenchmarkStreamEmptyFiles benchmarks Writer/Reader for 0B files
func (suite *DriverSuite) BenchmarkStreamEmptyFiles(c *check.C) { func (suite *DriverSuite) BenchmarkStreamEmptyFiles(c *check.C) {
suite.benchmarkStreamFiles(c, 0) suite.benchmarkStreamFiles(c, 0)
} }
// BenchmarkStream1KBFiles benchmarks WriteStream/ReadStream for 1KB files // BenchmarkStream1KBFiles benchmarks Writer/Reader for 1KB files
func (suite *DriverSuite) BenchmarkStream1KBFiles(c *check.C) { func (suite *DriverSuite) BenchmarkStream1KBFiles(c *check.C) {
suite.benchmarkStreamFiles(c, 1024) suite.benchmarkStreamFiles(c, 1024)
} }
// BenchmarkStream1MBFiles benchmarks WriteStream/ReadStream for 1MB files // BenchmarkStream1MBFiles benchmarks Writer/Reader for 1MB files
func (suite *DriverSuite) BenchmarkStream1MBFiles(c *check.C) { func (suite *DriverSuite) BenchmarkStream1MBFiles(c *check.C) {
suite.benchmarkStreamFiles(c, 1024*1024) suite.benchmarkStreamFiles(c, 1024*1024)
} }
// BenchmarkStream1GBFiles benchmarks WriteStream/ReadStream for 1GB files // BenchmarkStream1GBFiles benchmarks Writer/Reader for 1GB files
func (suite *DriverSuite) BenchmarkStream1GBFiles(c *check.C) { func (suite *DriverSuite) BenchmarkStream1GBFiles(c *check.C) {
suite.benchmarkStreamFiles(c, 1024*1024*1024) suite.benchmarkStreamFiles(c, 1024*1024*1024)
} }
@@ -998,11 +980,18 @@ func (suite *DriverSuite) benchmarkStreamFiles(c *check.C, size int64) {
for i := 0; i < c.N; i++ { for i := 0; i < c.N; i++ {
filename := path.Join(parentDir, randomPath(32)) filename := path.Join(parentDir, randomPath(32))
written, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, bytes.NewReader(randomContents(size))) writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
written, err := io.Copy(writer, bytes.NewReader(randomContents(size)))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(written, check.Equals, size) c.Assert(written, check.Equals, size)
rc, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil)
rc, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rc.Close() rc.Close()
} }
@@ -1083,11 +1072,18 @@ func (suite *DriverSuite) testFileStreams(c *check.C, size int64) {
tf.Sync() tf.Sync()
tf.Seek(0, os.SEEK_SET) tf.Seek(0, os.SEEK_SET)
nn, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, tf) writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
nn, err := io.Copy(writer, tf)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, size) c.Assert(nn, check.Equals, size)
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil)
reader, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()
@@ -1112,11 +1108,18 @@ func (suite *DriverSuite) writeReadCompare(c *check.C, filename string, contents
func (suite *DriverSuite) writeReadCompareStreams(c *check.C, filename string, contents []byte) { func (suite *DriverSuite) writeReadCompareStreams(c *check.C, filename string, contents []byte) {
defer suite.deletePath(c, firstPart(filename)) defer suite.deletePath(c, firstPart(filename))
nn, err := suite.StorageDriver.WriteStream(suite.ctx, filename, 0, bytes.NewReader(contents)) writer, err := suite.StorageDriver.Writer(suite.ctx, filename, false)
c.Assert(err, check.IsNil)
nn, err := io.Copy(writer, bytes.NewReader(contents))
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contents))) c.Assert(nn, check.Equals, int64(len(contents)))
reader, err := suite.StorageDriver.ReadStream(suite.ctx, filename, 0) err = writer.Commit()
c.Assert(err, check.IsNil)
err = writer.Close()
c.Assert(err, check.IsNil)
reader, err := suite.StorageDriver.Reader(suite.ctx, filename, 0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
defer reader.Close() defer reader.Close()

View file

@@ -119,7 +119,7 @@ func (fr *fileReader) reader() (io.Reader, error) {
} }
// If we don't have a reader, open one up. // If we don't have a reader, open one up.
rc, err := fr.driver.ReadStream(fr.ctx, fr.path, fr.offset) rc, err := fr.driver.Reader(fr.ctx, fr.path, fr.offset)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case storagedriver.PathNotFoundError: case storagedriver.PathNotFoundError:

View file

@@ -183,7 +183,7 @@ func TestFileReaderNonExistentFile(t *testing.T) {
// conditions that can arise when reading a layer. // conditions that can arise when reading a layer.
func TestFileReaderErrors(t *testing.T) { func TestFileReaderErrors(t *testing.T) {
// TODO(stevvooe): We need to cover error return types, driven by the // TODO(stevvooe): We need to cover error return types, driven by the
// errors returned via the HTTP API. For now, here is a incomplete list: // errors returned via the HTTP API. For now, here is an incomplete list:
// //
// 1. Layer Not Found: returned when layer is not found or access is // 1. Layer Not Found: returned when layer is not found or access is
// denied. // denied.

View file

@@ -1,135 +0,0 @@
package storage
import (
"bytes"
"fmt"
"io"
"os"
"github.com/docker/distribution/context"
storagedriver "github.com/docker/distribution/registry/storage/driver"
)
// fileWriter implements a remote file writer backed by a storage driver.
type fileWriter struct {
driver storagedriver.StorageDriver
ctx context.Context
// identifying fields
path string
// mutable fields
size int64 // size of the file, aka the current end
offset int64 // offset is the current write offset
err error // terminal error, if set, reader is closed
}
// fileWriterInterface makes the desired io compliant interface that the
// filewriter should implement.
type fileWriterInterface interface {
io.WriteSeeker
io.ReaderFrom
io.Closer
}
var _ fileWriterInterface = &fileWriter{}
// newFileWriter returns a prepared fileWriter for the driver and path. This
// could be considered similar to an "open" call on a regular filesystem.
func newFileWriter(ctx context.Context, driver storagedriver.StorageDriver, path string) (*fileWriter, error) {
fw := fileWriter{
driver: driver,
path: path,
ctx: ctx,
}
if fi, err := driver.Stat(ctx, path); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
// ignore, offset is zero
default:
return nil, err
}
} else {
if fi.IsDir() {
return nil, fmt.Errorf("cannot write to a directory")
}
fw.size = fi.Size()
}
return &fw, nil
}
// Write writes the buffer p at the current write offset.
func (fw *fileWriter) Write(p []byte) (n int, err error) {
nn, err := fw.ReadFrom(bytes.NewReader(p))
return int(nn), err
}
// ReadFrom reads reader r until io.EOF writing the contents at the current
// offset.
func (fw *fileWriter) ReadFrom(r io.Reader) (n int64, err error) {
if fw.err != nil {
return 0, fw.err
}
nn, err := fw.driver.WriteStream(fw.ctx, fw.path, fw.offset, r)
// We should forward the offset, whether or not there was an error.
// Basically, we keep the filewriter in sync with the reader's head. If an
// error is encountered, the whole thing should be retried but we proceed
// from an expected offset, even if the data didn't make it to the
// backend.
fw.offset += nn
if fw.offset > fw.size {
fw.size = fw.offset
}
return nn, err
}
// Seek moves the write position to the requested offset based on the whence
// argument, which can be os.SEEK_CUR, os.SEEK_END, or os.SEEK_SET.
func (fw *fileWriter) Seek(offset int64, whence int) (int64, error) {
if fw.err != nil {
return 0, fw.err
}
var err error
newOffset := fw.offset
switch whence {
case os.SEEK_CUR:
newOffset += int64(offset)
case os.SEEK_END:
newOffset = fw.size + int64(offset)
case os.SEEK_SET:
newOffset = int64(offset)
}
if newOffset < 0 {
err = fmt.Errorf("cannot seek to negative position")
} else {
// No problems, set the offset.
fw.offset = newOffset
}
return fw.offset, err
}
// Close closes the fileWriter for writing.
// Calling it once is valid and correct and it will
// return a nil error. Calling it subsequent times will
// detect that fw.err has been set and will return the error.
func (fw *fileWriter) Close() error {
if fw.err != nil {
return fw.err
}
fw.err = fmt.Errorf("filewriter@%v: closed", fw.path)
return nil
}

View file

@@ -1,226 +0,0 @@
package storage
import (
"bytes"
"crypto/rand"
"io"
"os"
"testing"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
storagedriver "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/inmemory"
)
// TestSimpleWrite takes the fileWriter through common write operations
// ensuring data integrity.
func TestSimpleWrite(t *testing.T) {
content := make([]byte, 1<<20)
n, err := rand.Read(content)
if err != nil {
t.Fatalf("unexpected error building random data: %v", err)
}
if n != len(content) {
t.Fatalf("random read did't fill buffer")
}
dgst, err := digest.FromReader(bytes.NewReader(content))
if err != nil {
t.Fatalf("unexpected error digesting random content: %v", err)
}
driver := inmemory.New()
path := "/random"
ctx := context.Background()
fw, err := newFileWriter(ctx, driver, path)
if err != nil {
t.Fatalf("unexpected error creating fileWriter: %v", err)
}
defer fw.Close()
n, err = fw.Write(content)
if err != nil {
t.Fatalf("unexpected error writing content: %v", err)
}
if n != len(content) {
t.Fatalf("unexpected write length: %d != %d", n, len(content))
}
fr, err := newFileReader(ctx, driver, path, int64(len(content)))
if err != nil {
t.Fatalf("unexpected error creating fileReader: %v", err)
}
defer fr.Close()
verifier, err := digest.NewDigestVerifier(dgst)
if err != nil {
t.Fatalf("unexpected error getting digest verifier: %s", err)
}
io.Copy(verifier, fr)
if !verifier.Verified() {
t.Fatalf("unable to verify write data")
}
// Check the seek position is equal to the content length
end, err := fw.Seek(0, os.SEEK_END)
if err != nil {
t.Fatalf("unexpected error seeking: %v", err)
}
if end != int64(len(content)) {
t.Fatalf("write did not advance offset: %d != %d", end, len(content))
}
// Double the content
doubled := append(content, content...)
doubledgst, err := digest.FromReader(bytes.NewReader(doubled))
if err != nil {
t.Fatalf("unexpected error digesting doubled content: %v", err)
}
nn, err := fw.ReadFrom(bytes.NewReader(content))
if err != nil {
t.Fatalf("unexpected error doubling content: %v", err)
}
if nn != int64(len(content)) {
t.Fatalf("writeat was short: %d != %d", n, len(content))
}
fr, err = newFileReader(ctx, driver, path, int64(len(doubled)))
if err != nil {
t.Fatalf("unexpected error creating fileReader: %v", err)
}
defer fr.Close()
verifier, err = digest.NewDigestVerifier(doubledgst)
if err != nil {
t.Fatalf("unexpected error getting digest verifier: %s", err)
}
io.Copy(verifier, fr)
if !verifier.Verified() {
t.Fatalf("unable to verify write data")
}
// Check that Write updated the offset.
end, err = fw.Seek(0, os.SEEK_END)
if err != nil {
t.Fatalf("unexpected error seeking: %v", err)
}
if end != int64(len(doubled)) {
t.Fatalf("write did not advance offset: %d != %d", end, len(doubled))
}
// Now, we copy from one path to another, running the data through the
// fileReader to fileWriter, rather than the driver.Move command to ensure
// everything is working correctly.
fr, err = newFileReader(ctx, driver, path, int64(len(doubled)))
if err != nil {
t.Fatalf("unexpected error creating fileReader: %v", err)
}
defer fr.Close()
fw, err = newFileWriter(ctx, driver, "/copied")
if err != nil {
t.Fatalf("unexpected error creating fileWriter: %v", err)
}
defer fw.Close()
nn, err = io.Copy(fw, fr)
if err != nil {
t.Fatalf("unexpected error copying data: %v", err)
}
if nn != int64(len(doubled)) {
t.Fatalf("unexpected copy length: %d != %d", nn, len(doubled))
}
fr, err = newFileReader(ctx, driver, "/copied", int64(len(doubled)))
if err != nil {
t.Fatalf("unexpected error creating fileReader: %v", err)
}
defer fr.Close()
verifier, err = digest.NewDigestVerifier(doubledgst)
if err != nil {
t.Fatalf("unexpected error getting digest verifier: %s", err)
}
io.Copy(verifier, fr)
if !verifier.Verified() {
t.Fatalf("unable to verify write data")
}
}
func BenchmarkFileWriter(b *testing.B) {
b.StopTimer() // not sure how long setup above will take
for i := 0; i < b.N; i++ {
// Start basic fileWriter initialization
fw := fileWriter{
driver: inmemory.New(),
path: "/random",
}
ctx := context.Background()
if fi, err := fw.driver.Stat(ctx, fw.path); err != nil {
switch err := err.(type) {
case storagedriver.PathNotFoundError:
// ignore, offset is zero
default:
b.Fatalf("Failed to initialize fileWriter: %v", err.Error())
}
} else {
if fi.IsDir() {
b.Fatalf("Cannot write to a directory")
}
fw.size = fi.Size()
}
randomBytes := make([]byte, 1<<20)
_, err := rand.Read(randomBytes)
if err != nil {
b.Fatalf("unexpected error building random data: %v", err)
}
// End basic file writer initialization
b.StartTimer()
for j := 0; j < 100; j++ {
fw.Write(randomBytes)
}
b.StopTimer()
}
}
func BenchmarkfileWriter(b *testing.B) {
b.StopTimer() // not sure how long setup above will take
ctx := context.Background()
for i := 0; i < b.N; i++ {
bfw, err := newFileWriter(ctx, inmemory.New(), "/random")
if err != nil {
b.Fatalf("Failed to initialize fileWriter: %v", err.Error())
}
randomBytes := make([]byte, 1<<20)
_, err = rand.Read(randomBytes)
if err != nil {
b.Fatalf("unexpected error building random data: %v", err)
}
b.StartTimer()
for j := 0; j < 100; j++ {
bfw.Write(randomBytes)
}
b.StopTimer()
}
}

View file

@@ -1,39 +1,34 @@
package registry package storage
import ( import (
"fmt" "fmt"
"os"
"github.com/docker/distribution" "github.com/docker/distribution"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
"github.com/docker/distribution/digest" "github.com/docker/distribution/digest"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2" "github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/reference" "github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/driver" "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/factory"
"github.com/spf13/cobra"
) )
func markAndSweep(storageDriver driver.StorageDriver) error { func emit(format string, a ...interface{}) {
ctx := context.Background() fmt.Printf(format+"\n", a...)
}
// Construct a registry
registry, err := storage.NewRegistry(ctx, storageDriver)
if err != nil {
return fmt.Errorf("failed to construct registry: %v", err)
}
// MarkAndSweep performs a mark and sweep of registry data
func MarkAndSweep(ctx context.Context, storageDriver driver.StorageDriver, registry distribution.Namespace, dryRun bool) error {
repositoryEnumerator, ok := registry.(distribution.RepositoryEnumerator) repositoryEnumerator, ok := registry.(distribution.RepositoryEnumerator)
if !ok { if !ok {
return fmt.Errorf("coercion error: unable to convert Namespace to RepositoryEnumerator") return fmt.Errorf("unable to convert Namespace to RepositoryEnumerator")
} }
// mark // mark
markSet := make(map[digest.Digest]struct{}) markSet := make(map[digest.Digest]struct{})
err = repositoryEnumerator.Enumerate(ctx, func(repoName string) error { err := repositoryEnumerator.Enumerate(ctx, func(repoName string) error {
if dryRun {
emit(repoName)
}
var err error var err error
named, err := reference.ParseNamed(repoName) named, err := reference.ParseNamed(repoName)
if err != nil { if err != nil {
@@ -51,11 +46,14 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator) manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator)
if !ok { if !ok {
return fmt.Errorf("coercion error: unable to convert ManifestService into ManifestEnumerator") return fmt.Errorf("unable to convert ManifestService into ManifestEnumerator")
} }
err = manifestEnumerator.Enumerate(ctx, func(dgst digest.Digest) error { err = manifestEnumerator.Enumerate(ctx, func(dgst digest.Digest) error {
// Mark the manifest's blob // Mark the manifest's blob
if dryRun {
emit("%s: marking manifest %s ", repoName, dgst)
}
markSet[dgst] = struct{}{} markSet[dgst] = struct{}{}
manifest, err := manifestService.Get(ctx, dgst) manifest, err := manifestService.Get(ctx, dgst)
@@ -66,24 +64,17 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
descriptors := manifest.References() descriptors := manifest.References()
for _, descriptor := range descriptors { for _, descriptor := range descriptors {
markSet[descriptor.Digest] = struct{}{} markSet[descriptor.Digest] = struct{}{}
if dryRun {
emit("%s: marking blob %s", repoName, descriptor.Digest)
}
} }
switch manifest.(type) { switch manifest.(type) {
case *schema1.SignedManifest:
signaturesGetter, ok := manifestService.(distribution.SignaturesGetter)
if !ok {
return fmt.Errorf("coercion error: unable to convert ManifestSErvice into SignaturesGetter")
}
signatures, err := signaturesGetter.GetSignatures(ctx, dgst)
if err != nil {
return fmt.Errorf("failed to get signatures for signed manifest: %v", err)
}
for _, signatureDigest := range signatures {
markSet[signatureDigest] = struct{}{}
}
break
case *schema2.DeserializedManifest: case *schema2.DeserializedManifest:
config := manifest.(*schema2.DeserializedManifest).Config config := manifest.(*schema2.DeserializedManifest).Config
if dryRun {
emit("%s: marking configuration %s", repoName, config.Digest)
}
markSet[config.Digest] = struct{}{} markSet[config.Digest] = struct{}{}
break break
} }
@@ -91,6 +82,17 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
return nil return nil
}) })
if err != nil {
// In certain situations such as unfinished uploads, deleting all
// tags in S3 or removing the _manifests folder manually, this
// error may be of type PathNotFound.
//
// In these cases we can continue marking other manifests safely.
if _, ok := err.(driver.PathNotFoundError); ok {
return nil
}
}
return err return err
}) })
@@ -108,10 +110,19 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
} }
return nil return nil
}) })
if err != nil {
return fmt.Errorf("error enumerating blobs: %v", err)
}
if dryRun {
emit("\n%d blobs marked, %d blobs eligible for deletion", len(markSet), len(deleteSet))
}
// Construct vacuum // Construct vacuum
vacuum := storage.NewVacuum(ctx, storageDriver) vacuum := NewVacuum(ctx, storageDriver)
for dgst := range deleteSet { for dgst := range deleteSet {
if dryRun {
emit("blob eligible for deletion: %s", dgst)
continue
}
err = vacuum.RemoveBlob(string(dgst)) err = vacuum.RemoveBlob(string(dgst))
if err != nil { if err != nil {
return fmt.Errorf("failed to delete blob %s: %v\n", dgst, err) return fmt.Errorf("failed to delete blob %s: %v\n", dgst, err)
@@ -120,31 +131,3 @@ func markAndSweep(storageDriver driver.StorageDriver) error {
return err return err
} }
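With the command wiring removed below, callers now invoke the exported function directly. A minimal sketch of a dry run, assuming the in-memory driver for brevity (a real caller would construct the driver from its configuration, as the removed GCCmd did):

package main

import (
	"log"

	"github.com/docker/distribution/context"
	"github.com/docker/distribution/registry/storage"
	"github.com/docker/distribution/registry/storage/driver/inmemory"
)

func main() {
	ctx := context.Background()
	d := inmemory.New()

	registry, err := storage.NewRegistry(ctx, d)
	if err != nil {
		log.Fatalf("failed to construct registry: %v", err)
	}

	// dryRun=true marks reachable blobs and reports what would be
	// deleted without vacuuming anything.
	if err := storage.MarkAndSweep(ctx, d, registry, true); err != nil {
		log.Fatalf("garbage collection failed: %v", err)
	}
}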
// GCCmd is the cobra command that corresponds to the garbage-collect subcommand
var GCCmd = &cobra.Command{
Use: "garbage-collect <config>",
Short: "`garbage-collects` deletes layers not referenced by any manifests",
Long: "`garbage-collects` deletes layers not referenced by any manifests",
Run: func(cmd *cobra.Command, args []string) {
config, err := resolveConfiguration(args)
if err != nil {
fmt.Fprintf(os.Stderr, "configuration error: %v\n", err)
cmd.Usage()
os.Exit(1)
}
driver, err := factory.Create(config.Storage.Type(), config.Storage.Parameters())
if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
os.Exit(1)
}
err = markAndSweep(driver)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to garbage collect: %v", err)
os.Exit(1)
}
},
}
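
The hunks above fold the old command wiring into a reworked entry point: the unexported markAndSweep(driver) becomes an exported MarkAndSweep(ctx, driver, registry, dryRun), which only reports what it would delete in dry-run mode and tolerates PathNotFound while enumerating manifests. A minimal sketch of driving the new API; the exact signature is inferred from the updated tests below, and the inmemory driver stands in for whatever driver a real deployment configures:

    package main

    import (
        "fmt"
        "os"

        "github.com/docker/distribution/context"
        "github.com/docker/distribution/registry/storage"
        "github.com/docker/distribution/registry/storage/driver/inmemory"
    )

    func main() {
        ctx := context.Background()
        driver := inmemory.New() // stand-in for the configured storage driver

        registry, err := storage.NewRegistry(ctx, driver)
        if err != nil {
            fmt.Fprintf(os.Stderr, "failed to construct registry: %v\n", err)
            os.Exit(1)
        }

        // dryRun=true: mark and report only; nothing is vacuumed.
        if err := storage.MarkAndSweep(ctx, driver, registry, true); err != nil {
            fmt.Fprintf(os.Stderr, "failed to garbage collect: %v\n", err)
            os.Exit(1)
        }
    }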

View file

@ -1,17 +1,18 @@
package registry package storage
import ( import (
"io" "io"
"path"
"testing" "testing"
"github.com/docker/distribution" "github.com/docker/distribution"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
"github.com/docker/distribution/digest" "github.com/docker/distribution/digest"
"github.com/docker/distribution/reference" "github.com/docker/distribution/reference"
"github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/driver" "github.com/docker/distribution/registry/storage/driver"
"github.com/docker/distribution/registry/storage/driver/inmemory" "github.com/docker/distribution/registry/storage/driver/inmemory"
"github.com/docker/distribution/testutil" "github.com/docker/distribution/testutil"
"github.com/docker/libtrust"
) )
type image struct { type image struct {
@ -22,7 +23,11 @@ type image struct {
func createRegistry(t *testing.T, driver driver.StorageDriver) distribution.Namespace { func createRegistry(t *testing.T, driver driver.StorageDriver) distribution.Namespace {
ctx := context.Background() ctx := context.Background()
registry, err := storage.NewRegistry(ctx, driver, storage.EnableDelete) k, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
registry, err := NewRegistry(ctx, driver, EnableDelete, Schema1SigningKey(k))
if err != nil { if err != nil {
t.Fatalf("Failed to construct namespace") t.Fatalf("Failed to construct namespace")
} }
@ -139,13 +144,13 @@ func TestNoDeletionNoEffect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
inmemoryDriver := inmemory.New() inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver) registry := createRegistry(t, inmemory.New())
repo := makeRepository(t, registry, "palailogos") repo := makeRepository(t, registry, "palailogos")
manifestService, err := repo.Manifests(ctx) manifestService, err := repo.Manifests(ctx)
image1 := uploadRandomSchema1Image(t, repo) image1 := uploadRandomSchema1Image(t, repo)
image2 := uploadRandomSchema1Image(t, repo) image2 := uploadRandomSchema1Image(t, repo)
image3 := uploadRandomSchema2Image(t, repo) uploadRandomSchema2Image(t, repo)
// construct manifestlist for fun. // construct manifestlist for fun.
blobstatter := registry.BlobStatter() blobstatter := registry.BlobStatter()
@ -160,20 +165,48 @@ func TestNoDeletionNoEffect(t *testing.T) {
t.Fatalf("Failed to add manifest list: %v", err) t.Fatalf("Failed to add manifest list: %v", err)
} }
before := allBlobs(t, registry)
// Run GC // Run GC
err = markAndSweep(inmemoryDriver) err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
if err != nil {
t.Fatalf("Failed mark and sweep: %v", err)
}
after := allBlobs(t, registry)
if len(before) != len(after) {
t.Fatalf("Garbage collection affected storage: %d != %d", len(before), len(after))
}
}
func TestGCWithMissingManifests(t *testing.T) {
ctx := context.Background()
d := inmemory.New()
registry := createRegistry(t, d)
repo := makeRepository(t, registry, "testrepo")
uploadRandomSchema1Image(t, repo)
// Simulate a missing _manifests directory
revPath, err := pathFor(manifestRevisionsPathSpec{"testrepo"})
if err != nil {
t.Fatal(err)
}
_manifestsPath := path.Dir(revPath)
err = d.Delete(ctx, _manifestsPath)
if err != nil {
t.Fatal(err)
}
err = MarkAndSweep(context.Background(), d, registry, false)
if err != nil { if err != nil {
t.Fatalf("Failed mark and sweep: %v", err) t.Fatalf("Failed mark and sweep: %v", err)
} }
blobs := allBlobs(t, registry) blobs := allBlobs(t, registry)
if len(blobs) > 0 {
// the +1 at the end is for the manifestList t.Errorf("unexpected blobs after gc")
// the first +3 at the end for each manifest's blob
// the second +3 at the end for each manifest's signature/config layer
totalBlobCount := len(image1.layers) + len(image2.layers) + len(image3.layers) + 1 + 3 + 3
if len(blobs) != totalBlobCount {
t.Fatalf("Garbage collection affected storage")
} }
} }
@ -193,7 +226,7 @@ func TestDeletionHasEffect(t *testing.T) {
manifests.Delete(ctx, image3.manifestDigest) manifests.Delete(ctx, image3.manifestDigest)
// Run GC // Run GC
err = markAndSweep(inmemoryDriver) err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
if err != nil { if err != nil {
t.Fatalf("Failed mark and sweep: %v", err) t.Fatalf("Failed mark and sweep: %v", err)
} }
@ -327,7 +360,7 @@ func TestOrphanBlobDeleted(t *testing.T) {
uploadRandomSchema2Image(t, repo) uploadRandomSchema2Image(t, repo)
// Run GC // Run GC
err = markAndSweep(inmemoryDriver) err = MarkAndSweep(context.Background(), inmemoryDriver, registry, false)
if err != nil { if err != nil {
t.Fatalf("Failed mark and sweep: %v", err) t.Fatalf("Failed mark and sweep: %v", err)
} }

View file

@ -35,7 +35,7 @@ type linkedBlobStore struct {
// control the repository blob link set to which the blob store // control the repository blob link set to which the blob store
// dispatches. This is required because manifest and layer blobs have not // dispatches. This is required because manifest and layer blobs have not
// yet been fully merged. At some point, this functionality should be // yet been fully merged. At some point, this functionality should be
// removed an the blob links folder should be merged. The first entry is // removed and the blob links folder should be merged. The first entry is
// treated as the "canonical" link location and will be used for writes. // treated as the "canonical" link location and will be used for writes.
linkPathFns []linkPathFunc linkPathFns []linkPathFunc
@ -179,7 +179,7 @@ func (lbs *linkedBlobStore) Create(ctx context.Context, options ...distribution.
return nil, err return nil, err
} }
return lbs.newBlobUpload(ctx, uuid, path, startedAt) return lbs.newBlobUpload(ctx, uuid, path, startedAt, false)
} }
func (lbs *linkedBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) { func (lbs *linkedBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) {
@ -218,7 +218,7 @@ func (lbs *linkedBlobStore) Resume(ctx context.Context, id string) (distribution
return nil, err return nil, err
} }
return lbs.newBlobUpload(ctx, id, path, startedAt) return lbs.newBlobUpload(ctx, id, path, startedAt, true)
} }
func (lbs *linkedBlobStore) Delete(ctx context.Context, dgst digest.Digest) error { func (lbs *linkedBlobStore) Delete(ctx context.Context, dgst digest.Digest) error {
@ -312,18 +312,21 @@ func (lbs *linkedBlobStore) mount(ctx context.Context, sourceRepo reference.Name
} }
// newBlobUpload allocates a new upload controller with the given state. // newBlobUpload allocates a new upload controller with the given state.
func (lbs *linkedBlobStore) newBlobUpload(ctx context.Context, uuid, path string, startedAt time.Time) (distribution.BlobWriter, error) { func (lbs *linkedBlobStore) newBlobUpload(ctx context.Context, uuid, path string, startedAt time.Time, append bool) (distribution.BlobWriter, error) {
fw, err := newFileWriter(ctx, lbs.driver, path) fw, err := lbs.driver.Writer(ctx, path, append)
if err != nil { if err != nil {
return nil, err return nil, err
} }
bw := &blobWriter{ bw := &blobWriter{
ctx: ctx,
blobStore: lbs, blobStore: lbs,
id: uuid, id: uuid,
startedAt: startedAt, startedAt: startedAt,
digester: digest.Canonical.New(), digester: digest.Canonical.New(),
fileWriter: *fw, fileWriter: fw,
driver: lbs.driver,
path: path,
resumableDigestEnabled: lbs.resumableDigestEnabled, resumableDigestEnabled: lbs.resumableDigestEnabled,
} }
@ -381,7 +384,7 @@ var _ distribution.BlobDescriptorService = &linkedBlobStatter{}
func (lbs *linkedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) { func (lbs *linkedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
var ( var (
resolveErr error found bool
target digest.Digest target digest.Digest
) )
@ -392,19 +395,20 @@ func (lbs *linkedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (dis
target, err = lbs.resolveWithLinkFunc(ctx, dgst, linkPathFn) target, err = lbs.resolveWithLinkFunc(ctx, dgst, linkPathFn)
if err == nil { if err == nil {
found = true
break // success! break // success!
} }
switch err := err.(type) { switch err := err.(type) {
case driver.PathNotFoundError: case driver.PathNotFoundError:
resolveErr = distribution.ErrBlobUnknown // move to the next linkPathFn, saving the error // do nothing, just move to the next linkPathFn
default: default:
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }
} }
if resolveErr != nil { if !found {
return distribution.Descriptor{}, resolveErr return distribution.Descriptor{}, distribution.ErrBlobUnknown
} }
if target != dgst { if target != dgst {
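
Buried in this file's hunks are two behavioural changes: Stat now tracks an explicit found flag rather than stashing ErrBlobUnknown along the way, and newBlobUpload obtains its writer straight from the storage driver, with an append flag separating Create (a fresh upload) from Resume (continuing an existing one). A sketch of that flag's effect, assuming the inmemory driver implements the new Writer contract as the diff uses it:

    package main

    import (
        "fmt"

        "github.com/docker/distribution/context"
        "github.com/docker/distribution/registry/storage/driver/inmemory"
    )

    func main() {
        ctx := context.Background()
        d := inmemory.New()

        // append=false: start a fresh upload, as Create does.
        w, err := d.Writer(ctx, "/uploads/example", false)
        if err != nil {
            panic(err)
        }
        w.Write([]byte("hello, "))
        w.Close()

        // append=true: pick up where the last writer stopped, as Resume does.
        w, err = d.Writer(ctx, "/uploads/example", true)
        if err != nil {
            panic(err)
        }
        w.Write([]byte("world"))
        w.Commit()
        w.Close()

        content, _ := d.GetContent(ctx, "/uploads/example")
        fmt.Println(string(content)) // hello, world
    }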

View file

@ -2,7 +2,6 @@ package storage
import ( import (
"fmt" "fmt"
"path"
"encoding/json" "encoding/json"
"github.com/docker/distribution" "github.com/docker/distribution"
@ -140,42 +139,3 @@ func (ms *manifestStore) Enumerate(ctx context.Context, ingester func(digest.Dig
}) })
return err return err
} }
// Only valid for schema1 signed manifests
func (ms *manifestStore) GetSignatures(ctx context.Context, manifestDigest digest.Digest) ([]digest.Digest, error) {
// sanity check that digest refers to a schema1 digest
manifest, err := ms.Get(ctx, manifestDigest)
if err != nil {
return nil, err
}
if _, ok := manifest.(*schema1.SignedManifest); !ok {
return nil, fmt.Errorf("digest %v is not for schema1 manifest", manifestDigest)
}
signaturesPath, err := pathFor(manifestSignaturesPathSpec{
name: ms.repository.Named().Name(),
revision: manifestDigest,
})
if err != nil {
return nil, err
}
signaturesPath = path.Join(signaturesPath, "sha256")
signaturePaths, err := ms.blobStore.driver.List(ctx, signaturesPath)
if err != nil {
return nil, err
}
var digests []digest.Digest
for _, sigPath := range signaturePaths {
sigdigest, err := digest.ParseDigest("sha256:" + path.Base(sigPath))
if err != nil {
// found something that is not a digest; skip it
continue
}
digests = append(digests, sigdigest)
}
return digests, nil
}

View file

@ -52,15 +52,11 @@ func newManifestStoreTestEnv(t *testing.T, name reference.Named, tag string, opt
} }
func TestManifestStorage(t *testing.T) { func TestManifestStorage(t *testing.T) {
testManifestStorage(t, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect)
}
func TestManifestStorageDisabledSignatures(t *testing.T) {
k, err := libtrust.GenerateECP256PrivateKey() k, err := libtrust.GenerateECP256PrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testManifestStorage(t, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect, DisableSchema1Signatures, Schema1SigningKey(k)) testManifestStorage(t, BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), EnableDelete, EnableRedirect, Schema1SigningKey(k))
} }
func testManifestStorage(t *testing.T, options ...RegistryOption) { func testManifestStorage(t *testing.T, options ...RegistryOption) {
@ -71,7 +67,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
equalSignatures := env.registry.(*registry).schema1SignaturesEnabled
m := schema1.Manifest{ m := schema1.Manifest{
Versioned: manifest.Versioned{ Versioned: manifest.Versioned{
@ -175,12 +170,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("fetched payload does not match original payload: %q != %q", fetchedManifest.Canonical, sm.Canonical) t.Fatalf("fetched payload does not match original payload: %q != %q", fetchedManifest.Canonical, sm.Canonical)
} }
if equalSignatures {
if !reflect.DeepEqual(fetchedManifest, sm) {
t.Fatalf("fetched manifest not equal: %#v != %#v", fetchedManifest.Manifest, sm.Manifest)
}
}
_, pl, err := fetchedManifest.Payload() _, pl, err := fetchedManifest.Payload()
if err != nil { if err != nil {
t.Fatalf("error getting payload %#v", err) t.Fatalf("error getting payload %#v", err)
@ -223,12 +212,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("fetched manifest not equal: %q != %q", byDigestManifest.Canonical, fetchedManifest.Canonical) t.Fatalf("fetched manifest not equal: %q != %q", byDigestManifest.Canonical, fetchedManifest.Canonical)
} }
if equalSignatures {
if !reflect.DeepEqual(fetchedByDigest, fetchedManifest) {
t.Fatalf("fetched manifest not equal: %#v != %#v", fetchedByDigest, fetchedManifest)
}
}
sigs, err := fetchedJWS.Signatures() sigs, err := fetchedJWS.Signatures()
if err != nil { if err != nil {
t.Fatalf("unable to extract signatures: %v", err) t.Fatalf("unable to extract signatures: %v", err)
@ -285,17 +268,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("unexpected error verifying manifest: %v", err) t.Fatalf("unexpected error verifying manifest: %v", err)
} }
// Assemble our payload and two signatures to get what we expect!
expectedJWS, err := libtrust.NewJSONSignature(payload, sigs[0], sigs2[0])
if err != nil {
t.Fatalf("unexpected error merging jws: %v", err)
}
expectedSigs, err := expectedJWS.Signatures()
if err != nil {
t.Fatalf("unexpected error getting expected signatures: %v", err)
}
_, pl, err = fetched.Payload() _, pl, err = fetched.Payload()
if err != nil { if err != nil {
t.Fatalf("error getting payload %#v", err) t.Fatalf("error getting payload %#v", err)
@ -315,19 +287,6 @@ func testManifestStorage(t *testing.T, options ...RegistryOption) {
t.Fatalf("payloads are not equal") t.Fatalf("payloads are not equal")
} }
if equalSignatures {
receivedSigs, err := receivedJWS.Signatures()
if err != nil {
t.Fatalf("error getting signatures: %v", err)
}
for i, sig := range receivedSigs {
if !bytes.Equal(sig, expectedSigs[i]) {
t.Fatalf("mismatched signatures from remote: %v != %v", string(sig), string(expectedSigs[i]))
}
}
}
// Test deleting manifests // Test deleting manifests
err = ms.Delete(ctx, dgst) err = ms.Delete(ctx, dgst)
if err != nil { if err != nil {

View file

@ -30,8 +30,6 @@ const (
// revisions // revisions
// -> <manifest digest path> // -> <manifest digest path>
// -> link // -> link
// -> signatures
// <algorithm>/<digest>/link
// tags/<tag> // tags/<tag>
// -> current/link // -> current/link
// -> index // -> index
@ -62,8 +60,7 @@ const (
// //
// The third component of the repository directory is the manifests store, // The third component of the repository directory is the manifests store,
// which is made up of a revision store and tag store. Manifests are stored in // which is made up of a revision store and tag store. Manifests are stored in
// the blob store and linked into the revision store. Signatures are separated // the blob store and linked into the revision store.
// from the manifest payload data and linked into the blob store, as well.
// While the registry can save all revisions of a manifest, no relationship is // While the registry can save all revisions of a manifest, no relationship is
// implied as to the ordering of changes to a manifest. The tag store provides // implied as to the ordering of changes to a manifest. The tag store provides
// support for name, tag lookups of manifests, using "current/link" under a // support for name, tag lookups of manifests, using "current/link" under a
@ -77,8 +74,6 @@ const (
// manifestRevisionsPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/ // manifestRevisionsPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/
// manifestRevisionPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/ // manifestRevisionPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/
// manifestRevisionLinkPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/link // manifestRevisionLinkPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/link
// manifestSignaturesPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/signatures/
// manifestSignatureLinkPathSpec: <root>/v2/repositories/<name>/_manifests/revisions/<algorithm>/<hex digest>/signatures/<algorithm>/<hex digest>/link
// //
// Tags: // Tags:
// //
@ -148,33 +143,6 @@ func pathFor(spec pathSpec) (string, error) {
} }
return path.Join(root, "link"), nil return path.Join(root, "link"), nil
case manifestSignaturesPathSpec:
root, err := pathFor(manifestRevisionPathSpec{
name: v.name,
revision: v.revision,
})
if err != nil {
return "", err
}
return path.Join(root, "signatures"), nil
case manifestSignatureLinkPathSpec:
root, err := pathFor(manifestSignaturesPathSpec{
name: v.name,
revision: v.revision,
})
if err != nil {
return "", err
}
signatureComponents, err := digestPathComponents(v.signature, false)
if err != nil {
return "", err
}
return path.Join(root, path.Join(append(signatureComponents, "link")...)), nil
case manifestTagsPathSpec: case manifestTagsPathSpec:
return path.Join(append(repoPrefix, v.name, "_manifests", "tags")...), nil return path.Join(append(repoPrefix, v.name, "_manifests", "tags")...), nil
case manifestTagPathSpec: case manifestTagPathSpec:
@ -325,26 +293,6 @@ type manifestRevisionLinkPathSpec struct {
func (manifestRevisionLinkPathSpec) pathSpec() {} func (manifestRevisionLinkPathSpec) pathSpec() {}
// manifestSignaturesPathSpec describes the path components for the directory
// containing all the signatures for the target blob. Entries are named with
// the underlying key id.
type manifestSignaturesPathSpec struct {
name string
revision digest.Digest
}
func (manifestSignaturesPathSpec) pathSpec() {}
// manifestSignatureLinkPathSpec describes the path components used to look up
// a signature file by the hash of its blob.
type manifestSignatureLinkPathSpec struct {
name string
revision digest.Digest
signature digest.Digest
}
func (manifestSignatureLinkPathSpec) pathSpec() {}
// manifestTagsPathSpec describes the path elements required to point to the // manifestTagsPathSpec describes the path elements required to point to the
// manifest tags directory. // manifest tags directory.
type manifestTagsPathSpec struct { type manifestTagsPathSpec struct {

View file

@ -26,21 +26,6 @@ func TestPathMapper(t *testing.T) {
}, },
expected: "/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/link", expected: "/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/link",
}, },
{
spec: manifestSignatureLinkPathSpec{
name: "foo/bar",
revision: "sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
signature: "sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
},
expected: "/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/signatures/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/link",
},
{
spec: manifestSignaturesPathSpec{
name: "foo/bar",
revision: "sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
},
expected: "/docker/registry/v2/repositories/foo/bar/_manifests/revisions/sha256/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789/signatures",
},
{ {
spec: manifestTagsPathSpec{ spec: manifestTagsPathSpec{
name: "foo/bar", name: "foo/bar",
@ -113,7 +98,7 @@ func TestPathMapper(t *testing.T) {
// Add a few test cases to ensure we cover some errors // Add a few test cases to ensure we cover some errors
// Specify a path that requires a revision and get a digest validation error. // Specify a path that requires a revision and get a digest validation error.
badpath, err := pathFor(manifestSignaturesPathSpec{ badpath, err := pathFor(manifestRevisionPathSpec{
name: "foo/bar", name: "foo/bar",
}) })

View file

@ -18,8 +18,8 @@ type registry struct {
blobDescriptorCacheProvider cache.BlobDescriptorCacheProvider blobDescriptorCacheProvider cache.BlobDescriptorCacheProvider
deleteEnabled bool deleteEnabled bool
resumableDigestEnabled bool resumableDigestEnabled bool
schema1SignaturesEnabled bool
schema1SigningKey libtrust.PrivateKey schema1SigningKey libtrust.PrivateKey
blobDescriptorServiceFactory distribution.BlobDescriptorServiceFactory
} }
// RegistryOption is the type used for functional options for NewRegistry. // RegistryOption is the type used for functional options for NewRegistry.
@ -46,17 +46,8 @@ func DisableDigestResumption(registry *registry) error {
return nil return nil
} }
// DisableSchema1Signatures is a functional option for NewRegistry. It disables
// signature storage and ensures all schema1 manifests will only be returned
// with a signature from a provided signing key.
func DisableSchema1Signatures(registry *registry) error {
registry.schema1SignaturesEnabled = false
return nil
}
// Schema1SigningKey returns a functional option for NewRegistry. It sets the // Schema1SigningKey returns a functional option for NewRegistry. It sets the
// signing key for adding a signature to all schema1 manifests. This should be // key for signing all schema1 manifests.
// used in conjunction with disabling signature store.
func Schema1SigningKey(key libtrust.PrivateKey) RegistryOption { func Schema1SigningKey(key libtrust.PrivateKey) RegistryOption {
return func(registry *registry) error { return func(registry *registry) error {
registry.schema1SigningKey = key registry.schema1SigningKey = key
@ -64,6 +55,15 @@ func Schema1SigningKey(key libtrust.PrivateKey) RegistryOption {
} }
} }
// BlobDescriptorServiceFactory returns a functional option for NewRegistry. It sets the
// factory to create BlobDescriptorServiceFactory middleware.
func BlobDescriptorServiceFactory(factory distribution.BlobDescriptorServiceFactory) RegistryOption {
return func(registry *registry) error {
registry.blobDescriptorServiceFactory = factory
return nil
}
}
// BlobDescriptorCacheProvider returns a functional option for // BlobDescriptorCacheProvider returns a functional option for
// NewRegistry. It creates a cached blob statter for use by the // NewRegistry. It creates a cached blob statter for use by the
// registry. // registry.
@ -108,7 +108,6 @@ func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, option
}, },
statter: statter, statter: statter,
resumableDigestEnabled: true, resumableDigestEnabled: true,
schema1SignaturesEnabled: true,
} }
for _, option := range options { for _, option := range options {
@ -190,16 +189,22 @@ func (repo *repository) Manifests(ctx context.Context, options ...distribution.M
manifestDirectoryPathSpec := manifestRevisionsPathSpec{name: repo.name.Name()} manifestDirectoryPathSpec := manifestRevisionsPathSpec{name: repo.name.Name()}
var statter distribution.BlobDescriptorService = &linkedBlobStatter{
blobStore: repo.blobStore,
repository: repo,
linkPathFns: manifestLinkPathFns,
}
if repo.registry.blobDescriptorServiceFactory != nil {
statter = repo.registry.blobDescriptorServiceFactory.BlobAccessController(statter)
}
blobStore := &linkedBlobStore{ blobStore := &linkedBlobStore{
ctx: ctx, ctx: ctx,
blobStore: repo.blobStore, blobStore: repo.blobStore,
repository: repo, repository: repo,
deleteEnabled: repo.registry.deleteEnabled, deleteEnabled: repo.registry.deleteEnabled,
blobAccessController: &linkedBlobStatter{ blobAccessController: statter,
blobStore: repo.blobStore,
repository: repo,
linkPathFns: manifestLinkPathFns,
},
// TODO(stevvooe): linkPath limits this blob store to only // TODO(stevvooe): linkPath limits this blob store to only
// manifests. This instance cannot be used for blob checks. // manifests. This instance cannot be used for blob checks.
@ -215,11 +220,6 @@ func (repo *repository) Manifests(ctx context.Context, options ...distribution.M
ctx: ctx, ctx: ctx,
repository: repo, repository: repo,
blobStore: blobStore, blobStore: blobStore,
signatures: &signatureStore{
ctx: ctx,
repository: repo,
blobStore: repo.blobStore,
},
}, },
schema2Handler: &schema2ManifestHandler{ schema2Handler: &schema2ManifestHandler{
ctx: ctx, ctx: ctx,
@ -258,6 +258,10 @@ func (repo *repository) Blobs(ctx context.Context) distribution.BlobStore {
statter = cache.NewCachedBlobStatter(repo.descriptorCache, statter) statter = cache.NewCachedBlobStatter(repo.descriptorCache, statter)
} }
if repo.registry.blobDescriptorServiceFactory != nil {
statter = repo.registry.blobDescriptorServiceFactory.BlobAccessController(statter)
}
return &linkedBlobStore{ return &linkedBlobStore{
registry: repo.registry, registry: repo.registry,
blobStore: repo.blobStore, blobStore: repo.blobStore,
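
The new BlobDescriptorServiceFactory option lets callers wrap every blob statter the registry builds, for both manifest links and repository blobs. A hypothetical middleware, assuming only the BlobAccessController method visible in this diff; the logging is purely illustrative:

    package main

    import (
        "fmt"

        "github.com/docker/distribution"
        "github.com/docker/distribution/context"
        "github.com/docker/distribution/digest"
        "github.com/docker/distribution/registry/storage"
        "github.com/docker/distribution/registry/storage/driver/inmemory"
    )

    // loggingFactory wraps each BlobDescriptorService the registry creates.
    type loggingFactory struct{}

    // loggingStatter embeds the wrapped service so SetDescriptor and Clear
    // pass through; only Stat is intercepted.
    type loggingStatter struct {
        distribution.BlobDescriptorService
    }

    func (loggingFactory) BlobAccessController(svc distribution.BlobDescriptorService) distribution.BlobDescriptorService {
        return loggingStatter{svc}
    }

    func (s loggingStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
        fmt.Println("stat", dgst)
        return s.BlobDescriptorService.Stat(ctx, dgst)
    }

    func main() {
        ctx := context.Background()
        registry, err := storage.NewRegistry(ctx, inmemory.New(),
            storage.BlobDescriptorServiceFactory(loggingFactory{}))
        if err != nil {
            panic(err)
        }
        _ = registry
    }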

View file

@ -1,15 +1,24 @@
package storage package storage
import ( import (
"errors"
"fmt" "fmt"
"net/url"
"encoding/json" "encoding/json"
"github.com/docker/distribution" "github.com/docker/distribution"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
"github.com/docker/distribution/digest" "github.com/docker/distribution/digest"
"github.com/docker/distribution/manifest/schema2" "github.com/docker/distribution/manifest/schema2"
) )
var (
errUnexpectedURL = errors.New("unexpected URL on layer")
errMissingURL = errors.New("missing URL on layer")
errInvalidURL = errors.New("invalid URL on layer")
)
//schema2ManifestHandler is a ManifestHandler that covers schema2 manifests. //schema2ManifestHandler is a ManifestHandler that covers schema2 manifests.
type schema2ManifestHandler struct { type schema2ManifestHandler struct {
repository *repository repository *repository
@ -80,7 +89,27 @@ func (ms *schema2ManifestHandler) verifyManifest(ctx context.Context, mnfst sche
} }
for _, fsLayer := range mnfst.References() { for _, fsLayer := range mnfst.References() {
_, err := ms.repository.Blobs(ctx).Stat(ctx, fsLayer.Digest) var err error
if fsLayer.MediaType != schema2.MediaTypeForeignLayer {
if len(fsLayer.URLs) == 0 {
_, err = ms.repository.Blobs(ctx).Stat(ctx, fsLayer.Digest)
} else {
err = errUnexpectedURL
}
} else {
// Clients download this layer from an external URL, so do not check for
// its presence.
if len(fsLayer.URLs) == 0 {
err = errMissingURL
}
for _, u := range fsLayer.URLs {
var pu *url.URL
pu, err = url.Parse(u)
if err != nil || (pu.Scheme != "http" && pu.Scheme != "https") || pu.Fragment != "" {
err = errInvalidURL
}
}
}
if err != nil { if err != nil {
if err != distribution.ErrBlobUnknown { if err != distribution.ErrBlobUnknown {
errs = append(errs, err) errs = append(errs, err)
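
Because the side-by-side layout interleaves old and new lines here, the added URL validation is easier to read restated on its own: foreign layers must carry at least one http(s) URL with no fragment, and ordinary layers must carry none. A standalone sketch of the same rules (the helper name is ours, not the PR's):

    package main

    import (
        "errors"
        "fmt"
        "net/url"

        "github.com/docker/distribution"
        "github.com/docker/distribution/manifest/schema2"
    )

    var (
        errUnexpectedURL = errors.New("unexpected URL on layer")
        errMissingURL    = errors.New("missing URL on layer")
        errInvalidURL    = errors.New("invalid URL on layer")
    )

    // checkLayerURLs restates the checks verifyManifest now performs per layer.
    func checkLayerURLs(fsLayer distribution.Descriptor) error {
        if fsLayer.MediaType != schema2.MediaTypeForeignLayer {
            if len(fsLayer.URLs) != 0 {
                return errUnexpectedURL
            }
            return nil // blob-store presence is checked separately
        }
        if len(fsLayer.URLs) == 0 {
            return errMissingURL
        }
        for _, u := range fsLayer.URLs {
            pu, err := url.Parse(u)
            if err != nil || (pu.Scheme != "http" && pu.Scheme != "https") || pu.Fragment != "" {
                return errInvalidURL
            }
        }
        return nil
    }

    func main() {
        l := distribution.Descriptor{MediaType: schema2.MediaTypeForeignLayer}
        fmt.Println(checkLayerURLs(l)) // missing URL on layer
    }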

View file

@ -0,0 +1,117 @@
package storage
import (
"testing"
"github.com/docker/distribution"
"github.com/docker/distribution/context"
"github.com/docker/distribution/manifest"
"github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/registry/storage/driver/inmemory"
)
func TestVerifyManifestForeignLayer(t *testing.T) {
ctx := context.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver)
repo := makeRepository(t, registry, "test")
manifestService := makeManifestService(t, repo)
config, err := repo.Blobs(ctx).Put(ctx, schema2.MediaTypeConfig, nil)
if err != nil {
t.Fatal(err)
}
layer, err := repo.Blobs(ctx).Put(ctx, schema2.MediaTypeLayer, nil)
if err != nil {
t.Fatal(err)
}
foreignLayer := distribution.Descriptor{
Digest: "sha256:463435349086340864309863409683460843608348608934092322395278926a",
Size: 6323,
MediaType: schema2.MediaTypeForeignLayer,
}
template := schema2.Manifest{
Versioned: manifest.Versioned{
SchemaVersion: 2,
MediaType: schema2.MediaTypeManifest,
},
Config: config,
}
type testcase struct {
BaseLayer distribution.Descriptor
URLs []string
Err error
}
cases := []testcase{
{
foreignLayer,
nil,
errMissingURL,
},
{
layer,
[]string{"http://foo/bar"},
errUnexpectedURL,
},
{
foreignLayer,
[]string{"file:///local/file"},
errInvalidURL,
},
{
foreignLayer,
[]string{"http://foo/bar#baz"},
errInvalidURL,
},
{
foreignLayer,
[]string{""},
errInvalidURL,
},
{
foreignLayer,
[]string{"https://foo/bar", ""},
errInvalidURL,
},
{
foreignLayer,
[]string{"http://foo/bar"},
nil,
},
{
foreignLayer,
[]string{"https://foo/bar"},
nil,
},
}
for _, c := range cases {
m := template
l := c.BaseLayer
l.URLs = c.URLs
m.Layers = []distribution.Descriptor{l}
dm, err := schema2.FromStruct(m)
if err != nil {
t.Error(err)
continue
}
_, err = manifestService.Put(ctx, dm)
if verr, ok := err.(distribution.ErrManifestVerification); ok {
// Extract the first error
if len(verr) == 2 {
if _, ok = verr[1].(distribution.ErrManifestBlobUnknown); ok {
err = verr[0]
}
}
}
if err != c.Err {
t.Errorf("%#v: expected %v, got %v", l, c.Err, err)
}
}
}

View file

@ -1,131 +0,0 @@
package storage
import (
"path"
"sync"
"github.com/docker/distribution/context"
"github.com/docker/distribution/digest"
)
type signatureStore struct {
repository *repository
blobStore *blobStore
ctx context.Context
}
func (s *signatureStore) Get(dgst digest.Digest) ([][]byte, error) {
signaturesPath, err := pathFor(manifestSignaturesPathSpec{
name: s.repository.Named().Name(),
revision: dgst,
})
if err != nil {
return nil, err
}
// Need to append signature digest algorithm to path to get all items.
// Perhaps, this should be in the pathMapper but it feels awkward. This
// can be eliminated by implementing listAll on drivers.
signaturesPath = path.Join(signaturesPath, "sha256")
signaturePaths, err := s.blobStore.driver.List(s.ctx, signaturesPath)
if err != nil {
return nil, err
}
var wg sync.WaitGroup
type result struct {
index int
signature []byte
err error
}
ch := make(chan result)
bs := s.linkedBlobStore(s.ctx, dgst)
for i, sigPath := range signaturePaths {
sigdgst, err := digest.ParseDigest("sha256:" + path.Base(sigPath))
if err != nil {
context.GetLogger(s.ctx).Errorf("could not get digest from path: %q, skipping", sigPath)
continue
}
wg.Add(1)
go func(idx int, sigdgst digest.Digest) {
defer wg.Done()
context.GetLogger(s.ctx).
Debugf("fetching signature %q", sigdgst)
r := result{index: idx}
if p, err := bs.Get(s.ctx, sigdgst); err != nil {
context.GetLogger(s.ctx).
Errorf("error fetching signature %q: %v", sigdgst, err)
r.err = err
} else {
r.signature = p
}
ch <- r
}(i, sigdgst)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
// aggregate the results
signatures := make([][]byte, len(signaturePaths))
loop:
for {
select {
case result := <-ch:
signatures[result.index] = result.signature
if result.err != nil && err == nil {
// only set the first one.
err = result.err
}
case <-done:
break loop
}
}
return signatures, err
}
func (s *signatureStore) Put(dgst digest.Digest, signatures ...[]byte) error {
bs := s.linkedBlobStore(s.ctx, dgst)
for _, signature := range signatures {
if _, err := bs.Put(s.ctx, "application/json", signature); err != nil {
return err
}
}
return nil
}
// linkedBlobStore returns the namedBlobStore of the signatures for the
// manifest with the given digest. Effectively, each signature link path
// layout is a unique linked blob store.
func (s *signatureStore) linkedBlobStore(ctx context.Context, revision digest.Digest) *linkedBlobStore {
linkpath := func(name string, dgst digest.Digest) (string, error) {
return pathFor(manifestSignatureLinkPathSpec{
name: name,
revision: revision,
signature: dgst,
})
}
return &linkedBlobStore{
ctx: ctx,
repository: s.repository,
blobStore: s.blobStore,
blobAccessController: &linkedBlobStatter{
blobStore: s.blobStore,
repository: s.repository,
linkPathFns: []linkPathFunc{linkpath},
},
linkPathFns: []linkPathFunc{linkpath},
}
}

View file

@ -18,7 +18,6 @@ type signedManifestHandler struct {
repository *repository repository *repository
blobStore *linkedBlobStore blobStore *linkedBlobStore
ctx context.Context ctx context.Context
signatures *signatureStore
} }
var _ ManifestHandler = &signedManifestHandler{} var _ ManifestHandler = &signedManifestHandler{}
@ -30,13 +29,6 @@ func (ms *signedManifestHandler) Unmarshal(ctx context.Context, dgst digest.Dige
signatures [][]byte signatures [][]byte
err error err error
) )
if ms.repository.schema1SignaturesEnabled {
// Fetch the signatures for the manifest
signatures, err = ms.signatures.Get(dgst)
if err != nil {
return nil, err
}
}
jsig, err := libtrust.NewJSONSignature(content, signatures...) jsig, err := libtrust.NewJSONSignature(content, signatures...)
if err != nil { if err != nil {
@ -47,8 +39,6 @@ func (ms *signedManifestHandler) Unmarshal(ctx context.Context, dgst digest.Dige
if err := jsig.Sign(ms.repository.schema1SigningKey); err != nil { if err := jsig.Sign(ms.repository.schema1SigningKey); err != nil {
return nil, err return nil, err
} }
} else if !ms.repository.schema1SignaturesEnabled {
return nil, fmt.Errorf("missing signing key with signature store disabled")
} }
// Extract the pretty JWS // Extract the pretty JWS
@ -90,18 +80,6 @@ func (ms *signedManifestHandler) Put(ctx context.Context, manifest distribution.
return "", err return "", err
} }
if ms.repository.schema1SignaturesEnabled {
// Grab each json signature and store them.
signatures, err := sm.Signatures()
if err != nil {
return "", err
}
if err := ms.signatures.Put(revision.Digest, signatures...); err != nil {
return "", err
}
}
return revision.Digest, nil return revision.Digest, nil
} }
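
With the signature store gone, Unmarshal rebuilds the JWS from the stored canonical payload and signs it with the registry's schema1 signing key every time, rather than replaying stored signatures. Roughly what that amounts to, sketched directly against libtrust; the payload here is illustrative, where a real one comes from the blob store:

    package main

    import (
        "fmt"

        "github.com/docker/libtrust"
    )

    func main() {
        // Illustrative canonical schema1 payload.
        content := []byte(`{"schemaVersion": 1, "name": "foo/bar", "tag": "latest", "fsLayers": [], "history": []}`)

        key, err := libtrust.GenerateECP256PrivateKey()
        if err != nil {
            panic(err)
        }

        // No stored signatures are passed any more; the registry's key is
        // the only signer.
        jsig, err := libtrust.NewJSONSignature(content)
        if err != nil {
            panic(err)
        }
        if err := jsig.Sign(key); err != nil {
            panic(err)
        }

        // Extract the pretty JWS, as the handler does before returning the
        // signed manifest.
        raw, err := jsig.PrettySignature("signatures")
        if err != nil {
            panic(err)
        }
        fmt.Println(string(raw))
    }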