diff --git a/configuration/configuration.go b/configuration/configuration.go index d880597b..d4783c4d 100644 --- a/configuration/configuration.go +++ b/configuration/configuration.go @@ -85,6 +85,10 @@ type Configuration struct { // Location headers RelativeURLs bool `yaml:"relativeurls,omitempty"` + // Amount of time to wait for connection to drain before shutting down when registry + // receives a stop signal + DrainTimeout time.Duration `yaml:"draintimeout,omitempty"` + // TLS instructs the http server to listen with a TLS configuration. // This only support simple tls configuration with a cert and key. // Mostly, this is useful for testing situations or simple deployments diff --git a/configuration/configuration_test.go b/configuration/configuration_test.go index 1ae4ce9b..e5f71486 100644 --- a/configuration/configuration_test.go +++ b/configuration/configuration_test.go @@ -7,6 +7,7 @@ import ( "reflect" "strings" "testing" + "time" . "gopkg.in/check.v1" "gopkg.in/yaml.v2" @@ -71,12 +72,13 @@ var configStruct = Configuration{ }, }, HTTP: struct { - Addr string `yaml:"addr,omitempty"` - Net string `yaml:"net,omitempty"` - Host string `yaml:"host,omitempty"` - Prefix string `yaml:"prefix,omitempty"` - Secret string `yaml:"secret,omitempty"` - RelativeURLs bool `yaml:"relativeurls,omitempty"` + Addr string `yaml:"addr,omitempty"` + Net string `yaml:"net,omitempty"` + Host string `yaml:"host,omitempty"` + Prefix string `yaml:"prefix,omitempty"` + Secret string `yaml:"secret,omitempty"` + RelativeURLs bool `yaml:"relativeurls,omitempty"` + DrainTimeout time.Duration `yaml:"draintimeout,omitempty"` TLS struct { Certificate string `yaml:"certificate,omitempty"` Key string `yaml:"key,omitempty"` diff --git a/docs/configuration.md b/docs/configuration.md index db916f50..a6f1b190 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,6 +206,7 @@ http: host: https://myregistryaddress.org:5000 secret: asecretforlocaldevelopment relativeurls: false + draintimeout: 60s tls: certificate: /path/to/x509/public key: /path/to/x509/private @@ -739,6 +740,7 @@ http: host: https://myregistryaddress.org:5000 secret: asecretforlocaldevelopment relativeurls: false + draintimeout: 60s tls: certificate: /path/to/x509/public key: /path/to/x509/private @@ -768,6 +770,7 @@ registry. | `host` | no | A fully-qualified URL for an externally-reachable address for the registry. If present, it is used when creating generated URLs. Otherwise, these URLs are derived from client requests. | | `secret` | no | A random piece of data used to sign state that may be stored with the client to protect against tampering. For production environments you should generate a random piece of data using a cryptographically secure random generator. If you omit the secret, the registry will automatically generate a secret when it starts. **If you are building a cluster of registries behind a load balancer, you MUST ensure the secret is the same for all registries.**| | `relativeurls`| no | If `true`, the registry returns relative URLs in Location headers. The client is responsible for resolving the correct URL. **This option is not compatible with Docker 1.7 and earlier.**| +| `draintimeout`| no | Amount of time to wait for HTTP connections to drain before shutting down after registry receives SIGTERM signal| ### `tls` diff --git a/registry/registry.go b/registry/registry.go index f0708bcf..44c0edf5 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -8,6 +8,8 @@ import ( "io/ioutil" "net/http" "os" + "os/signal" + "syscall" "time" "rsc.io/letsencrypt" @@ -28,6 +30,9 @@ import ( "github.com/yvasiyarov/gorelic" ) +// this channel gets notified when process receives signal. It is global to ease unit testing +var quit = make(chan os.Signal, 1) + // ServeCmd is a cobra command for running the registry. var ServeCmd = &cobra.Command{ Use: "serve ", @@ -195,7 +200,29 @@ func (registry *Registry) ListenAndServe() error { dcontext.GetLogger(registry.app).Infof("listening on %v", ln.Addr()) } - return registry.server.Serve(ln) + if config.HTTP.DrainTimeout == 0 { + return registry.server.Serve(ln) + } + + // setup channel to get notified on SIGTERM signal + signal.Notify(quit, syscall.SIGTERM) + serveErr := make(chan error) + + // Start serving in goroutine and listen for stop signal in main thread + go func() { + serveErr <- registry.server.Serve(ln) + }() + + select { + case err := <-serveErr: + return err + case <-quit: + dcontext.GetLogger(registry.app).Info("stopping server gracefully. Draining connections for ", config.HTTP.DrainTimeout) + // shutdown the server with a grace period of configured timeout + c, cancel := context.WithTimeout(context.Background(), config.HTTP.DrainTimeout) + defer cancel() + return registry.server.Shutdown(c) + } } func configureReporting(app *handlers.App) http.Handler { diff --git a/registry/registry_test.go b/registry/registry_test.go index 34673117..d8deb35e 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -1,10 +1,19 @@ package registry import ( + "bufio" + "context" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" "reflect" "testing" + "time" "github.com/docker/distribution/configuration" + _ "github.com/docker/distribution/registry/storage/driver/inmemory" ) // Tests to ensure nextProtos returns the correct protocols when: @@ -28,3 +37,64 @@ func TestNextProtos(t *testing.T) { t.Fatalf("expected protos to equal [http/1.1], got %s", protos) } } + +func setupRegistry() (*Registry, error) { + config := &configuration.Configuration{} + // TODO: this needs to change to something ephemeral as the test will fail if there is any server + // already listening on port 5000 + config.HTTP.Addr = ":5000" + config.HTTP.DrainTimeout = time.Duration(10) * time.Second + config.Storage = map[string]configuration.Parameters{"inmemory": map[string]interface{}{}} + return NewRegistry(context.Background(), config) +} + +func TestGracefulShutdown(t *testing.T) { + registry, err := setupRegistry() + if err != nil { + t.Fatal(err) + } + + // run registry server + var errchan chan error + go func() { + errchan <- registry.ListenAndServe() + }() + select { + case err = <-errchan: + t.Fatalf("Error listening: %v", err) + default: + } + + // Wait for some unknown random time for server to start listening + time.Sleep(3 * time.Second) + + // send incomplete request + conn, err := net.Dial("tcp", "localhost:5000") + if err != nil { + t.Fatal(err) + } + fmt.Fprintf(conn, "GET /v2/ ") + + // send stop signal + quit <- os.Interrupt + time.Sleep(100 * time.Millisecond) + + // try connecting again. it shouldn't + _, err = net.Dial("tcp", "localhost:5000") + if err == nil { + t.Fatal("Managed to connect after stopping.") + } + + // make sure earlier request is not disconnected and response can be received + fmt.Fprintf(conn, "HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n") + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + t.Fatal(err) + } + if resp.Status != "200 OK" { + t.Error("response status is not 200 OK: ", resp.Status) + } + if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" { + t.Error("Body is not {}; ", string(body)) + } +}