// Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package autocert import ( "context" "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/json" "fmt" "html/template" "io" "math/big" "net/http" "net/http/httptest" "reflect" "strings" "sync" "testing" "time" "golang.org/x/crypto/acme" ) var discoTmpl = template.Must(template.New("disco").Parse(`{ "new-reg": "{{.}}/new-reg", "new-authz": "{{.}}/new-authz", "new-cert": "{{.}}/new-cert" }`)) var authzTmpl = template.Must(template.New("authz").Parse(`{ "status": "pending", "challenges": [ { "uri": "{{.}}/challenge/1", "type": "tls-sni-01", "token": "token-01" }, { "uri": "{{.}}/challenge/2", "type": "tls-sni-02", "token": "token-02" }, { "uri": "{{.}}/challenge/dns-01", "type": "dns-01", "token": "token-dns-01" }, { "uri": "{{.}}/challenge/http-01", "type": "http-01", "token": "token-http-01" } ] }`)) type memCache struct { mu sync.Mutex keyData map[string][]byte } func (m *memCache) Get(ctx context.Context, key string) ([]byte, error) { m.mu.Lock() defer m.mu.Unlock() v, ok := m.keyData[key] if !ok { return nil, ErrCacheMiss } return v, nil } func (m *memCache) Put(ctx context.Context, key string, data []byte) error { m.mu.Lock() defer m.mu.Unlock() m.keyData[key] = data return nil } func (m *memCache) Delete(ctx context.Context, key string) error { m.mu.Lock() defer m.mu.Unlock() delete(m.keyData, key) return nil } func newMemCache() *memCache { return &memCache{ keyData: make(map[string][]byte), } } func dummyCert(pub interface{}, san ...string) ([]byte, error) { return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...) } func dateDummyCert(pub interface{}, start, end time.Time, san ...string) ([]byte, error) { // use EC key to run faster on 386 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, err } t := &x509.Certificate{ SerialNumber: big.NewInt(1), NotBefore: start, NotAfter: end, BasicConstraintsValid: true, KeyUsage: x509.KeyUsageKeyEncipherment, DNSNames: san, } if pub == nil { pub = &key.PublicKey } return x509.CreateCertificate(rand.Reader, t, t, pub, key) } func decodePayload(v interface{}, r io.Reader) error { var req struct{ Payload string } if err := json.NewDecoder(r).Decode(&req); err != nil { return err } payload, err := base64.RawURLEncoding.DecodeString(req.Payload) if err != nil { return err } return json.Unmarshal(payload, v) } func TestGetCertificate(t *testing.T) { man := &Manager{Prompt: AcceptTOS} defer man.stopRenew() hello := &tls.ClientHelloInfo{ServerName: "example.org"} testGetCertificate(t, man, "example.org", hello) } func TestGetCertificate_trailingDot(t *testing.T) { man := &Manager{Prompt: AcceptTOS} defer man.stopRenew() hello := &tls.ClientHelloInfo{ServerName: "example.org."} testGetCertificate(t, man, "example.org", hello) } func TestGetCertificate_ForceRSA(t *testing.T) { man := &Manager{ Prompt: AcceptTOS, Cache: newMemCache(), ForceRSA: true, } defer man.stopRenew() hello := &tls.ClientHelloInfo{ServerName: "example.org"} testGetCertificate(t, man, "example.org", hello) cert, err := man.cacheGet(context.Background(), "example.org") if err != nil { t.Fatalf("man.cacheGet: %v", err) } if _, ok := cert.PrivateKey.(*rsa.PrivateKey); !ok { t.Errorf("cert.PrivateKey is %T; want *rsa.PrivateKey", cert.PrivateKey) } } func TestGetCertificate_nilPrompt(t *testing.T) { man := &Manager{} defer man.stopRenew() url, finish := startACMEServerStub(t, man, "example.org") defer finish() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } man.Client = &acme.Client{ Key: key, DirectoryURL: url, } hello := &tls.ClientHelloInfo{ServerName: "example.org"} if _, err := man.GetCertificate(hello); err == nil { t.Error("got certificate for example.org; wanted error") } } func TestGetCertificate_expiredCache(t *testing.T) { // Make an expired cert and cache it. pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } tmpl := &x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: "example.org"}, NotAfter: time.Now(), } pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk) if err != nil { t.Fatal(err) } tlscert := &tls.Certificate{ Certificate: [][]byte{pub}, PrivateKey: pk, } man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()} defer man.stopRenew() if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil { t.Fatalf("man.cachePut: %v", err) } // The expired cached cert should trigger a new cert issuance // and return without an error. hello := &tls.ClientHelloInfo{ServerName: "example.org"} testGetCertificate(t, man, "example.org", hello) } func TestGetCertificate_failedAttempt(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) })) defer ts.Close() const example = "example.org" d := createCertRetryAfter f := testDidRemoveState defer func() { createCertRetryAfter = d testDidRemoveState = f }() createCertRetryAfter = 0 done := make(chan struct{}) testDidRemoveState = func(domain string) { if domain != example { t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example) } close(done) } key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } man := &Manager{ Prompt: AcceptTOS, Client: &acme.Client{ Key: key, DirectoryURL: ts.URL, }, } defer man.stopRenew() hello := &tls.ClientHelloInfo{ServerName: example} if _, err := man.GetCertificate(hello); err == nil { t.Error("GetCertificate: err is nil") } select { case <-time.After(5 * time.Second): t.Errorf("took too long to remove the %q state", example) case <-done: man.stateMu.Lock() defer man.stateMu.Unlock() if v, exist := man.state[example]; exist { t.Errorf("state exists for %q: %+v", example, v) } } } // startACMEServerStub runs an ACME server // The domain argument is the expected domain name of a certificate request. func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) { // echo token-02 | shasum -a 256 // then divide result in 2 parts separated by dot tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid" verifyTokenCert := func() { hello := &tls.ClientHelloInfo{ServerName: tokenCertName} _, err := man.GetCertificate(hello) if err != nil { t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err) return } } // ACME CA server stub var ca *httptest.Server ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Replay-Nonce", "nonce") if r.Method == "HEAD" { // a nonce request return } switch r.URL.Path { // discovery case "/": if err := discoTmpl.Execute(w, ca.URL); err != nil { t.Errorf("discoTmpl: %v", err) } // client key registration case "/new-reg": w.Write([]byte("{}")) // domain authorization case "/new-authz": w.Header().Set("Location", ca.URL+"/authz/1") w.WriteHeader(http.StatusCreated) if err := authzTmpl.Execute(w, ca.URL); err != nil { t.Errorf("authzTmpl: %v", err) } // accept tls-sni-02 challenge case "/challenge/2": verifyTokenCert() w.Write([]byte("{}")) // authorization status case "/authz/1": w.Write([]byte(`{"status": "valid"}`)) // cert request case "/new-cert": var req struct { CSR string `json:"csr"` } decodePayload(&req, r.Body) b, _ := base64.RawURLEncoding.DecodeString(req.CSR) csr, err := x509.ParseCertificateRequest(b) if err != nil { t.Errorf("new-cert: CSR: %v", err) } if csr.Subject.CommonName != domain { t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain) } der, err := dummyCert(csr.PublicKey, domain) if err != nil { t.Errorf("new-cert: dummyCert: %v", err) } chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL) w.Header().Set("Link", chainUp) w.WriteHeader(http.StatusCreated) w.Write(der) // CA chain cert case "/ca-cert": der, err := dummyCert(nil, "ca") if err != nil { t.Errorf("ca-cert: dummyCert: %v", err) } w.Write(der) default: t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path) } })) finish = func() { ca.Close() // make sure token cert was removed cancel := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) tick := time.NewTicker(100 * time.Millisecond) defer tick.Stop() for { hello := &tls.ClientHelloInfo{ServerName: tokenCertName} if _, err := man.GetCertificate(hello); err != nil { return } select { case <-tick.C: case <-cancel: return } } }() select { case <-done: case <-time.After(5 * time.Second): close(cancel) t.Error("token cert was not removed") <-done } } return ca.URL, finish } // tests man.GetCertificate flow using the provided hello argument. // The domain argument is the expected domain name of a certificate request. func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) { url, finish := startACMEServerStub(t, man, domain) defer finish() // use EC key to run faster on 386 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } man.Client = &acme.Client{ Key: key, DirectoryURL: url, } // simulate tls.Config.GetCertificate var tlscert *tls.Certificate done := make(chan struct{}) go func() { tlscert, err = man.GetCertificate(hello) close(done) }() select { case <-time.After(time.Minute): t.Fatal("man.GetCertificate took too long to return") case <-done: } if err != nil { t.Fatalf("man.GetCertificate: %v", err) } // verify the tlscert is the same we responded with from the CA stub if len(tlscert.Certificate) == 0 { t.Fatal("len(tlscert.Certificate) is 0") } cert, err := x509.ParseCertificate(tlscert.Certificate[0]) if err != nil { t.Fatalf("x509.ParseCertificate: %v", err) } if len(cert.DNSNames) == 0 || cert.DNSNames[0] != domain { t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain) } } func TestVerifyHTTP01(t *testing.T) { var ( http01 http.Handler authzCount int // num. of created authorizations didAcceptHTTP01 bool ) verifyHTTPToken := func() { r := httptest.NewRequest("GET", "/.well-known/acme-challenge/token-http-01", nil) w := httptest.NewRecorder() http01.ServeHTTP(w, r) if w.Code != http.StatusOK { t.Errorf("http token: w.Code = %d; want %d", w.Code, http.StatusOK) } if v := string(w.Body.Bytes()); !strings.HasPrefix(v, "token-http-01.") { t.Errorf("http token value = %q; want 'token-http-01.' prefix", v) } } // ACME CA server stub, only the needed bits. // TODO: Merge this with startACMEServerStub, making it a configurable CA for testing. var ca *httptest.Server ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Replay-Nonce", "nonce") if r.Method == "HEAD" { // a nonce request return } switch r.URL.Path { // Discovery. case "/": if err := discoTmpl.Execute(w, ca.URL); err != nil { t.Errorf("discoTmpl: %v", err) } // Client key registration. case "/new-reg": w.Write([]byte("{}")) // New domain authorization. case "/new-authz": authzCount++ w.Header().Set("Location", fmt.Sprintf("%s/authz/%d", ca.URL, authzCount)) w.WriteHeader(http.StatusCreated) if err := authzTmpl.Execute(w, ca.URL); err != nil { t.Errorf("authzTmpl: %v", err) } // Accept tls-sni-02. case "/challenge/2": w.Write([]byte("{}")) // Reject tls-sni-01. case "/challenge/1": http.Error(w, "won't accept tls-sni-01", http.StatusBadRequest) // Should not accept dns-01. case "/challenge/dns-01": t.Errorf("dns-01 challenge was accepted") http.Error(w, "won't accept dns-01", http.StatusBadRequest) // Accept http-01. case "/challenge/http-01": didAcceptHTTP01 = true verifyHTTPToken() w.Write([]byte("{}")) // Authorization statuses. // Make tls-sni-xxx invalid. case "/authz/1", "/authz/2": w.Write([]byte(`{"status": "invalid"}`)) case "/authz/3", "/authz/4": w.Write([]byte(`{"status": "valid"}`)) default: http.NotFound(w, r) t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path) } })) defer ca.Close() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } m := &Manager{ Client: &acme.Client{ Key: key, DirectoryURL: ca.URL, }, } http01 = m.HTTPHandler(nil) if err := m.verify(context.Background(), m.Client, "example.org"); err != nil { t.Errorf("m.verify: %v", err) } // Only tls-sni-01, tls-sni-02 and http-01 must be accepted // The dns-01 challenge is unsupported. if authzCount != 3 { t.Errorf("authzCount = %d; want 3", authzCount) } if !didAcceptHTTP01 { t.Error("did not accept http-01 challenge") } } func TestHTTPHandlerDefaultFallback(t *testing.T) { tt := []struct { method, url string wantCode int wantLocation string }{ {"GET", "http://example.org", 302, "https://example.org/"}, {"GET", "http://example.org/foo", 302, "https://example.org/foo"}, {"GET", "http://example.org/foo/bar/", 302, "https://example.org/foo/bar/"}, {"GET", "http://example.org/?a=b", 302, "https://example.org/?a=b"}, {"GET", "http://example.org/foo?a=b", 302, "https://example.org/foo?a=b"}, {"GET", "http://example.org:80/foo?a=b", 302, "https://example.org:443/foo?a=b"}, {"GET", "http://example.org:80/foo%20bar", 302, "https://example.org:443/foo%20bar"}, {"GET", "http://[2602:d1:xxxx::c60a]:1234", 302, "https://[2602:d1:xxxx::c60a]:443/"}, {"GET", "http://[2602:d1:xxxx::c60a]", 302, "https://[2602:d1:xxxx::c60a]/"}, {"GET", "http://[2602:d1:xxxx::c60a]/foo?a=b", 302, "https://[2602:d1:xxxx::c60a]/foo?a=b"}, {"HEAD", "http://example.org", 302, "https://example.org/"}, {"HEAD", "http://example.org/foo", 302, "https://example.org/foo"}, {"HEAD", "http://example.org/foo/bar/", 302, "https://example.org/foo/bar/"}, {"HEAD", "http://example.org/?a=b", 302, "https://example.org/?a=b"}, {"HEAD", "http://example.org/foo?a=b", 302, "https://example.org/foo?a=b"}, {"POST", "http://example.org", 400, ""}, {"PUT", "http://example.org", 400, ""}, {"GET", "http://example.org/.well-known/acme-challenge/x", 404, ""}, } var m Manager h := m.HTTPHandler(nil) for i, test := range tt { r := httptest.NewRequest(test.method, test.url, nil) w := httptest.NewRecorder() h.ServeHTTP(w, r) if w.Code != test.wantCode { t.Errorf("%d: w.Code = %d; want %d", i, w.Code, test.wantCode) t.Errorf("%d: body: %s", i, w.Body.Bytes()) } if v := w.Header().Get("Location"); v != test.wantLocation { t.Errorf("%d: Location = %q; want %q", i, v, test.wantLocation) } } } func TestAccountKeyCache(t *testing.T) { m := Manager{Cache: newMemCache()} ctx := context.Background() k1, err := m.accountKey(ctx) if err != nil { t.Fatal(err) } k2, err := m.accountKey(ctx) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(k1, k2) { t.Errorf("account keys don't match: k1 = %#v; k2 = %#v", k1, k2) } } func TestCache(t *testing.T) { privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } tmpl := &x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: "example.org"}, NotAfter: time.Now().Add(time.Hour), } pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey) if err != nil { t.Fatal(err) } tlscert := &tls.Certificate{ Certificate: [][]byte{pub}, PrivateKey: privKey, } man := &Manager{Cache: newMemCache()} defer man.stopRenew() ctx := context.Background() if err := man.cachePut(ctx, "example.org", tlscert); err != nil { t.Fatalf("man.cachePut: %v", err) } res, err := man.cacheGet(ctx, "example.org") if err != nil { t.Fatalf("man.cacheGet: %v", err) } if res == nil { t.Fatal("res is nil") } } func TestHostWhitelist(t *testing.T) { policy := HostWhitelist("example.com", "example.org", "*.example.net") tt := []struct { host string allow bool }{ {"example.com", true}, {"example.org", true}, {"one.example.com", false}, {"two.example.org", false}, {"three.example.net", false}, {"dummy", false}, } for i, test := range tt { err := policy(nil, test.host) if err != nil && test.allow { t.Errorf("%d: policy(%q): %v; want nil", i, test.host, err) } if err == nil && !test.allow { t.Errorf("%d: policy(%q): nil; want an error", i, test.host) } } } func TestValidCert(t *testing.T) { key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatal(err) } key3, err := rsa.GenerateKey(rand.Reader, 512) if err != nil { t.Fatal(err) } cert1, err := dummyCert(key1.Public(), "example.org") if err != nil { t.Fatal(err) } cert2, err := dummyCert(key2.Public(), "example.org") if err != nil { t.Fatal(err) } cert3, err := dummyCert(key3.Public(), "example.org") if err != nil { t.Fatal(err) } now := time.Now() early, err := dateDummyCert(key1.Public(), now.Add(time.Hour), now.Add(2*time.Hour), "example.org") if err != nil { t.Fatal(err) } expired, err := dateDummyCert(key1.Public(), now.Add(-2*time.Hour), now.Add(-time.Hour), "example.org") if err != nil { t.Fatal(err) } tt := []struct { domain string key crypto.Signer cert [][]byte ok bool }{ {"example.org", key1, [][]byte{cert1}, true}, {"example.org", key3, [][]byte{cert3}, true}, {"example.org", key1, [][]byte{cert1, cert2, cert3}, true}, {"example.org", key1, [][]byte{cert1, {1}}, false}, {"example.org", key1, [][]byte{{1}}, false}, {"example.org", key1, [][]byte{cert2}, false}, {"example.org", key2, [][]byte{cert1}, false}, {"example.org", key1, [][]byte{cert3}, false}, {"example.org", key3, [][]byte{cert1}, false}, {"example.net", key1, [][]byte{cert1}, false}, {"example.org", key1, [][]byte{early}, false}, {"example.org", key1, [][]byte{expired}, false}, } for i, test := range tt { leaf, err := validCert(test.domain, test.cert, test.key) if err != nil && test.ok { t.Errorf("%d: err = %v", i, err) } if err == nil && !test.ok { t.Errorf("%d: err is nil", i) } if err == nil && test.ok && leaf == nil { t.Errorf("%d: leaf is nil", i) } } } type cacheGetFunc func(ctx context.Context, key string) ([]byte, error) func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) { return f(ctx, key) } func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error { return fmt.Errorf("unsupported Put of %q = %q", key, data) } func (f cacheGetFunc) Delete(ctx context.Context, key string) error { return fmt.Errorf("unsupported Delete of %q", key) } func TestManagerGetCertificateBogusSNI(t *testing.T) { m := Manager{ Prompt: AcceptTOS, Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) { return nil, fmt.Errorf("cache.Get of %s", key) }), } tests := []struct { name string wantErr string }{ {"foo.com", "cache.Get of foo.com"}, {"foo.com.", "cache.Get of foo.com"}, {`a\b.com`, "acme/autocert: server name contains invalid character"}, {`a/b.com`, "acme/autocert: server name contains invalid character"}, {"", "acme/autocert: missing server name"}, {"foo", "acme/autocert: server name component count invalid"}, {".foo", "acme/autocert: server name component count invalid"}, {"foo.", "acme/autocert: server name component count invalid"}, {"fo.o", "cache.Get of fo.o"}, } for _, tt := range tests { _, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name}) got := fmt.Sprint(err) if got != tt.wantErr { t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr) } } }