diff --git a/contrib/token-server/main.go b/contrib/token-server/main.go index edd894f4..f53d5c8d 100644 --- a/contrib/token-server/main.go +++ b/contrib/token-server/main.go @@ -176,6 +176,18 @@ func filterAccessList(ctx context.Context, scope string, requestedAccessList []a return grantedAccessList } +type acctSubject struct{} + +func (acctSubject) String() string { return "acctSubject" } + +type requestedAccess struct{} + +func (requestedAccess) String() string { return "requestedAccess" } + +type grantedAccess struct{} + +func (grantedAccess) String() string { return "grantedAccess" } + // getToken handles authenticating the request and authorizing access to the // requested scopes. func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *http.Request) { @@ -218,17 +230,17 @@ func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *h username := context.GetStringValue(ctx, "auth.user.name") - ctx = context.WithValue(ctx, "acctSubject", username) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "acctSubject")) + ctx = context.WithValue(ctx, acctSubject{}, username) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, acctSubject{})) context.GetLogger(ctx).Info("authenticated client") - ctx = context.WithValue(ctx, "requestedAccess", requestedAccessList) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "requestedAccess")) + ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, requestedAccess{})) grantedAccessList := filterAccessList(ctx, username, requestedAccessList) - ctx = context.WithValue(ctx, "grantedAccess", grantedAccessList) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "grantedAccess")) + ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, grantedAccess{})) token, err := ts.issuer.CreateJWT(username, service, grantedAccessList) if err != nil { @@ -340,17 +352,17 @@ func (ts *tokenServer) postToken(ctx context.Context, w http.ResponseWriter, r * return } - ctx = context.WithValue(ctx, "acctSubject", subject) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "acctSubject")) + ctx = context.WithValue(ctx, acctSubject{}, subject) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, acctSubject{})) context.GetLogger(ctx).Info("authenticated client") - ctx = context.WithValue(ctx, "requestedAccess", requestedAccessList) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "requestedAccess")) + ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, requestedAccess{})) grantedAccessList := filterAccessList(ctx, subject, requestedAccessList) - ctx = context.WithValue(ctx, "grantedAccess", grantedAccessList) - ctx = context.WithLogger(ctx, context.GetLogger(ctx, "grantedAccess")) + ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList) + ctx = context.WithLogger(ctx, context.GetLogger(ctx, grantedAccess{})) token, err := ts.issuer.CreateJWT(subject, service, grantedAccessList) if err != nil { diff --git a/registry/auth/silly/access_test.go b/registry/auth/silly/access_test.go index a7c14cb9..0a5103e6 100644 --- a/registry/auth/silly/access_test.go +++ b/registry/auth/silly/access_test.go @@ -16,7 +16,7 @@ func TestSillyAccessController(t *testing.T) { } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(nil, "http.request", r) + ctx := context.WithRequest(context.Background(), r) authCtx, err := ac.Authorized(ctx) if err != nil { switch err := err.(type) { diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index 827dbbd7..3140d473 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -284,7 +284,7 @@ func TestAccessController(t *testing.T) { Action: "baz", } - ctx := context.WithValue(nil, "http.request", req) + ctx := context.WithRequest(context.Background(), req) authCtx, err := accessController.Authorized(ctx, testAccess) challenge, ok := err.(auth.Challenge) if !ok { diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 33f49670..b80b8b9d 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -425,6 +425,8 @@ func (app *App) configureEvents(configuration *configuration.Configuration) { } } +type redisStartAtKey struct{} + func (app *App) configureRedis(configuration *configuration.Configuration) { if configuration.Redis.Addr == "" { ctxu.GetLogger(app).Infof("redis not configured") @@ -434,11 +436,11 @@ func (app *App) configureRedis(configuration *configuration.Configuration) { pool := &redis.Pool{ Dial: func() (redis.Conn, error) { // TODO(stevvooe): Yet another use case for contextual timing. - ctx := context.WithValue(app, "redis.connect.startedat", time.Now()) + ctx := context.WithValue(app, redisStartAtKey{}, time.Now()) done := func(err error) { logger := ctxu.GetLoggerWithField(ctx, "redis.connect.duration", - ctxu.Since(ctx, "redis.connect.startedat")) + ctxu.Since(ctx, redisStartAtKey{})) if err != nil { logger.Errorf("redis: error connecting: %v", err) } else { @@ -671,6 +673,18 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler { }) } +type errCodeKey struct{} + +func (errCodeKey) String() string { return "err.code" } + +type errMessageKey struct{} + +func (errMessageKey) String() string { return "err.message" } + +type errDetailKey struct{} + +func (errDetailKey) String() string { return "err.detail" } + func (app *App) logError(context context.Context, errors errcode.Errors) { for _, e1 := range errors { var c ctxu.Context @@ -678,23 +692,23 @@ func (app *App) logError(context context.Context, errors errcode.Errors) { switch e1.(type) { case errcode.Error: e, _ := e1.(errcode.Error) - c = ctxu.WithValue(context, "err.code", e.Code) - c = ctxu.WithValue(c, "err.message", e.Code.Message()) - c = ctxu.WithValue(c, "err.detail", e.Detail) + c = ctxu.WithValue(context, errCodeKey{}, e.Code) + c = ctxu.WithValue(c, errMessageKey{}, e.Code.Message()) + c = ctxu.WithValue(c, errDetailKey{}, e.Detail) case errcode.ErrorCode: e, _ := e1.(errcode.ErrorCode) - c = ctxu.WithValue(context, "err.code", e) - c = ctxu.WithValue(c, "err.message", e.Message()) + c = ctxu.WithValue(context, errCodeKey{}, e) + c = ctxu.WithValue(c, errMessageKey{}, e.Message()) default: // just normal go 'error' - c = ctxu.WithValue(context, "err.code", errcode.ErrorCodeUnknown) - c = ctxu.WithValue(c, "err.message", e1.Error()) + c = ctxu.WithValue(context, errCodeKey{}, errcode.ErrorCodeUnknown) + c = ctxu.WithValue(c, errMessageKey{}, e1.Error()) } c = ctxu.WithLogger(c, ctxu.GetLogger(c, - "err.code", - "err.message", - "err.detail")) + errCodeKey{}, + errMessageKey{}, + errDetailKey{})) ctxu.GetResponseLogger(c).Errorf("response completed with error") } }