Instead, provide a variant of instrumentedResponseWriter that does not implement CloseNotifier, and use that when necessary. In copyFullPayload, log instead of panicing when we encounter something that doesn't implement CloseNotifier. This is more complicated than I'd like, but it's necessary because instrumentedResponseWriter must not embed CloseNotifier unless there's really a CloseNotifier to embed. Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
		
			
				
	
	
		
			285 lines
		
	
	
	
		
			6.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			285 lines
		
	
	
	
		
			6.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package context
 | |
| 
 | |
| import (
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"net/http/httputil"
 | |
| 	"net/url"
 | |
| 	"reflect"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
| func TestWithRequest(t *testing.T) {
 | |
| 	var req http.Request
 | |
| 
 | |
| 	start := time.Now()
 | |
| 	req.Method = "GET"
 | |
| 	req.Host = "example.com"
 | |
| 	req.RequestURI = "/test-test"
 | |
| 	req.Header = make(http.Header)
 | |
| 	req.Header.Set("Referer", "foo.com/referer")
 | |
| 	req.Header.Set("User-Agent", "test/0.1")
 | |
| 
 | |
| 	ctx := WithRequest(Background(), &req)
 | |
| 	for _, testcase := range []struct {
 | |
| 		key      string
 | |
| 		expected interface{}
 | |
| 	}{
 | |
| 		{
 | |
| 			key:      "http.request",
 | |
| 			expected: &req,
 | |
| 		},
 | |
| 		{
 | |
| 			key: "http.request.id",
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.method",
 | |
| 			expected: req.Method,
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.host",
 | |
| 			expected: req.Host,
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.uri",
 | |
| 			expected: req.RequestURI,
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.referer",
 | |
| 			expected: req.Referer(),
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.useragent",
 | |
| 			expected: req.UserAgent(),
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "http.request.remoteaddr",
 | |
| 			expected: req.RemoteAddr,
 | |
| 		},
 | |
| 		{
 | |
| 			key: "http.request.startedat",
 | |
| 		},
 | |
| 	} {
 | |
| 		v := ctx.Value(testcase.key)
 | |
| 
 | |
| 		if v == nil {
 | |
| 			t.Fatalf("value not found for %q", testcase.key)
 | |
| 		}
 | |
| 
 | |
| 		if testcase.expected != nil && v != testcase.expected {
 | |
| 			t.Fatalf("%s: %v != %v", testcase.key, v, testcase.expected)
 | |
| 		}
 | |
| 
 | |
| 		// Key specific checks!
 | |
| 		switch testcase.key {
 | |
| 		case "http.request.id":
 | |
| 			if _, ok := v.(string); !ok {
 | |
| 				t.Fatalf("request id not a string: %v", v)
 | |
| 			}
 | |
| 		case "http.request.startedat":
 | |
| 			vt, ok := v.(time.Time)
 | |
| 			if !ok {
 | |
| 				t.Fatalf("value not a time: %v", v)
 | |
| 			}
 | |
| 
 | |
| 			now := time.Now()
 | |
| 			if vt.After(now) {
 | |
| 				t.Fatalf("time generated too late: %v > %v", vt, now)
 | |
| 			}
 | |
| 
 | |
| 			if vt.Before(start) {
 | |
| 				t.Fatalf("time generated too early: %v < %v", vt, start)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type testResponseWriter struct {
 | |
| 	flushed bool
 | |
| 	status  int
 | |
| 	written int64
 | |
| 	header  http.Header
 | |
| }
 | |
| 
 | |
| func (trw *testResponseWriter) Header() http.Header {
 | |
| 	if trw.header == nil {
 | |
| 		trw.header = make(http.Header)
 | |
| 	}
 | |
| 
 | |
| 	return trw.header
 | |
| }
 | |
| 
 | |
| func (trw *testResponseWriter) Write(p []byte) (n int, err error) {
 | |
| 	if trw.status == 0 {
 | |
| 		trw.status = http.StatusOK
 | |
| 	}
 | |
| 
 | |
| 	n = len(p)
 | |
| 	trw.written += int64(n)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (trw *testResponseWriter) WriteHeader(status int) {
 | |
| 	trw.status = status
 | |
| }
 | |
| 
 | |
| func (trw *testResponseWriter) Flush() {
 | |
| 	trw.flushed = true
 | |
| }
 | |
| 
 | |
| func TestWithResponseWriter(t *testing.T) {
 | |
| 	trw := testResponseWriter{}
 | |
| 	ctx, rw := WithResponseWriter(Background(), &trw)
 | |
| 
 | |
| 	if ctx.Value("http.response") != rw {
 | |
| 		t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), rw)
 | |
| 	}
 | |
| 
 | |
| 	grw, err := GetResponseWriter(ctx)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("error getting response writer: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	if grw != rw {
 | |
| 		t.Fatalf("unexpected response writer returned: %#v != %#v", grw, rw)
 | |
| 	}
 | |
| 
 | |
| 	if ctx.Value("http.response.status") != 0 {
 | |
| 		t.Fatalf("response status should always be a number and should be zero here: %v != 0", ctx.Value("http.response.status"))
 | |
| 	}
 | |
| 
 | |
| 	if n, err := rw.Write(make([]byte, 1024)); err != nil {
 | |
| 		t.Fatalf("unexpected error writing: %v", err)
 | |
| 	} else if n != 1024 {
 | |
| 		t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
 | |
| 	}
 | |
| 
 | |
| 	if ctx.Value("http.response.status") != http.StatusOK {
 | |
| 		t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
 | |
| 	}
 | |
| 
 | |
| 	if ctx.Value("http.response.written") != int64(1024) {
 | |
| 		t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
 | |
| 	}
 | |
| 
 | |
| 	// Make sure flush propagates
 | |
| 	rw.(http.Flusher).Flush()
 | |
| 
 | |
| 	if !trw.flushed {
 | |
| 		t.Fatalf("response writer not flushed")
 | |
| 	}
 | |
| 
 | |
| 	// Write another status and make sure context is correct. This normally
 | |
| 	// wouldn't work except for in this contrived testcase.
 | |
| 	rw.WriteHeader(http.StatusBadRequest)
 | |
| 
 | |
| 	if ctx.Value("http.response.status") != http.StatusBadRequest {
 | |
| 		t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestWithVars(t *testing.T) {
 | |
| 	var req http.Request
 | |
| 	vars := map[string]string{
 | |
| 		"foo": "asdf",
 | |
| 		"bar": "qwer",
 | |
| 	}
 | |
| 
 | |
| 	getVarsFromRequest = func(r *http.Request) map[string]string {
 | |
| 		if r != &req {
 | |
| 			t.Fatalf("unexpected request: %v != %v", r, req)
 | |
| 		}
 | |
| 
 | |
| 		return vars
 | |
| 	}
 | |
| 
 | |
| 	ctx := WithVars(Background(), &req)
 | |
| 	for _, testcase := range []struct {
 | |
| 		key      string
 | |
| 		expected interface{}
 | |
| 	}{
 | |
| 		{
 | |
| 			key:      "vars",
 | |
| 			expected: vars,
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "vars.foo",
 | |
| 			expected: "asdf",
 | |
| 		},
 | |
| 		{
 | |
| 			key:      "vars.bar",
 | |
| 			expected: "qwer",
 | |
| 		},
 | |
| 	} {
 | |
| 		v := ctx.Value(testcase.key)
 | |
| 
 | |
| 		if !reflect.DeepEqual(v, testcase.expected) {
 | |
| 			t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
 | |
| // RemoteAddr().  A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
 | |
| // at the transport layer to 127.0.0.1:<port> .  However, as the X-Forwarded-For header
 | |
| // just contains the IP address, it is different enough for testing.
 | |
| func TestRemoteAddr(t *testing.T) {
 | |
| 	var expectedRemote string
 | |
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		defer r.Body.Close()
 | |
| 
 | |
| 		if r.RemoteAddr == expectedRemote {
 | |
| 			t.Errorf("Unexpected matching remote addresses")
 | |
| 		}
 | |
| 
 | |
| 		actualRemote := RemoteAddr(r)
 | |
| 		if expectedRemote != actualRemote {
 | |
| 			t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
 | |
| 		}
 | |
| 
 | |
| 		w.WriteHeader(200)
 | |
| 	}))
 | |
| 
 | |
| 	defer backend.Close()
 | |
| 	backendURL, err := url.Parse(backend.URL)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	proxy := httputil.NewSingleHostReverseProxy(backendURL)
 | |
| 	frontend := httptest.NewServer(proxy)
 | |
| 	defer frontend.Close()
 | |
| 
 | |
| 	// X-Forwarded-For set by proxy
 | |
| 	expectedRemote = "127.0.0.1"
 | |
| 	proxyReq, err := http.NewRequest("GET", frontend.URL, nil)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	_, err = http.DefaultClient.Do(proxyReq)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// RemoteAddr in X-Real-Ip
 | |
| 	getReq, err := http.NewRequest("GET", backend.URL, nil)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	expectedRemote = "1.2.3.4"
 | |
| 	getReq.Header["X-Real-ip"] = []string{expectedRemote}
 | |
| 	_, err = http.DefaultClient.Do(getReq)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	// Valid X-Real-Ip and invalid X-Forwarded-For
 | |
| 	getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
 | |
| 	_, err = http.DefaultClient.Do(getReq)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 |