package middleware

import (
	"context"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"net/http/httptest"
	"reflect" // used as a replacement for testify
	"sync/atomic"
	"testing"
	"time"

	dcontext "github.com/docker/distribution/context"
)

// assertEqual stops the test immediately when the expected value x and the
// actual value y are not equal according to reflect.DeepEqual. It exists so
// the package does not have to pull in all of testify for one assertion.
func assertEqual(t *testing.T, x, y interface{}) {
	// Attribute a failure to the caller's line rather than to this helper.
	t.Helper()
	if !reflect.DeepEqual(x, y) {
		t.Fatalf("%s: Not equal! Expected='%v', Actual='%v'", t.Name(), x, y)
	}
}

// mockIPRangeHandler is an http.Handler that serves data as a JSON document,
// standing in for the AWS IP ranges endpoint in tests.
type mockIPRangeHandler struct {
	data awsIPResponse
}

// ServeHTTP writes the JSON encoding of m.data, or responds with status 500
// if the data cannot be marshalled.
func (m mockIPRangeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	body, err := json.Marshal(m.data)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// A failed write only means the client went away; a mock has nothing to
	// recover, so the error is deliberately discarded.
	_, _ = w.Write(body)
}

// newTestHandler starts an httptest server that serves data via
// mockIPRangeHandler. The caller is responsible for closing the server.
func newTestHandler(data awsIPResponse) *httptest.Server {
	return httptest.NewServer(mockIPRangeHandler{
		data: data,
	})
}

// serverIPRanges returns the URL under which server serves its IP ranges
// document (the server's base URL with a trailing slash).
func serverIPRanges(server *httptest.Server) string {
	return fmt.Sprintf("%s/", server.URL)
}

// setupTest starts a mock IP ranges server that advertises exactly the
// prefixes contained in data. The caller must Close the returned server.
func setupTest(data awsIPResponse) *httptest.Server {
	return newTestHandler(data)
}

// TestS3TryUpdate checks that a freshly constructed awsIPs holds the single
// advertised IPv4 prefix and no IPv6 prefixes. No explicit tryUpdate call is
// made, so this presumes newAWSIPs performs an initial fetch.
func TestS3TryUpdate(t *testing.T) {
	t.Parallel()
	server := setupTest(awsIPResponse{
		Prefixes: []prefixEntry{
			{IPV4Prefix: "123.231.123.231/32"},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)

	assertEqual(t, 1, len(ips.ipv4))
	assertEqual(t, 0, len(ips.ipv6))
}

// TestMatchIPV6 checks that an address inside an advertised IPv6 prefix is
// reported as contained, and that only the IPv6 list is populated.
func TestMatchIPV6(t *testing.T) {
	t.Parallel()
	server := setupTest(awsIPResponse{
		V6Prefixes: []prefixEntry{
			{IPV6Prefix: "ff00::/16"},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
	ips.tryUpdate()
	assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
	assertEqual(t, 1, len(ips.ipv6))
	assertEqual(t, 0, len(ips.ipv4))
}

// TestMatchIPV4 checks containment at the start of, inside, and outside an
// advertised /24 IPv4 prefix.
func TestMatchIPV4(t *testing.T) {
	t.Parallel()
	server := setupTest(awsIPResponse{
		Prefixes: []prefixEntry{
			{IPV4Prefix: "192.168.0.0/24"},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
	ips.tryUpdate()
	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}

// TestMatchIPV4_2 is TestMatchIPV4 with a region attached to the prefix: with
// no region filter (nil), the prefix must still be matched.
func TestMatchIPV4_2(t *testing.T) {
	t.Parallel()
	server := setupTest(awsIPResponse{
		Prefixes: []prefixEntry{
			{
				IPV4Prefix: "192.168.0.0/24",
				Region:     "us-east-1",
			},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
	ips.tryUpdate()
	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}

// TestMatchIPV4WithRegionMatched checks that a prefix whose region is in the
// single-element region filter is matched.
func TestMatchIPV4WithRegionMatched(t *testing.T) {
	t.Parallel()
	server := setupTest(awsIPResponse{
		Prefixes: []prefixEntry{
			{
				IPV4Prefix: "192.168.0.0/24",
				Region:     "us-east-1",
			},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"})
	ips.tryUpdate()
	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}

// TestMatchIPV4WithRegionMatch_2 checks that a prefix is matched when its
// region appears anywhere in a multi-element region filter.
func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
	t.Parallel()
	server := setupTest(awsIPResponse{
		Prefixes: []prefixEntry{
			{
				IPV4Prefix: "192.168.0.0/24",
				Region:     "us-east-1",
			},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
	ips.tryUpdate()
	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}

// TestMatchIPV4WithRegionNotMatched checks that a prefix whose region is not
// in the region filter is ignored entirely.
func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
	t.Parallel()
	server := setupTest(awsIPResponse{
		Prefixes: []prefixEntry{
			{
				IPV4Prefix: "192.168.0.0/24",
				Region:     "us-east-1",
			},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"})
	ips.tryUpdate()
	assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
	assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
}

// TestInvalidData checks that unparsable prefixes are skipped while valid
// ones in the same response are still loaded.
func TestInvalidData(t *testing.T) {
	t.Parallel()
	// Invalid entries from aws should be ignored.
	server := setupTest(awsIPResponse{
		Prefixes: []prefixEntry{
			{IPV4Prefix: "9000"},
			{IPV4Prefix: "192.168.0.0/24"},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
	ips.tryUpdate()
	assertEqual(t, 1, len(ips.ipv4))
}

// TestInvalidNetworkType checks that getCandidateNetworks selects the network
// list by address length: 4 bytes -> IPv4, 16 bytes -> IPv6, anything else ->
// none. No explicit tryUpdate call is made, so this presumes newAWSIPs
// performs an initial fetch.
func TestInvalidNetworkType(t *testing.T) {
	t.Parallel()
	server := setupTest(awsIPResponse{
		Prefixes: []prefixEntry{
			{IPV4Prefix: "192.168.0.0/24"},
		},
		V6Prefixes: []prefixEntry{
			{IPV6Prefix: "ff00::/8"},
			{IPV6Prefix: "fe00::/8"},
		},
	})
	defer server.Close()

	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
	assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
	assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4))))  // netv4 networks
	assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
}

// TestParsing checks that fetchAWSIPs decodes the ip_prefix / ipv6_prefix,
// region and service fields of a raw JSON ranges document.
func TestParsing(t *testing.T) {
	t.Parallel()
	var data = `{
      "prefixes": [{
        "ip_prefix": "192.168.0.0",
        "region": "someregion",
        "service": "s3"}],
      "ipv6_prefixes": [{
        "ipv6_prefix": "2001:4860:4860::8888",
        "region": "anotherregion",
        "service": "ec2"}]
    }`
	rawMockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte(data))
	})
	server := httptest.NewServer(rawMockHandler)
	defer server.Close()

	schema, err := fetchAWSIPs(server.URL)

	assertEqual(t, nil, err)
	assertEqual(t, 1, len(schema.Prefixes))
	assertEqual(t, prefixEntry{
		IPV4Prefix: "192.168.0.0",
		Region:     "someregion",
		Service:    "s3",
	}, schema.Prefixes[0])
	assertEqual(t, 1, len(schema.V6Prefixes))
	assertEqual(t, prefixEntry{
		IPV6Prefix: "2001:4860:4860::8888",
		Region:     "anotherregion",
		Service:    "ec2",
	}, schema.V6Prefixes[0])
}

// TestUpdateCalledRegularly checks that, with a one-second update frequency,
// the ranges endpoint is requested at least four times in 4.5 seconds.
func TestUpdateCalledRegularly(t *testing.T) {
	t.Parallel()

	// The handler runs on the server's goroutines while the test goroutine
	// reads the counter, so every access must be atomic to avoid a data race.
	var updateCount int32
	server := httptest.NewServer(http.HandlerFunc(
		func(rw http.ResponseWriter, req *http.Request) {
			atomic.AddInt32(&updateCount, 1)
			_, _ = rw.Write([]byte("ok"))
		}))
	defer server.Close()

	newAWSIPs(serverIPRanges(server), time.Second, nil)
	time.Sleep(time.Second*4 + time.Millisecond*500)
	if count := atomic.LoadInt32(&updateCount); count < 4 {
		t.Errorf("Update should have been called at least 4 times, actual=%d", count)
	}
}

// TestEligibleForS3 checks eligibleForS3 against an initialized awsIPs that
// contains 192.168.1.0/24: a context without a request and a client outside
// the network are ineligible; a client inside the network is eligible.
func TestEligibleForS3(t *testing.T) {
	ips := &awsIPs{
		ipv4: []net.IPNet{{
			IP:   net.ParseIP("192.168.1.1"),
			Mask: net.IPv4Mask(255, 255, 255, 0),
		}},
		initialized: true,
	}
	empty := context.TODO()
	makeContext := func(ip string) context.Context {
		req := &http.Request{
			RemoteAddr: ip,
		}
		return dcontext.WithRequest(empty, req)
	}

	cases := []struct {
		Context  context.Context
		Expected bool
	}{
		{Context: empty, Expected: false},
		{Context: makeContext("192.168.1.2"), Expected: true},
		{Context: makeContext("192.168.0.2"), Expected: false},
	}

	for _, testCase := range cases {
		name := fmt.Sprintf("Client IP = %v", testCase.Context.Value("http.request.ip"))
		t.Run(name, func(t *testing.T) {
			assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, ips))
		})
	}
}

// TestEligibleForS3WithAWSIPNotInitialized checks that, while awsIPs is not
// marked initialized, no client is eligible, even one inside a listed network.
func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) {
	ips := &awsIPs{
		ipv4: []net.IPNet{{
			IP:   net.ParseIP("192.168.1.1"),
			Mask: net.IPv4Mask(255, 255, 255, 0),
		}},
		initialized: false,
	}
	empty := context.TODO()
	makeContext := func(ip string) context.Context {
		req := &http.Request{
			RemoteAddr: ip,
		}
		return dcontext.WithRequest(empty, req)
	}

	cases := []struct {
		Context  context.Context
		Expected bool
	}{
		{Context: empty, Expected: false},
		{Context: makeContext("192.168.1.2"), Expected: false},
		{Context: makeContext("192.168.0.2"), Expected: false},
	}

	for _, testCase := range cases {
		name := fmt.Sprintf("Client IP = %v", testCase.Context.Value("http.request.ip"))
		t.Run(name, func(t *testing.T) {
			assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, ips))
		})
	}
}
func populateRandomNetworks(b *testing.B, ips *awsIPs, ipv4Count, ipv6Count int) { generateNetworks := func(dest *[]net.IPNet, bytes int, count int) { for i := 0; i < count; i++ { ip := make([]byte, bytes) _, err := rand.Read(ip) if err != nil { b.Fatalf("failed to generate network for test : %s", err.Error()) } mask := make([]byte, bytes) for i := 0; i < bytes; i++ { mask[i] = 0xff } *dest = append(*dest, net.IPNet{ IP: ip, Mask: mask, }) } } generateNetworks(&ips.ipv4, 4, ipv4Count) generateNetworks(&ips.ipv6, 16, ipv6Count) } func BenchmarkContainsRandom(b *testing.B) { // Generate a random network configuration, of size comparable to // aws official networks list // curl -s https://ip-ranges.amazonaws.com/ip-ranges.json | jq '.prefixes | length' // 941 numNetworksPerType := 1000 // keep in sync with the above // intentionally skip constructor when creating awsIPs, to avoid updater routine. // This benchmark is only concerned with contains() performance. awsIPs := awsIPs{} populateRandomNetworks(b, &awsIPs, numNetworksPerType, numNetworksPerType) ipv4 := make([][]byte, b.N) ipv6 := make([][]byte, b.N) for i := 0; i < b.N; i++ { ipv4[i] = make([]byte, 4) ipv6[i] = make([]byte, 16) rand.Read(ipv4[i]) rand.Read(ipv6[i]) } b.ResetTimer() for i := 0; i < b.N; i++ { awsIPs.contains(ipv4[i]) awsIPs.contains(ipv6[i]) } } func BenchmarkContainsProd(b *testing.B) { awsIPs := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil) ipv4 := make([][]byte, b.N) ipv6 := make([][]byte, b.N) for i := 0; i < b.N; i++ { ipv4[i] = make([]byte, 4) ipv6[i] = make([]byte, 16) rand.Read(ipv4[i]) rand.Read(ipv6[i]) } b.ResetTimer() for i := 0; i < b.N; i++ { awsIPs.contains(ipv4[i]) awsIPs.contains(ipv6[i]) } }