From e8ecc6dc5506b2046800221ee66583ae692640eb Mon Sep 17 00:00:00 2001 From: tifayuki Date: Tue, 14 Nov 2017 17:21:36 -0800 Subject: [PATCH] add s3 region filters for cloudfront Signed-off-by: tifayuki --- .gitignore | 1 + context/logger.go | 2 + docs/configuration.md | 16 + .../middleware/cloudfront/middleware.go | 79 ++- .../driver/middleware/cloudfront/s3filter.go | 223 ++++++ .../middleware/cloudfront/s3filter_test.go | 401 +++++++++++ vendor/github.com/miekg/dns/msg_generate.go | 340 --------- vendor/github.com/miekg/dns/types_generate.go | 271 ------- vendor/golang.org/x/net/idna/idna.go | 68 -- vendor/golang.org/x/net/idna/punycode.go | 200 ------ vendor/golang.org/x/net/publicsuffix/gen.go | 663 ------------------ 11 files changed, 716 insertions(+), 1548 deletions(-) create mode 100644 registry/storage/driver/middleware/cloudfront/s3filter.go create mode 100644 registry/storage/driver/middleware/cloudfront/s3filter_test.go delete mode 100644 vendor/github.com/miekg/dns/msg_generate.go delete mode 100644 vendor/github.com/miekg/dns/types_generate.go delete mode 100644 vendor/golang.org/x/net/idna/idna.go delete mode 100644 vendor/golang.org/x/net/idna/punycode.go delete mode 100644 vendor/golang.org/x/net/publicsuffix/gen.go diff --git a/.gitignore b/.gitignore index 1c3ae0a7..4cf7888e 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,4 @@ bin/* # Editor/IDE specific files. *.sublime-project *.sublime-workspace +.idea/* diff --git a/context/logger.go b/context/logger.go index afc8860b..3e5b81bb 100644 --- a/context/logger.go +++ b/context/logger.go @@ -39,6 +39,8 @@ type Logger interface { Warn(args ...interface{}) Warnf(format string, args ...interface{}) Warnln(args ...interface{}) + + WithError(err error) *logrus.Entry } type loggerKey struct{} diff --git a/docs/configuration.md b/docs/configuration.md index c7f9023f..807353cc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -183,6 +183,10 @@ middleware: privatekey: /path/to/pem keypairid: cloudfrontkeypairid duration: 3000s + ipfilteredby: awsregion + awsregion: us-east-1, use-east-2 + updatefrenquency: 12h + iprangesurl: https://ip-ranges.amazonaws.com/ip-ranges.json storage: - name: redirect options: @@ -636,6 +640,10 @@ middleware: privatekey: /path/to/pem keypairid: cloudfrontkeypairid duration: 3000s + ipfilteredby: awsregion + awsregion: us-east-1, use-east-2 + updatefrenquency: 12h + iprangesurl: https://ip-ranges.amazonaws.com/ip-ranges.json ``` Each middleware entry has `name` and `options` entries. The `name` must @@ -655,6 +663,14 @@ interpretation of the options. | `privatekey` | yes | The private key for Cloudfront, provided by AWS. | | `keypairid` | yes | The key pair ID provided by AWS. | | `duration` | no | An integer and unit for the duration of the Cloudfront session. Valid time units are `ns`, `us` (or `µs`), `ms`, `s`, `m`, or `h`. For example, `3000s` is valid, but `3000 s` is not. If you do not specify a `duration` or you specify an integer without a time unit, the duration defaults to `20m` (20 minutes).| +|`ipfilteredby`|no | A string with the following value `none|aws|awsregion`. | +|`awsregion`|no | A comma separated string of AWS regions, only available when `ipfilteredby` is `awsregion`. For example, `us-east-1, us-west-2`| +|`updatefrenquency`|no | The frequency to update AWS IP regions, default: `12h`| +|`iprangesurl`|no | The URL contains the AWS IP ranges information, default: `https://ip-ranges.amazonaws.com/ip-ranges.json`| +Then value of ipfilteredby: +`none`: default, do not filter by IP +`aws`: IP from AWS goes to S3 directly +`awsregion`: IP from certain AWS regions goes to S3 directly, use together with `awsregion` ### `redirect` diff --git a/registry/storage/driver/middleware/cloudfront/middleware.go b/registry/storage/driver/middleware/cloudfront/middleware.go index 61e787a4..5dc3ee41 100644 --- a/registry/storage/driver/middleware/cloudfront/middleware.go +++ b/registry/storage/driver/middleware/cloudfront/middleware.go @@ -16,7 +16,7 @@ import ( "github.com/aws/aws-sdk-go/service/cloudfront/sign" dcontext "github.com/docker/distribution/context" storagedriver "github.com/docker/distribution/registry/storage/driver" - storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware" + "github.com/docker/distribution/registry/storage/driver/middleware" ) // cloudFrontStorageMiddleware provides a simple implementation of layerHandler that @@ -24,6 +24,7 @@ import ( // then issues HTTP Temporary Redirects to this CloudFront content URL. type cloudFrontStorageMiddleware struct { storagedriver.StorageDriver + awsIPs *awsIPs urlSigner *sign.URLSigner baseURL string duration time.Duration @@ -34,7 +35,13 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{} // newCloudFrontLayerHandler constructs and returns a new CloudFront // LayerHandler implementation. // Required options: baseurl, privatekey, keypairid + +// Optional options: ipFilteredBy, awsregion +// ipfilteredby: valid value "none|aws|awsregion". "none", do not filter any IP, default value. "aws", only aws IP goes +// to S3 directly. "awsregion", only regions listed in awsregion options goes to S3 directly +// awsregion: a comma separated string of AWS regions. func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { + // parse baseurl base, ok := options["baseurl"] if !ok { return nil, fmt.Errorf("no baseurl provided") @@ -52,6 +59,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o if _, err := url.Parse(baseURL); err != nil { return nil, fmt.Errorf("invalid baseurl: %v", err) } + + // parse privatekey to get pkPath pk, ok := options["privatekey"] if !ok { return nil, fmt.Errorf("no privatekey provided") @@ -60,6 +69,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o if !ok { return nil, fmt.Errorf("privatekey must be a string") } + + // parse keypairid kpid, ok := options["keypairid"] if !ok { return nil, fmt.Errorf("no keypairid provided") @@ -69,6 +80,7 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o return nil, fmt.Errorf("keypairid must be a string") } + // get urlSigner from the file specified in pkPath pkBytes, err := ioutil.ReadFile(pkPath) if err != nil { return nil, fmt.Errorf("failed to read privatekey file: %s", err) @@ -82,12 +94,11 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o if err != nil { return nil, err } - urlSigner := sign.NewURLSigner(keypairID, privateKey) + // parse duration duration := 20 * time.Minute - d, ok := options["duration"] - if ok { + if d, ok := options["duration"]; ok { switch d := d.(type) { case time.Duration: duration = d @@ -100,11 +111,62 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o } } + // parse updatefrenquency + updateFrequency := defaultUpdateFrequency + if u, ok := options["updatefrenquency"]; ok { + switch u := u.(type) { + case time.Duration: + updateFrequency = u + case string: + updateFreq, err := time.ParseDuration(u) + if err != nil { + return nil, fmt.Errorf("invalid updatefrenquency: %s", err) + } + duration = updateFreq + } + } + + // parse iprangesurl + ipRangesURL := defaultIPRangesURL + if i, ok := options["iprangesurl"]; ok { + if iprangeurl, ok := i.(string); ok { + ipRangesURL = iprangeurl + } else { + return nil, fmt.Errorf("iprangesurl must be a string") + } + } + + // parse ipfilteredby + var awsIPs *awsIPs + if ipFilteredBy := options["ipfilteredby"].(string); ok { + switch strings.ToLower(strings.TrimSpace(ipFilteredBy)) { + case "", "none": + awsIPs = nil + case "aws": + newAWSIPs(ipRangesURL, updateFrequency, nil) + case "awsregion": + var awsRegion []string + if regions, ok := options["awsregion"].(string); ok { + for _, awsRegions := range strings.Split(regions, ",") { + awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions))) + } + awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion) + } else { + return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions") + } + default: + return nil, fmt.Errorf("ipfilteredby only allows a string the following value: none|aws|awsregion") + } + } else { + return nil, fmt.Errorf("ipfilteredby only allows a string with the following value: none|aws|awsregion") + } + return &cloudFrontStorageMiddleware{ StorageDriver: storageDriver, urlSigner: urlSigner, baseURL: baseURL, duration: duration, + awsIPs: awsIPs, }, nil } @@ -114,8 +176,8 @@ type S3BucketKeyer interface { S3BucketKey(path string) string } -// Resolve returns an http.Handler which can serve the contents of the given -// Layer, or an error if not supported by the storagedriver. +// URLFor attempts to find a url which may be used to retrieve the file at the given path. +// Returns an error if the file cannot be found. func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { // TODO(endophage): currently only supports S3 keyer, ok := lh.StorageDriver.(S3BucketKeyer) @@ -124,6 +186,11 @@ func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, return lh.StorageDriver.URLFor(ctx, path, options) } + if eligibleForS3(ctx, lh.awsIPs) { + return lh.StorageDriver.URLFor(ctx, path, options) + } + + // Get signed cloudfront url. cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration)) if err != nil { return "", err diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go new file mode 100644 index 00000000..c8c7f570 --- /dev/null +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -0,0 +1,223 @@ +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "strings" + "sync" + "time" + + dcontext "github.com/docker/distribution/context" +) + +const ( + // ipRangesURL is the URL to get definition of AWS IPs + defaultIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json" + // updateFrequency tells how frequently AWS IPs need to be updated + defaultUpdateFrequency = time.Hour * 12 +) + +// newAWSIPs returns a New awsIP object. +// If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified +func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs { + ips := &awsIPs{ + host: host, + updateFrequency: updateFrequency, + awsRegion: awsRegion, + updaterStopChan: make(chan bool), + } + if err := ips.tryUpdate(); err != nil { + dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP") + } + go ips.updater() + return ips +} + +// awsIPs tracks a list of AWS ips, filtered by awsRegion +type awsIPs struct { + host string + updateFrequency time.Duration + ipv4 []net.IPNet + ipv6 []net.IPNet + mutex sync.RWMutex + awsRegion []string + updaterStopChan chan bool + initialized bool +} + +type awsIPResponse struct { + Prefixes []prefixEntry `json:"prefixes"` + V6Prefixes []prefixEntry `json:"ipv6_prefixes"` +} + +type prefixEntry struct { + IPV4Prefix string `json:"ip_prefix"` + IPV6Prefix string `json:"ipv6_prefix"` + Region string `json:"region"` + Service string `json:"service"` +} + +func fetchAWSIPs(url string) (awsIPResponse, error) { + var response awsIPResponse + resp, err := http.Get(url) + if err != nil { + return response, err + } + if resp.StatusCode != 200 { + body, _ := ioutil.ReadAll(resp.Body) + return response, fmt.Errorf("failed to fetch network data. response = %s", body) + } + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&response) + if err != nil { + return response, err + } + return response, nil +} + +// tryUpdate attempts to download the new set of ip addresses. +// tryUpdate must be thread safe with contains +func (s *awsIPs) tryUpdate() error { + response, err := fetchAWSIPs(s.host) + if err != nil { + return err + } + + var ipv4 []net.IPNet + var ipv6 []net.IPNet + + processAddress := func(output *[]net.IPNet, prefix string, region string) { + regionAllowed := false + if len(s.awsRegion) > 0 { + for _, ar := range s.awsRegion { + if strings.ToLower(region) == ar { + regionAllowed = true + break + } + } + } else { + regionAllowed = true + } + + _, network, err := net.ParseCIDR(prefix) + if err != nil { + dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{ + "cidr": prefix, + }).Error("unparseable cidr") + return + } + if regionAllowed { + *output = append(*output, *network) + } + + } + + for _, prefix := range response.Prefixes { + processAddress(&ipv4, prefix.IPV4Prefix, prefix.Region) + } + for _, prefix := range response.V6Prefixes { + processAddress(&ipv6, prefix.IPV6Prefix, prefix.Region) + } + s.mutex.Lock() + defer s.mutex.Unlock() + // Update each attr of awsips atomically. + s.ipv4 = ipv4 + s.ipv6 = ipv6 + s.initialized = true + return nil +} + +// This function is meant to be run in a background goroutine. +// It will periodically update the ips from aws. +func (s *awsIPs) updater() { + defer close(s.updaterStopChan) + for { + time.Sleep(s.updateFrequency) + select { + case <-s.updaterStopChan: + dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal") + return + default: + err := s.tryUpdate() + if err != nil { + dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP") + } + } + } +} + +// getCandidateNetworks returns either the ipv4 or ipv6 networks +// that were last read from aws. The networks returned +// have the same type as the ip address provided. +func (s *awsIPs) getCandidateNetworks(ip net.IP) []net.IPNet { + s.mutex.RLock() + defer s.mutex.RUnlock() + if ip.To4() != nil { + return s.ipv4 + } else if ip.To16() != nil { + return s.ipv6 + } else { + dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{ + "ip": ip, + }).Error("unknown ip address format") + // assume mismatch, pass through cloudfront + return nil + } +} + +// Contains determines whether the host is within aws. +func (s *awsIPs) contains(ip net.IP) bool { + networks := s.getCandidateNetworks(ip) + for _, network := range networks { + if network.Contains(ip) { + return true + } + } + return false +} + +// parseIPFromRequest attempts to extract the ip address of the +// client that made the request +func parseIPFromRequest(ctx context.Context) (net.IP, error) { + request, err := dcontext.GetRequest(ctx) + if err != nil { + return nil, err + } + ipStr := dcontext.RemoteIP(request) + ip := net.ParseIP(ipStr) + if ip == nil { + return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr) + } + + return ip, nil +} + +// eligibleForS3 checks if a request is eligible for using S3 directly +// Return true only when the IP belongs to a specific aws region and user-agent is docker +func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { + if awsIPs != nil && awsIPs.initialized { + if addr, err := parseIPFromRequest(ctx); err == nil { + request, err := dcontext.GetRequest(ctx) + if err != nil { + dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err) + } else { + loggerField := map[interface{}]interface{}{ + "user-client": request.UserAgent(), + "ip": dcontext.RemoteIP(request), + } + if awsIPs.contains(addr) { + dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront") + return true + } + dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront") + } + } else { + dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") + } + } + return false +} diff --git a/registry/storage/driver/middleware/cloudfront/s3filter_test.go b/registry/storage/driver/middleware/cloudfront/s3filter_test.go new file mode 100644 index 00000000..6aca3abc --- /dev/null +++ b/registry/storage/driver/middleware/cloudfront/s3filter_test.go @@ -0,0 +1,401 @@ +package middleware + +import ( + "context" + "crypto/rand" + "encoding/json" + "fmt" + dcontext "github.com/docker/distribution/context" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "reflect" // used as a replacement for testify +) + +// Rather than pull in all of testify +func assertEqual(t *testing.T, x, y interface{}) { + if !reflect.DeepEqual(x, y) { + t.Errorf("%s: Not equal! Expected='%v', Actual='%v'\n", t.Name(), x, y) + t.FailNow() + } +} + +type mockIPRangeHandler struct { + data awsIPResponse +} + +func (m mockIPRangeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + bytes, err := json.Marshal(m.data) + if err != nil { + w.WriteHeader(500) + return + } + w.Write(bytes) + +} + +func newTestHandler(data awsIPResponse) *httptest.Server { + return httptest.NewServer(mockIPRangeHandler{ + data: data, + }) +} + +func serverIPRanges(server *httptest.Server) string { + return fmt.Sprintf("%s/", server.URL) +} + +func setupTest(data awsIPResponse) *httptest.Server { + // This is a basic schema which only claims the exact ip + // is in aws. + server := newTestHandler(data) + return server +} + +func TestS3TryUpdate(t *testing.T) { + t.Parallel() + server := setupTest(awsIPResponse{ + Prefixes: []prefixEntry{ + {IPV4Prefix: "123.231.123.231/32"}, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + + assertEqual(t, 1, len(ips.ipv4)) + assertEqual(t, 0, len(ips.ipv6)) + +} + +func TestMatchIPV6(t *testing.T) { + t.Parallel() + server := setupTest(awsIPResponse{ + V6Prefixes: []prefixEntry{ + {IPV6Prefix: "ff00::/16"}, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + ips.tryUpdate() + assertEqual(t, true, ips.contains(net.ParseIP("ff00::"))) + assertEqual(t, 1, len(ips.ipv6)) + assertEqual(t, 0, len(ips.ipv4)) +} + +func TestMatchIPV4(t *testing.T) { + t.Parallel() + server := setupTest(awsIPResponse{ + Prefixes: []prefixEntry{ + {IPV4Prefix: "192.168.0.0/24"}, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + ips.tryUpdate() + assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) + assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) + assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) +} + +func TestMatchIPV4_2(t *testing.T) { + t.Parallel() + server := setupTest(awsIPResponse{ + Prefixes: []prefixEntry{ + { + IPV4Prefix: "192.168.0.0/24", + Region: "us-east-1", + }, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + ips.tryUpdate() + assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) + assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) + assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) +} + +func TestMatchIPV4WithRegionMatched(t *testing.T) { + t.Parallel() + server := setupTest(awsIPResponse{ + Prefixes: []prefixEntry{ + { + IPV4Prefix: "192.168.0.0/24", + Region: "us-east-1", + }, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"}) + ips.tryUpdate() + assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) + assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) + assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) +} + +func TestMatchIPV4WithRegionMatch_2(t *testing.T) { + t.Parallel() + server := setupTest(awsIPResponse{ + Prefixes: []prefixEntry{ + { + IPV4Prefix: "192.168.0.0/24", + Region: "us-east-1", + }, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"}) + ips.tryUpdate() + assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) + assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) + assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) +} + +func TestMatchIPV4WithRegionNotMatched(t *testing.T) { + t.Parallel() + server := setupTest(awsIPResponse{ + Prefixes: []prefixEntry{ + { + IPV4Prefix: "192.168.0.0/24", + Region: "us-east-1", + }, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"}) + ips.tryUpdate() + assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0"))) + assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1"))) + assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) +} + +func TestInvalidData(t *testing.T) { + t.Parallel() + // Invalid entries from aws should be ignored. + server := setupTest(awsIPResponse{ + Prefixes: []prefixEntry{ + {IPV4Prefix: "9000"}, + {IPV4Prefix: "192.168.0.0/24"}, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + ips.tryUpdate() + assertEqual(t, 1, len(ips.ipv4)) +} + +func TestInvalidNetworkType(t *testing.T) { + t.Parallel() + server := setupTest(awsIPResponse{ + Prefixes: []prefixEntry{ + {IPV4Prefix: "192.168.0.0/24"}, + }, + V6Prefixes: []prefixEntry{ + {IPV6Prefix: "ff00::/8"}, + {IPV6Prefix: "fe00::/8"}, + }, + }) + defer server.Close() + + ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type + assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks + assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks +} + +func TestParsing(t *testing.T) { + var data = `{ + "prefixes": [{ + "ip_prefix": "192.168.0.0", + "region": "someregion", + "service": "s3"}], + "ipv6_prefixes": [{ + "ipv6_prefix": "2001:4860:4860::8888", + "region": "anotherregion", + "service": "ec2"}] + }` + rawMockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(data)) }) + t.Parallel() + server := httptest.NewServer(rawMockHandler) + defer server.Close() + schema, err := fetchAWSIPs(server.URL) + + assertEqual(t, nil, err) + assertEqual(t, 1, len(schema.Prefixes)) + assertEqual(t, prefixEntry{ + IPV4Prefix: "192.168.0.0", + Region: "someregion", + Service: "s3", + }, schema.Prefixes[0]) + assertEqual(t, 1, len(schema.V6Prefixes)) + assertEqual(t, prefixEntry{ + IPV6Prefix: "2001:4860:4860::8888", + Region: "anotherregion", + Service: "ec2", + }, schema.V6Prefixes[0]) +} + +func TestUpdateCalledRegularly(t *testing.T) { + t.Parallel() + + updateCount := 0 + server := httptest.NewServer(http.HandlerFunc( + func(rw http.ResponseWriter, req *http.Request) { + updateCount++ + rw.Write([]byte("ok")) + })) + defer server.Close() + newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil) + time.Sleep(time.Second*4 + time.Millisecond*500) + if updateCount < 4 { + t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount) + } +} + +func TestEligibleForS3(t *testing.T) { + awsIPs := &awsIPs{ + ipv4: []net.IPNet{{ + IP: net.ParseIP("192.168.1.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }}, + initialized: true, + } + empty := context.TODO() + makeContext := func(ip string) context.Context { + req := &http.Request{ + RemoteAddr: ip, + } + + return dcontext.WithRequest(empty, req) + } + + cases := []struct { + Context context.Context + Expected bool + }{ + {Context: empty, Expected: false}, + {Context: makeContext("192.168.1.2"), Expected: true}, + {Context: makeContext("192.168.0.2"), Expected: false}, + } + + for _, testCase := range cases { + name := fmt.Sprintf("Client IP = %v", + testCase.Context.Value("http.request.ip")) + t.Run(name, func(t *testing.T) { + assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs)) + }) + } +} + +func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) { + awsIPs := &awsIPs{ + ipv4: []net.IPNet{{ + IP: net.ParseIP("192.168.1.1"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }}, + initialized: false, + } + empty := context.TODO() + makeContext := func(ip string) context.Context { + req := &http.Request{ + RemoteAddr: ip, + } + + return dcontext.WithRequest(empty, req) + } + + cases := []struct { + Context context.Context + Expected bool + }{ + {Context: empty, Expected: false}, + {Context: makeContext("192.168.1.2"), Expected: false}, + {Context: makeContext("192.168.0.2"), Expected: false}, + } + + for _, testCase := range cases { + name := fmt.Sprintf("Client IP = %v", + testCase.Context.Value("http.request.ip")) + t.Run(name, func(t *testing.T) { + assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs)) + }) + } +} + +// populate ips with a number of different ipv4 and ipv6 networks, for the purposes +// of benchmarking contains() performance. +func populateRandomNetworks(b *testing.B, ips *awsIPs, ipv4Count, ipv6Count int) { + generateNetworks := func(dest *[]net.IPNet, bytes int, count int) { + for i := 0; i < count; i++ { + ip := make([]byte, bytes) + _, err := rand.Read(ip) + if err != nil { + b.Fatalf("failed to generate network for test : %s", err.Error()) + } + mask := make([]byte, bytes) + for i := 0; i < bytes; i++ { + mask[i] = 0xff + } + *dest = append(*dest, net.IPNet{ + IP: ip, + Mask: mask, + }) + } + } + + generateNetworks(&ips.ipv4, 4, ipv4Count) + generateNetworks(&ips.ipv6, 16, ipv6Count) +} + +func BenchmarkContainsRandom(b *testing.B) { + // Generate a random network configuration, of size comparable to + // aws official networks list + // curl -s https://ip-ranges.amazonaws.com/ip-ranges.json | jq '.prefixes | length' + // 941 + numNetworksPerType := 1000 // keep in sync with the above + // intentionally skip constructor when creating awsIPs, to avoid updater routine. + // This benchmark is only concerned with contains() performance. + awsIPs := awsIPs{} + populateRandomNetworks(b, &awsIPs, numNetworksPerType, numNetworksPerType) + + ipv4 := make([][]byte, b.N) + ipv6 := make([][]byte, b.N) + for i := 0; i < b.N; i++ { + ipv4[i] = make([]byte, 4) + ipv6[i] = make([]byte, 16) + rand.Read(ipv4[i]) + rand.Read(ipv6[i]) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + awsIPs.contains(ipv4[i]) + awsIPs.contains(ipv6[i]) + } +} + +func BenchmarkContainsProd(b *testing.B) { + awsIPs := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil) + ipv4 := make([][]byte, b.N) + ipv6 := make([][]byte, b.N) + for i := 0; i < b.N; i++ { + ipv4[i] = make([]byte, 4) + ipv6[i] = make([]byte, 16) + rand.Read(ipv4[i]) + rand.Read(ipv6[i]) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + awsIPs.contains(ipv4[i]) + awsIPs.contains(ipv6[i]) + } +} diff --git a/vendor/github.com/miekg/dns/msg_generate.go b/vendor/github.com/miekg/dns/msg_generate.go deleted file mode 100644 index 35786f22..00000000 --- a/vendor/github.com/miekg/dns/msg_generate.go +++ /dev/null @@ -1,340 +0,0 @@ -//+build ignore - -// msg_generate.go is meant to run with go generate. It will use -// go/{importer,types} to track down all the RR struct types. Then for each type -// it will generate pack/unpack methods based on the struct tags. The generated source is -// written to zmsg.go, and is meant to be checked into git. -package main - -import ( - "bytes" - "fmt" - "go/format" - "go/importer" - "go/types" - "log" - "os" - "strings" -) - -var packageHdr = ` -// *** DO NOT MODIFY *** -// AUTOGENERATED BY go generate from msg_generate.go - -package dns - -` - -// getTypeStruct will take a type and the package scope, and return the -// (innermost) struct if the type is considered a RR type (currently defined as -// those structs beginning with a RR_Header, could be redefined as implementing -// the RR interface). The bool return value indicates if embedded structs were -// resolved. -func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { - st, ok := t.Underlying().(*types.Struct) - if !ok { - return nil, false - } - if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { - return st, false - } - if st.Field(0).Anonymous() { - st, _ := getTypeStruct(st.Field(0).Type(), scope) - return st, true - } - return nil, false -} - -func main() { - // Import and type-check the package - pkg, err := importer.Default().Import("github.com/miekg/dns") - fatalIfErr(err) - scope := pkg.Scope() - - // Collect actual types (*X) - var namedTypes []string - for _, name := range scope.Names() { - o := scope.Lookup(name) - if o == nil || !o.Exported() { - continue - } - if st, _ := getTypeStruct(o.Type(), scope); st == nil { - continue - } - if name == "PrivateRR" { - continue - } - - // Check if corresponding TypeX exists - if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { - log.Fatalf("Constant Type%s does not exist.", o.Name()) - } - - namedTypes = append(namedTypes, o.Name()) - } - - b := &bytes.Buffer{} - b.WriteString(packageHdr) - - fmt.Fprint(b, "// pack*() functions\n\n") - for _, name := range namedTypes { - o := scope.Lookup(name) - st, _ := getTypeStruct(o.Type(), scope) - - fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name) - fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress) -if err != nil { - return off, err -} -headerEnd := off -`) - for i := 1; i < st.NumFields(); i++ { - o := func(s string) { - fmt.Fprintf(b, s, st.Field(i).Name()) - fmt.Fprint(b, `if err != nil { -return off, err -} -`) - } - - if _, ok := st.Field(i).Type().(*types.Slice); ok { - switch st.Tag(i) { - case `dns:"-"`: // ignored - case `dns:"txt"`: - o("off, err = packStringTxt(rr.%s, msg, off)\n") - case `dns:"opt"`: - o("off, err = packDataOpt(rr.%s, msg, off)\n") - case `dns:"nsec"`: - o("off, err = packDataNsec(rr.%s, msg, off)\n") - case `dns:"domain-name"`: - o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n") - default: - log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) - } - continue - } - - switch { - case st.Tag(i) == `dns:"-"`: // ignored - case st.Tag(i) == `dns:"cdomain-name"`: - fallthrough - case st.Tag(i) == `dns:"domain-name"`: - o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n") - case st.Tag(i) == `dns:"a"`: - o("off, err = packDataA(rr.%s, msg, off)\n") - case st.Tag(i) == `dns:"aaaa"`: - o("off, err = packDataAAAA(rr.%s, msg, off)\n") - case st.Tag(i) == `dns:"uint48"`: - o("off, err = packUint48(rr.%s, msg, off)\n") - case st.Tag(i) == `dns:"txt"`: - o("off, err = packString(rr.%s, msg, off)\n") - - case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32 - fallthrough - case st.Tag(i) == `dns:"base32"`: - o("off, err = packStringBase32(rr.%s, msg, off)\n") - - case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64 - fallthrough - case st.Tag(i) == `dns:"base64"`: - o("off, err = packStringBase64(rr.%s, msg, off)\n") - - case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): // Hack to fix empty salt length for NSEC3 - o("if rr.%s == \"-\" { /* do nothing, empty salt */ }\n") - continue - case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex - fallthrough - case st.Tag(i) == `dns:"hex"`: - o("off, err = packStringHex(rr.%s, msg, off)\n") - - case st.Tag(i) == `dns:"octet"`: - o("off, err = packStringOctet(rr.%s, msg, off)\n") - case st.Tag(i) == "": - switch st.Field(i).Type().(*types.Basic).Kind() { - case types.Uint8: - o("off, err = packUint8(rr.%s, msg, off)\n") - case types.Uint16: - o("off, err = packUint16(rr.%s, msg, off)\n") - case types.Uint32: - o("off, err = packUint32(rr.%s, msg, off)\n") - case types.Uint64: - o("off, err = packUint64(rr.%s, msg, off)\n") - case types.String: - o("off, err = packString(rr.%s, msg, off)\n") - default: - log.Fatalln(name, st.Field(i).Name()) - } - default: - log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) - } - } - // We have packed everything, only now we know the rdlength of this RR - fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)") - fmt.Fprintln(b, "return off, nil }\n") - } - - fmt.Fprint(b, "// unpack*() functions\n\n") - for _, name := range namedTypes { - o := scope.Lookup(name) - st, _ := getTypeStruct(o.Type(), scope) - - fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name) - fmt.Fprintf(b, "rr := new(%s)\n", name) - fmt.Fprint(b, "rr.Hdr = h\n") - fmt.Fprint(b, `if noRdata(h) { -return rr, off, nil - } -var err error -rdStart := off -_ = rdStart - -`) - for i := 1; i < st.NumFields(); i++ { - o := func(s string) { - fmt.Fprintf(b, s, st.Field(i).Name()) - fmt.Fprint(b, `if err != nil { -return rr, off, err -} -`) - } - - // size-* are special, because they reference a struct member we should use for the length. - if strings.HasPrefix(st.Tag(i), `dns:"size-`) { - structMember := structMember(st.Tag(i)) - structTag := structTag(st.Tag(i)) - switch structTag { - case "hex": - fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) - case "base32": - fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) - case "base64": - fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) - default: - log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) - } - fmt.Fprint(b, `if err != nil { -return rr, off, err -} -`) - continue - } - - if _, ok := st.Field(i).Type().(*types.Slice); ok { - switch st.Tag(i) { - case `dns:"-"`: // ignored - case `dns:"txt"`: - o("rr.%s, off, err = unpackStringTxt(msg, off)\n") - case `dns:"opt"`: - o("rr.%s, off, err = unpackDataOpt(msg, off)\n") - case `dns:"nsec"`: - o("rr.%s, off, err = unpackDataNsec(msg, off)\n") - case `dns:"domain-name"`: - o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") - default: - log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) - } - continue - } - - switch st.Tag(i) { - case `dns:"-"`: // ignored - case `dns:"cdomain-name"`: - fallthrough - case `dns:"domain-name"`: - o("rr.%s, off, err = UnpackDomainName(msg, off)\n") - case `dns:"a"`: - o("rr.%s, off, err = unpackDataA(msg, off)\n") - case `dns:"aaaa"`: - o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") - case `dns:"uint48"`: - o("rr.%s, off, err = unpackUint48(msg, off)\n") - case `dns:"txt"`: - o("rr.%s, off, err = unpackString(msg, off)\n") - case `dns:"base32"`: - o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") - case `dns:"base64"`: - o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") - case `dns:"hex"`: - o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") - case `dns:"octet"`: - o("rr.%s, off, err = unpackStringOctet(msg, off)\n") - case "": - switch st.Field(i).Type().(*types.Basic).Kind() { - case types.Uint8: - o("rr.%s, off, err = unpackUint8(msg, off)\n") - case types.Uint16: - o("rr.%s, off, err = unpackUint16(msg, off)\n") - case types.Uint32: - o("rr.%s, off, err = unpackUint32(msg, off)\n") - case types.Uint64: - o("rr.%s, off, err = unpackUint64(msg, off)\n") - case types.String: - o("rr.%s, off, err = unpackString(msg, off)\n") - default: - log.Fatalln(name, st.Field(i).Name()) - } - default: - log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) - } - // If we've hit len(msg) we return without error. - if i < st.NumFields()-1 { - fmt.Fprintf(b, `if off == len(msg) { -return rr, off, nil - } -`) - } - } - fmt.Fprintf(b, "return rr, off, err }\n\n") - } - // Generate typeToUnpack map - fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){") - for _, name := range namedTypes { - if name == "RFC3597" { - continue - } - fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name) - } - fmt.Fprintln(b, "}\n") - - // gofmt - res, err := format.Source(b.Bytes()) - if err != nil { - b.WriteTo(os.Stderr) - log.Fatal(err) - } - - // write result - f, err := os.Create("zmsg.go") - fatalIfErr(err) - defer f.Close() - f.Write(res) -} - -// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string. -func structMember(s string) string { - fields := strings.Split(s, ":") - if len(fields) == 0 { - return "" - } - f := fields[len(fields)-1] - // f should have a closing " - if len(f) > 1 { - return f[:len(f)-1] - } - return f -} - -// structTag will take a tag like dns:"size-base32:SaltLength" and return base32. -func structTag(s string) string { - fields := strings.Split(s, ":") - if len(fields) < 2 { - return "" - } - return fields[1][len("\"size-"):] -} - -func fatalIfErr(err error) { - if err != nil { - log.Fatal(err) - } -} diff --git a/vendor/github.com/miekg/dns/types_generate.go b/vendor/github.com/miekg/dns/types_generate.go deleted file mode 100644 index bf80da32..00000000 --- a/vendor/github.com/miekg/dns/types_generate.go +++ /dev/null @@ -1,271 +0,0 @@ -//+build ignore - -// types_generate.go is meant to run with go generate. It will use -// go/{importer,types} to track down all the RR struct types. Then for each type -// it will generate conversion tables (TypeToRR and TypeToString) and banal -// methods (len, Header, copy) based on the struct tags. The generated source is -// written to ztypes.go, and is meant to be checked into git. -package main - -import ( - "bytes" - "fmt" - "go/format" - "go/importer" - "go/types" - "log" - "os" - "strings" - "text/template" -) - -var skipLen = map[string]struct{}{ - "NSEC": {}, - "NSEC3": {}, - "OPT": {}, -} - -var packageHdr = ` -// *** DO NOT MODIFY *** -// AUTOGENERATED BY go generate from type_generate.go - -package dns - -import ( - "encoding/base64" - "net" -) - -` - -var TypeToRR = template.Must(template.New("TypeToRR").Parse(` -// TypeToRR is a map of constructors for each RR type. -var TypeToRR = map[uint16]func() RR{ -{{range .}}{{if ne . "RFC3597"}} Type{{.}}: func() RR { return new({{.}}) }, -{{end}}{{end}} } - -`)) - -var typeToString = template.Must(template.New("typeToString").Parse(` -// TypeToString is a map of strings for each RR type. -var TypeToString = map[uint16]string{ -{{range .}}{{if ne . "NSAPPTR"}} Type{{.}}: "{{.}}", -{{end}}{{end}} TypeNSAPPTR: "NSAP-PTR", -} - -`)) - -var headerFunc = template.Must(template.New("headerFunc").Parse(` -// Header() functions -{{range .}} func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr } -{{end}} - -`)) - -// getTypeStruct will take a type and the package scope, and return the -// (innermost) struct if the type is considered a RR type (currently defined as -// those structs beginning with a RR_Header, could be redefined as implementing -// the RR interface). The bool return value indicates if embedded structs were -// resolved. -func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { - st, ok := t.Underlying().(*types.Struct) - if !ok { - return nil, false - } - if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { - return st, false - } - if st.Field(0).Anonymous() { - st, _ := getTypeStruct(st.Field(0).Type(), scope) - return st, true - } - return nil, false -} - -func main() { - // Import and type-check the package - pkg, err := importer.Default().Import("github.com/miekg/dns") - fatalIfErr(err) - scope := pkg.Scope() - - // Collect constants like TypeX - var numberedTypes []string - for _, name := range scope.Names() { - o := scope.Lookup(name) - if o == nil || !o.Exported() { - continue - } - b, ok := o.Type().(*types.Basic) - if !ok || b.Kind() != types.Uint16 { - continue - } - if !strings.HasPrefix(o.Name(), "Type") { - continue - } - name := strings.TrimPrefix(o.Name(), "Type") - if name == "PrivateRR" { - continue - } - numberedTypes = append(numberedTypes, name) - } - - // Collect actual types (*X) - var namedTypes []string - for _, name := range scope.Names() { - o := scope.Lookup(name) - if o == nil || !o.Exported() { - continue - } - if st, _ := getTypeStruct(o.Type(), scope); st == nil { - continue - } - if name == "PrivateRR" { - continue - } - - // Check if corresponding TypeX exists - if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { - log.Fatalf("Constant Type%s does not exist.", o.Name()) - } - - namedTypes = append(namedTypes, o.Name()) - } - - b := &bytes.Buffer{} - b.WriteString(packageHdr) - - // Generate TypeToRR - fatalIfErr(TypeToRR.Execute(b, namedTypes)) - - // Generate typeToString - fatalIfErr(typeToString.Execute(b, numberedTypes)) - - // Generate headerFunc - fatalIfErr(headerFunc.Execute(b, namedTypes)) - - // Generate len() - fmt.Fprint(b, "// len() functions\n") - for _, name := range namedTypes { - if _, ok := skipLen[name]; ok { - continue - } - o := scope.Lookup(name) - st, isEmbedded := getTypeStruct(o.Type(), scope) - if isEmbedded { - continue - } - fmt.Fprintf(b, "func (rr *%s) len() int {\n", name) - fmt.Fprintf(b, "l := rr.Hdr.len()\n") - for i := 1; i < st.NumFields(); i++ { - o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) } - - if _, ok := st.Field(i).Type().(*types.Slice); ok { - switch st.Tag(i) { - case `dns:"-"`: - // ignored - case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: - o("for _, x := range rr.%s { l += len(x) + 1 }\n") - default: - log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) - } - continue - } - - switch { - case st.Tag(i) == `dns:"-"`: - // ignored - case st.Tag(i) == `dns:"cdomain-name"`, st.Tag(i) == `dns:"domain-name"`: - o("l += len(rr.%s) + 1\n") - case st.Tag(i) == `dns:"octet"`: - o("l += len(rr.%s)\n") - case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): - fallthrough - case st.Tag(i) == `dns:"base64"`: - o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n") - case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): - fallthrough - case st.Tag(i) == `dns:"hex"`: - o("l += len(rr.%s)/2 + 1\n") - case st.Tag(i) == `dns:"a"`: - o("l += net.IPv4len // %s\n") - case st.Tag(i) == `dns:"aaaa"`: - o("l += net.IPv6len // %s\n") - case st.Tag(i) == `dns:"txt"`: - o("for _, t := range rr.%s { l += len(t) + 1 }\n") - case st.Tag(i) == `dns:"uint48"`: - o("l += 6 // %s\n") - case st.Tag(i) == "": - switch st.Field(i).Type().(*types.Basic).Kind() { - case types.Uint8: - o("l += 1 // %s\n") - case types.Uint16: - o("l += 2 // %s\n") - case types.Uint32: - o("l += 4 // %s\n") - case types.Uint64: - o("l += 8 // %s\n") - case types.String: - o("l += len(rr.%s) + 1\n") - default: - log.Fatalln(name, st.Field(i).Name()) - } - default: - log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) - } - } - fmt.Fprintf(b, "return l }\n") - } - - // Generate copy() - fmt.Fprint(b, "// copy() functions\n") - for _, name := range namedTypes { - o := scope.Lookup(name) - st, isEmbedded := getTypeStruct(o.Type(), scope) - if isEmbedded { - continue - } - fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name) - fields := []string{"*rr.Hdr.copyHeader()"} - for i := 1; i < st.NumFields(); i++ { - f := st.Field(i).Name() - if sl, ok := st.Field(i).Type().(*types.Slice); ok { - t := sl.Underlying().String() - t = strings.TrimPrefix(t, "[]") - if strings.Contains(t, ".") { - splits := strings.Split(t, ".") - t = splits[len(splits)-1] - } - fmt.Fprintf(b, "%s := make([]%s, len(rr.%s)); copy(%s, rr.%s)\n", - f, t, f, f, f) - fields = append(fields, f) - continue - } - if st.Field(i).Type().String() == "net.IP" { - fields = append(fields, "copyIP(rr."+f+")") - continue - } - fields = append(fields, "rr."+f) - } - fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ",")) - fmt.Fprintf(b, "}\n") - } - - // gofmt - res, err := format.Source(b.Bytes()) - if err != nil { - b.WriteTo(os.Stderr) - log.Fatal(err) - } - - // write result - f, err := os.Create("ztypes.go") - fatalIfErr(err) - defer f.Close() - f.Write(res) -} - -func fatalIfErr(err error) { - if err != nil { - log.Fatal(err) - } -} diff --git a/vendor/golang.org/x/net/idna/idna.go b/vendor/golang.org/x/net/idna/idna.go deleted file mode 100644 index 3daa8979..00000000 --- a/vendor/golang.org/x/net/idna/idna.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package idna implements IDNA2008 (Internationalized Domain Names for -// Applications), defined in RFC 5890, RFC 5891, RFC 5892, RFC 5893 and -// RFC 5894. -package idna // import "golang.org/x/net/idna" - -import ( - "strings" - "unicode/utf8" -) - -// TODO(nigeltao): specify when errors occur. For example, is ToASCII(".") or -// ToASCII("foo\x00") an error? See also http://www.unicode.org/faq/idn.html#11 - -// acePrefix is the ASCII Compatible Encoding prefix. -const acePrefix = "xn--" - -// ToASCII converts a domain or domain label to its ASCII form. For example, -// ToASCII("bücher.example.com") is "xn--bcher-kva.example.com", and -// ToASCII("golang") is "golang". -func ToASCII(s string) (string, error) { - if ascii(s) { - return s, nil - } - labels := strings.Split(s, ".") - for i, label := range labels { - if !ascii(label) { - a, err := encode(acePrefix, label) - if err != nil { - return "", err - } - labels[i] = a - } - } - return strings.Join(labels, "."), nil -} - -// ToUnicode converts a domain or domain label to its Unicode form. For example, -// ToUnicode("xn--bcher-kva.example.com") is "bücher.example.com", and -// ToUnicode("golang") is "golang". -func ToUnicode(s string) (string, error) { - if !strings.Contains(s, acePrefix) { - return s, nil - } - labels := strings.Split(s, ".") - for i, label := range labels { - if strings.HasPrefix(label, acePrefix) { - u, err := decode(label[len(acePrefix):]) - if err != nil { - return "", err - } - labels[i] = u - } - } - return strings.Join(labels, "."), nil -} - -func ascii(s string) bool { - for i := 0; i < len(s); i++ { - if s[i] >= utf8.RuneSelf { - return false - } - } - return true -} diff --git a/vendor/golang.org/x/net/idna/punycode.go b/vendor/golang.org/x/net/idna/punycode.go deleted file mode 100644 index 92e733f6..00000000 --- a/vendor/golang.org/x/net/idna/punycode.go +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package idna - -// This file implements the Punycode algorithm from RFC 3492. - -import ( - "fmt" - "math" - "strings" - "unicode/utf8" -) - -// These parameter values are specified in section 5. -// -// All computation is done with int32s, so that overflow behavior is identical -// regardless of whether int is 32-bit or 64-bit. -const ( - base int32 = 36 - damp int32 = 700 - initialBias int32 = 72 - initialN int32 = 128 - skew int32 = 38 - tmax int32 = 26 - tmin int32 = 1 -) - -// decode decodes a string as specified in section 6.2. -func decode(encoded string) (string, error) { - if encoded == "" { - return "", nil - } - pos := 1 + strings.LastIndex(encoded, "-") - if pos == 1 { - return "", fmt.Errorf("idna: invalid label %q", encoded) - } - if pos == len(encoded) { - return encoded[:len(encoded)-1], nil - } - output := make([]rune, 0, len(encoded)) - if pos != 0 { - for _, r := range encoded[:pos-1] { - output = append(output, r) - } - } - i, n, bias := int32(0), initialN, initialBias - for pos < len(encoded) { - oldI, w := i, int32(1) - for k := base; ; k += base { - if pos == len(encoded) { - return "", fmt.Errorf("idna: invalid label %q", encoded) - } - digit, ok := decodeDigit(encoded[pos]) - if !ok { - return "", fmt.Errorf("idna: invalid label %q", encoded) - } - pos++ - i += digit * w - if i < 0 { - return "", fmt.Errorf("idna: invalid label %q", encoded) - } - t := k - bias - if t < tmin { - t = tmin - } else if t > tmax { - t = tmax - } - if digit < t { - break - } - w *= base - t - if w >= math.MaxInt32/base { - return "", fmt.Errorf("idna: invalid label %q", encoded) - } - } - x := int32(len(output) + 1) - bias = adapt(i-oldI, x, oldI == 0) - n += i / x - i %= x - if n > utf8.MaxRune || len(output) >= 1024 { - return "", fmt.Errorf("idna: invalid label %q", encoded) - } - output = append(output, 0) - copy(output[i+1:], output[i:]) - output[i] = n - i++ - } - return string(output), nil -} - -// encode encodes a string as specified in section 6.3 and prepends prefix to -// the result. -// -// The "while h < length(input)" line in the specification becomes "for -// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes. -func encode(prefix, s string) (string, error) { - output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) - copy(output, prefix) - delta, n, bias := int32(0), initialN, initialBias - b, remaining := int32(0), int32(0) - for _, r := range s { - if r < 0x80 { - b++ - output = append(output, byte(r)) - } else { - remaining++ - } - } - h := b - if b > 0 { - output = append(output, '-') - } - for remaining != 0 { - m := int32(0x7fffffff) - for _, r := range s { - if m > r && r >= n { - m = r - } - } - delta += (m - n) * (h + 1) - if delta < 0 { - return "", fmt.Errorf("idna: invalid label %q", s) - } - n = m - for _, r := range s { - if r < n { - delta++ - if delta < 0 { - return "", fmt.Errorf("idna: invalid label %q", s) - } - continue - } - if r > n { - continue - } - q := delta - for k := base; ; k += base { - t := k - bias - if t < tmin { - t = tmin - } else if t > tmax { - t = tmax - } - if q < t { - break - } - output = append(output, encodeDigit(t+(q-t)%(base-t))) - q = (q - t) / (base - t) - } - output = append(output, encodeDigit(q)) - bias = adapt(delta, h+1, h == b) - delta = 0 - h++ - remaining-- - } - delta++ - n++ - } - return string(output), nil -} - -func decodeDigit(x byte) (digit int32, ok bool) { - switch { - case '0' <= x && x <= '9': - return int32(x - ('0' - 26)), true - case 'A' <= x && x <= 'Z': - return int32(x - 'A'), true - case 'a' <= x && x <= 'z': - return int32(x - 'a'), true - } - return 0, false -} - -func encodeDigit(digit int32) byte { - switch { - case 0 <= digit && digit < 26: - return byte(digit + 'a') - case 26 <= digit && digit < 36: - return byte(digit + ('0' - 26)) - } - panic("idna: internal error in punycode encoding") -} - -// adapt is the bias adaptation function specified in section 6.1. -func adapt(delta, numPoints int32, firstTime bool) int32 { - if firstTime { - delta /= damp - } else { - delta /= 2 - } - delta += delta / numPoints - k := int32(0) - for delta > ((base-tmin)*tmax)/2 { - delta /= base - tmin - k += base - } - return k + (base-tmin+1)*delta/(delta+skew) -} diff --git a/vendor/golang.org/x/net/publicsuffix/gen.go b/vendor/golang.org/x/net/publicsuffix/gen.go deleted file mode 100644 index 5c8d7b5f..00000000 --- a/vendor/golang.org/x/net/publicsuffix/gen.go +++ /dev/null @@ -1,663 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build ignore - -package main - -// This program generates table.go and table_test.go. -// Invoke as: -// -// go run gen.go -version "xxx" >table.go -// go run gen.go -version "xxx" -test >table_test.go -// -// Pass -v to print verbose progress information. -// -// The version is derived from information found at -// https://github.com/publicsuffix/list/commits/master/public_suffix_list.dat -// -// To fetch a particular git revision, such as 5c70ccd250, pass -// -url "https://raw.githubusercontent.com/publicsuffix/list/5c70ccd250/public_suffix_list.dat" - -import ( - "bufio" - "bytes" - "flag" - "fmt" - "go/format" - "io" - "net/http" - "os" - "regexp" - "sort" - "strings" - - "golang.org/x/net/idna" -) - -const ( - // These sum of these four values must be no greater than 32. - nodesBitsChildren = 9 - nodesBitsICANN = 1 - nodesBitsTextOffset = 15 - nodesBitsTextLength = 6 - - // These sum of these four values must be no greater than 32. - childrenBitsWildcard = 1 - childrenBitsNodeType = 2 - childrenBitsHi = 14 - childrenBitsLo = 14 -) - -var ( - maxChildren int - maxTextOffset int - maxTextLength int - maxHi uint32 - maxLo uint32 -) - -func max(a, b int) int { - if a < b { - return b - } - return a -} - -func u32max(a, b uint32) uint32 { - if a < b { - return b - } - return a -} - -const ( - nodeTypeNormal = 0 - nodeTypeException = 1 - nodeTypeParentOnly = 2 - numNodeType = 3 -) - -func nodeTypeStr(n int) string { - switch n { - case nodeTypeNormal: - return "+" - case nodeTypeException: - return "!" - case nodeTypeParentOnly: - return "o" - } - panic("unreachable") -} - -var ( - labelEncoding = map[string]uint32{} - labelsList = []string{} - labelsMap = map[string]bool{} - rules = []string{} - - // validSuffix is used to check that the entries in the public suffix list - // are in canonical form (after Punycode encoding). Specifically, capital - // letters are not allowed. - validSuffix = regexp.MustCompile(`^[a-z0-9_\!\*\-\.]+$`) - - subset = flag.Bool("subset", false, "generate only a subset of the full table, for debugging") - url = flag.String("url", - "https://publicsuffix.org/list/effective_tld_names.dat", - "URL of the publicsuffix.org list. If empty, stdin is read instead") - v = flag.Bool("v", false, "verbose output (to stderr)") - version = flag.String("version", "", "the effective_tld_names.dat version") - test = flag.Bool("test", false, "generate table_test.go") -) - -func main() { - if err := main1(); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } -} - -func main1() error { - flag.Parse() - if nodesBitsTextLength+nodesBitsTextOffset+nodesBitsICANN+nodesBitsChildren > 32 { - return fmt.Errorf("not enough bits to encode the nodes table") - } - if childrenBitsLo+childrenBitsHi+childrenBitsNodeType+childrenBitsWildcard > 32 { - return fmt.Errorf("not enough bits to encode the children table") - } - if *version == "" { - return fmt.Errorf("-version was not specified") - } - var r io.Reader = os.Stdin - if *url != "" { - res, err := http.Get(*url) - if err != nil { - return err - } - if res.StatusCode != http.StatusOK { - return fmt.Errorf("bad GET status for %s: %d", *url, res.Status) - } - r = res.Body - defer res.Body.Close() - } - - var root node - icann := false - buf := new(bytes.Buffer) - br := bufio.NewReader(r) - for { - s, err := br.ReadString('\n') - if err != nil { - if err == io.EOF { - break - } - return err - } - s = strings.TrimSpace(s) - if strings.Contains(s, "BEGIN ICANN DOMAINS") { - icann = true - continue - } - if strings.Contains(s, "END ICANN DOMAINS") { - icann = false - continue - } - if s == "" || strings.HasPrefix(s, "//") { - continue - } - s, err = idna.ToASCII(s) - if err != nil { - return err - } - if !validSuffix.MatchString(s) { - return fmt.Errorf("bad publicsuffix.org list data: %q", s) - } - - if *subset { - switch { - case s == "ac.jp" || strings.HasSuffix(s, ".ac.jp"): - case s == "ak.us" || strings.HasSuffix(s, ".ak.us"): - case s == "ao" || strings.HasSuffix(s, ".ao"): - case s == "ar" || strings.HasSuffix(s, ".ar"): - case s == "arpa" || strings.HasSuffix(s, ".arpa"): - case s == "cy" || strings.HasSuffix(s, ".cy"): - case s == "dyndns.org" || strings.HasSuffix(s, ".dyndns.org"): - case s == "jp": - case s == "kobe.jp" || strings.HasSuffix(s, ".kobe.jp"): - case s == "kyoto.jp" || strings.HasSuffix(s, ".kyoto.jp"): - case s == "om" || strings.HasSuffix(s, ".om"): - case s == "uk" || strings.HasSuffix(s, ".uk"): - case s == "uk.com" || strings.HasSuffix(s, ".uk.com"): - case s == "tw" || strings.HasSuffix(s, ".tw"): - case s == "zw" || strings.HasSuffix(s, ".zw"): - case s == "xn--p1ai" || strings.HasSuffix(s, ".xn--p1ai"): - // xn--p1ai is Russian-Cyrillic "рф". - default: - continue - } - } - - rules = append(rules, s) - - nt, wildcard := nodeTypeNormal, false - switch { - case strings.HasPrefix(s, "*."): - s, nt = s[2:], nodeTypeParentOnly - wildcard = true - case strings.HasPrefix(s, "!"): - s, nt = s[1:], nodeTypeException - } - labels := strings.Split(s, ".") - for n, i := &root, len(labels)-1; i >= 0; i-- { - label := labels[i] - n = n.child(label) - if i == 0 { - if nt != nodeTypeParentOnly && n.nodeType == nodeTypeParentOnly { - n.nodeType = nt - } - n.icann = n.icann && icann - n.wildcard = n.wildcard || wildcard - } - labelsMap[label] = true - } - } - labelsList = make([]string, 0, len(labelsMap)) - for label := range labelsMap { - labelsList = append(labelsList, label) - } - sort.Strings(labelsList) - - p := printReal - if *test { - p = printTest - } - if err := p(buf, &root); err != nil { - return err - } - - b, err := format.Source(buf.Bytes()) - if err != nil { - return err - } - _, err = os.Stdout.Write(b) - return err -} - -func printTest(w io.Writer, n *node) error { - fmt.Fprintf(w, "// generated by go run gen.go; DO NOT EDIT\n\n") - fmt.Fprintf(w, "package publicsuffix\n\nvar rules = [...]string{\n") - for _, rule := range rules { - fmt.Fprintf(w, "%q,\n", rule) - } - fmt.Fprintf(w, "}\n\nvar nodeLabels = [...]string{\n") - if err := n.walk(w, printNodeLabel); err != nil { - return err - } - fmt.Fprintf(w, "}\n") - return nil -} - -func printReal(w io.Writer, n *node) error { - const header = `// generated by go run gen.go; DO NOT EDIT - -package publicsuffix - -const version = %q - -const ( - nodesBitsChildren = %d - nodesBitsICANN = %d - nodesBitsTextOffset = %d - nodesBitsTextLength = %d - - childrenBitsWildcard = %d - childrenBitsNodeType = %d - childrenBitsHi = %d - childrenBitsLo = %d -) - -const ( - nodeTypeNormal = %d - nodeTypeException = %d - nodeTypeParentOnly = %d -) - -// numTLD is the number of top level domains. -const numTLD = %d - -` - fmt.Fprintf(w, header, *version, - nodesBitsChildren, nodesBitsICANN, nodesBitsTextOffset, nodesBitsTextLength, - childrenBitsWildcard, childrenBitsNodeType, childrenBitsHi, childrenBitsLo, - nodeTypeNormal, nodeTypeException, nodeTypeParentOnly, len(n.children)) - - text := combineText(labelsList) - if text == "" { - return fmt.Errorf("internal error: makeText returned no text") - } - for _, label := range labelsList { - offset, length := strings.Index(text, label), len(label) - if offset < 0 { - return fmt.Errorf("internal error: could not find %q in text %q", label, text) - } - maxTextOffset, maxTextLength = max(maxTextOffset, offset), max(maxTextLength, length) - if offset >= 1<= 1< 64 { - n, plus = 64, " +" - } - fmt.Fprintf(w, "%q%s\n", text[:n], plus) - text = text[n:] - } - - if err := n.walk(w, assignIndexes); err != nil { - return err - } - - fmt.Fprintf(w, ` - -// nodes is the list of nodes. Each node is represented as a uint32, which -// encodes the node's children, wildcard bit and node type (as an index into -// the children array), ICANN bit and text. -// -// In the //-comment after each node's data, the nodes indexes of the children -// are formatted as (n0x1234-n0x1256), with * denoting the wildcard bit. The -// nodeType is printed as + for normal, ! for exception, and o for parent-only -// nodes that have children but don't match a domain label in their own right. -// An I denotes an ICANN domain. -// -// The layout within the uint32, from MSB to LSB, is: -// [%2d bits] unused -// [%2d bits] children index -// [%2d bits] ICANN bit -// [%2d bits] text index -// [%2d bits] text length -var nodes = [...]uint32{ -`, - 32-nodesBitsChildren-nodesBitsICANN-nodesBitsTextOffset-nodesBitsTextLength, - nodesBitsChildren, nodesBitsICANN, nodesBitsTextOffset, nodesBitsTextLength) - if err := n.walk(w, printNode); err != nil { - return err - } - fmt.Fprintf(w, `} - -// children is the list of nodes' children, the parent's wildcard bit and the -// parent's node type. If a node has no children then their children index -// will be in the range [0, 6), depending on the wildcard bit and node type. -// -// The layout within the uint32, from MSB to LSB, is: -// [%2d bits] unused -// [%2d bits] wildcard bit -// [%2d bits] node type -// [%2d bits] high nodes index (exclusive) of children -// [%2d bits] low nodes index (inclusive) of children -var children=[...]uint32{ -`, - 32-childrenBitsWildcard-childrenBitsNodeType-childrenBitsHi-childrenBitsLo, - childrenBitsWildcard, childrenBitsNodeType, childrenBitsHi, childrenBitsLo) - for i, c := range childrenEncoding { - s := "---------------" - lo := c & (1<> childrenBitsLo) & (1<>(childrenBitsLo+childrenBitsHi)) & (1<>(childrenBitsLo+childrenBitsHi+childrenBitsNodeType) != 0 - fmt.Fprintf(w, "0x%08x, // c0x%04x (%s)%s %s\n", - c, i, s, wildcardStr(wildcard), nodeTypeStr(nodeType)) - } - fmt.Fprintf(w, "}\n\n") - fmt.Fprintf(w, "// max children %d (capacity %d)\n", maxChildren, 1<= 1<= 1<= 1< 0 && ss[0] == "" { - ss = ss[1:] - } - return ss -} - -// crush combines a list of strings, taking advantage of overlaps. It returns a -// single string that contains each input string as a substring. -func crush(ss []string) string { - maxLabelLen := 0 - for _, s := range ss { - if maxLabelLen < len(s) { - maxLabelLen = len(s) - } - } - - for prefixLen := maxLabelLen; prefixLen > 0; prefixLen-- { - prefixes := makePrefixMap(ss, prefixLen) - for i, s := range ss { - if len(s) <= prefixLen { - continue - } - mergeLabel(ss, i, prefixLen, prefixes) - } - } - - return strings.Join(ss, "") -} - -// mergeLabel merges the label at ss[i] with the first available matching label -// in prefixMap, where the last "prefixLen" characters in ss[i] match the first -// "prefixLen" characters in the matching label. -// It will merge ss[i] repeatedly until no more matches are available. -// All matching labels merged into ss[i] are replaced by "". -func mergeLabel(ss []string, i, prefixLen int, prefixes prefixMap) { - s := ss[i] - suffix := s[len(s)-prefixLen:] - for _, j := range prefixes[suffix] { - // Empty strings mean "already used." Also avoid merging with self. - if ss[j] == "" || i == j { - continue - } - if *v { - fmt.Fprintf(os.Stderr, "%d-length overlap at (%4d,%4d): %q and %q share %q\n", - prefixLen, i, j, ss[i], ss[j], suffix) - } - ss[i] += ss[j][prefixLen:] - ss[j] = "" - // ss[i] has a new suffix, so merge again if possible. - // Note: we only have to merge again at the same prefix length. Shorter - // prefix lengths will be handled in the next iteration of crush's for loop. - // Can there be matches for longer prefix lengths, introduced by the merge? - // I believe that any such matches would by necessity have been eliminated - // during substring removal or merged at a higher prefix length. For - // instance, in crush("abc", "cde", "bcdef"), combining "abc" and "cde" - // would yield "abcde", which could be merged with "bcdef." However, in - // practice "cde" would already have been elimintated by removeSubstrings. - mergeLabel(ss, i, prefixLen, prefixes) - return - } -} - -// prefixMap maps from a prefix to a list of strings containing that prefix. The -// list of strings is represented as indexes into a slice of strings stored -// elsewhere. -type prefixMap map[string][]int - -// makePrefixMap constructs a prefixMap from a slice of strings. -func makePrefixMap(ss []string, prefixLen int) prefixMap { - prefixes := make(prefixMap) - for i, s := range ss { - // We use < rather than <= because if a label matches on a prefix equal to - // its full length, that's actually a substring match handled by - // removeSubstrings. - if prefixLen < len(s) { - prefix := s[:prefixLen] - prefixes[prefix] = append(prefixes[prefix], i) - } - } - - return prefixes -}