Fix godeps

Signed-off-by: Olivier Gambier <olivier@docker.com>
This commit is contained in:
Olivier Gambier 2016-03-21 12:08:47 -07:00
parent 77e69b9cf3
commit 53e3c1d7b2
806 changed files with 431 additions and 1075412 deletions

251
Godeps/Godeps.json generated
View file

@ -1,48 +1,14 @@
{ {
"ImportPath": "github.com/docker/distribution", "ImportPath": "github.com/docker/distribution",
"GoVersion": "go1.4.2", "GoVersion": "go1.6",
"GodepVersion": "v60",
"Packages": [ "Packages": [
"./..." "./..."
], ],
"Deps": [ "Deps": [
{
"ImportPath": "golang.org/x/oauth2",
"Rev": "8914e5017ca260f2a3a1575b1e6868874050d95e"
},
{
"ImportPath": "golang.org/x/oauth2/google",
"Rev": "8914e5017ca260f2a3a1575b1e6868874050d95e"
},
{
"ImportPath": "google.golang.org/api/storage/v1",
"Rev": "18450f4e95c7e76ce3a5dc3a8cb7178ab6d56121"
},
{
"ImportPath": "google.golang.org/cloud",
"Rev": "2400193c85c3561d13880d34e0e10c4315bb02af"
},
{
"ImportPath": "google.golang.org/api",
"Rev": "18450f4e95c7e76ce3a5dc3a8cb7178ab6d56121"
},
{
"ImportPath": "google.golang.org/grpc",
"Rev": "91c8b79535eb6045d70ec671d302213f88a3ab95"
},
{
"ImportPath": "github.com/bradfitz/http2",
"Rev": "f8202bc903bda493ebba4aa54922d78430c2c42f"
},
{
"ImportPath": "github.com/golang/protobuf",
"Rev": "0f7a9caded1fb3c9cc5a9b4bcf2ff633cc8ae644"
},
{
"ImportPath": "google.golang.org/cloud/storage",
"Rev": "2400193c85c3561d13880d34e0e10c4315bb02af"
},
{ {
"ImportPath": "github.com/Azure/azure-sdk-for-go/storage", "ImportPath": "github.com/Azure/azure-sdk-for-go/storage",
"Comment": "v1.2-334-g95361a2",
"Rev": "95361a2573b1fa92a00c5fc2707a80308483c6f9" "Rev": "95361a2573b1fa92a00c5fc2707a80308483c6f9"
}, },
{ {
@ -50,11 +16,71 @@
"Comment": "v0.7.3", "Comment": "v0.7.3",
"Rev": "55eb11d21d2a31a3cc93838241d04800f52e823d" "Rev": "55eb11d21d2a31a3cc93838241d04800f52e823d"
}, },
{
"ImportPath": "github.com/Sirupsen/logrus/formatters/logstash",
"Comment": "v0.7.3",
"Rev": "55eb11d21d2a31a3cc93838241d04800f52e823d"
},
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws", "ImportPath": "github.com/aws/aws-sdk-go/aws",
"Comment": "v1.1.0-14-g49c3892", "Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee" "Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
}, },
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/awserr",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/awsutil",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/client",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/client/metadata",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/corehandlers",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/credentials",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/defaults",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/ec2metadata",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/request",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws/session",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/endpoints", "ImportPath": "github.com/aws/aws-sdk-go/private/endpoints",
"Comment": "v1.1.0-14-g49c3892", "Comment": "v1.1.0-14-g49c3892",
@ -65,6 +91,31 @@
"Comment": "v1.1.0-14-g49c3892", "Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee" "Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
}, },
{
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query/queryutil",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/rest",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/restxml",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
"Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
},
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/signer/v4", "ImportPath": "github.com/aws/aws-sdk-go/private/signer/v4",
"Comment": "v1.1.0-14-g49c3892", "Comment": "v1.1.0-14-g49c3892",
@ -85,19 +136,37 @@
"Comment": "v1.1.0-14-g49c3892", "Comment": "v1.1.0-14-g49c3892",
"Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee" "Rev": "49c3892b61af1d4996292a3025f36e4dfa25eaee"
}, },
{
"ImportPath": "github.com/bradfitz/http2",
"Rev": "f8202bc903bda493ebba4aa54922d78430c2c42f"
},
{
"ImportPath": "github.com/bradfitz/http2/hpack",
"Rev": "f8202bc903bda493ebba4aa54922d78430c2c42f"
},
{ {
"ImportPath": "github.com/bugsnag/bugsnag-go", "ImportPath": "github.com/bugsnag/bugsnag-go",
"Comment": "v1.0.2-5-gb1d1530", "Comment": "v1.0.2-5-gb1d1530",
"Rev": "b1d153021fcd90ca3f080db36bec96dc690fb274" "Rev": "b1d153021fcd90ca3f080db36bec96dc690fb274"
}, },
{
"ImportPath": "github.com/bugsnag/bugsnag-go/errors",
"Comment": "v1.0.2-5-gb1d1530",
"Rev": "b1d153021fcd90ca3f080db36bec96dc690fb274"
},
{ {
"ImportPath": "github.com/bugsnag/osext", "ImportPath": "github.com/bugsnag/osext",
"Rev": "0dd3f918b21bec95ace9dc86c7e70266cfc5c702" "Rev": "0dd3f918b21bec95ace9dc86c7e70266cfc5c702"
}, },
{ {
"ImportPath": "github.com/bugsnag/panicwrap", "ImportPath": "github.com/bugsnag/panicwrap",
"Comment": "1.0.0-2-ge2c2850",
"Rev": "e2c28503fcd0675329da73bf48b33404db873782" "Rev": "e2c28503fcd0675329da73bf48b33404db873782"
}, },
{
"ImportPath": "github.com/denverdino/aliyungo/common",
"Rev": "6ffb587da9da6d029d0ce517b85fecc82172d502"
},
{ {
"ImportPath": "github.com/denverdino/aliyungo/oss", "ImportPath": "github.com/denverdino/aliyungo/oss",
"Rev": "6ffb587da9da6d029d0ce517b85fecc82172d502" "Rev": "6ffb587da9da6d029d0ce517b85fecc82172d502"
@ -106,10 +175,6 @@
"ImportPath": "github.com/denverdino/aliyungo/util", "ImportPath": "github.com/denverdino/aliyungo/util",
"Rev": "6ffb587da9da6d029d0ce517b85fecc82172d502" "Rev": "6ffb587da9da6d029d0ce517b85fecc82172d502"
}, },
{
"ImportPath": "github.com/denverdino/aliyungo/common",
"Rev": "6ffb587da9da6d029d0ce517b85fecc82172d502"
},
{ {
"ImportPath": "github.com/docker/goamz/aws", "ImportPath": "github.com/docker/goamz/aws",
"Rev": "f0a21f5b2e12f83a505ecf79b633bb2035cf6f85" "Rev": "f0a21f5b2e12f83a505ecf79b633bb2035cf6f85"
@ -135,6 +200,10 @@
"Comment": "v1.8.6", "Comment": "v1.8.6",
"Rev": "afbd495e5aaea13597b5e14fe514ddeaa4d76fc3" "Rev": "afbd495e5aaea13597b5e14fe514ddeaa4d76fc3"
}, },
{
"ImportPath": "github.com/golang/protobuf/proto",
"Rev": "0f7a9caded1fb3c9cc5a9b4bcf2ff633cc8ae644"
},
{ {
"ImportPath": "github.com/gorilla/context", "ImportPath": "github.com/gorilla/context",
"Rev": "14f550f51af52180c2eefed15e5fd18d63c0a64a" "Rev": "14f550f51af52180c2eefed15e5fd18d63c0a64a"
@ -151,19 +220,23 @@
"ImportPath": "github.com/inconshreveable/mousetrap", "ImportPath": "github.com/inconshreveable/mousetrap",
"Rev": "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" "Rev": "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75"
}, },
{
"ImportPath": "github.com/mitchellh/mapstructure",
"Rev": "482a9fd5fa83e8c4e7817413b80f3eb8feec03ef"
},
{ {
"ImportPath": "github.com/jmespath/go-jmespath", "ImportPath": "github.com/jmespath/go-jmespath",
"Comment": "0.2.2-12-g0b12d6b", "Comment": "0.2.2-12-g0b12d6b",
"Rev": "0b12d6b521d83fc7f755e7cfc1b1fbdd35a01a74" "Rev": "0b12d6b521d83fc7f755e7cfc1b1fbdd35a01a74"
}, },
{
"ImportPath": "github.com/mitchellh/mapstructure",
"Rev": "482a9fd5fa83e8c4e7817413b80f3eb8feec03ef"
},
{ {
"ImportPath": "github.com/ncw/swift", "ImportPath": "github.com/ncw/swift",
"Rev": "c54732e87b0b283d1baf0a18db689d0aea460ba3" "Rev": "c54732e87b0b283d1baf0a18db689d0aea460ba3"
}, },
{
"ImportPath": "github.com/ncw/swift/swifttest",
"Rev": "c54732e87b0b283d1baf0a18db689d0aea460ba3"
},
{ {
"ImportPath": "github.com/spf13/cobra", "ImportPath": "github.com/spf13/cobra",
"Rev": "312092086bed4968099259622145a0c9ae280064" "Rev": "312092086bed4968099259622145a0c9ae280064"
@ -176,6 +249,14 @@
"ImportPath": "github.com/stevvooe/resumable", "ImportPath": "github.com/stevvooe/resumable",
"Rev": "51ad44105773cafcbe91927f70ac68e1bf78f8b4" "Rev": "51ad44105773cafcbe91927f70ac68e1bf78f8b4"
}, },
{
"ImportPath": "github.com/stevvooe/resumable/sha256",
"Rev": "51ad44105773cafcbe91927f70ac68e1bf78f8b4"
},
{
"ImportPath": "github.com/stevvooe/resumable/sha512",
"Rev": "51ad44105773cafcbe91927f70ac68e1bf78f8b4"
},
{ {
"ImportPath": "github.com/yvasiyarov/go-metrics", "ImportPath": "github.com/yvasiyarov/go-metrics",
"Rev": "57bccd1ccd43f94bb17fdd8bf3007059b802f85e" "Rev": "57bccd1ccd43f94bb17fdd8bf3007059b802f85e"
@ -201,10 +282,90 @@
"ImportPath": "golang.org/x/net/context", "ImportPath": "golang.org/x/net/context",
"Rev": "2cba614e8ff920c60240d2677bc019af32ee04e5" "Rev": "2cba614e8ff920c60240d2677bc019af32ee04e5"
}, },
{
"ImportPath": "golang.org/x/net/internal/timeseries",
"Rev": "2cba614e8ff920c60240d2677bc019af32ee04e5"
},
{ {
"ImportPath": "golang.org/x/net/trace", "ImportPath": "golang.org/x/net/trace",
"Rev": "2cba614e8ff920c60240d2677bc019af32ee04e5" "Rev": "2cba614e8ff920c60240d2677bc019af32ee04e5"
}, },
{
"ImportPath": "golang.org/x/oauth2",
"Rev": "8914e5017ca260f2a3a1575b1e6868874050d95e"
},
{
"ImportPath": "golang.org/x/oauth2/google",
"Rev": "8914e5017ca260f2a3a1575b1e6868874050d95e"
},
{
"ImportPath": "golang.org/x/oauth2/internal",
"Rev": "8914e5017ca260f2a3a1575b1e6868874050d95e"
},
{
"ImportPath": "golang.org/x/oauth2/jws",
"Rev": "8914e5017ca260f2a3a1575b1e6868874050d95e"
},
{
"ImportPath": "golang.org/x/oauth2/jwt",
"Rev": "8914e5017ca260f2a3a1575b1e6868874050d95e"
},
{
"ImportPath": "google.golang.org/api/googleapi",
"Rev": "18450f4e95c7e76ce3a5dc3a8cb7178ab6d56121"
},
{
"ImportPath": "google.golang.org/api/googleapi/internal/uritemplates",
"Rev": "18450f4e95c7e76ce3a5dc3a8cb7178ab6d56121"
},
{
"ImportPath": "google.golang.org/api/storage/v1",
"Rev": "18450f4e95c7e76ce3a5dc3a8cb7178ab6d56121"
},
{
"ImportPath": "google.golang.org/cloud",
"Rev": "2400193c85c3561d13880d34e0e10c4315bb02af"
},
{
"ImportPath": "google.golang.org/cloud/compute/metadata",
"Rev": "2400193c85c3561d13880d34e0e10c4315bb02af"
},
{
"ImportPath": "google.golang.org/cloud/internal",
"Rev": "2400193c85c3561d13880d34e0e10c4315bb02af"
},
{
"ImportPath": "google.golang.org/cloud/internal/opts",
"Rev": "2400193c85c3561d13880d34e0e10c4315bb02af"
},
{
"ImportPath": "google.golang.org/cloud/storage",
"Rev": "2400193c85c3561d13880d34e0e10c4315bb02af"
},
{
"ImportPath": "google.golang.org/grpc",
"Rev": "91c8b79535eb6045d70ec671d302213f88a3ab95"
},
{
"ImportPath": "google.golang.org/grpc/codes",
"Rev": "91c8b79535eb6045d70ec671d302213f88a3ab95"
},
{
"ImportPath": "google.golang.org/grpc/credentials",
"Rev": "91c8b79535eb6045d70ec671d302213f88a3ab95"
},
{
"ImportPath": "google.golang.org/grpc/grpclog",
"Rev": "91c8b79535eb6045d70ec671d302213f88a3ab95"
},
{
"ImportPath": "google.golang.org/grpc/metadata",
"Rev": "91c8b79535eb6045d70ec671d302213f88a3ab95"
},
{
"ImportPath": "google.golang.org/grpc/transport",
"Rev": "91c8b79535eb6045d70ec671d302213f88a3ab95"
},
{ {
"ImportPath": "gopkg.in/check.v1", "ImportPath": "gopkg.in/check.v1",
"Rev": "64131543e7896d5bcc6bd5a76287eb75ea96c673" "Rev": "64131543e7896d5bcc6bd5a76287eb75ea96c673"

View file

@ -34,6 +34,9 @@ PKGS := $(shell go list -tags "${DOCKER_BUILDTAGS}" ./... | grep -v ^github.com/
GOLINT_BIN := $(GOPATH)/bin/golint GOLINT_BIN := $(GOPATH)/bin/golint
GOLINT := $(shell [ -x $(GOLINT_BIN) ] && echo $(GOLINT_BIN) || echo '') GOLINT := $(shell [ -x $(GOLINT_BIN) ] && echo $(GOLINT_BIN) || echo '')
GODEP_BIN := $(GOPATH)/bin/godep
GODEP := $(shell [ -x $(GODEP_BIN) ] && echo $(GODEP_BIN) || echo '')
${PREFIX}/bin/registry: $(wildcard **/*.go) ${PREFIX}/bin/registry: $(wildcard **/*.go)
@echo "+ $@" @echo "+ $@"
@go build -tags "${DOCKER_BUILDTAGS}" -o $@ ${GO_LDFLAGS} ${GO_GCFLAGS} ./cmd/registry @go build -tags "${DOCKER_BUILDTAGS}" -o $@ ${GO_LDFLAGS} ${GO_GCFLAGS} ./cmd/registry
@ -82,3 +85,13 @@ binaries: ${PREFIX}/bin/registry ${PREFIX}/bin/digest ${PREFIX}/bin/registry-api
clean: clean:
@echo "+ $@" @echo "+ $@"
@rm -rf "${PREFIX}/bin/registry" "${PREFIX}/bin/digest" "${PREFIX}/bin/registry-api-descriptor-template" @rm -rf "${PREFIX}/bin/registry" "${PREFIX}/bin/digest" "${PREFIX}/bin/registry-api-descriptor-template"
dep-save:
$(if $(GODEP), , \
$(error Please install godep: go get github.com/tools/godep))
$(GODEP) save $(PKGS)
dep-restore:
$(if $(GODEP), , \
$(error Please install godep: go get github.com/tools/godep))
$(GODEP) restore -v

2
vendor/.gitignore vendored
View file

@ -1,2 +0,0 @@
/pkg
/bin

View file

@ -1,29 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
# Editor swap files
*.swp
*~
.DS_Store

View file

@ -1,19 +0,0 @@
sudo: false
language: go
before_script:
- go get -u golang.org/x/tools/cmd/vet
- go get -u github.com/golang/lint/golint
go: tip
script:
- test -z "$(gofmt -s -l -w management | tee /dev/stderr)"
- test -z "$(gofmt -s -l -w storage | tee /dev/stderr)"
- go build -v ./...
- go test -v ./storage/... -check.v
- test -z "$(golint ./storage/... | tee /dev/stderr)"
- go vet ./storage/...
- go test -v ./management/...
- test -z "$(golint ./management/... | grep -v 'should have comment' | grep -v 'stutters' | tee /dev/stderr)"
- go vet ./management/...

View file

@ -1,88 +0,0 @@
# Microsoft Azure SDK for Go
This project provides various Go packages to perform operations
on Microsoft Azure REST APIs.
[![GoDoc](https://godoc.org/github.com/Azure/azure-sdk-for-go?status.svg)](https://godoc.org/github.com/Azure/azure-sdk-for-go) [![Build Status](https://travis-ci.org/Azure/azure-sdk-for-go.svg?branch=master)](https://travis-ci.org/Azure/azure-sdk-for-go)
See list of implemented API clients [here](http://godoc.org/github.com/Azure/azure-sdk-for-go).
> **NOTE:** This repository is under heavy ongoing development and
is likely to break over time. We currently do not have any releases
yet. If you are planning to use the repository, please consider vendoring
the packages in your project and update them when a stable tag is out.
# Installation
go get -d github.com/Azure/azure-sdk-for-go/management
# Usage
Read Godoc of the repository at: http://godoc.org/github.com/Azure/azure-sdk-for-go/
The client currently supports authentication to the Service Management
API with certificates or Azure `.publishSettings` file. You can
download the `.publishSettings` file for your subscriptions
[here](https://manage.windowsazure.com/publishsettings).
### Example: Creating a Linux Virtual Machine
```go
package main
import (
"encoding/base64"
"fmt"
"github.com/Azure/azure-sdk-for-go/management"
"github.com/Azure/azure-sdk-for-go/management/hostedservice"
"github.com/Azure/azure-sdk-for-go/management/virtualmachine"
"github.com/Azure/azure-sdk-for-go/management/vmutils"
)
func main() {
dnsName := "test-vm-from-go"
storageAccount := "mystorageaccount"
location := "West US"
vmSize := "Small"
vmImage := "b39f27a8b8c64d52b05eac6a62ebad85__Ubuntu-14_04-LTS-amd64-server-20140724-en-us-30GB"
userName := "testuser"
userPassword := "Test123"
client, err := management.ClientFromPublishSettingsFile("path/to/downloaded.publishsettings", "")
if err != nil {
panic(err)
}
// create hosted service
if err := hostedservice.NewClient(client).CreateHostedService(hostedservice.CreateHostedServiceParameters{
ServiceName: dnsName,
Location: location,
Label: base64.StdEncoding.EncodeToString([]byte(dnsName))}); err != nil {
panic(err)
}
// create virtual machine
role := vmutils.NewVMConfiguration(dnsName, vmSize)
vmutils.ConfigureDeploymentFromPlatformImage(
&role,
vmImage,
fmt.Sprintf("http://%s.blob.core.windows.net/sdktest/%s.vhd", storageAccount, dnsName),
"")
vmutils.ConfigureForLinux(&role, dnsName, userName, userPassword)
vmutils.ConfigureWithPublicSSH(&role)
operationID, err := virtualmachine.NewClient(client).
CreateDeployment(role, dnsName, virtualmachine.CreateDeploymentOptions{})
if err != nil {
panic(err)
}
if err := client.WaitForOperation(operationID, nil); err != nil {
panic(err)
}
}
```
# License
This project is published under [Apache 2.0 License](LICENSE).

View file

@ -325,7 +325,7 @@ func (c Client) exec(verb, url string, headers map[string]string, body io.Reader
} }
statusCode := resp.StatusCode statusCode := resp.StatusCode
if statusCode >= 400 && statusCode <= 505 && statusCode != 404 { if statusCode >= 400 && statusCode <= 505 {
var respBody []byte var respBody []byte
respBody, err = readResponseBody(resp) respBody, err = readResponseBody(resp)
if err != nil { if err != nil {

View file

@ -1,53 +0,0 @@
package logrus
import (
"bytes"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEntryPanicln(t *testing.T) {
errBoom := fmt.Errorf("boom time")
defer func() {
p := recover()
assert.NotNil(t, p)
switch pVal := p.(type) {
case *Entry:
assert.Equal(t, "kaboom", pVal.Message)
assert.Equal(t, errBoom, pVal.Data["err"])
default:
t.Fatalf("want type *Entry, got %T: %#v", pVal, pVal)
}
}()
logger := New()
logger.Out = &bytes.Buffer{}
entry := NewEntry(logger)
entry.WithField("err", errBoom).Panicln("kaboom")
}
func TestEntryPanicf(t *testing.T) {
errBoom := fmt.Errorf("boom again")
defer func() {
p := recover()
assert.NotNil(t, p)
switch pVal := p.(type) {
case *Entry:
assert.Equal(t, "kaboom true", pVal.Message)
assert.Equal(t, errBoom, pVal.Data["err"])
default:
t.Fatalf("want type *Entry, got %T: %#v", pVal, pVal)
}
}()
logger := New()
logger.Out = &bytes.Buffer{}
entry := NewEntry(logger)
entry.WithField("err", errBoom).Panicf("kaboom %v", true)
}

View file

@ -1,50 +0,0 @@
package main
import (
"github.com/Sirupsen/logrus"
)
var log = logrus.New()
func init() {
log.Formatter = new(logrus.JSONFormatter)
log.Formatter = new(logrus.TextFormatter) // default
log.Level = logrus.DebugLevel
}
func main() {
defer func() {
err := recover()
if err != nil {
log.WithFields(logrus.Fields{
"omg": true,
"err": err,
"number": 100,
}).Fatal("The ice breaks!")
}
}()
log.WithFields(logrus.Fields{
"animal": "walrus",
"number": 8,
}).Debug("Started observing beach")
log.WithFields(logrus.Fields{
"animal": "walrus",
"size": 10,
}).Info("A group of walrus emerges from the ocean")
log.WithFields(logrus.Fields{
"omg": true,
"number": 122,
}).Warn("The group's number increased tremendously!")
log.WithFields(logrus.Fields{
"temperature": -4,
}).Debug("Temperature changes")
log.WithFields(logrus.Fields{
"animal": "orca",
"size": 9009,
}).Panic("It's over 9000!")
}

View file

@ -1,30 +0,0 @@
package main
import (
"github.com/Sirupsen/logrus"
"github.com/Sirupsen/logrus/hooks/airbrake"
)
var log = logrus.New()
func init() {
log.Formatter = new(logrus.TextFormatter) // default
log.Hooks.Add(airbrake.NewHook("https://example.com", "xyz", "development"))
}
func main() {
log.WithFields(logrus.Fields{
"animal": "walrus",
"size": 10,
}).Info("A group of walrus emerges from the ocean")
log.WithFields(logrus.Fields{
"omg": true,
"number": 122,
}).Warn("The group's number increased tremendously!")
log.WithFields(logrus.Fields{
"omg": true,
"number": 100,
}).Fatal("The ice breaks!")
}

View file

@ -1,88 +0,0 @@
package logrus
import (
"testing"
"time"
)
// smallFields is a small size data set for benchmarking
var smallFields = Fields{
"foo": "bar",
"baz": "qux",
"one": "two",
"three": "four",
}
// largeFields is a large size data set for benchmarking
var largeFields = Fields{
"foo": "bar",
"baz": "qux",
"one": "two",
"three": "four",
"five": "six",
"seven": "eight",
"nine": "ten",
"eleven": "twelve",
"thirteen": "fourteen",
"fifteen": "sixteen",
"seventeen": "eighteen",
"nineteen": "twenty",
"a": "b",
"c": "d",
"e": "f",
"g": "h",
"i": "j",
"k": "l",
"m": "n",
"o": "p",
"q": "r",
"s": "t",
"u": "v",
"w": "x",
"y": "z",
"this": "will",
"make": "thirty",
"entries": "yeah",
}
func BenchmarkSmallTextFormatter(b *testing.B) {
doBenchmark(b, &TextFormatter{DisableColors: true}, smallFields)
}
func BenchmarkLargeTextFormatter(b *testing.B) {
doBenchmark(b, &TextFormatter{DisableColors: true}, largeFields)
}
func BenchmarkSmallColoredTextFormatter(b *testing.B) {
doBenchmark(b, &TextFormatter{ForceColors: true}, smallFields)
}
func BenchmarkLargeColoredTextFormatter(b *testing.B) {
doBenchmark(b, &TextFormatter{ForceColors: true}, largeFields)
}
func BenchmarkSmallJSONFormatter(b *testing.B) {
doBenchmark(b, &JSONFormatter{}, smallFields)
}
func BenchmarkLargeJSONFormatter(b *testing.B) {
doBenchmark(b, &JSONFormatter{}, largeFields)
}
func doBenchmark(b *testing.B, formatter Formatter, fields Fields) {
entry := &Entry{
Time: time.Time{},
Level: InfoLevel,
Message: "message",
Data: fields,
}
var d []byte
var err error
for i := 0; i < b.N; i++ {
d, err = formatter.Format(entry)
if err != nil {
b.Fatal(err)
}
b.SetBytes(int64(len(d)))
}
}

View file

@ -1,52 +0,0 @@
package logstash
import (
"bytes"
"encoding/json"
"github.com/Sirupsen/logrus"
"github.com/stretchr/testify/assert"
"testing"
)
func TestLogstashFormatter(t *testing.T) {
assert := assert.New(t)
lf := LogstashFormatter{Type: "abc"}
fields := logrus.Fields{
"message": "def",
"level": "ijk",
"type": "lmn",
"one": 1,
"pi": 3.14,
"bool": true,
}
entry := logrus.WithFields(fields)
entry.Message = "msg"
entry.Level = logrus.InfoLevel
b, _ := lf.Format(entry)
var data map[string]interface{}
dec := json.NewDecoder(bytes.NewReader(b))
dec.UseNumber()
dec.Decode(&data)
// base fields
assert.Equal(json.Number("1"), data["@version"])
assert.NotEmpty(data["@timestamp"])
assert.Equal("abc", data["type"])
assert.Equal("msg", data["message"])
assert.Equal("info", data["level"])
// substituted fields
assert.Equal("def", data["fields.message"])
assert.Equal("ijk", data["fields.level"])
assert.Equal("lmn", data["fields.type"])
// formats
assert.Equal(json.Number("1"), data["one"])
assert.Equal(json.Number("3.14"), data["pi"])
assert.Equal(true, data["bool"])
}

View file

@ -1,122 +0,0 @@
package logrus
import (
"testing"
"github.com/stretchr/testify/assert"
)
type TestHook struct {
Fired bool
}
func (hook *TestHook) Fire(entry *Entry) error {
hook.Fired = true
return nil
}
func (hook *TestHook) Levels() []Level {
return []Level{
DebugLevel,
InfoLevel,
WarnLevel,
ErrorLevel,
FatalLevel,
PanicLevel,
}
}
func TestHookFires(t *testing.T) {
hook := new(TestHook)
LogAndAssertJSON(t, func(log *Logger) {
log.Hooks.Add(hook)
assert.Equal(t, hook.Fired, false)
log.Print("test")
}, func(fields Fields) {
assert.Equal(t, hook.Fired, true)
})
}
type ModifyHook struct {
}
func (hook *ModifyHook) Fire(entry *Entry) error {
entry.Data["wow"] = "whale"
return nil
}
func (hook *ModifyHook) Levels() []Level {
return []Level{
DebugLevel,
InfoLevel,
WarnLevel,
ErrorLevel,
FatalLevel,
PanicLevel,
}
}
func TestHookCanModifyEntry(t *testing.T) {
hook := new(ModifyHook)
LogAndAssertJSON(t, func(log *Logger) {
log.Hooks.Add(hook)
log.WithField("wow", "elephant").Print("test")
}, func(fields Fields) {
assert.Equal(t, fields["wow"], "whale")
})
}
func TestCanFireMultipleHooks(t *testing.T) {
hook1 := new(ModifyHook)
hook2 := new(TestHook)
LogAndAssertJSON(t, func(log *Logger) {
log.Hooks.Add(hook1)
log.Hooks.Add(hook2)
log.WithField("wow", "elephant").Print("test")
}, func(fields Fields) {
assert.Equal(t, fields["wow"], "whale")
assert.Equal(t, hook2.Fired, true)
})
}
type ErrorHook struct {
Fired bool
}
func (hook *ErrorHook) Fire(entry *Entry) error {
hook.Fired = true
return nil
}
func (hook *ErrorHook) Levels() []Level {
return []Level{
ErrorLevel,
}
}
func TestErrorHookShouldntFireOnInfo(t *testing.T) {
hook := new(ErrorHook)
LogAndAssertJSON(t, func(log *Logger) {
log.Hooks.Add(hook)
log.Info("test")
}, func(fields Fields) {
assert.Equal(t, hook.Fired, false)
})
}
func TestErrorHookShouldFireOnError(t *testing.T) {
hook := new(ErrorHook)
LogAndAssertJSON(t, func(log *Logger) {
log.Hooks.Add(hook)
log.Error("test")
}, func(fields Fields) {
assert.Equal(t, hook.Fired, true)
})
}

View file

@ -1,54 +0,0 @@
package airbrake
import (
"errors"
"fmt"
"github.com/Sirupsen/logrus"
"github.com/tobi/airbrake-go"
)
// AirbrakeHook to send exceptions to an exception-tracking service compatible
// with the Airbrake API.
type airbrakeHook struct {
APIKey string
Endpoint string
Environment string
}
func NewHook(endpoint, apiKey, env string) *airbrakeHook {
return &airbrakeHook{
APIKey: apiKey,
Endpoint: endpoint,
Environment: env,
}
}
func (hook *airbrakeHook) Fire(entry *logrus.Entry) error {
airbrake.ApiKey = hook.APIKey
airbrake.Endpoint = hook.Endpoint
airbrake.Environment = hook.Environment
var notifyErr error
err, ok := entry.Data["error"].(error)
if ok {
notifyErr = err
} else {
notifyErr = errors.New(entry.Message)
}
airErr := airbrake.Notify(notifyErr)
if airErr != nil {
return fmt.Errorf("Failed to send error to Airbrake: %s", airErr)
}
return nil
}
func (hook *airbrakeHook) Levels() []logrus.Level {
return []logrus.Level{
logrus.ErrorLevel,
logrus.FatalLevel,
logrus.PanicLevel,
}
}

View file

@ -1,133 +0,0 @@
package airbrake
import (
"encoding/xml"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Sirupsen/logrus"
)
type notice struct {
Error NoticeError `xml:"error"`
}
type NoticeError struct {
Class string `xml:"class"`
Message string `xml:"message"`
}
type customErr struct {
msg string
}
func (e *customErr) Error() string {
return e.msg
}
const (
testAPIKey = "abcxyz"
testEnv = "development"
expectedClass = "*airbrake.customErr"
expectedMsg = "foo"
unintendedMsg = "Airbrake will not see this string"
)
var (
noticeError = make(chan NoticeError, 1)
)
// TestLogEntryMessageReceived checks if invoking Logrus' log.Error
// method causes an XML payload containing the log entry message is received
// by a HTTP server emulating an Airbrake-compatible endpoint.
func TestLogEntryMessageReceived(t *testing.T) {
log := logrus.New()
ts := startAirbrakeServer(t)
defer ts.Close()
hook := NewHook(ts.URL, testAPIKey, "production")
log.Hooks.Add(hook)
log.Error(expectedMsg)
select {
case received := <-noticeError:
if received.Message != expectedMsg {
t.Errorf("Unexpected message received: %s", received.Message)
}
case <-time.After(time.Second):
t.Error("Timed out; no notice received by Airbrake API")
}
}
// TestLogEntryMessageReceived confirms that, when passing an error type using
// logrus.Fields, a HTTP server emulating an Airbrake endpoint receives the
// error message returned by the Error() method on the error interface
// rather than the logrus.Entry.Message string.
func TestLogEntryWithErrorReceived(t *testing.T) {
log := logrus.New()
ts := startAirbrakeServer(t)
defer ts.Close()
hook := NewHook(ts.URL, testAPIKey, "production")
log.Hooks.Add(hook)
log.WithFields(logrus.Fields{
"error": &customErr{expectedMsg},
}).Error(unintendedMsg)
select {
case received := <-noticeError:
if received.Message != expectedMsg {
t.Errorf("Unexpected message received: %s", received.Message)
}
if received.Class != expectedClass {
t.Errorf("Unexpected error class: %s", received.Class)
}
case <-time.After(time.Second):
t.Error("Timed out; no notice received by Airbrake API")
}
}
// TestLogEntryWithNonErrorTypeNotReceived confirms that, when passing a
// non-error type using logrus.Fields, a HTTP server emulating an Airbrake
// endpoint receives the logrus.Entry.Message string.
//
// Only error types are supported when setting the 'error' field using
// logrus.WithFields().
func TestLogEntryWithNonErrorTypeNotReceived(t *testing.T) {
log := logrus.New()
ts := startAirbrakeServer(t)
defer ts.Close()
hook := NewHook(ts.URL, testAPIKey, "production")
log.Hooks.Add(hook)
log.WithFields(logrus.Fields{
"error": expectedMsg,
}).Error(unintendedMsg)
select {
case received := <-noticeError:
if received.Message != unintendedMsg {
t.Errorf("Unexpected message received: %s", received.Message)
}
case <-time.After(time.Second):
t.Error("Timed out; no notice received by Airbrake API")
}
}
func startAirbrakeServer(t *testing.T) *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var notice notice
if err := xml.NewDecoder(r.Body).Decode(&notice); err != nil {
t.Error(err)
}
r.Body.Close()
noticeError <- notice.Error
}))
return ts
}

View file

@ -1,68 +0,0 @@
package logrus_bugsnag
import (
"errors"
"github.com/Sirupsen/logrus"
"github.com/bugsnag/bugsnag-go"
)
type bugsnagHook struct{}
// ErrBugsnagUnconfigured is returned if NewBugsnagHook is called before
// bugsnag.Configure. Bugsnag must be configured before the hook.
var ErrBugsnagUnconfigured = errors.New("bugsnag must be configured before installing this logrus hook")
// ErrBugsnagSendFailed indicates that the hook failed to submit an error to
// bugsnag. The error was successfully generated, but `bugsnag.Notify()`
// failed.
type ErrBugsnagSendFailed struct {
err error
}
func (e ErrBugsnagSendFailed) Error() string {
return "failed to send error to Bugsnag: " + e.err.Error()
}
// NewBugsnagHook initializes a logrus hook which sends exceptions to an
// exception-tracking service compatible with the Bugsnag API. Before using
// this hook, you must call bugsnag.Configure(). The returned object should be
// registered with a log via `AddHook()`
//
// Entries that trigger an Error, Fatal or Panic should now include an "error"
// field to send to Bugsnag.
func NewBugsnagHook() (*bugsnagHook, error) {
if bugsnag.Config.APIKey == "" {
return nil, ErrBugsnagUnconfigured
}
return &bugsnagHook{}, nil
}
// Fire forwards an error to Bugsnag. Given a logrus.Entry, it extracts the
// "error" field (or the Message if the error isn't present) and sends it off.
func (hook *bugsnagHook) Fire(entry *logrus.Entry) error {
var notifyErr error
err, ok := entry.Data["error"].(error)
if ok {
notifyErr = err
} else {
notifyErr = errors.New(entry.Message)
}
bugsnagErr := bugsnag.Notify(notifyErr)
if bugsnagErr != nil {
return ErrBugsnagSendFailed{bugsnagErr}
}
return nil
}
// Levels enumerates the log levels on which the error should be forwarded to
// bugsnag: everything at or above the "Error" level.
func (hook *bugsnagHook) Levels() []logrus.Level {
return []logrus.Level{
logrus.ErrorLevel,
logrus.FatalLevel,
logrus.PanicLevel,
}
}

View file

@ -1,64 +0,0 @@
package logrus_bugsnag
import (
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Sirupsen/logrus"
"github.com/bugsnag/bugsnag-go"
)
type notice struct {
Events []struct {
Exceptions []struct {
Message string `json:"message"`
} `json:"exceptions"`
} `json:"events"`
}
func TestNoticeReceived(t *testing.T) {
msg := make(chan string, 1)
expectedMsg := "foo"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var notice notice
data, _ := ioutil.ReadAll(r.Body)
if err := json.Unmarshal(data, &notice); err != nil {
t.Error(err)
}
_ = r.Body.Close()
msg <- notice.Events[0].Exceptions[0].Message
}))
defer ts.Close()
hook := &bugsnagHook{}
bugsnag.Configure(bugsnag.Configuration{
Endpoint: ts.URL,
ReleaseStage: "production",
APIKey: "12345678901234567890123456789012",
Synchronous: true,
})
log := logrus.New()
log.Hooks.Add(hook)
log.WithFields(logrus.Fields{
"error": errors.New(expectedMsg),
}).Error("Bugsnag will not see this string")
select {
case received := <-msg:
if received != expectedMsg {
t.Errorf("Unexpected message received: %s", received)
}
case <-time.After(time.Second):
t.Error("Timed out; no notice received by Bugsnag API")
}
}

View file

@ -1,28 +0,0 @@
# Papertrail Hook for Logrus <img src="http://i.imgur.com/hTeVwmJ.png" width="40" height="40" alt=":walrus:" class="emoji" title=":walrus:" />
[Papertrail](https://papertrailapp.com) provides hosted log management. Once stored in Papertrail, you can [group](http://help.papertrailapp.com/kb/how-it-works/groups/) your logs on various dimensions, [search](http://help.papertrailapp.com/kb/how-it-works/search-syntax) them, and trigger [alerts](http://help.papertrailapp.com/kb/how-it-works/alerts).
In most deployments, you'll want to send logs to Papertrail via their [remote_syslog](http://help.papertrailapp.com/kb/configuration/configuring-centralized-logging-from-text-log-files-in-unix/) daemon, which requires no application-specific configuration. This hook is intended for relatively low-volume logging, likely in managed cloud hosting deployments where installing `remote_syslog` is not possible.
## Usage
You can find your Papertrail UDP port on your [Papertrail account page](https://papertrailapp.com/account/destinations). Substitute it below for `YOUR_PAPERTRAIL_UDP_PORT`.
For `YOUR_APP_NAME`, substitute a short string that will readily identify your application or service in the logs.
```go
import (
"log/syslog"
"github.com/Sirupsen/logrus"
"github.com/Sirupsen/logrus/hooks/papertrail"
)
func main() {
log := logrus.New()
hook, err := logrus_papertrail.NewPapertrailHook("logs.papertrailapp.com", YOUR_PAPERTRAIL_UDP_PORT, YOUR_APP_NAME)
if err == nil {
log.Hooks.Add(hook)
}
}
```

View file

@ -1,55 +0,0 @@
package logrus_papertrail
import (
"fmt"
"net"
"os"
"time"
"github.com/Sirupsen/logrus"
)
const (
format = "Jan 2 15:04:05"
)
// PapertrailHook to send logs to a logging service compatible with the Papertrail API.
type PapertrailHook struct {
Host string
Port int
AppName string
UDPConn net.Conn
}
// NewPapertrailHook creates a hook to be added to an instance of logger.
func NewPapertrailHook(host string, port int, appName string) (*PapertrailHook, error) {
conn, err := net.Dial("udp", fmt.Sprintf("%s:%d", host, port))
return &PapertrailHook{host, port, appName, conn}, err
}
// Fire is called when a log event is fired.
func (hook *PapertrailHook) Fire(entry *logrus.Entry) error {
date := time.Now().Format(format)
msg, _ := entry.String()
payload := fmt.Sprintf("<22> %s %s: %s", date, hook.AppName, msg)
bytesWritten, err := hook.UDPConn.Write([]byte(payload))
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to send log line to Papertrail via UDP. Wrote %d bytes before error: %v", bytesWritten, err)
return err
}
return nil
}
// Levels returns the available logging levels.
func (hook *PapertrailHook) Levels() []logrus.Level {
return []logrus.Level{
logrus.PanicLevel,
logrus.FatalLevel,
logrus.ErrorLevel,
logrus.WarnLevel,
logrus.InfoLevel,
logrus.DebugLevel,
}
}

View file

@ -1,26 +0,0 @@
package logrus_papertrail
import (
"fmt"
"testing"
"github.com/Sirupsen/logrus"
"github.com/stvp/go-udp-testing"
)
func TestWritingToUDP(t *testing.T) {
port := 16661
udp.SetAddr(fmt.Sprintf(":%d", port))
hook, err := NewPapertrailHook("localhost", port, "test")
if err != nil {
t.Errorf("Unable to connect to local UDP server.")
}
log := logrus.New()
log.Hooks.Add(hook)
udp.ShouldReceive(t, "foo", func() {
log.Info("foo")
})
}

View file

@ -1,61 +0,0 @@
# Sentry Hook for Logrus <img src="http://i.imgur.com/hTeVwmJ.png" width="40" height="40" alt=":walrus:" class="emoji" title=":walrus:" />
[Sentry](https://getsentry.com) provides both self-hosted and hosted
solutions for exception tracking.
Both client and server are
[open source](https://github.com/getsentry/sentry).
## Usage
Every sentry application defined on the server gets a different
[DSN](https://www.getsentry.com/docs/). In the example below replace
`YOUR_DSN` with the one created for your application.
```go
import (
"github.com/Sirupsen/logrus"
"github.com/Sirupsen/logrus/hooks/sentry"
)
func main() {
log := logrus.New()
hook, err := logrus_sentry.NewSentryHook(YOUR_DSN, []logrus.Level{
logrus.PanicLevel,
logrus.FatalLevel,
logrus.ErrorLevel,
})
if err == nil {
log.Hooks.Add(hook)
}
}
```
## Special fields
Some logrus fields have a special meaning in this hook,
these are server_name and logger.
When logs are sent to sentry these fields are treated differently.
- server_name (also known as hostname) is the name of the server which
is logging the event (hostname.example.com)
- logger is the part of the application which is logging the event.
In go this usually means setting it to the name of the package.
## Timeout
`Timeout` is the time the sentry hook will wait for a response
from the sentry server.
If this time elapses with no response from
the server an error will be returned.
If `Timeout` is set to 0 the SentryHook will not wait for a reply
and will assume a correct delivery.
The SentryHook has a default timeout of `100 milliseconds` when created
with a call to `NewSentryHook`. This can be changed by assigning a value to the `Timeout` field:
```go
hook, _ := logrus_sentry.NewSentryHook(...)
hook.Timeout = 20*time.Second
```

View file

@ -1,100 +0,0 @@
package logrus_sentry
import (
"fmt"
"time"
"github.com/Sirupsen/logrus"
"github.com/getsentry/raven-go"
)
var (
severityMap = map[logrus.Level]raven.Severity{
logrus.DebugLevel: raven.DEBUG,
logrus.InfoLevel: raven.INFO,
logrus.WarnLevel: raven.WARNING,
logrus.ErrorLevel: raven.ERROR,
logrus.FatalLevel: raven.FATAL,
logrus.PanicLevel: raven.FATAL,
}
)
func getAndDel(d logrus.Fields, key string) (string, bool) {
var (
ok bool
v interface{}
val string
)
if v, ok = d[key]; !ok {
return "", false
}
if val, ok = v.(string); !ok {
return "", false
}
delete(d, key)
return val, true
}
// SentryHook delivers logs to a sentry server.
type SentryHook struct {
// Timeout sets the time to wait for a delivery error from the sentry server.
// If this is set to zero the server will not wait for any response and will
// consider the message correctly sent
Timeout time.Duration
client *raven.Client
levels []logrus.Level
}
// NewSentryHook creates a hook to be added to an instance of logger
// and initializes the raven client.
// This method sets the timeout to 100 milliseconds.
func NewSentryHook(DSN string, levels []logrus.Level) (*SentryHook, error) {
client, err := raven.NewClient(DSN, nil)
if err != nil {
return nil, err
}
return &SentryHook{100 * time.Millisecond, client, levels}, nil
}
// Called when an event should be sent to sentry
// Special fields that sentry uses to give more information to the server
// are extracted from entry.Data (if they are found)
// These fields are: logger and server_name
func (hook *SentryHook) Fire(entry *logrus.Entry) error {
packet := &raven.Packet{
Message: entry.Message,
Timestamp: raven.Timestamp(entry.Time),
Level: severityMap[entry.Level],
Platform: "go",
}
d := entry.Data
if logger, ok := getAndDel(d, "logger"); ok {
packet.Logger = logger
}
if serverName, ok := getAndDel(d, "server_name"); ok {
packet.ServerName = serverName
}
packet.Extra = map[string]interface{}(d)
_, errCh := hook.client.Capture(packet, nil)
timeout := hook.Timeout
if timeout != 0 {
timeoutCh := time.After(timeout)
select {
case err := <-errCh:
return err
case <-timeoutCh:
return fmt.Errorf("no response from sentry server in %s", timeout)
}
}
return nil
}
// Levels returns the available logging levels.
func (hook *SentryHook) Levels() []logrus.Level {
return hook.levels
}

View file

@ -1,97 +0,0 @@
package logrus_sentry
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Sirupsen/logrus"
"github.com/getsentry/raven-go"
)
const (
message = "error message"
server_name = "testserver.internal"
logger_name = "test.logger"
)
func getTestLogger() *logrus.Logger {
l := logrus.New()
l.Out = ioutil.Discard
return l
}
func WithTestDSN(t *testing.T, tf func(string, <-chan *raven.Packet)) {
pch := make(chan *raven.Packet, 1)
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
defer req.Body.Close()
d := json.NewDecoder(req.Body)
p := &raven.Packet{}
err := d.Decode(p)
if err != nil {
t.Fatal(err.Error())
}
pch <- p
}))
defer s.Close()
fragments := strings.SplitN(s.URL, "://", 2)
dsn := fmt.Sprintf(
"%s://public:secret@%s/sentry/project-id",
fragments[0],
fragments[1],
)
tf(dsn, pch)
}
func TestSpecialFields(t *testing.T) {
WithTestDSN(t, func(dsn string, pch <-chan *raven.Packet) {
logger := getTestLogger()
hook, err := NewSentryHook(dsn, []logrus.Level{
logrus.ErrorLevel,
})
if err != nil {
t.Fatal(err.Error())
}
logger.Hooks.Add(hook)
logger.WithFields(logrus.Fields{
"server_name": server_name,
"logger": logger_name,
}).Error(message)
packet := <-pch
if packet.Logger != logger_name {
t.Errorf("logger should have been %s, was %s", logger_name, packet.Logger)
}
if packet.ServerName != server_name {
t.Errorf("server_name should have been %s, was %s", server_name, packet.ServerName)
}
})
}
func TestSentryHandler(t *testing.T) {
WithTestDSN(t, func(dsn string, pch <-chan *raven.Packet) {
logger := getTestLogger()
hook, err := NewSentryHook(dsn, []logrus.Level{
logrus.ErrorLevel,
})
if err != nil {
t.Fatal(err.Error())
}
logger.Hooks.Add(hook)
logger.Error(message)
packet := <-pch
if packet.Message != message {
t.Errorf("message should have been %s, was %s", message, packet.Message)
}
})
}

View file

@ -1,20 +0,0 @@
# Syslog Hooks for Logrus <img src="http://i.imgur.com/hTeVwmJ.png" width="40" height="40" alt=":walrus:" class="emoji" title=":walrus:"/>
## Usage
```go
import (
"log/syslog"
"github.com/Sirupsen/logrus"
logrus_syslog "github.com/Sirupsen/logrus/hooks/syslog"
)
func main() {
log := logrus.New()
hook, err := logrus_syslog.NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, "")
if err == nil {
log.Hooks.Add(hook)
}
}
```

View file

@ -1,59 +0,0 @@
package logrus_syslog
import (
"fmt"
"github.com/Sirupsen/logrus"
"log/syslog"
"os"
)
// SyslogHook to send logs via syslog.
type SyslogHook struct {
Writer *syslog.Writer
SyslogNetwork string
SyslogRaddr string
}
// Creates a hook to be added to an instance of logger. This is called with
// `hook, err := NewSyslogHook("udp", "localhost:514", syslog.LOG_DEBUG, "")`
// `if err == nil { log.Hooks.Add(hook) }`
func NewSyslogHook(network, raddr string, priority syslog.Priority, tag string) (*SyslogHook, error) {
w, err := syslog.Dial(network, raddr, priority, tag)
return &SyslogHook{w, network, raddr}, err
}
func (hook *SyslogHook) Fire(entry *logrus.Entry) error {
line, err := entry.String()
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err)
return err
}
switch entry.Level {
case logrus.PanicLevel:
return hook.Writer.Crit(line)
case logrus.FatalLevel:
return hook.Writer.Crit(line)
case logrus.ErrorLevel:
return hook.Writer.Err(line)
case logrus.WarnLevel:
return hook.Writer.Warning(line)
case logrus.InfoLevel:
return hook.Writer.Info(line)
case logrus.DebugLevel:
return hook.Writer.Debug(line)
default:
return nil
}
}
func (hook *SyslogHook) Levels() []logrus.Level {
return []logrus.Level{
logrus.PanicLevel,
logrus.FatalLevel,
logrus.ErrorLevel,
logrus.WarnLevel,
logrus.InfoLevel,
logrus.DebugLevel,
}
}

View file

@ -1,26 +0,0 @@
package logrus_syslog
import (
"github.com/Sirupsen/logrus"
"log/syslog"
"testing"
)
func TestLocalhostAddAndPrint(t *testing.T) {
log := logrus.New()
hook, err := NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, "")
if err != nil {
t.Errorf("Unable to connect to local syslog.")
}
log.Hooks.Add(hook)
for _, level := range hook.Levels() {
if len(log.Hooks[level]) != 1 {
t.Errorf("SyslogHook was not added. The length of log.Hooks[%v]: %v", level, len(log.Hooks[level]))
}
}
log.Info("Congratulations!")
}

View file

@ -1,120 +0,0 @@
package logrus
import (
"encoding/json"
"errors"
"testing"
)
func TestErrorNotLost(t *testing.T) {
formatter := &JSONFormatter{}
b, err := formatter.Format(WithField("error", errors.New("wild walrus")))
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
if entry["error"] != "wild walrus" {
t.Fatal("Error field not set")
}
}
func TestErrorNotLostOnFieldNotNamedError(t *testing.T) {
formatter := &JSONFormatter{}
b, err := formatter.Format(WithField("omg", errors.New("wild walrus")))
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
if entry["omg"] != "wild walrus" {
t.Fatal("Error field not set")
}
}
func TestFieldClashWithTime(t *testing.T) {
formatter := &JSONFormatter{}
b, err := formatter.Format(WithField("time", "right now!"))
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
if entry["fields.time"] != "right now!" {
t.Fatal("fields.time not set to original time field")
}
if entry["time"] != "0001-01-01T00:00:00Z" {
t.Fatal("time field not set to current time, was: ", entry["time"])
}
}
func TestFieldClashWithMsg(t *testing.T) {
formatter := &JSONFormatter{}
b, err := formatter.Format(WithField("msg", "something"))
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
if entry["fields.msg"] != "something" {
t.Fatal("fields.msg not set to original msg field")
}
}
func TestFieldClashWithLevel(t *testing.T) {
formatter := &JSONFormatter{}
b, err := formatter.Format(WithField("level", "something"))
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
if entry["fields.level"] != "something" {
t.Fatal("fields.level not set to original level field")
}
}
func TestJSONEntryEndsWithNewline(t *testing.T) {
formatter := &JSONFormatter{}
b, err := formatter.Format(WithField("level", "something"))
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
if b[len(b)-1] != '\n' {
t.Fatal("Expected JSON log entry to end with a newline")
}
}

View file

@ -1,301 +0,0 @@
package logrus
import (
"bytes"
"encoding/json"
"strconv"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func LogAndAssertJSON(t *testing.T, log func(*Logger), assertions func(fields Fields)) {
var buffer bytes.Buffer
var fields Fields
logger := New()
logger.Out = &buffer
logger.Formatter = new(JSONFormatter)
log(logger)
err := json.Unmarshal(buffer.Bytes(), &fields)
assert.Nil(t, err)
assertions(fields)
}
func LogAndAssertText(t *testing.T, log func(*Logger), assertions func(fields map[string]string)) {
var buffer bytes.Buffer
logger := New()
logger.Out = &buffer
logger.Formatter = &TextFormatter{
DisableColors: true,
}
log(logger)
fields := make(map[string]string)
for _, kv := range strings.Split(buffer.String(), " ") {
if !strings.Contains(kv, "=") {
continue
}
kvArr := strings.Split(kv, "=")
key := strings.TrimSpace(kvArr[0])
val := kvArr[1]
if kvArr[1][0] == '"' {
var err error
val, err = strconv.Unquote(val)
assert.NoError(t, err)
}
fields[key] = val
}
assertions(fields)
}
func TestPrint(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Print("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, fields["level"], "info")
})
}
func TestInfo(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, fields["level"], "info")
})
}
func TestWarn(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Warn("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, fields["level"], "warning")
})
}
func TestInfolnShouldAddSpacesBetweenStrings(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Infoln("test", "test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test test")
})
}
func TestInfolnShouldAddSpacesBetweenStringAndNonstring(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Infoln("test", 10)
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test 10")
})
}
func TestInfolnShouldAddSpacesBetweenTwoNonStrings(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Infoln(10, 10)
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "10 10")
})
}
func TestInfoShouldAddSpacesBetweenTwoNonStrings(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Infoln(10, 10)
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "10 10")
})
}
func TestInfoShouldNotAddSpacesBetweenStringAndNonstring(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Info("test", 10)
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test10")
})
}
func TestInfoShouldNotAddSpacesBetweenStrings(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Info("test", "test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "testtest")
})
}
func TestWithFieldsShouldAllowAssignments(t *testing.T) {
var buffer bytes.Buffer
var fields Fields
logger := New()
logger.Out = &buffer
logger.Formatter = new(JSONFormatter)
localLog := logger.WithFields(Fields{
"key1": "value1",
})
localLog.WithField("key2", "value2").Info("test")
err := json.Unmarshal(buffer.Bytes(), &fields)
assert.Nil(t, err)
assert.Equal(t, "value2", fields["key2"])
assert.Equal(t, "value1", fields["key1"])
buffer = bytes.Buffer{}
fields = Fields{}
localLog.Info("test")
err = json.Unmarshal(buffer.Bytes(), &fields)
assert.Nil(t, err)
_, ok := fields["key2"]
assert.Equal(t, false, ok)
assert.Equal(t, "value1", fields["key1"])
}
func TestUserSuppliedFieldDoesNotOverwriteDefaults(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("msg", "hello").Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
})
}
func TestUserSuppliedMsgFieldHasPrefix(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("msg", "hello").Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, fields["fields.msg"], "hello")
})
}
func TestUserSuppliedTimeFieldHasPrefix(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("time", "hello").Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["fields.time"], "hello")
})
}
func TestUserSuppliedLevelFieldHasPrefix(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("level", 1).Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["level"], "info")
assert.Equal(t, fields["fields.level"], 1)
})
}
func TestDefaultFieldsAreNotPrefixed(t *testing.T) {
LogAndAssertText(t, func(log *Logger) {
ll := log.WithField("herp", "derp")
ll.Info("hello")
ll.Info("bye")
}, func(fields map[string]string) {
for _, fieldName := range []string{"fields.level", "fields.time", "fields.msg"} {
if _, ok := fields[fieldName]; ok {
t.Fatalf("should not have prefixed %q: %v", fieldName, fields)
}
}
})
}
func TestDoubleLoggingDoesntPrefixPreviousFields(t *testing.T) {
var buffer bytes.Buffer
var fields Fields
logger := New()
logger.Out = &buffer
logger.Formatter = new(JSONFormatter)
llog := logger.WithField("context", "eating raw fish")
llog.Info("looks delicious")
err := json.Unmarshal(buffer.Bytes(), &fields)
assert.NoError(t, err, "should have decoded first message")
assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields")
assert.Equal(t, fields["msg"], "looks delicious")
assert.Equal(t, fields["context"], "eating raw fish")
buffer.Reset()
llog.Warn("omg it is!")
err = json.Unmarshal(buffer.Bytes(), &fields)
assert.NoError(t, err, "should have decoded second message")
assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields")
assert.Equal(t, fields["msg"], "omg it is!")
assert.Equal(t, fields["context"], "eating raw fish")
assert.Nil(t, fields["fields.msg"], "should not have prefixed previous `msg` entry")
}
func TestConvertLevelToString(t *testing.T) {
assert.Equal(t, "debug", DebugLevel.String())
assert.Equal(t, "info", InfoLevel.String())
assert.Equal(t, "warning", WarnLevel.String())
assert.Equal(t, "error", ErrorLevel.String())
assert.Equal(t, "fatal", FatalLevel.String())
assert.Equal(t, "panic", PanicLevel.String())
}
func TestParseLevel(t *testing.T) {
l, err := ParseLevel("panic")
assert.Nil(t, err)
assert.Equal(t, PanicLevel, l)
l, err = ParseLevel("fatal")
assert.Nil(t, err)
assert.Equal(t, FatalLevel, l)
l, err = ParseLevel("error")
assert.Nil(t, err)
assert.Equal(t, ErrorLevel, l)
l, err = ParseLevel("warn")
assert.Nil(t, err)
assert.Equal(t, WarnLevel, l)
l, err = ParseLevel("warning")
assert.Nil(t, err)
assert.Equal(t, WarnLevel, l)
l, err = ParseLevel("info")
assert.Nil(t, err)
assert.Equal(t, InfoLevel, l)
l, err = ParseLevel("debug")
assert.Nil(t, err)
assert.Equal(t, DebugLevel, l)
l, err = ParseLevel("invalid")
assert.Equal(t, "not a valid logrus Level: \"invalid\"", err.Error())
}
func TestGetSetLevelRace(t *testing.T) {
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
if i%2 == 0 {
SetLevel(InfoLevel)
} else {
GetLevel()
}
}(i)
}
wg.Wait()
}

View file

@ -1,61 +0,0 @@
package logrus
import (
"bytes"
"errors"
"testing"
"time"
)
func TestQuoting(t *testing.T) {
tf := &TextFormatter{DisableColors: true}
checkQuoting := func(q bool, value interface{}) {
b, _ := tf.Format(WithField("test", value))
idx := bytes.Index(b, ([]byte)("test="))
cont := bytes.Contains(b[idx+5:], []byte{'"'})
if cont != q {
if q {
t.Errorf("quoting expected for: %#v", value)
} else {
t.Errorf("quoting not expected for: %#v", value)
}
}
}
checkQuoting(false, "abcd")
checkQuoting(false, "v1.0")
checkQuoting(false, "1234567890")
checkQuoting(true, "/foobar")
checkQuoting(true, "x y")
checkQuoting(true, "x,y")
checkQuoting(false, errors.New("invalid"))
checkQuoting(true, errors.New("invalid argument"))
}
func TestTimestampFormat(t *testing.T) {
checkTimeStr := func(format string) {
customFormatter := &TextFormatter{DisableColors: true, TimestampFormat: format}
customStr, _ := customFormatter.Format(WithField("test", "test"))
timeStart := bytes.Index(customStr, ([]byte)("time="))
timeEnd := bytes.Index(customStr, ([]byte)("level="))
timeStr := customStr[timeStart+5 : timeEnd-1]
if timeStr[0] == '"' && timeStr[len(timeStr)-1] == '"' {
timeStr = timeStr[1 : len(timeStr)-1]
}
if format == "" {
format = time.RFC3339
}
_, e := time.Parse(format, (string)(timeStr))
if e != nil {
t.Errorf("time string \"%s\" did not match provided time format \"%s\": %s", timeStr, format, e)
}
}
checkTimeStr("2006-01-02T15:04:05.000000000Z07:00")
checkTimeStr("Mon Jan _2 15:04:05 2006")
checkTimeStr("")
}
// TODO add tests for sorting etc., this requires a parser for the text
// formatter output.

View file

@ -1,134 +0,0 @@
// Package stscreds are credential Providers to retrieve STS AWS credentials.
//
// STS provides multiple ways to retrieve credentials which can be used when making
// future AWS service API operation calls.
package stscreds
import (
"fmt"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
)
// ProviderName provides a name of AssumeRole provider
const ProviderName = "AssumeRoleProvider"
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}
// DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for.
var DefaultDuration = time.Duration(15) * time.Minute
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
type AssumeRoleProvider struct {
credentials.Expiry
// STS client to make assume role request with.
Client AssumeRoler
// Role to be assumed.
RoleARN string
// Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// Optional ExternalID to pass along, defaults to nil if not set.
ExternalID *string
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// Takes a Config provider to create the STS client. The ConfigProvider is
// satisfied by the session.Session type.
func NewCredentials(c client.ConfigProvider, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: sts.New(c),
RoleARN: roleARN,
Duration: DefaultDuration,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsWithClient returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// Takes an AssumeRoler which can be satisfiede by the STS client.
func NewCredentialsWithClient(svc AssumeRoler, roleARN string, options ...func(*AssumeRoleProvider)) *credentials.Credentials {
p := &AssumeRoleProvider{
Client: svc,
RoleARN: roleARN,
Duration: DefaultDuration,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano())
}
if p.Duration == 0 {
// Expire as often as AWS permits.
p.Duration = DefaultDuration
}
roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{
DurationSeconds: aws.Int64(int64(p.Duration / time.Second)),
RoleArn: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
ExternalId: p.ExternalID,
})
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
// We will proactively generate new credentials before they expire.
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow)
return credentials.Value{
AccessKeyID: *roleOutput.Credentials.AccessKeyId,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
ProviderName: ProviderName,
}, nil
}

View file

@ -1,35 +0,0 @@
// Package ec2query provides serialisation of AWS EC2 requests and responses.
package ec2query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/ec2.json build_test.go
import (
"net/url"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/query/queryutil"
)
// BuildHandler is a named request handler for building ec2query protocol requests
var BuildHandler = request.NamedHandler{Name: "awssdk.ec2query.Build", Fn: Build}
// Build builds a request for the EC2 protocol.
func Build(r *request.Request) {
body := url.Values{
"Action": {r.Operation.Name},
"Version": {r.ClientInfo.APIVersion},
}
if err := queryutil.Parse(body, r.Params, true); err != nil {
r.Error = awserr.New("SerializationError", "failed encoding EC2 Query request", err)
}
if r.ExpireTime == 0 {
r.HTTPRequest.Method = "POST"
r.HTTPRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
r.SetBufferBody([]byte(body.Encode()))
} else { // This is a pre-signed request
r.HTTPRequest.Method = "GET"
r.HTTPRequest.URL.RawQuery = body.Encode()
}
}

View file

@ -1,63 +0,0 @@
package ec2query
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/ec2.json unmarshal_test.go
import (
"encoding/xml"
"io"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
)
// UnmarshalHandler is a named request handler for unmarshaling ec2query protocol requests
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.ec2query.Unmarshal", Fn: Unmarshal}
// UnmarshalMetaHandler is a named request handler for unmarshaling ec2query protocol request metadata
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.ec2query.UnmarshalMeta", Fn: UnmarshalMeta}
// UnmarshalErrorHandler is a named request handler for unmarshaling ec2query protocol request errors
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.ec2query.UnmarshalError", Fn: UnmarshalError}
// Unmarshal unmarshals a response body for the EC2 protocol.
func Unmarshal(r *request.Request) {
defer r.HTTPResponse.Body.Close()
if r.DataFilled() {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, "")
if err != nil {
r.Error = awserr.New("SerializationError", "failed decoding EC2 Query response", err)
return
}
}
}
// UnmarshalMeta unmarshals response headers for the EC2 protocol.
func UnmarshalMeta(r *request.Request) {
// TODO implement unmarshaling of request IDs
}
type xmlErrorResponse struct {
XMLName xml.Name `xml:"Response"`
Code string `xml:"Errors>Error>Code"`
Message string `xml:"Errors>Error>Message"`
RequestID string `xml:"RequestId"`
}
// UnmarshalError unmarshals a response error for the EC2 protocol.
func UnmarshalError(r *request.Request) {
defer r.HTTPResponse.Body.Close()
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = awserr.New("SerializationError", "failed decoding EC2 Query error response", err)
} else {
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)
}
}

View file

@ -1,251 +0,0 @@
// Package jsonutil provides JSON serialisation of AWS requests and responses.
package jsonutil
import (
"bytes"
"encoding/base64"
"fmt"
"reflect"
"sort"
"strconv"
"time"
"github.com/aws/aws-sdk-go/private/protocol"
)
var timeType = reflect.ValueOf(time.Time{}).Type()
var byteSliceType = reflect.ValueOf([]byte{}).Type()
// BuildJSON builds a JSON string for a given object v.
func BuildJSON(v interface{}) ([]byte, error) {
var buf bytes.Buffer
err := buildAny(reflect.ValueOf(v), &buf, "")
return buf.Bytes(), err
}
func buildAny(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
value = reflect.Indirect(value)
if !value.IsValid() {
return nil
}
vtype := value.Type()
t := tag.Get("type")
if t == "" {
switch vtype.Kind() {
case reflect.Struct:
// also it can't be a time object
if value.Type() != timeType {
t = "structure"
}
case reflect.Slice:
// also it can't be a byte slice
if _, ok := value.Interface().([]byte); !ok {
t = "list"
}
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := vtype.FieldByName("_"); ok {
tag = field.Tag
}
return buildStruct(value, buf, tag)
case "list":
return buildList(value, buf, tag)
case "map":
return buildMap(value, buf, tag)
default:
return buildScalar(value, buf, tag)
}
}
func buildStruct(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
if !value.IsValid() {
return nil
}
// unwrap payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := value.Type().FieldByName(payload)
tag = field.Tag
value = elemOf(value.FieldByName(payload))
if !value.IsValid() {
return nil
}
}
buf.WriteByte('{')
t := value.Type()
first := true
for i := 0; i < t.NumField(); i++ {
member := value.Field(i)
field := t.Field(i)
if field.PkgPath != "" {
continue // ignore unexported fields
}
if field.Tag.Get("json") == "-" {
continue
}
if field.Tag.Get("location") != "" {
continue // ignore non-body elements
}
if protocol.CanSetIdempotencyToken(member, field) {
token := protocol.GetIdempotencyToken()
member = reflect.ValueOf(&token)
}
if (member.Kind() == reflect.Ptr || member.Kind() == reflect.Slice || member.Kind() == reflect.Map) && member.IsNil() {
continue // ignore unset fields
}
if first {
first = false
} else {
buf.WriteByte(',')
}
// figure out what this field is called
name := field.Name
if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
fmt.Fprintf(buf, "%q:", name)
err := buildAny(member, buf, field.Tag)
if err != nil {
return err
}
}
buf.WriteString("}")
return nil
}
func buildList(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
buf.WriteString("[")
for i := 0; i < value.Len(); i++ {
buildAny(value.Index(i), buf, "")
if i < value.Len()-1 {
buf.WriteString(",")
}
}
buf.WriteString("]")
return nil
}
type sortedValues []reflect.Value
func (sv sortedValues) Len() int { return len(sv) }
func (sv sortedValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
func (sv sortedValues) Less(i, j int) bool { return sv[i].String() < sv[j].String() }
func buildMap(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
buf.WriteString("{")
var sv sortedValues = value.MapKeys()
sort.Sort(sv)
for i, k := range sv {
if i > 0 {
buf.WriteByte(',')
}
fmt.Fprintf(buf, "%q:", k)
buildAny(value.MapIndex(k), buf, "")
}
buf.WriteString("}")
return nil
}
func buildScalar(value reflect.Value, buf *bytes.Buffer, tag reflect.StructTag) error {
switch value.Kind() {
case reflect.String:
writeString(value.String(), buf)
case reflect.Bool:
buf.WriteString(strconv.FormatBool(value.Bool()))
case reflect.Int64:
buf.WriteString(strconv.FormatInt(value.Int(), 10))
case reflect.Float64:
buf.WriteString(strconv.FormatFloat(value.Float(), 'f', -1, 64))
default:
switch value.Type() {
case timeType:
converted := value.Interface().(time.Time)
buf.WriteString(strconv.FormatInt(converted.UTC().Unix(), 10))
case byteSliceType:
if !value.IsNil() {
converted := value.Interface().([]byte)
buf.WriteByte('"')
if len(converted) < 1024 {
// for small buffers, using Encode directly is much faster.
dst := make([]byte, base64.StdEncoding.EncodedLen(len(converted)))
base64.StdEncoding.Encode(dst, converted)
buf.Write(dst)
} else {
// for large buffers, avoid unnecessary extra temporary
// buffer space.
enc := base64.NewEncoder(base64.StdEncoding, buf)
enc.Write(converted)
enc.Close()
}
buf.WriteByte('"')
}
default:
return fmt.Errorf("unsupported JSON value %v (%s)", value.Interface(), value.Type())
}
}
return nil
}
func writeString(s string, buf *bytes.Buffer) {
buf.WriteByte('"')
for _, r := range s {
if r == '"' {
buf.WriteString(`\"`)
} else if r == '\\' {
buf.WriteString(`\\`)
} else if r == '\b' {
buf.WriteString(`\b`)
} else if r == '\f' {
buf.WriteString(`\f`)
} else if r == '\r' {
buf.WriteString(`\r`)
} else if r == '\t' {
buf.WriteString(`\t`)
} else if r == '\n' {
buf.WriteString(`\n`)
} else if r < 32 {
fmt.Fprintf(buf, "\\u%0.4x", r)
} else {
buf.WriteRune(r)
}
}
buf.WriteByte('"')
}
// Returns the reflection element of a value, if it is a pointer.
func elemOf(value reflect.Value) reflect.Value {
for value.Kind() == reflect.Ptr {
value = value.Elem()
}
return value
}

View file

@ -1,213 +0,0 @@
package jsonutil
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"reflect"
"time"
)
// UnmarshalJSON reads a stream and unmarshals the results in object v.
func UnmarshalJSON(v interface{}, stream io.Reader) error {
var out interface{}
b, err := ioutil.ReadAll(stream)
if err != nil {
return err
}
if len(b) == 0 {
return nil
}
if err := json.Unmarshal(b, &out); err != nil {
return err
}
return unmarshalAny(reflect.ValueOf(v), out, "")
}
func unmarshalAny(value reflect.Value, data interface{}, tag reflect.StructTag) error {
vtype := value.Type()
if vtype.Kind() == reflect.Ptr {
vtype = vtype.Elem() // check kind of actual element type
}
t := tag.Get("type")
if t == "" {
switch vtype.Kind() {
case reflect.Struct:
// also it can't be a time object
if _, ok := value.Interface().(*time.Time); !ok {
t = "structure"
}
case reflect.Slice:
// also it can't be a byte slice
if _, ok := value.Interface().([]byte); !ok {
t = "list"
}
case reflect.Map:
t = "map"
}
}
switch t {
case "structure":
if field, ok := vtype.FieldByName("_"); ok {
tag = field.Tag
}
return unmarshalStruct(value, data, tag)
case "list":
return unmarshalList(value, data, tag)
case "map":
return unmarshalMap(value, data, tag)
default:
return unmarshalScalar(value, data, tag)
}
}
func unmarshalStruct(value reflect.Value, data interface{}, tag reflect.StructTag) error {
if data == nil {
return nil
}
mapData, ok := data.(map[string]interface{})
if !ok {
return fmt.Errorf("JSON value is not a structure (%#v)", data)
}
t := value.Type()
if value.Kind() == reflect.Ptr {
if value.IsNil() { // create the structure if it's nil
s := reflect.New(value.Type().Elem())
value.Set(s)
value = s
}
value = value.Elem()
t = t.Elem()
}
// unwrap any payloads
if payload := tag.Get("payload"); payload != "" {
field, _ := t.FieldByName(payload)
return unmarshalAny(value.FieldByName(payload), data, field.Tag)
}
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.PkgPath != "" {
continue // ignore unexported fields
}
// figure out what this field is called
name := field.Name
if locName := field.Tag.Get("locationName"); locName != "" {
name = locName
}
member := value.FieldByIndex(field.Index)
err := unmarshalAny(member, mapData[name], field.Tag)
if err != nil {
return err
}
}
return nil
}
func unmarshalList(value reflect.Value, data interface{}, tag reflect.StructTag) error {
if data == nil {
return nil
}
listData, ok := data.([]interface{})
if !ok {
return fmt.Errorf("JSON value is not a list (%#v)", data)
}
if value.IsNil() {
l := len(listData)
value.Set(reflect.MakeSlice(value.Type(), l, l))
}
for i, c := range listData {
err := unmarshalAny(value.Index(i), c, "")
if err != nil {
return err
}
}
return nil
}
func unmarshalMap(value reflect.Value, data interface{}, tag reflect.StructTag) error {
if data == nil {
return nil
}
mapData, ok := data.(map[string]interface{})
if !ok {
return fmt.Errorf("JSON value is not a map (%#v)", data)
}
if value.IsNil() {
value.Set(reflect.MakeMap(value.Type()))
}
for k, v := range mapData {
kvalue := reflect.ValueOf(k)
vvalue := reflect.New(value.Type().Elem()).Elem()
unmarshalAny(vvalue, v, "")
value.SetMapIndex(kvalue, vvalue)
}
return nil
}
func unmarshalScalar(value reflect.Value, data interface{}, tag reflect.StructTag) error {
errf := func() error {
return fmt.Errorf("unsupported value: %v (%s)", value.Interface(), value.Type())
}
switch d := data.(type) {
case nil:
return nil // nothing to do here
case string:
switch value.Interface().(type) {
case *string:
value.Set(reflect.ValueOf(&d))
case []byte:
b, err := base64.StdEncoding.DecodeString(d)
if err != nil {
return err
}
value.Set(reflect.ValueOf(b))
default:
return errf()
}
case float64:
switch value.Interface().(type) {
case *int64:
di := int64(d)
value.Set(reflect.ValueOf(&di))
case *float64:
value.Set(reflect.ValueOf(&d))
case *time.Time:
t := time.Unix(int64(d), 0).UTC()
value.Set(reflect.ValueOf(&t))
default:
return errf()
}
case bool:
switch value.Interface().(type) {
case *bool:
value.Set(reflect.ValueOf(&d))
default:
return errf()
}
default:
return fmt.Errorf("unsupported JSON value (%v)", data)
}
return nil
}

View file

@ -1,111 +0,0 @@
// Package jsonrpc provides JSON RPC utilities for serialisation of AWS
// requests and responses.
package jsonrpc
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/json.json build_test.go
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go
import (
"encoding/json"
"io/ioutil"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/json/jsonutil"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
var emptyJSON = []byte("{}")
// BuildHandler is a named request handler for building jsonrpc protocol requests
var BuildHandler = request.NamedHandler{Name: "awssdk.jsonrpc.Build", Fn: Build}
// UnmarshalHandler is a named request handler for unmarshaling jsonrpc protocol requests
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.jsonrpc.Unmarshal", Fn: Unmarshal}
// UnmarshalMetaHandler is a named request handler for unmarshaling jsonrpc protocol request metadata
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.jsonrpc.UnmarshalMeta", Fn: UnmarshalMeta}
// UnmarshalErrorHandler is a named request handler for unmarshaling jsonrpc protocol request errors
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.jsonrpc.UnmarshalError", Fn: UnmarshalError}
// Build builds a JSON payload for a JSON RPC request.
func Build(req *request.Request) {
var buf []byte
var err error
if req.ParamsFilled() {
buf, err = jsonutil.BuildJSON(req.Params)
if err != nil {
req.Error = awserr.New("SerializationError", "failed encoding JSON RPC request", err)
return
}
} else {
buf = emptyJSON
}
if req.ClientInfo.TargetPrefix != "" || string(buf) != "{}" {
req.SetBufferBody(buf)
}
if req.ClientInfo.TargetPrefix != "" {
target := req.ClientInfo.TargetPrefix + "." + req.Operation.Name
req.HTTPRequest.Header.Add("X-Amz-Target", target)
}
if req.ClientInfo.JSONVersion != "" {
jsonVersion := req.ClientInfo.JSONVersion
req.HTTPRequest.Header.Add("Content-Type", "application/x-amz-json-"+jsonVersion)
}
}
// Unmarshal unmarshals a response for a JSON RPC service.
func Unmarshal(req *request.Request) {
defer req.HTTPResponse.Body.Close()
if req.DataFilled() {
err := jsonutil.UnmarshalJSON(req.Data, req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.New("SerializationError", "failed decoding JSON RPC response", err)
}
}
return
}
// UnmarshalMeta unmarshals headers from a response for a JSON RPC service.
func UnmarshalMeta(req *request.Request) {
rest.UnmarshalMeta(req)
}
// UnmarshalError unmarshals an error response for a JSON RPC service.
func UnmarshalError(req *request.Request) {
defer req.HTTPResponse.Body.Close()
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = awserr.New("SerializationError", "failed reading JSON RPC error response", err)
return
}
if len(bodyBytes) == 0 {
req.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", req.HTTPResponse.Status, nil),
req.HTTPResponse.StatusCode,
"",
)
return
}
var jsonErr jsonErrorResponse
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
req.Error = awserr.New("SerializationError", "failed decoding JSON RPC error response", err)
return
}
codes := strings.SplitN(jsonErr.Code, "#", 2)
req.Error = awserr.NewRequestFailure(
awserr.New(codes[len(codes)-1], jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
req.RequestID,
)
}
type jsonErrorResponse struct {
Code string `json:"__type"`
Message string `json:"message"`
}

View file

@ -1,91 +0,0 @@
// Package restjson provides RESTful JSON serialisation of AWS
// requests and responses.
package restjson
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/input/rest-json.json build_test.go
//go:generate go run ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-json.json unmarshal_test.go
import (
"encoding/json"
"io/ioutil"
"strings"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
"github.com/aws/aws-sdk-go/private/protocol/rest"
)
// BuildHandler is a named request handler for building restjson protocol requests
var BuildHandler = request.NamedHandler{Name: "awssdk.restjson.Build", Fn: Build}
// UnmarshalHandler is a named request handler for unmarshaling restjson protocol requests
var UnmarshalHandler = request.NamedHandler{Name: "awssdk.restjson.Unmarshal", Fn: Unmarshal}
// UnmarshalMetaHandler is a named request handler for unmarshaling restjson protocol request metadata
var UnmarshalMetaHandler = request.NamedHandler{Name: "awssdk.restjson.UnmarshalMeta", Fn: UnmarshalMeta}
// UnmarshalErrorHandler is a named request handler for unmarshaling restjson protocol request errors
var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.restjson.UnmarshalError", Fn: UnmarshalError}
// Build builds a request for the REST JSON protocol.
func Build(r *request.Request) {
rest.Build(r)
if t := rest.PayloadType(r.Params); t == "structure" || t == "" {
jsonrpc.Build(r)
}
}
// Unmarshal unmarshals a response body for the REST JSON protocol.
func Unmarshal(r *request.Request) {
if t := rest.PayloadType(r.Data); t == "structure" || t == "" {
jsonrpc.Unmarshal(r)
} else {
rest.Unmarshal(r)
}
}
// UnmarshalMeta unmarshals response headers for the REST JSON protocol.
func UnmarshalMeta(r *request.Request) {
rest.UnmarshalMeta(r)
}
// UnmarshalError unmarshals a response error for the REST JSON protocol.
func UnmarshalError(r *request.Request) {
code := r.HTTPResponse.Header.Get("X-Amzn-Errortype")
bodyBytes, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed reading REST JSON error response", err)
return
}
if len(bodyBytes) == 0 {
r.Error = awserr.NewRequestFailure(
awserr.New("SerializationError", r.HTTPResponse.Status, nil),
r.HTTPResponse.StatusCode,
"",
)
return
}
var jsonErr jsonErrorResponse
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
r.Error = awserr.New("SerializationError", "failed decoding REST JSON error response", err)
return
}
if code == "" {
code = jsonErr.Code
}
code = strings.SplitN(code, ":", 2)[0]
r.Error = awserr.NewRequestFailure(
awserr.New(code, jsonErr.Message, nil),
r.HTTPResponse.StatusCode,
r.RequestID,
)
}
type jsonErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}

View file

@ -1,246 +0,0 @@
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
// Package s3iface provides an interface for the Amazon Simple Storage Service.
package s3iface
import (
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
)
// S3API is the interface type for s3.S3.
type S3API interface {
AbortMultipartUploadRequest(*s3.AbortMultipartUploadInput) (*request.Request, *s3.AbortMultipartUploadOutput)
AbortMultipartUpload(*s3.AbortMultipartUploadInput) (*s3.AbortMultipartUploadOutput, error)
CompleteMultipartUploadRequest(*s3.CompleteMultipartUploadInput) (*request.Request, *s3.CompleteMultipartUploadOutput)
CompleteMultipartUpload(*s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error)
CopyObjectRequest(*s3.CopyObjectInput) (*request.Request, *s3.CopyObjectOutput)
CopyObject(*s3.CopyObjectInput) (*s3.CopyObjectOutput, error)
CreateBucketRequest(*s3.CreateBucketInput) (*request.Request, *s3.CreateBucketOutput)
CreateBucket(*s3.CreateBucketInput) (*s3.CreateBucketOutput, error)
CreateMultipartUploadRequest(*s3.CreateMultipartUploadInput) (*request.Request, *s3.CreateMultipartUploadOutput)
CreateMultipartUpload(*s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error)
DeleteBucketRequest(*s3.DeleteBucketInput) (*request.Request, *s3.DeleteBucketOutput)
DeleteBucket(*s3.DeleteBucketInput) (*s3.DeleteBucketOutput, error)
DeleteBucketCorsRequest(*s3.DeleteBucketCorsInput) (*request.Request, *s3.DeleteBucketCorsOutput)
DeleteBucketCors(*s3.DeleteBucketCorsInput) (*s3.DeleteBucketCorsOutput, error)
DeleteBucketLifecycleRequest(*s3.DeleteBucketLifecycleInput) (*request.Request, *s3.DeleteBucketLifecycleOutput)
DeleteBucketLifecycle(*s3.DeleteBucketLifecycleInput) (*s3.DeleteBucketLifecycleOutput, error)
DeleteBucketPolicyRequest(*s3.DeleteBucketPolicyInput) (*request.Request, *s3.DeleteBucketPolicyOutput)
DeleteBucketPolicy(*s3.DeleteBucketPolicyInput) (*s3.DeleteBucketPolicyOutput, error)
DeleteBucketReplicationRequest(*s3.DeleteBucketReplicationInput) (*request.Request, *s3.DeleteBucketReplicationOutput)
DeleteBucketReplication(*s3.DeleteBucketReplicationInput) (*s3.DeleteBucketReplicationOutput, error)
DeleteBucketTaggingRequest(*s3.DeleteBucketTaggingInput) (*request.Request, *s3.DeleteBucketTaggingOutput)
DeleteBucketTagging(*s3.DeleteBucketTaggingInput) (*s3.DeleteBucketTaggingOutput, error)
DeleteBucketWebsiteRequest(*s3.DeleteBucketWebsiteInput) (*request.Request, *s3.DeleteBucketWebsiteOutput)
DeleteBucketWebsite(*s3.DeleteBucketWebsiteInput) (*s3.DeleteBucketWebsiteOutput, error)
DeleteObjectRequest(*s3.DeleteObjectInput) (*request.Request, *s3.DeleteObjectOutput)
DeleteObject(*s3.DeleteObjectInput) (*s3.DeleteObjectOutput, error)
DeleteObjectsRequest(*s3.DeleteObjectsInput) (*request.Request, *s3.DeleteObjectsOutput)
DeleteObjects(*s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error)
GetBucketAclRequest(*s3.GetBucketAclInput) (*request.Request, *s3.GetBucketAclOutput)
GetBucketAcl(*s3.GetBucketAclInput) (*s3.GetBucketAclOutput, error)
GetBucketCorsRequest(*s3.GetBucketCorsInput) (*request.Request, *s3.GetBucketCorsOutput)
GetBucketCors(*s3.GetBucketCorsInput) (*s3.GetBucketCorsOutput, error)
GetBucketLifecycleRequest(*s3.GetBucketLifecycleInput) (*request.Request, *s3.GetBucketLifecycleOutput)
GetBucketLifecycle(*s3.GetBucketLifecycleInput) (*s3.GetBucketLifecycleOutput, error)
GetBucketLifecycleConfigurationRequest(*s3.GetBucketLifecycleConfigurationInput) (*request.Request, *s3.GetBucketLifecycleConfigurationOutput)
GetBucketLifecycleConfiguration(*s3.GetBucketLifecycleConfigurationInput) (*s3.GetBucketLifecycleConfigurationOutput, error)
GetBucketLocationRequest(*s3.GetBucketLocationInput) (*request.Request, *s3.GetBucketLocationOutput)
GetBucketLocation(*s3.GetBucketLocationInput) (*s3.GetBucketLocationOutput, error)
GetBucketLoggingRequest(*s3.GetBucketLoggingInput) (*request.Request, *s3.GetBucketLoggingOutput)
GetBucketLogging(*s3.GetBucketLoggingInput) (*s3.GetBucketLoggingOutput, error)
GetBucketNotificationRequest(*s3.GetBucketNotificationConfigurationRequest) (*request.Request, *s3.NotificationConfigurationDeprecated)
GetBucketNotification(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfigurationDeprecated, error)
GetBucketNotificationConfigurationRequest(*s3.GetBucketNotificationConfigurationRequest) (*request.Request, *s3.NotificationConfiguration)
GetBucketNotificationConfiguration(*s3.GetBucketNotificationConfigurationRequest) (*s3.NotificationConfiguration, error)
GetBucketPolicyRequest(*s3.GetBucketPolicyInput) (*request.Request, *s3.GetBucketPolicyOutput)
GetBucketPolicy(*s3.GetBucketPolicyInput) (*s3.GetBucketPolicyOutput, error)
GetBucketReplicationRequest(*s3.GetBucketReplicationInput) (*request.Request, *s3.GetBucketReplicationOutput)
GetBucketReplication(*s3.GetBucketReplicationInput) (*s3.GetBucketReplicationOutput, error)
GetBucketRequestPaymentRequest(*s3.GetBucketRequestPaymentInput) (*request.Request, *s3.GetBucketRequestPaymentOutput)
GetBucketRequestPayment(*s3.GetBucketRequestPaymentInput) (*s3.GetBucketRequestPaymentOutput, error)
GetBucketTaggingRequest(*s3.GetBucketTaggingInput) (*request.Request, *s3.GetBucketTaggingOutput)
GetBucketTagging(*s3.GetBucketTaggingInput) (*s3.GetBucketTaggingOutput, error)
GetBucketVersioningRequest(*s3.GetBucketVersioningInput) (*request.Request, *s3.GetBucketVersioningOutput)
GetBucketVersioning(*s3.GetBucketVersioningInput) (*s3.GetBucketVersioningOutput, error)
GetBucketWebsiteRequest(*s3.GetBucketWebsiteInput) (*request.Request, *s3.GetBucketWebsiteOutput)
GetBucketWebsite(*s3.GetBucketWebsiteInput) (*s3.GetBucketWebsiteOutput, error)
GetObjectRequest(*s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput)
GetObject(*s3.GetObjectInput) (*s3.GetObjectOutput, error)
GetObjectAclRequest(*s3.GetObjectAclInput) (*request.Request, *s3.GetObjectAclOutput)
GetObjectAcl(*s3.GetObjectAclInput) (*s3.GetObjectAclOutput, error)
GetObjectTorrentRequest(*s3.GetObjectTorrentInput) (*request.Request, *s3.GetObjectTorrentOutput)
GetObjectTorrent(*s3.GetObjectTorrentInput) (*s3.GetObjectTorrentOutput, error)
HeadBucketRequest(*s3.HeadBucketInput) (*request.Request, *s3.HeadBucketOutput)
HeadBucket(*s3.HeadBucketInput) (*s3.HeadBucketOutput, error)
HeadObjectRequest(*s3.HeadObjectInput) (*request.Request, *s3.HeadObjectOutput)
HeadObject(*s3.HeadObjectInput) (*s3.HeadObjectOutput, error)
ListBucketsRequest(*s3.ListBucketsInput) (*request.Request, *s3.ListBucketsOutput)
ListBuckets(*s3.ListBucketsInput) (*s3.ListBucketsOutput, error)
ListMultipartUploadsRequest(*s3.ListMultipartUploadsInput) (*request.Request, *s3.ListMultipartUploadsOutput)
ListMultipartUploads(*s3.ListMultipartUploadsInput) (*s3.ListMultipartUploadsOutput, error)
ListMultipartUploadsPages(*s3.ListMultipartUploadsInput, func(*s3.ListMultipartUploadsOutput, bool) bool) error
ListObjectVersionsRequest(*s3.ListObjectVersionsInput) (*request.Request, *s3.ListObjectVersionsOutput)
ListObjectVersions(*s3.ListObjectVersionsInput) (*s3.ListObjectVersionsOutput, error)
ListObjectVersionsPages(*s3.ListObjectVersionsInput, func(*s3.ListObjectVersionsOutput, bool) bool) error
ListObjectsRequest(*s3.ListObjectsInput) (*request.Request, *s3.ListObjectsOutput)
ListObjects(*s3.ListObjectsInput) (*s3.ListObjectsOutput, error)
ListObjectsPages(*s3.ListObjectsInput, func(*s3.ListObjectsOutput, bool) bool) error
ListPartsRequest(*s3.ListPartsInput) (*request.Request, *s3.ListPartsOutput)
ListParts(*s3.ListPartsInput) (*s3.ListPartsOutput, error)
ListPartsPages(*s3.ListPartsInput, func(*s3.ListPartsOutput, bool) bool) error
PutBucketAclRequest(*s3.PutBucketAclInput) (*request.Request, *s3.PutBucketAclOutput)
PutBucketAcl(*s3.PutBucketAclInput) (*s3.PutBucketAclOutput, error)
PutBucketCorsRequest(*s3.PutBucketCorsInput) (*request.Request, *s3.PutBucketCorsOutput)
PutBucketCors(*s3.PutBucketCorsInput) (*s3.PutBucketCorsOutput, error)
PutBucketLifecycleRequest(*s3.PutBucketLifecycleInput) (*request.Request, *s3.PutBucketLifecycleOutput)
PutBucketLifecycle(*s3.PutBucketLifecycleInput) (*s3.PutBucketLifecycleOutput, error)
PutBucketLifecycleConfigurationRequest(*s3.PutBucketLifecycleConfigurationInput) (*request.Request, *s3.PutBucketLifecycleConfigurationOutput)
PutBucketLifecycleConfiguration(*s3.PutBucketLifecycleConfigurationInput) (*s3.PutBucketLifecycleConfigurationOutput, error)
PutBucketLoggingRequest(*s3.PutBucketLoggingInput) (*request.Request, *s3.PutBucketLoggingOutput)
PutBucketLogging(*s3.PutBucketLoggingInput) (*s3.PutBucketLoggingOutput, error)
PutBucketNotificationRequest(*s3.PutBucketNotificationInput) (*request.Request, *s3.PutBucketNotificationOutput)
PutBucketNotification(*s3.PutBucketNotificationInput) (*s3.PutBucketNotificationOutput, error)
PutBucketNotificationConfigurationRequest(*s3.PutBucketNotificationConfigurationInput) (*request.Request, *s3.PutBucketNotificationConfigurationOutput)
PutBucketNotificationConfiguration(*s3.PutBucketNotificationConfigurationInput) (*s3.PutBucketNotificationConfigurationOutput, error)
PutBucketPolicyRequest(*s3.PutBucketPolicyInput) (*request.Request, *s3.PutBucketPolicyOutput)
PutBucketPolicy(*s3.PutBucketPolicyInput) (*s3.PutBucketPolicyOutput, error)
PutBucketReplicationRequest(*s3.PutBucketReplicationInput) (*request.Request, *s3.PutBucketReplicationOutput)
PutBucketReplication(*s3.PutBucketReplicationInput) (*s3.PutBucketReplicationOutput, error)
PutBucketRequestPaymentRequest(*s3.PutBucketRequestPaymentInput) (*request.Request, *s3.PutBucketRequestPaymentOutput)
PutBucketRequestPayment(*s3.PutBucketRequestPaymentInput) (*s3.PutBucketRequestPaymentOutput, error)
PutBucketTaggingRequest(*s3.PutBucketTaggingInput) (*request.Request, *s3.PutBucketTaggingOutput)
PutBucketTagging(*s3.PutBucketTaggingInput) (*s3.PutBucketTaggingOutput, error)
PutBucketVersioningRequest(*s3.PutBucketVersioningInput) (*request.Request, *s3.PutBucketVersioningOutput)
PutBucketVersioning(*s3.PutBucketVersioningInput) (*s3.PutBucketVersioningOutput, error)
PutBucketWebsiteRequest(*s3.PutBucketWebsiteInput) (*request.Request, *s3.PutBucketWebsiteOutput)
PutBucketWebsite(*s3.PutBucketWebsiteInput) (*s3.PutBucketWebsiteOutput, error)
PutObjectRequest(*s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput)
PutObject(*s3.PutObjectInput) (*s3.PutObjectOutput, error)
PutObjectAclRequest(*s3.PutObjectAclInput) (*request.Request, *s3.PutObjectAclOutput)
PutObjectAcl(*s3.PutObjectAclInput) (*s3.PutObjectAclOutput, error)
RestoreObjectRequest(*s3.RestoreObjectInput) (*request.Request, *s3.RestoreObjectOutput)
RestoreObject(*s3.RestoreObjectInput) (*s3.RestoreObjectOutput, error)
UploadPartRequest(*s3.UploadPartInput) (*request.Request, *s3.UploadPartOutput)
UploadPart(*s3.UploadPartInput) (*s3.UploadPartOutput, error)
UploadPartCopyRequest(*s3.UploadPartCopyInput) (*request.Request, *s3.UploadPartCopyOutput)
UploadPartCopy(*s3.UploadPartCopyInput) (*s3.UploadPartCopyOutput, error)
}
var _ S3API = (*s3.S3)(nil)

View file

@ -1,3 +0,0 @@
// Package s3manager provides utilities to upload and download objects from
// S3 concurrently. Helpful for when working with large objects.
package s3manager

View file

@ -1,354 +0,0 @@
package s3manager
import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// DefaultDownloadPartSize is the default range of bytes to get at a time when
// using Download().
const DefaultDownloadPartSize = 1024 * 1024 * 5
// DefaultDownloadConcurrency is the default number of goroutines to spin up
// when using Download().
const DefaultDownloadConcurrency = 5
// The Downloader structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
// Mutating the Downloader's properties is not safe to be done concurrently.
type Downloader struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultConcurrency value will be used.
Concurrency int
// An S3 client to use when performing downloads.
S3 s3iface.S3API
}
// NewDownloader creates a new Downloader instance to downloads objects from
// S3 in concurrent chunks. Pass in additional functional options to customize
// the downloader behavior. Requires a client.ConfigProvider in order to create
// a S3 service client. The session.Session satisfies the client.ConfigProvider
// interface.
//
// Example:
// // The session the S3 Downloader will use
// sess := session.New()
//
// // Create a downloader with the session and default options
// downloader := s3manager.NewDownloader(sess)
//
// // Create a downloader with the session and custom options
// downloader := s3manager.NewDownloader(sess, func(d *s3manager.Uploader) {
// d.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewDownloader(c client.ConfigProvider, options ...func(*Downloader)) *Downloader {
d := &Downloader{
S3: s3.New(c),
PartSize: DefaultDownloadPartSize,
Concurrency: DefaultDownloadConcurrency,
}
for _, option := range options {
option(d)
}
return d
}
// NewDownloaderWithClient creates a new Downloader instance to downloads
// objects from S3 in concurrent chunks. Pass in additional functional
// options to customize the downloader behavior. Requires a S3 service client
// to make S3 API calls.
//
// Example:
// // The S3 client the S3 Downloader will use
// s3Svc := s3.new(session.New())
//
// // Create a downloader with the s3 client and default options
// downloader := s3manager.NewDownloaderWithClient(s3Svc)
//
// // Create a downloader with the s3 client and custom options
// downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Uploader) {
// d.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *Downloader {
d := &Downloader{
S3: svc,
PartSize: DefaultDownloadPartSize,
Concurrency: DefaultDownloadConcurrency,
}
for _, option := range options {
option(d)
}
return d
}
// Download downloads an object in S3 and writes the payload into w using
// concurrent GET requests.
//
// Additional functional options can be provided to configure the individual
// upload. These options are copies of the Uploader instance Upload is called from.
// Modifying the options will not impact the original Uploader instance.
//
// It is safe to call this method concurrently across goroutines.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
impl := downloader{w: w, in: input, ctx: d}
for _, option := range options {
option(&impl.ctx)
}
return impl.download()
}
// downloader is the implementation structure used internally by Downloader.
type downloader struct {
ctx Downloader
in *s3.GetObjectInput
w io.WriterAt
wg sync.WaitGroup
m sync.Mutex
pos int64
totalBytes int64
written int64
err error
}
// init initializes the downloader with default options.
func (d *downloader) init() {
d.totalBytes = -1
if d.ctx.Concurrency == 0 {
d.ctx.Concurrency = DefaultDownloadConcurrency
}
if d.ctx.PartSize == 0 {
d.ctx.PartSize = DefaultDownloadPartSize
}
}
// download performs the implementation of the object download across ranged
// GETs.
func (d *downloader) download() (n int64, err error) {
d.init()
// Spin off first worker to check additional header information
d.getChunk()
if total := d.getTotalBytes(); total >= 0 {
// Spin up workers
ch := make(chan dlchunk, d.ctx.Concurrency)
for i := 0; i < d.ctx.Concurrency; i++ {
d.wg.Add(1)
go d.downloadPart(ch)
}
// Assign work
for d.getErr() == nil {
if d.pos >= total {
break // We're finished queueing chunks
}
// Queue the next range of bytes to read.
ch <- dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
d.pos += d.ctx.PartSize
}
// Wait for completion
close(ch)
d.wg.Wait()
} else {
// Checking if we read anything new
for d.err == nil {
d.getChunk()
}
// We expect a 416 error letting us know we are done downloading the
// total bytes. Since we do not know the content's length, this will
// keep grabbing chunks of data until the range of bytes specified in
// the request is out of range of the content. Once, this happens, a
// 416 should occur.
e, ok := d.err.(awserr.RequestFailure)
if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
d.err = nil
}
}
// Return error
return d.written, d.err
}
// downloadPart is an individual goroutine worker reading from the ch channel
// and performing a GetObject request on the data with a given byte range.
//
// If this is the first worker, this operation also resolves the total number
// of bytes to be read so that the worker manager knows when it is finished.
func (d *downloader) downloadPart(ch chan dlchunk) {
defer d.wg.Done()
for {
chunk, ok := <-ch
if !ok {
break
}
d.downloadChunk(chunk)
}
}
// getChunk grabs a chunk of data from the body.
// Not thread safe. Should only used when grabbing data on a single thread.
func (d *downloader) getChunk() {
chunk := dlchunk{w: d.w, start: d.pos, size: d.ctx.PartSize}
d.pos += d.ctx.PartSize
d.downloadChunk(chunk)
}
// downloadChunk downloads the chunk froom s3
func (d *downloader) downloadChunk(chunk dlchunk) {
if d.getErr() != nil {
return
}
// Get the next byte range of data
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
rng := fmt.Sprintf("bytes=%d-%d",
chunk.start, chunk.start+chunk.size-1)
in.Range = &rng
req, resp := d.ctx.S3.GetObjectRequest(in)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
err := req.Send()
if err != nil {
d.setErr(err)
} else {
d.setTotalBytes(resp) // Set total if not yet set.
n, err := io.Copy(&chunk, resp.Body)
resp.Body.Close()
if err != nil {
d.setErr(err)
}
d.incrWritten(n)
}
}
// getTotalBytes is a thread-safe getter for retrieving the total byte status.
func (d *downloader) getTotalBytes() int64 {
d.m.Lock()
defer d.m.Unlock()
return d.totalBytes
}
// setTotalBytes is a thread-safe setter for setting the total byte status.
// Will extract the object's total bytes from the Content-Range if the file
// will be chunked, or Content-Length. Content-Length is used when the response
// does not include a Content-Range. Meaning the object was not chunked. This
// occurs when the full file fits within the PartSize directive.
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
d.m.Lock()
defer d.m.Unlock()
if d.totalBytes >= 0 {
return
}
if resp.ContentRange == nil {
// ContentRange is nil when the full file contents is provied, and
// is not chunked. Use ContentLength instead.
if resp.ContentLength != nil {
d.totalBytes = *resp.ContentLength
return
}
} else {
parts := strings.Split(*resp.ContentRange, "/")
total := int64(-1)
var err error
// Checking for whether or not a numbered total exists
// If one does not exist, we will assume the total to be -1, undefined,
// and sequentially download each chunk until hitting a 416 error
totalStr := parts[len(parts)-1]
if totalStr != "*" {
total, err = strconv.ParseInt(totalStr, 10, 64)
if err != nil {
d.err = err
return
}
}
d.totalBytes = total
}
}
func (d *downloader) incrWritten(n int64) {
d.m.Lock()
defer d.m.Unlock()
d.written += n
}
// getErr is a thread-safe getter for the error object
func (d *downloader) getErr() error {
d.m.Lock()
defer d.m.Unlock()
return d.err
}
// setErr is a thread-safe setter for the error object
func (d *downloader) setErr(e error) {
d.m.Lock()
defer d.m.Unlock()
d.err = e
}
// dlchunk represents a single chunk of data to write by the worker routine.
// This structure also implements an io.SectionReader style interface for
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type dlchunk struct {
w io.WriterAt
start int64
size int64
cur int64
}
// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
func (c *dlchunk) Write(p []byte) (n int, err error) {
if c.cur >= c.size {
return 0, io.EOF
}
n, err = c.w.WriteAt(p, c.start+c.cur)
c.cur += int64(n)
return
}

View file

@ -1,23 +0,0 @@
// Package s3manageriface provides an interface for the s3manager package
package s3manageriface
import (
"io"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
// DownloaderAPI is the interface type for s3manager.Downloader.
type DownloaderAPI interface {
Download(io.WriterAt, *s3.GetObjectInput, ...func(*s3manager.Downloader)) (int64, error)
}
var _ DownloaderAPI = (*s3manager.Downloader)(nil)
// UploaderAPI is the interface type for s3manager.Uploader.
type UploaderAPI interface {
Upload(*s3manager.UploadInput, ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
var _ UploaderAPI = (*s3manager.Uploader)(nil)

View file

@ -1,661 +0,0 @@
package s3manager
import (
"bytes"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
// MaxUploadParts is the maximum allowed number of parts in a multi-part upload
// on Amazon S3.
const MaxUploadParts = 10000
// MinUploadPartSize is the minimum allowed part size when uploading a part to
// Amazon S3.
const MinUploadPartSize int64 = 1024 * 1024 * 5
// DefaultUploadPartSize is the default part size to buffer chunks of a
// payload into.
const DefaultUploadPartSize = MinUploadPartSize
// DefaultUploadConcurrency is the default number of goroutines to spin up when
// using Upload().
const DefaultUploadConcurrency = 5
// A MultiUploadFailure wraps a failed S3 multipart upload. An error returned
// will satisfy this interface when a multi part upload failed to upload all
// chucks to S3. In the case of a failure the UploadID is needed to operate on
// the chunks, if any, which were uploaded.
//
// Example:
//
// u := s3manager.NewUploader(opts)
// output, err := u.upload(input)
// if err != nil {
// if multierr, ok := err.(MultiUploadFailure); ok {
// // Process error and its associated uploadID
// fmt.Println("Error:", multierr.Code(), multierr.Message(), multierr.UploadID())
// } else {
// // Process error generically
// fmt.Println("Error:", err.Error())
// }
// }
//
type MultiUploadFailure interface {
awserr.Error
// Returns the upload id for the S3 multipart upload that failed.
UploadID() string
}
// So that the Error interface type can be included as an anonymous field
// in the multiUploadError struct and not conflict with the error.Error() method.
type awsError awserr.Error
// A multiUploadError wraps the upload ID of a failed s3 multipart upload.
// Composed of BaseError for code, message, and original error
//
// Should be used for an error that occurred failing a S3 multipart upload,
// and a upload ID is available. If an uploadID is not available a more relevant
type multiUploadError struct {
awsError
// ID for multipart upload which failed.
uploadID string
}
// Error returns the string representation of the error.
//
// See apierr.BaseError ErrorWithExtra for output format
//
// Satisfies the error interface.
func (m multiUploadError) Error() string {
extra := fmt.Sprintf("upload id: %s", m.uploadID)
return awserr.SprintError(m.Code(), m.Message(), extra, m.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (m multiUploadError) String() string {
return m.Error()
}
// UploadID returns the id of the S3 upload which failed.
func (m multiUploadError) UploadID() string {
return m.uploadID
}
// UploadInput contains all input for upload requests to Amazon S3.
type UploadInput struct {
// The canned ACL to apply to the object.
ACL *string `location:"header" locationName:"x-amz-acl" type:"string"`
Bucket *string `location:"uri" locationName:"Bucket" type:"string" required:"true"`
// Specifies caching behavior along the request/reply chain.
CacheControl *string `location:"header" locationName:"Cache-Control" type:"string"`
// Specifies presentational information for the object.
ContentDisposition *string `location:"header" locationName:"Content-Disposition" type:"string"`
// Specifies what content encodings have been applied to the object and thus
// what decoding mechanisms must be applied to obtain the media-type referenced
// by the Content-Type header field.
ContentEncoding *string `location:"header" locationName:"Content-Encoding" type:"string"`
// The language the content is in.
ContentLanguage *string `location:"header" locationName:"Content-Language" type:"string"`
// A standard MIME type describing the format of the object data.
ContentType *string `location:"header" locationName:"Content-Type" type:"string"`
// The date and time at which the object is no longer cacheable.
Expires *time.Time `location:"header" locationName:"Expires" type:"timestamp" timestampFormat:"rfc822"`
// Gives the grantee READ, READ_ACP, and WRITE_ACP permissions on the object.
GrantFullControl *string `location:"header" locationName:"x-amz-grant-full-control" type:"string"`
// Allows grantee to read the object data and its metadata.
GrantRead *string `location:"header" locationName:"x-amz-grant-read" type:"string"`
// Allows grantee to read the object ACL.
GrantReadACP *string `location:"header" locationName:"x-amz-grant-read-acp" type:"string"`
// Allows grantee to write the ACL for the applicable object.
GrantWriteACP *string `location:"header" locationName:"x-amz-grant-write-acp" type:"string"`
Key *string `location:"uri" locationName:"Key" type:"string" required:"true"`
// A map of metadata to store with the object in S3.
Metadata map[string]*string `location:"headers" locationName:"x-amz-meta-" type:"map"`
// Confirms that the requester knows that she or he will be charged for the
// request. Bucket owners need not specify this parameter in their requests.
// Documentation on downloading objects from requester pays buckets can be found
// at http://docs.aws.amazon.com/AmazonS3/latest/dev/ObjectsinRequesterPaysBuckets.html
RequestPayer *string `location:"header" locationName:"x-amz-request-payer" type:"string"`
// Specifies the algorithm to use to when encrypting the object (e.g., AES256,
// aws:kms).
SSECustomerAlgorithm *string `location:"header" locationName:"x-amz-server-side-encryption-customer-algorithm" type:"string"`
// Specifies the customer-provided encryption key for Amazon S3 to use in encrypting
// data. This value is used to store the object and then it is discarded; Amazon
// does not store the encryption key. The key must be appropriate for use with
// the algorithm specified in the x-amz-server-side-encryption-customer-algorithm
// header.
SSECustomerKey *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key" type:"string"`
// Specifies the 128-bit MD5 digest of the encryption key according to RFC 1321.
// Amazon S3 uses this header for a message integrity check to ensure the encryption
// key was transmitted without error.
SSECustomerKeyMD5 *string `location:"header" locationName:"x-amz-server-side-encryption-customer-key-MD5" type:"string"`
// Specifies the AWS KMS key ID to use for object encryption. All GET and PUT
// requests for an object protected by AWS KMS will fail if not made via SSL
// or using SigV4. Documentation on configuring any of the officially supported
// AWS SDKs and CLI can be found at http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingAWSSDK.html#specify-signature-version
SSEKMSKeyID *string `location:"header" locationName:"x-amz-server-side-encryption-aws-kms-key-id" type:"string"`
// The Server-side encryption algorithm used when storing this object in S3
// (e.g., AES256, aws:kms).
ServerSideEncryption *string `location:"header" locationName:"x-amz-server-side-encryption" type:"string"`
// The type of storage to use for the object. Defaults to 'STANDARD'.
StorageClass *string `location:"header" locationName:"x-amz-storage-class" type:"string"`
// If the bucket is configured as a website, redirects requests for this object
// to another object in the same bucket or to an external URL. Amazon S3 stores
// the value of this header in the object metadata.
WebsiteRedirectLocation *string `location:"header" locationName:"x-amz-website-redirect-location" type:"string"`
// The readable body payload to send to S3.
Body io.Reader
}
// UploadOutput represents a response from the Upload() call.
type UploadOutput struct {
// The URL where the object was uploaded to.
Location string
// The version of the object that was uploaded. Will only be populated if
// the S3 Bucket is versioned. If the bucket is not versioned this field
// will not be set.
VersionID *string
// The ID for a multipart upload to S3. In the case of an error the error
// can be cast to the MultiUploadFailure interface to extract the upload ID.
UploadID string
}
// The Uploader structure that calls Upload(). It is safe to call Upload()
// on this structure for multiple objects and across concurrent goroutines.
// Mutating the Uploader's properties is not safe to be done concurrently.
type Uploader struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultConcurrency value will be used.
Concurrency int
// Setting this value to true will cause the SDK to avoid calling
// AbortMultipartUpload on a failure, leaving all successfully uploaded
// parts on S3 for manual recovery.
//
// Note that storing parts of an incomplete multipart upload counts towards
// space usage on S3 and will add additional costs if not cleaned up.
LeavePartsOnError bool
// MaxUploadParts is the max number of parts which will be uploaded to S3.
// Will be used to calculate the partsize of the object to be uploaded.
// E.g: 5GB file, with MaxUploadParts set to 100, will upload the file
// as 100, 50MB parts.
// With a limited of s3.MaxUploadParts (10,000 parts).
MaxUploadParts int
// The client to use when uploading to S3.
S3 s3iface.S3API
}
// NewUploader creates a new Uploader instance to upload objects to S3. Pass In
// additional functional options to customize the uploader's behavior. Requires a
// client.ConfigProvider in order to create a S3 service client. The session.Session
// satisfies the client.ConfigProvider interface.
//
// Example:
// // The session the S3 Uploader will use
// sess := session.New()
//
// // Create an uploader with the session and default options
// uploader := s3manager.NewUploader(sess)
//
// // Create an uploader with the session and custom options
// uploader := s3manager.NewUploader(session, func(u *s3manager.Uploader) {
// u.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewUploader(c client.ConfigProvider, options ...func(*Uploader)) *Uploader {
u := &Uploader{
S3: s3.New(c),
PartSize: DefaultUploadPartSize,
Concurrency: DefaultUploadConcurrency,
LeavePartsOnError: false,
MaxUploadParts: MaxUploadParts,
}
for _, option := range options {
option(u)
}
return u
}
// NewUploaderWithClient creates a new Uploader instance to upload objects to S3. Pass in
// additional functional options to customize the uploader's behavior. Requires
// a S3 service client to make S3 API calls.
//
// Example:
// // S3 service client the Upload manager will use.
// s3Svc := s3.New(session.New())
//
// // Create an uploader with S3 client and default options
// uploader := s3manager.NewUploaderWithClient(s3Svc)
//
// // Create an uploader with S3 client and custom options
// uploader := s3manager.NewUploaderWithClient(s3Svc, func(u *s3manager.Uploader) {
// u.PartSize = 64 * 1024 * 1024 // 64MB per part
// })
func NewUploaderWithClient(svc s3iface.S3API, options ...func(*Uploader)) *Uploader {
u := &Uploader{
S3: svc,
PartSize: DefaultUploadPartSize,
Concurrency: DefaultUploadConcurrency,
LeavePartsOnError: false,
MaxUploadParts: MaxUploadParts,
}
for _, option := range options {
option(u)
}
return u
}
// Upload uploads an object to S3, intelligently buffering large files into
// smaller chunks and sending them in parallel across multiple goroutines. You
// can configure the buffer size and concurrency through the Uploader's parameters.
//
// Additional functional options can be provided to configure the individual
// upload. These options are copies of the Uploader instance Upload is called from.
// Modifying the options will not impact the original Uploader instance.
//
// It is safe to call this method concurrently across goroutines.
//
// Example:
// // Upload input parameters
// upParams := &s3manager.UploadInput{
// Bucket: &bucketName,
// Key: &keyName,
// Body: file,
// }
//
// // Perform an upload.
// result, err := uploader.Upload(upParams)
//
// // Perform upload with options different than the those in the Uploader.
// result, err := uploader.Upload(upParams, func(u *s3manager.Uploader) {
// u.PartSize = 10 * 1024 * 1024 // 10MB part size
// u.LeavePartsOnError = true // Dont delete the parts if the upload fails.
// })
func (u Uploader) Upload(input *UploadInput, options ...func(*Uploader)) (*UploadOutput, error) {
i := uploader{in: input, ctx: u}
for _, option := range options {
option(&i.ctx)
}
return i.upload()
}
// internal structure to manage an upload to S3.
type uploader struct {
ctx Uploader
in *UploadInput
readerPos int64 // current reader position
totalSize int64 // set to -1 if the size is not known
}
// internal logic for deciding whether to upload a single part or use a
// multipart upload.
func (u *uploader) upload() (*UploadOutput, error) {
u.init()
if u.ctx.PartSize < MinUploadPartSize {
msg := fmt.Sprintf("part size must be at least %d bytes", MinUploadPartSize)
return nil, awserr.New("ConfigError", msg, nil)
}
// Do one read to determine if we have more than one part
buf, err := u.nextReader()
if err == io.EOF || err == io.ErrUnexpectedEOF { // single part
return u.singlePart(buf)
} else if err != nil {
return nil, awserr.New("ReadRequestBody", "read upload data failed", err)
}
mu := multiuploader{uploader: u}
return mu.upload(buf)
}
// init will initialize all default options.
func (u *uploader) init() {
if u.ctx.Concurrency == 0 {
u.ctx.Concurrency = DefaultUploadConcurrency
}
if u.ctx.PartSize == 0 {
u.ctx.PartSize = DefaultUploadPartSize
}
// Try to get the total size for some optimizations
u.initSize()
}
// initSize tries to detect the total stream size, setting u.totalSize. If
// the size is not known, totalSize is set to -1.
func (u *uploader) initSize() {
u.totalSize = -1
switch r := u.in.Body.(type) {
case io.Seeker:
pos, _ := r.Seek(0, 1)
defer r.Seek(pos, 0)
n, err := r.Seek(0, 2)
if err != nil {
return
}
u.totalSize = n
// Try to adjust partSize if it is too small and account for
// integer division truncation.
if u.totalSize/u.ctx.PartSize >= int64(u.ctx.MaxUploadParts) {
// Add one to the part size to account for remainders
// during the size calculation. e.g odd number of bytes.
u.ctx.PartSize = (u.totalSize / int64(u.ctx.MaxUploadParts)) + 1
}
}
}
// nextReader returns a seekable reader representing the next packet of data.
// This operation increases the shared u.readerPos counter, but note that it
// does not need to be wrapped in a mutex because nextReader is only called
// from the main thread.
func (u *uploader) nextReader() (io.ReadSeeker, error) {
switch r := u.in.Body.(type) {
case io.ReaderAt:
var err error
n := u.ctx.PartSize
if u.totalSize >= 0 {
bytesLeft := u.totalSize - u.readerPos
if bytesLeft == 0 {
err = io.EOF
n = bytesLeft
} else if bytesLeft <= u.ctx.PartSize {
err = io.ErrUnexpectedEOF
n = bytesLeft
}
}
buf := io.NewSectionReader(r, u.readerPos, n)
u.readerPos += n
return buf, err
default:
packet := make([]byte, u.ctx.PartSize)
n, err := io.ReadFull(u.in.Body, packet)
u.readerPos += int64(n)
return bytes.NewReader(packet[0:n]), err
}
}
// singlePart contains upload logic for uploading a single chunk via
// a regular PutObject request. Multipart requests require at least two
// parts, or at least 5MB of data.
func (u *uploader) singlePart(buf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.PutObjectInput{}
awsutil.Copy(params, u.in)
params.Body = buf
req, out := u.ctx.S3.PutObjectRequest(params)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
if err := req.Send(); err != nil {
return nil, err
}
url := req.HTTPRequest.URL.String()
return &UploadOutput{
Location: url,
VersionID: out.VersionId,
}, nil
}
// internal structure to manage a specific multipart upload to S3.
type multiuploader struct {
*uploader
wg sync.WaitGroup
m sync.Mutex
err error
uploadID string
parts completedParts
}
// keeps track of a single chunk of data being sent to S3.
type chunk struct {
buf io.ReadSeeker
num int64
}
// completedParts is a wrapper to make parts sortable by their part number,
// since S3 required this list to be sent in sorted order.
type completedParts []*s3.CompletedPart
func (a completedParts) Len() int { return len(a) }
func (a completedParts) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }
// upload will perform a multipart upload using the firstBuf buffer containing
// the first chunk of data.
func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
params := &s3.CreateMultipartUploadInput{}
awsutil.Copy(params, u.in)
// Create the multipart
req, resp := u.ctx.S3.CreateMultipartUploadRequest(params)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
if err := req.Send(); err != nil {
return nil, err
}
u.uploadID = *resp.UploadId
// Create the workers
ch := make(chan chunk, u.ctx.Concurrency)
for i := 0; i < u.ctx.Concurrency; i++ {
u.wg.Add(1)
go u.readChunk(ch)
}
// Send part 1 to the workers
var num int64 = 1
ch <- chunk{buf: firstBuf, num: num}
// Read and queue the rest of the parts
for u.geterr() == nil {
// This upload exceeded maximum number of supported parts, error now.
if num > int64(u.ctx.MaxUploadParts) || num > int64(MaxUploadParts) {
var msg string
if num > int64(u.ctx.MaxUploadParts) {
msg = fmt.Sprintf("exceeded total allowed configured MaxUploadParts (%d). Adjust PartSize to fit in this limit",
u.ctx.MaxUploadParts)
} else {
msg = fmt.Sprintf("exceeded total allowed S3 limit MaxUploadParts (%d). Adjust PartSize to fit in this limit",
MaxUploadParts)
}
u.seterr(awserr.New("TotalPartsExceeded", msg, nil))
break
}
num++
buf, err := u.nextReader()
if err == io.EOF {
break
}
ch <- chunk{buf: buf, num: num}
if err != nil && err != io.ErrUnexpectedEOF {
u.seterr(awserr.New(
"ReadRequestBody",
"read multipart upload data failed",
err))
break
}
}
// Close the channel, wait for workers, and complete upload
close(ch)
u.wg.Wait()
complete := u.complete()
if err := u.geterr(); err != nil {
return nil, &multiUploadError{
awsError: awserr.New(
"MultipartUpload",
"upload multipart failed",
err),
uploadID: u.uploadID,
}
}
return &UploadOutput{
Location: *complete.Location,
VersionID: complete.VersionId,
UploadID: u.uploadID,
}, nil
}
// readChunk runs in worker goroutines to pull chunks off of the ch channel
// and send() them as UploadPart requests.
func (u *multiuploader) readChunk(ch chan chunk) {
defer u.wg.Done()
for {
data, ok := <-ch
if !ok {
break
}
if u.geterr() == nil {
if err := u.send(data); err != nil {
u.seterr(err)
}
}
}
}
// send performs an UploadPart request and keeps track of the completed
// part information.
func (u *multiuploader) send(c chunk) error {
req, resp := u.ctx.S3.UploadPartRequest(&s3.UploadPartInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
Body: c.buf,
UploadId: &u.uploadID,
PartNumber: &c.num,
})
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
if err := req.Send(); err != nil {
return err
}
n := c.num
completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &n}
u.m.Lock()
u.parts = append(u.parts, completed)
u.m.Unlock()
return nil
}
// geterr is a thread-safe getter for the error object
func (u *multiuploader) geterr() error {
u.m.Lock()
defer u.m.Unlock()
return u.err
}
// seterr is a thread-safe setter for the error object
func (u *multiuploader) seterr(e error) {
u.m.Lock()
defer u.m.Unlock()
u.err = e
}
// fail will abort the multipart unless LeavePartsOnError is set to true.
func (u *multiuploader) fail() {
if u.ctx.LeavePartsOnError {
return
}
req, _ := u.ctx.S3.AbortMultipartUploadRequest(&s3.AbortMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadId: &u.uploadID,
})
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
req.Send()
}
// complete successfully completes a multipart upload and returns the response.
func (u *multiuploader) complete() *s3.CompleteMultipartUploadOutput {
if u.geterr() != nil {
u.fail()
return nil
}
// Parts must be sorted in PartNumber order.
sort.Sort(u.parts)
req, resp := u.ctx.S3.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{
Bucket: u.in.Bucket,
Key: u.in.Key,
UploadId: &u.uploadID,
MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts},
})
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("S3Manager"))
if err := req.Send(); err != nil {
u.seterr(err)
u.fail()
}
return resp
}

View file

@ -1,154 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"io"
"reflect"
"testing"
)
var bufferReadTests = []struct {
buf buffer
read, wn int
werr error
wp []byte
wbuf buffer
}{
{
buffer{[]byte{'a', 0}, 0, 1, false, nil},
5, 1, nil, []byte{'a'},
buffer{[]byte{'a', 0}, 1, 1, false, nil},
},
{
buffer{[]byte{'a', 0}, 0, 1, true, io.EOF},
5, 1, io.EOF, []byte{'a'},
buffer{[]byte{'a', 0}, 1, 1, true, io.EOF},
},
{
buffer{[]byte{0, 'a'}, 1, 2, false, nil},
5, 1, nil, []byte{'a'},
buffer{[]byte{0, 'a'}, 2, 2, false, nil},
},
{
buffer{[]byte{0, 'a'}, 1, 2, true, io.EOF},
5, 1, io.EOF, []byte{'a'},
buffer{[]byte{0, 'a'}, 2, 2, true, io.EOF},
},
{
buffer{[]byte{}, 0, 0, false, nil},
5, 0, errReadEmpty, []byte{},
buffer{[]byte{}, 0, 0, false, nil},
},
{
buffer{[]byte{}, 0, 0, true, io.EOF},
5, 0, io.EOF, []byte{},
buffer{[]byte{}, 0, 0, true, io.EOF},
},
}
func TestBufferRead(t *testing.T) {
for i, tt := range bufferReadTests {
read := make([]byte, tt.read)
n, err := tt.buf.Read(read)
if n != tt.wn {
t.Errorf("#%d: wn = %d want %d", i, n, tt.wn)
continue
}
if err != tt.werr {
t.Errorf("#%d: werr = %v want %v", i, err, tt.werr)
continue
}
read = read[:n]
if !reflect.DeepEqual(read, tt.wp) {
t.Errorf("#%d: read = %+v want %+v", i, read, tt.wp)
}
if !reflect.DeepEqual(tt.buf, tt.wbuf) {
t.Errorf("#%d: buf = %+v want %+v", i, tt.buf, tt.wbuf)
}
}
}
var bufferWriteTests = []struct {
buf buffer
write, wn int
werr error
wbuf buffer
}{
{
buf: buffer{
buf: []byte{},
},
wbuf: buffer{
buf: []byte{},
},
},
{
buf: buffer{
buf: []byte{1, 'a'},
},
write: 1,
wn: 1,
wbuf: buffer{
buf: []byte{0, 'a'},
w: 1,
},
},
{
buf: buffer{
buf: []byte{'a', 1},
r: 1,
w: 1,
},
write: 2,
wn: 2,
wbuf: buffer{
buf: []byte{0, 0},
w: 2,
},
},
{
buf: buffer{
buf: []byte{},
r: 1,
closed: true,
},
write: 5,
werr: errWriteClosed,
wbuf: buffer{
buf: []byte{},
r: 1,
closed: true,
},
},
{
buf: buffer{
buf: []byte{},
},
write: 5,
werr: errWriteFull,
wbuf: buffer{
buf: []byte{},
},
},
}
func TestBufferWrite(t *testing.T) {
for i, tt := range bufferWriteTests {
n, err := tt.buf.Write(make([]byte, tt.write))
if n != tt.wn {
t.Errorf("#%d: wrote %d bytes; want %d", i, n, tt.wn)
continue
}
if err != tt.werr {
t.Errorf("#%d: error = %v; want %v", i, err, tt.werr)
continue
}
if !reflect.DeepEqual(tt.buf, tt.wbuf) {
t.Errorf("#%d: buf = %+v; want %+v", i, tt.buf, tt.wbuf)
}
}
}

View file

@ -1,27 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import "testing"
func TestErrCodeString(t *testing.T) {
tests := []struct {
err ErrCode
want string
}{
{ErrCodeProtocol, "PROTOCOL_ERROR"},
{0xd, "HTTP_1_1_REQUIRED"},
{0xf, "unknown error code 0xf"},
}
for i, tt := range tests {
got := tt.err.String()
if got != tt.want {
t.Errorf("%d. Error = %q; want %q", i, got, tt.want)
}
}
}

View file

@ -1,54 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import "testing"
func TestFlow(t *testing.T) {
var st flow
var conn flow
st.add(3)
conn.add(2)
if got, want := st.available(), int32(3); got != want {
t.Errorf("available = %d; want %d", got, want)
}
st.setConnFlow(&conn)
if got, want := st.available(), int32(2); got != want {
t.Errorf("after parent setup, available = %d; want %d", got, want)
}
st.take(2)
if got, want := conn.available(), int32(0); got != want {
t.Errorf("after taking 2, conn = %d; want %d", got, want)
}
if got, want := st.available(), int32(0); got != want {
t.Errorf("after taking 2, stream = %d; want %d", got, want)
}
}
func TestFlowAdd(t *testing.T) {
var f flow
if !f.add(1) {
t.Fatal("failed to add 1")
}
if !f.add(-1) {
t.Fatal("failed to add -1")
}
if got, want := f.available(), int32(0); got != want {
t.Fatalf("size = %d; want %d", got, want)
}
if !f.add(1<<31 - 1) {
t.Fatal("failed to add 2^31-1")
}
if got, want := f.available(), int32(1<<31-1); got != want {
t.Fatalf("size = %d; want %d", got, want)
}
if f.add(1) {
t.Fatal("adding 1 to max shouldn't be allowed")
}
}

View file

@ -1,597 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"bytes"
"reflect"
"strings"
"testing"
"unsafe"
)
func testFramer() (*Framer, *bytes.Buffer) {
buf := new(bytes.Buffer)
return NewFramer(buf, buf), buf
}
func TestFrameSizes(t *testing.T) {
// Catch people rearranging the FrameHeader fields.
if got, want := int(unsafe.Sizeof(FrameHeader{})), 12; got != want {
t.Errorf("FrameHeader size = %d; want %d", got, want)
}
}
func TestFrameTypeString(t *testing.T) {
tests := []struct {
ft FrameType
want string
}{
{FrameData, "DATA"},
{FramePing, "PING"},
{FrameGoAway, "GOAWAY"},
{0xf, "UNKNOWN_FRAME_TYPE_15"},
}
for i, tt := range tests {
got := tt.ft.String()
if got != tt.want {
t.Errorf("%d. String(FrameType %d) = %q; want %q", i, int(tt.ft), got, tt.want)
}
}
}
func TestWriteRST(t *testing.T) {
fr, buf := testFramer()
var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4
fr.WriteRSTStream(streamID, ErrCode(errCode))
const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
want := &RSTStreamFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x3,
Flags: 0x0,
Length: 0x4,
StreamID: 0x1020304,
},
ErrCode: 0x7060504,
}
if !reflect.DeepEqual(f, want) {
t.Errorf("parsed back %#v; want %#v", f, want)
}
}
func TestWriteData(t *testing.T) {
fr, buf := testFramer()
var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
data := []byte("ABC")
fr.WriteData(streamID, true, data)
const wantEnc = "\x00\x00\x03\x00\x01\x01\x02\x03\x04ABC"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
df, ok := f.(*DataFrame)
if !ok {
t.Fatalf("got %T; want *DataFrame", f)
}
if !bytes.Equal(df.Data(), data) {
t.Errorf("got %q; want %q", df.Data(), data)
}
if f.Header().Flags&1 == 0 {
t.Errorf("didn't see END_STREAM flag")
}
}
func TestWriteHeaders(t *testing.T) {
tests := []struct {
name string
p HeadersFrameParam
wantEnc string
wantFrame *HeadersFrame
}{
{
"basic",
HeadersFrameParam{
StreamID: 42,
BlockFragment: []byte("abc"),
Priority: PriorityParam{},
},
"\x00\x00\x03\x01\x00\x00\x00\x00*abc",
&HeadersFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: 42,
Type: FrameHeaders,
Length: uint32(len("abc")),
},
Priority: PriorityParam{},
headerFragBuf: []byte("abc"),
},
},
{
"basic + end flags",
HeadersFrameParam{
StreamID: 42,
BlockFragment: []byte("abc"),
EndStream: true,
EndHeaders: true,
Priority: PriorityParam{},
},
"\x00\x00\x03\x01\x05\x00\x00\x00*abc",
&HeadersFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: 42,
Type: FrameHeaders,
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders,
Length: uint32(len("abc")),
},
Priority: PriorityParam{},
headerFragBuf: []byte("abc"),
},
},
{
"with padding",
HeadersFrameParam{
StreamID: 42,
BlockFragment: []byte("abc"),
EndStream: true,
EndHeaders: true,
PadLength: 5,
Priority: PriorityParam{},
},
"\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00",
&HeadersFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: 42,
Type: FrameHeaders,
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded,
Length: uint32(1 + len("abc") + 5), // pad length + contents + padding
},
Priority: PriorityParam{},
headerFragBuf: []byte("abc"),
},
},
{
"with priority",
HeadersFrameParam{
StreamID: 42,
BlockFragment: []byte("abc"),
EndStream: true,
EndHeaders: true,
PadLength: 2,
Priority: PriorityParam{
StreamDep: 15,
Exclusive: true,
Weight: 127,
},
},
"\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00",
&HeadersFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: 42,
Type: FrameHeaders,
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority,
Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding
},
Priority: PriorityParam{
StreamDep: 15,
Exclusive: true,
Weight: 127,
},
headerFragBuf: []byte("abc"),
},
},
}
for _, tt := range tests {
fr, buf := testFramer()
if err := fr.WriteHeaders(tt.p); err != nil {
t.Errorf("test %q: %v", tt.name, err)
continue
}
if buf.String() != tt.wantEnc {
t.Errorf("test %q: encoded %q; want %q", tt.name, buf.Bytes(), tt.wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
continue
}
if !reflect.DeepEqual(f, tt.wantFrame) {
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
}
}
}
func TestWriteContinuation(t *testing.T) {
const streamID = 42
tests := []struct {
name string
end bool
frag []byte
wantFrame *ContinuationFrame
}{
{
"not end",
false,
[]byte("abc"),
&ContinuationFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: streamID,
Type: FrameContinuation,
Length: uint32(len("abc")),
},
headerFragBuf: []byte("abc"),
},
},
{
"end",
true,
[]byte("def"),
&ContinuationFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: streamID,
Type: FrameContinuation,
Flags: FlagContinuationEndHeaders,
Length: uint32(len("def")),
},
headerFragBuf: []byte("def"),
},
},
}
for _, tt := range tests {
fr, _ := testFramer()
if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil {
t.Errorf("test %q: %v", tt.name, err)
continue
}
f, err := fr.ReadFrame()
if err != nil {
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
continue
}
if !reflect.DeepEqual(f, tt.wantFrame) {
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
}
}
}
func TestWritePriority(t *testing.T) {
const streamID = 42
tests := []struct {
name string
priority PriorityParam
wantFrame *PriorityFrame
}{
{
"not exclusive",
PriorityParam{
StreamDep: 2,
Exclusive: false,
Weight: 127,
},
&PriorityFrame{
FrameHeader{
valid: true,
StreamID: streamID,
Type: FramePriority,
Length: 5,
},
PriorityParam{
StreamDep: 2,
Exclusive: false,
Weight: 127,
},
},
},
{
"exclusive",
PriorityParam{
StreamDep: 3,
Exclusive: true,
Weight: 77,
},
&PriorityFrame{
FrameHeader{
valid: true,
StreamID: streamID,
Type: FramePriority,
Length: 5,
},
PriorityParam{
StreamDep: 3,
Exclusive: true,
Weight: 77,
},
},
},
}
for _, tt := range tests {
fr, _ := testFramer()
if err := fr.WritePriority(streamID, tt.priority); err != nil {
t.Errorf("test %q: %v", tt.name, err)
continue
}
f, err := fr.ReadFrame()
if err != nil {
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
continue
}
if !reflect.DeepEqual(f, tt.wantFrame) {
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
}
}
}
func TestWriteSettings(t *testing.T) {
fr, buf := testFramer()
settings := []Setting{{1, 2}, {3, 4}}
fr.WriteSettings(settings...)
const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
sf, ok := f.(*SettingsFrame)
if !ok {
t.Fatalf("Got a %T; want a SettingsFrame", f)
}
var got []Setting
sf.ForeachSetting(func(s Setting) error {
got = append(got, s)
valBack, ok := sf.Value(s.ID)
if !ok || valBack != s.Val {
t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok, s.Val)
}
return nil
})
if !reflect.DeepEqual(settings, got) {
t.Errorf("Read settings %+v != written settings %+v", got, settings)
}
}
func TestWriteSettingsAck(t *testing.T) {
fr, buf := testFramer()
fr.WriteSettingsAck()
const wantEnc = "\x00\x00\x00\x04\x01\x00\x00\x00\x00"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
}
func TestWriteWindowUpdate(t *testing.T) {
fr, buf := testFramer()
const streamID = 1<<24 + 2<<16 + 3<<8 + 4
const incr = 7<<24 + 6<<16 + 5<<8 + 4
if err := fr.WriteWindowUpdate(streamID, incr); err != nil {
t.Fatal(err)
}
const wantEnc = "\x00\x00\x04\x08\x00\x01\x02\x03\x04\x07\x06\x05\x04"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
want := &WindowUpdateFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x8,
Flags: 0x0,
Length: 0x4,
StreamID: 0x1020304,
},
Increment: 0x7060504,
}
if !reflect.DeepEqual(f, want) {
t.Errorf("parsed back %#v; want %#v", f, want)
}
}
func TestWritePing(t *testing.T) { testWritePing(t, false) }
func TestWritePingAck(t *testing.T) { testWritePing(t, true) }
func testWritePing(t *testing.T, ack bool) {
fr, buf := testFramer()
if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
t.Fatal(err)
}
var wantFlags Flags
if ack {
wantFlags = FlagPingAck
}
var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
want := &PingFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x6,
Flags: wantFlags,
Length: 0x8,
StreamID: 0,
},
Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8},
}
if !reflect.DeepEqual(f, want) {
t.Errorf("parsed back %#v; want %#v", f, want)
}
}
func TestReadFrameHeader(t *testing.T) {
tests := []struct {
in string
want FrameHeader
}{
{in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: FrameHeader{}},
{in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: FrameHeader{
Length: 66051, Type: 4, Flags: 5, StreamID: 101124105,
}},
// Ignore high bit:
{in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: FrameHeader{
Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
{in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: FrameHeader{
Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
}
for i, tt := range tests {
got, err := readFrameHeader(make([]byte, 9), strings.NewReader(tt.in))
if err != nil {
t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err)
continue
}
tt.want.valid = true
if got != tt.want {
t.Errorf("%d. readFrameHeader(%q) = %+v; want %+v", i, tt.in, got, tt.want)
}
}
}
func TestReadWriteFrameHeader(t *testing.T) {
tests := []struct {
len uint32
typ FrameType
flags Flags
streamID uint32
}{
{len: 0, typ: 255, flags: 1, streamID: 0},
{len: 0, typ: 255, flags: 1, streamID: 1},
{len: 0, typ: 255, flags: 1, streamID: 255},
{len: 0, typ: 255, flags: 1, streamID: 256},
{len: 0, typ: 255, flags: 1, streamID: 65535},
{len: 0, typ: 255, flags: 1, streamID: 65536},
{len: 0, typ: 1, flags: 255, streamID: 1},
{len: 255, typ: 1, flags: 255, streamID: 1},
{len: 256, typ: 1, flags: 255, streamID: 1},
{len: 65535, typ: 1, flags: 255, streamID: 1},
{len: 65536, typ: 1, flags: 255, streamID: 1},
{len: 16777215, typ: 1, flags: 255, streamID: 1},
}
for _, tt := range tests {
fr, buf := testFramer()
fr.startWrite(tt.typ, tt.flags, tt.streamID)
fr.writeBytes(make([]byte, tt.len))
fr.endWrite()
fh, err := ReadFrameHeader(buf)
if err != nil {
t.Errorf("ReadFrameHeader(%+v) = %v", tt, err)
continue
}
if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID {
t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh)
}
}
}
func TestWriteTooLargeFrame(t *testing.T) {
fr, _ := testFramer()
fr.startWrite(0, 1, 1)
fr.writeBytes(make([]byte, 1<<24))
err := fr.endWrite()
if err != ErrFrameTooLarge {
t.Errorf("endWrite = %v; want errFrameTooLarge", err)
}
}
func TestWriteGoAway(t *testing.T) {
const debug = "foo"
fr, buf := testFramer()
if err := fr.WriteGoAway(0x01020304, 0x05060708, []byte(debug)); err != nil {
t.Fatal(err)
}
const wantEnc = "\x00\x00\v\a\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08" + debug
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
want := &GoAwayFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x7,
Flags: 0,
Length: uint32(4 + 4 + len(debug)),
StreamID: 0,
},
LastStreamID: 0x01020304,
ErrCode: 0x05060708,
debugData: []byte(debug),
}
if !reflect.DeepEqual(f, want) {
t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
}
if got := string(f.(*GoAwayFrame).DebugData()); got != debug {
t.Errorf("debug data = %q; want %q", got, debug)
}
}
func TestWritePushPromise(t *testing.T) {
pp := PushPromiseParam{
StreamID: 42,
PromiseID: 42,
BlockFragment: []byte("abc"),
}
fr, buf := testFramer()
if err := fr.WritePushPromise(pp); err != nil {
t.Fatal(err)
}
const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
_, ok := f.(*PushPromiseFrame)
if !ok {
t.Fatalf("got %T; want *PushPromiseFrame", f)
}
want := &PushPromiseFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x5,
Flags: 0x0,
Length: 0x7,
StreamID: 42,
},
PromiseID: 42,
headerFragBuf: []byte("abc"),
}
if !reflect.DeepEqual(f, want) {
t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
}
}

View file

@ -1,33 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"fmt"
"strings"
"testing"
)
func TestGoroutineLock(t *testing.T) {
DebugGoroutines = true
g := newGoroutineLock()
g.check()
sawPanic := make(chan interface{})
go func() {
defer func() { sawPanic <- recover() }()
g.check() // should panic
}()
e := <-sawPanic
if e == nil {
t.Fatal("did not see panic from check in other goroutine")
}
if !strings.Contains(fmt.Sprint(e), "wrong goroutine") {
t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e)
}
}

View file

@ -1,5 +0,0 @@
h2demo
h2demo.linux
client-id.dat
client-secret.dat
token.dat

View file

@ -1,5 +0,0 @@
h2demo.linux: h2demo.go
GOOS=linux go build --tags=h2demo -o h2demo.linux .
upload: h2demo.linux
cat h2demo.linux | go run launch.go --write_object=http2-demo-server-tls/h2demo --write_object_is_public

View file

@ -1,16 +0,0 @@
Client:
-- Firefox nightly with about:config network.http.spdy.enabled.http2draft set true
-- Chrome: go to chrome://flags/#enable-spdy4, save and restart (button at bottom)
Make CA:
$ openssl genrsa -out rootCA.key 2048
$ openssl req -x509 -new -nodes -key rootCA.key -days 1024 -out rootCA.pem
... install that to Firefox
Make cert:
$ openssl genrsa -out server.key 2048
$ openssl req -new -key server.key -out server.csr
$ openssl x509 -req -in server.csr -CA rootCA.pem -CAkey rootCA.key -CAcreateserial -out server.crt -days 500

View file

@ -1,426 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// +build h2demo
package main
import (
"bytes"
"crypto/tls"
"flag"
"fmt"
"hash/crc32"
"image"
"image/jpeg"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os/exec"
"path"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
"camlistore.org/pkg/googlestorage"
"camlistore.org/pkg/singleflight"
"github.com/bradfitz/http2"
)
var (
openFirefox = flag.Bool("openff", false, "Open Firefox")
addr = flag.String("addr", "localhost:4430", "TLS address to listen on")
httpAddr = flag.String("httpaddr", "", "If non-empty, address to listen for regular HTTP on")
prod = flag.Bool("prod", false, "Whether to configure itself to be the production http2.golang.org server.")
)
func homeOldHTTP(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, `<html>
<body>
<h1>Go + HTTP/2</h1>
<p>Welcome to <a href="https://golang.org/">the Go language</a>'s <a href="https://http2.github.io/">HTTP/2</a> demo & interop server.</p>
<p>Unfortunately, you're <b>not</b> using HTTP/2 right now. To do so:</p>
<ul>
<li>Use Firefox Nightly or go to <b>about:config</b> and enable "network.http.spdy.enabled.http2draft"</li>
<li>Use Google Chrome Canary and/or go to <b>chrome://flags/#enable-spdy4</b> to <i>Enable SPDY/4</i> (Chrome's name for HTTP/2)</li>
</ul>
<p>See code & instructions for connecting at <a href="https://github.com/bradfitz/http2">https://github.com/bradfitz/http2</a>.</p>
</body></html>`)
}
func home(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
io.WriteString(w, `<html>
<body>
<h1>Go + HTTP/2</h1>
<p>Welcome to <a href="https://golang.org/">the Go language</a>'s <a
href="https://http2.github.io/">HTTP/2</a> demo & interop server.</p>
<p>Congratulations, <b>you're using HTTP/2 right now</b>.</p>
<p>This server exists for others in the HTTP/2 community to test their HTTP/2 client implementations and point out flaws in our server.</p>
<p> The code is currently at <a
href="https://github.com/bradfitz/http2">github.com/bradfitz/http2</a>
but will move to the Go standard library at some point in the future
(enabled by default, without users needing to change their code).</p>
<p>Contact info: <i>bradfitz@golang.org</i>, or <a
href="https://github.com/bradfitz/http2/issues">file a bug</a>.</p>
<h2>Handlers for testing</h2>
<ul>
<li>GET <a href="/reqinfo">/reqinfo</a> to dump the request + headers received</li>
<li>GET <a href="/clockstream">/clockstream</a> streams the current time every second</li>
<li>GET <a href="/gophertiles">/gophertiles</a> to see a page with a bunch of images</li>
<li>GET <a href="/file/gopher.png">/file/gopher.png</a> for a small file (does If-Modified-Since, Content-Range, etc)</li>
<li>GET <a href="/file/go.src.tar.gz">/file/go.src.tar.gz</a> for a larger file (~10 MB)</li>
<li>GET <a href="/redirect">/redirect</a> to redirect back to / (this page)</li>
<li>GET <a href="/goroutines">/goroutines</a> to see all active goroutines in this server</li>
<li>PUT something to <a href="/crc32">/crc32</a> to get a count of number of bytes and its CRC-32</li>
</ul>
</body></html>`)
}
func reqInfoHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "Method: %s\n", r.Method)
fmt.Fprintf(w, "Protocol: %s\n", r.Proto)
fmt.Fprintf(w, "Host: %s\n", r.Host)
fmt.Fprintf(w, "RemoteAddr: %s\n", r.RemoteAddr)
fmt.Fprintf(w, "RequestURI: %q\n", r.RequestURI)
fmt.Fprintf(w, "URL: %#v\n", r.URL)
fmt.Fprintf(w, "Body.ContentLength: %d (-1 means unknown)\n", r.ContentLength)
fmt.Fprintf(w, "Close: %v (relevant for HTTP/1 only)\n", r.Close)
fmt.Fprintf(w, "TLS: %#v\n", r.TLS)
fmt.Fprintf(w, "\nHeaders:\n")
r.Header.Write(w)
}
func crcHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
http.Error(w, "PUT required.", 400)
return
}
crc := crc32.NewIEEE()
n, err := io.Copy(crc, r.Body)
if err == nil {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "bytes=%d, CRC32=%x", n, crc.Sum(nil))
}
}
var (
fsGrp singleflight.Group
fsMu sync.Mutex // guards fsCache
fsCache = map[string]http.Handler{}
)
// fileServer returns a file-serving handler that proxies URL.
// It lazily fetches URL on the first access and caches its contents forever.
func fileServer(url string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hi, err := fsGrp.Do(url, func() (interface{}, error) {
fsMu.Lock()
if h, ok := fsCache[url]; ok {
fsMu.Unlock()
return h, nil
}
fsMu.Unlock()
res, err := http.Get(url)
if err != nil {
return nil, err
}
defer res.Body.Close()
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, err
}
modTime := time.Now()
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.ServeContent(w, r, path.Base(url), modTime, bytes.NewReader(slurp))
})
fsMu.Lock()
fsCache[url] = h
fsMu.Unlock()
return h, nil
})
if err != nil {
http.Error(w, err.Error(), 500)
return
}
hi.(http.Handler).ServeHTTP(w, r)
})
}
func clockStreamHandler(w http.ResponseWriter, r *http.Request) {
clientGone := w.(http.CloseNotifier).CloseNotify()
w.Header().Set("Content-Type", "text/plain")
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
fmt.Fprintf(w, "# ~1KB of junk to force browsers to start rendering immediately: \n")
io.WriteString(w, strings.Repeat("# xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n", 13))
for {
fmt.Fprintf(w, "%v\n", time.Now())
w.(http.Flusher).Flush()
select {
case <-ticker.C:
case <-clientGone:
log.Printf("Client %v disconnected from the clock", r.RemoteAddr)
return
}
}
}
func registerHandlers() {
tiles := newGopherTilesHandler()
mux2 := http.NewServeMux()
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil {
if r.URL.Path == "/gophertiles" {
tiles.ServeHTTP(w, r)
return
}
http.Redirect(w, r, "https://http2.golang.org/", http.StatusFound)
return
}
if r.ProtoMajor == 1 {
if r.URL.Path == "/reqinfo" {
reqInfoHandler(w, r)
return
}
homeOldHTTP(w, r)
return
}
mux2.ServeHTTP(w, r)
})
mux2.HandleFunc("/", home)
mux2.Handle("/file/gopher.png", fileServer("https://golang.org/doc/gopher/frontpage.png"))
mux2.Handle("/file/go.src.tar.gz", fileServer("https://storage.googleapis.com/golang/go1.4.1.src.tar.gz"))
mux2.HandleFunc("/reqinfo", reqInfoHandler)
mux2.HandleFunc("/crc32", crcHandler)
mux2.HandleFunc("/clockstream", clockStreamHandler)
mux2.Handle("/gophertiles", tiles)
mux2.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/", http.StatusFound)
})
stripHomedir := regexp.MustCompile(`/(Users|home)/\w+`)
mux2.HandleFunc("/goroutines", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
buf := make([]byte, 2<<20)
w.Write(stripHomedir.ReplaceAll(buf[:runtime.Stack(buf, true)], nil))
})
}
func newGopherTilesHandler() http.Handler {
const gopherURL = "https://blog.golang.org/go-programming-language-turns-two_gophers.jpg"
res, err := http.Get(gopherURL)
if err != nil {
log.Fatal(err)
}
if res.StatusCode != 200 {
log.Fatalf("Error fetching %s: %v", gopherURL, res.Status)
}
slurp, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
im, err := jpeg.Decode(bytes.NewReader(slurp))
if err != nil {
if len(slurp) > 1024 {
slurp = slurp[:1024]
}
log.Fatalf("Failed to decode gopher image: %v (got %q)", err, slurp)
}
type subImager interface {
SubImage(image.Rectangle) image.Image
}
const tileSize = 32
xt := im.Bounds().Max.X / tileSize
yt := im.Bounds().Max.Y / tileSize
var tile [][][]byte // y -> x -> jpeg bytes
for yi := 0; yi < yt; yi++ {
var row [][]byte
for xi := 0; xi < xt; xi++ {
si := im.(subImager).SubImage(image.Rectangle{
Min: image.Point{xi * tileSize, yi * tileSize},
Max: image.Point{(xi + 1) * tileSize, (yi + 1) * tileSize},
})
buf := new(bytes.Buffer)
if err := jpeg.Encode(buf, si, &jpeg.Options{Quality: 90}); err != nil {
log.Fatal(err)
}
row = append(row, buf.Bytes())
}
tile = append(tile, row)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ms, _ := strconv.Atoi(r.FormValue("latency"))
const nanosPerMilli = 1e6
if r.FormValue("x") != "" {
x, _ := strconv.Atoi(r.FormValue("x"))
y, _ := strconv.Atoi(r.FormValue("y"))
if ms <= 1000 {
time.Sleep(time.Duration(ms) * nanosPerMilli)
}
if x >= 0 && x < xt && y >= 0 && y < yt {
http.ServeContent(w, r, "", time.Time{}, bytes.NewReader(tile[y][x]))
return
}
}
io.WriteString(w, "<html><body>")
fmt.Fprintf(w, "A grid of %d tiled images is below. Compare:<p>", xt*yt)
for _, ms := range []int{0, 30, 200, 1000} {
d := time.Duration(ms) * nanosPerMilli
fmt.Fprintf(w, "[<a href='https://%s/gophertiles?latency=%d'>HTTP/2, %v latency</a>] [<a href='http://%s/gophertiles?latency=%d'>HTTP/1, %v latency</a>]<br>\n",
httpsHost(), ms, d,
httpHost(), ms, d,
)
}
io.WriteString(w, "<p>\n")
cacheBust := time.Now().UnixNano()
for y := 0; y < yt; y++ {
for x := 0; x < xt; x++ {
fmt.Fprintf(w, "<img width=%d height=%d src='/gophertiles?x=%d&y=%d&cachebust=%d&latency=%d'>",
tileSize, tileSize, x, y, cacheBust, ms)
}
io.WriteString(w, "<br/>\n")
}
io.WriteString(w, "<hr><a href='/'>&lt;&lt Back to Go HTTP/2 demo server</a></body></html>")
})
}
func httpsHost() string {
if *prod {
return "http2.golang.org"
}
if v := *addr; strings.HasPrefix(v, ":") {
return "localhost" + v
} else {
return v
}
}
func httpHost() string {
if *prod {
return "http2.golang.org"
}
if v := *httpAddr; strings.HasPrefix(v, ":") {
return "localhost" + v
} else {
return v
}
}
func serveProdTLS() error {
c, err := googlestorage.NewServiceClient()
if err != nil {
return err
}
slurp := func(key string) ([]byte, error) {
const bucket = "http2-demo-server-tls"
rc, _, err := c.GetObject(&googlestorage.Object{
Bucket: bucket,
Key: key,
})
if err != nil {
return nil, fmt.Errorf("Error fetching GCS object %q in bucket %q: %v", key, bucket, err)
}
defer rc.Close()
return ioutil.ReadAll(rc)
}
certPem, err := slurp("http2.golang.org.chained.pem")
if err != nil {
return err
}
keyPem, err := slurp("http2.golang.org.key")
if err != nil {
return err
}
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
return err
}
srv := &http.Server{
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
}
http2.ConfigureServer(srv, &http2.Server{})
ln, err := net.Listen("tcp", ":443")
if err != nil {
return err
}
return srv.Serve(tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig))
}
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
func serveProd() error {
errc := make(chan error, 2)
go func() { errc <- http.ListenAndServe(":80", nil) }()
go func() { errc <- serveProdTLS() }()
return <-errc
}
func main() {
var srv http.Server
flag.BoolVar(&http2.VerboseLogs, "verbose", false, "Verbose HTTP/2 debugging.")
flag.Parse()
srv.Addr = *addr
registerHandlers()
if *prod {
*httpAddr = "http2.golang.org"
log.Fatal(serveProd())
}
url := "https://" + *addr + "/"
log.Printf("Listening on " + url)
http2.ConfigureServer(&srv, &http2.Server{})
if *httpAddr != "" {
go func() { log.Fatal(http.ListenAndServe(*httpAddr, nil)) }()
}
go func() {
log.Fatal(srv.ListenAndServeTLS("server.crt", "server.key"))
}()
if *openFirefox && runtime.GOOS == "darwin" {
time.Sleep(250 * time.Millisecond)
exec.Command("open", "-b", "org.mozilla.nightly", "https://localhost:4430/").Run()
}
select {}
}

View file

@ -1,302 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
package main
import (
"bufio"
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
compute "google.golang.org/api/compute/v1"
)
var (
proj = flag.String("project", "symbolic-datum-552", "name of Project")
zone = flag.String("zone", "us-central1-a", "GCE zone")
mach = flag.String("machinetype", "n1-standard-1", "Machine type")
instName = flag.String("instance_name", "http2-demo", "Name of VM instance.")
sshPub = flag.String("ssh_public_key", "", "ssh public key file to authorize. Can modify later in Google's web UI anyway.")
staticIP = flag.String("static_ip", "130.211.116.44", "Static IP to use. If empty, automatic.")
writeObject = flag.String("write_object", "", "If non-empty, a VM isn't created and the flag value is Google Cloud Storage bucket/object to write. The contents from stdin.")
publicObject = flag.Bool("write_object_is_public", false, "Whether the object created by --write_object should be public.")
)
func readFile(v string) string {
slurp, err := ioutil.ReadFile(v)
if err != nil {
log.Fatalf("Error reading %s: %v", v, err)
}
return strings.TrimSpace(string(slurp))
}
var config = &oauth2.Config{
// The client-id and secret should be for an "Installed Application" when using
// the CLI. Later we'll use a web application with a callback.
ClientID: readFile("client-id.dat"),
ClientSecret: readFile("client-secret.dat"),
Endpoint: google.Endpoint,
Scopes: []string{
compute.DevstorageFull_controlScope,
compute.ComputeScope,
"https://www.googleapis.com/auth/sqlservice",
"https://www.googleapis.com/auth/sqlservice.admin",
},
RedirectURL: "urn:ietf:wg:oauth:2.0:oob",
}
const baseConfig = `#cloud-config
coreos:
units:
- name: h2demo.service
command: start
content: |
[Unit]
Description=HTTP2 Demo
[Service]
ExecStartPre=/bin/bash -c 'mkdir -p /opt/bin && curl -s -o /opt/bin/h2demo http://storage.googleapis.com/http2-demo-server-tls/h2demo && chmod +x /opt/bin/h2demo'
ExecStart=/opt/bin/h2demo --prod
RestartSec=5s
Restart=always
Type=simple
[Install]
WantedBy=multi-user.target
`
func main() {
flag.Parse()
if *proj == "" {
log.Fatalf("Missing --project flag")
}
prefix := "https://www.googleapis.com/compute/v1/projects/" + *proj
machType := prefix + "/zones/" + *zone + "/machineTypes/" + *mach
const tokenFileName = "token.dat"
tokenFile := tokenCacheFile(tokenFileName)
tokenSource := oauth2.ReuseTokenSource(nil, tokenFile)
token, err := tokenSource.Token()
if err != nil {
if *writeObject != "" {
log.Fatalf("Can't use --write_object without a valid token.dat file already cached.")
}
log.Printf("Error getting token from %s: %v", tokenFileName, err)
log.Printf("Get auth code from %v", config.AuthCodeURL("my-state"))
fmt.Print("\nEnter auth code: ")
sc := bufio.NewScanner(os.Stdin)
sc.Scan()
authCode := strings.TrimSpace(sc.Text())
token, err = config.Exchange(oauth2.NoContext, authCode)
if err != nil {
log.Fatalf("Error exchanging auth code for a token: %v", err)
}
if err := tokenFile.WriteToken(token); err != nil {
log.Fatalf("Error writing to %s: %v", tokenFileName, err)
}
tokenSource = oauth2.ReuseTokenSource(token, nil)
}
oauthClient := oauth2.NewClient(oauth2.NoContext, tokenSource)
if *writeObject != "" {
writeCloudStorageObject(oauthClient)
return
}
computeService, _ := compute.New(oauthClient)
natIP := *staticIP
if natIP == "" {
// Try to find it by name.
aggAddrList, err := computeService.Addresses.AggregatedList(*proj).Do()
if err != nil {
log.Fatal(err)
}
// http://godoc.org/code.google.com/p/google-api-go-client/compute/v1#AddressAggregatedList
IPLoop:
for _, asl := range aggAddrList.Items {
for _, addr := range asl.Addresses {
if addr.Name == *instName+"-ip" && addr.Status == "RESERVED" {
natIP = addr.Address
break IPLoop
}
}
}
}
cloudConfig := baseConfig
if *sshPub != "" {
key := strings.TrimSpace(readFile(*sshPub))
cloudConfig += fmt.Sprintf("\nssh_authorized_keys:\n - %s\n", key)
}
if os.Getenv("USER") == "bradfitz" {
cloudConfig += fmt.Sprintf("\nssh_authorized_keys:\n - %s\n", "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEAwks9dwWKlRC+73gRbvYtVg0vdCwDSuIlyt4z6xa/YU/jTDynM4R4W10hm2tPjy8iR1k8XhDv4/qdxe6m07NjG/By1tkmGpm1mGwho4Pr5kbAAy/Qg+NLCSdAYnnE00FQEcFOC15GFVMOW2AzDGKisReohwH9eIzHPzdYQNPRWXE= bradfitz@papag.bradfitz.com")
}
const maxCloudConfig = 32 << 10 // per compute API docs
if len(cloudConfig) > maxCloudConfig {
log.Fatalf("cloud config length of %d bytes is over %d byte limit", len(cloudConfig), maxCloudConfig)
}
instance := &compute.Instance{
Name: *instName,
Description: "Go Builder",
MachineType: machType,
Disks: []*compute.AttachedDisk{instanceDisk(computeService)},
Tags: &compute.Tags{
Items: []string{"http-server", "https-server"},
},
Metadata: &compute.Metadata{
Items: []*compute.MetadataItems{
{
Key: "user-data",
Value: cloudConfig,
},
},
},
NetworkInterfaces: []*compute.NetworkInterface{
&compute.NetworkInterface{
AccessConfigs: []*compute.AccessConfig{
&compute.AccessConfig{
Type: "ONE_TO_ONE_NAT",
Name: "External NAT",
NatIP: natIP,
},
},
Network: prefix + "/global/networks/default",
},
},
ServiceAccounts: []*compute.ServiceAccount{
{
Email: "default",
Scopes: []string{
compute.DevstorageFull_controlScope,
compute.ComputeScope,
},
},
},
}
log.Printf("Creating instance...")
op, err := computeService.Instances.Insert(*proj, *zone, instance).Do()
if err != nil {
log.Fatalf("Failed to create instance: %v", err)
}
opName := op.Name
log.Printf("Created. Waiting on operation %v", opName)
OpLoop:
for {
time.Sleep(2 * time.Second)
op, err := computeService.ZoneOperations.Get(*proj, *zone, opName).Do()
if err != nil {
log.Fatalf("Failed to get op %s: %v", opName, err)
}
switch op.Status {
case "PENDING", "RUNNING":
log.Printf("Waiting on operation %v", opName)
continue
case "DONE":
if op.Error != nil {
for _, operr := range op.Error.Errors {
log.Printf("Error: %+v", operr)
}
log.Fatalf("Failed to start.")
}
log.Printf("Success. %+v", op)
break OpLoop
default:
log.Fatalf("Unknown status %q: %+v", op.Status, op)
}
}
inst, err := computeService.Instances.Get(*proj, *zone, *instName).Do()
if err != nil {
log.Fatalf("Error getting instance after creation: %v", err)
}
ij, _ := json.MarshalIndent(inst, "", " ")
log.Printf("Instance: %s", ij)
}
func instanceDisk(svc *compute.Service) *compute.AttachedDisk {
const imageURL = "https://www.googleapis.com/compute/v1/projects/coreos-cloud/global/images/coreos-stable-444-5-0-v20141016"
diskName := *instName + "-disk"
return &compute.AttachedDisk{
AutoDelete: true,
Boot: true,
Type: "PERSISTENT",
InitializeParams: &compute.AttachedDiskInitializeParams{
DiskName: diskName,
SourceImage: imageURL,
DiskSizeGb: 50,
},
}
}
func writeCloudStorageObject(httpClient *http.Client) {
content := os.Stdin
const maxSlurp = 1 << 20
var buf bytes.Buffer
n, err := io.CopyN(&buf, content, maxSlurp)
if err != nil && err != io.EOF {
log.Fatalf("Error reading from stdin: %v, %v", n, err)
}
contentType := http.DetectContentType(buf.Bytes())
req, err := http.NewRequest("PUT", "https://storage.googleapis.com/"+*writeObject, io.MultiReader(&buf, content))
if err != nil {
log.Fatal(err)
}
req.Header.Set("x-goog-api-version", "2")
if *publicObject {
req.Header.Set("x-goog-acl", "public-read")
}
req.Header.Set("Content-Type", contentType)
res, err := httpClient.Do(req)
if err != nil {
log.Fatal(err)
}
if res.StatusCode != 200 {
res.Write(os.Stderr)
log.Fatalf("Failed.")
}
log.Printf("Success.")
os.Exit(0)
}
type tokenCacheFile string
func (f tokenCacheFile) Token() (*oauth2.Token, error) {
slurp, err := ioutil.ReadFile(string(f))
if err != nil {
return nil, err
}
t := new(oauth2.Token)
if err := json.Unmarshal(slurp, t); err != nil {
return nil, err
}
return t, nil
}
func (f tokenCacheFile) WriteToken(t *oauth2.Token) error {
jt, err := json.Marshal(t)
if err != nil {
return err
}
return ioutil.WriteFile(string(f), jt, 0600)
}

View file

@ -1,27 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAt5fAjp4fTcekWUTfzsp0kyih1OYbsGL0KX1eRbSSR8Od0+9Q
62Hyny+GFwMTb4A/KU8mssoHvcceSAAbwfbxFK/+s51TobqUnORZrOoTZjkUygby
XDSK99YBbcR1Pip8vwMTm4XKuLtCigeBBdjjAQdgUO28LENGlsMnmeYkJfODVGnV
mr5Ltb9ANA8IKyTfsnHJ4iOCS/PlPbUj2q7YnoVLposUBMlgUb/CykX3mOoLb4yJ
JQyA/iST6ZxiIEj36D4yWZ5lg7YJl+UiiBQHGCnPdGyipqV06ex0heYWcaiW8LWZ
SUQ93jQ+WVCH8hT7DQO1dmsvUmXlq/JeAlwQ/QIDAQABAoIBAFFHV7JMAqPWnMYA
nezY6J81v9+XN+7xABNWM2Q8uv4WdksbigGLTXR3/680Z2hXqJ7LMeC5XJACFT/e
/Gr0vmpgOCygnCPfjGehGKpavtfksXV3edikUlnCXsOP1C//c1bFL+sMYmFCVgTx
qYdDK8yKzXNGrKYT6q5YG7IglyRNV1rsQa8lM/5taFYiD1Ck/3tQi3YIq8Lcuser
hrxsMABcQ6mi+EIvG6Xr4mfJug0dGJMHG4RG1UGFQn6RXrQq2+q53fC8ZbVUSi0j
NQ918aKFzktwv+DouKU0ME4I9toks03gM860bAL7zCbKGmwR3hfgX/TqzVCWpG9E
LDVfvekCgYEA8fk9N53jbBRmULUGEf4qWypcLGiZnNU0OeXWpbPV9aa3H0VDytA7
8fCN2dPAVDPqlthMDdVe983NCNwp2Yo8ZimDgowyIAKhdC25s1kejuaiH9OAPj3c
0f8KbriYX4n8zNHxFwK6Ae3pQ6EqOLJVCUsziUaZX9nyKY5aZlyX6xcCgYEAwjws
K62PjC64U5wYddNLp+kNdJ4edx+a7qBb3mEgPvSFT2RO3/xafJyG8kQB30Mfstjd
bRxyUV6N0vtX1zA7VQtRUAvfGCecpMo+VQZzcHXKzoRTnQ7eZg4Lmj5fQ9tOAKAo
QCVBoSW/DI4PZL26CAMDcAba4Pa22ooLapoRIQsCgYA6pIfkkbxLNkpxpt2YwLtt
Kr/590O7UaR9n6k8sW/aQBRDXNsILR1KDl2ifAIxpf9lnXgZJiwE7HiTfCAcW7c1
nzwDCI0hWuHcMTS/NYsFYPnLsstyyjVZI3FY0h4DkYKV9Q9z3zJLQ2hz/nwoD3gy
b2pHC7giFcTts1VPV4Nt8wKBgHeFn4ihHJweg76vZz3Z78w7VNRWGFklUalVdDK7
gaQ7w2y/ROn/146mo0OhJaXFIFRlrpvdzVrU3GDf2YXJYDlM5ZRkObwbZADjksev
WInzcgDy3KDg7WnPasRXbTfMU4t/AkW2p1QKbi3DnSVYuokDkbH2Beo45vxDxhKr
C69RAoGBAIyo3+OJenoZmoNzNJl2WPW5MeBUzSh8T/bgyjFTdqFHF5WiYRD/lfHj
x9Glyw2nutuT4hlOqHvKhgTYdDMsF2oQ72fe3v8Q5FU7FuKndNPEAyvKNXZaShVA
hnlhv5DjXKb0wFWnt5PCCiQLtzG0yyHaITrrEme7FikkIcTxaX/Y
-----END RSA PRIVATE KEY-----

View file

@ -1 +0,0 @@
E2CE26BF3285059C

View file

@ -1,20 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIDPjCCAiYCCQDizia/MoUFnDANBgkqhkiG9w0BAQUFADB7MQswCQYDVQQGEwJV
UzELMAkGA1UECBMCQ0ExFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xFDASBgNVBAoT
C0JyYWRmaXR6aW5jMRIwEAYDVQQDEwlsb2NhbGhvc3QxHTAbBgkqhkiG9w0BCQEW
DmJyYWRAZGFuZ2EuY29tMB4XDTE0MDcxNTIwNTAyN1oXDTE1MTEyNzIwNTAyN1ow
RzELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMQswCQYDVQQHEwJTRjEeMBwGA1UE
ChMVYnJhZGZpdHogaHR0cDIgc2VydmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEAs1Y9CyLFrdL8VQWN1WaifDqaZFnoqjHhCMlc1TfG2zA+InDifx2l
gZD3o8FeNnAcfM2sPlk3+ZleOYw9P/CklFVDlvqmpCv9ss/BEp/dDaWvy1LmJ4c2
dbQJfmTxn7CV1H3TsVJvKdwFmdoABb41NoBp6+NNO7OtDyhbIMiCI0pL3Nefb3HL
A7hIMo3DYbORTtJLTIH9W8YKrEWL0lwHLrYFx/UdutZnv+HjdmO6vCN4na55mjws
/vjKQUmc7xeY7Xe20xDEG2oDKVkL2eD7FfyrYMS3rO1ExP2KSqlXYG/1S9I/fz88
F0GK7HX55b5WjZCl2J3ERVdnv/0MQv+sYQIDAQABMA0GCSqGSIb3DQEBBQUAA4IB
AQC0zL+n/YpRZOdulSu9tS8FxrstXqGWoxfe+vIUgqfMZ5+0MkjJ/vW0FqlLDl2R
rn4XaR3e7FmWkwdDVbq/UB6lPmoAaFkCgh9/5oapMaclNVNnfF3fjCJfRr+qj/iD
EmJStTIN0ZuUjAlpiACmfnpEU55PafT5Zx+i1yE4FGjw8bJpFoyD4Hnm54nGjX19
KeCuvcYFUPnBm3lcL0FalF2AjqV02WTHYNQk7YF/oeO7NKBoEgvGvKG3x+xaOeBI
dwvdq175ZsGul30h+QjrRlXhH/twcuaT3GSdoysDl9cCYE8f1Mk8PD6gan3uBCJU
90p6/CbU71bGbfpM2PHot2fm
-----END CERTIFICATE-----

View file

@ -1,27 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAs1Y9CyLFrdL8VQWN1WaifDqaZFnoqjHhCMlc1TfG2zA+InDi
fx2lgZD3o8FeNnAcfM2sPlk3+ZleOYw9P/CklFVDlvqmpCv9ss/BEp/dDaWvy1Lm
J4c2dbQJfmTxn7CV1H3TsVJvKdwFmdoABb41NoBp6+NNO7OtDyhbIMiCI0pL3Nef
b3HLA7hIMo3DYbORTtJLTIH9W8YKrEWL0lwHLrYFx/UdutZnv+HjdmO6vCN4na55
mjws/vjKQUmc7xeY7Xe20xDEG2oDKVkL2eD7FfyrYMS3rO1ExP2KSqlXYG/1S9I/
fz88F0GK7HX55b5WjZCl2J3ERVdnv/0MQv+sYQIDAQABAoIBADQ2spUwbY+bcz4p
3M66ECrNQTBggP40gYl2XyHxGGOu2xhZ94f9ELf1hjRWU2DUKWco1rJcdZClV6q3
qwmXvcM2Q/SMS8JW0ImkNVl/0/NqPxGatEnj8zY30d/L8hGFb0orzFu/XYA5gCP4
NbN2WrXgk3ZLeqwcNxHHtSiJWGJ/fPyeDWAu/apy75u9Xf2GlzBZmV6HYD9EfK80
LTlI60f5FO487CrJnboL7ovPJrIHn+k05xRQqwma4orpz932rTXnTjs9Lg6KtbQN
a7PrqfAntIISgr11a66Mng3IYH1lYqJsWJJwX/xHT4WLEy0EH4/0+PfYemJekz2+
Co62drECgYEA6O9zVJZXrLSDsIi54cfxA7nEZWm5CAtkYWeAHa4EJ+IlZ7gIf9sL
W8oFcEfFGpvwVqWZ+AsQ70dsjXAv3zXaG0tmg9FtqWp7pzRSMPidifZcQwWkKeTO
gJnFmnVyed8h6GfjTEu4gxo1/S5U0V+mYSha01z5NTnN6ltKx1Or3b0CgYEAxRgm
S30nZxnyg/V7ys61AZhst1DG2tkZXEMcA7dYhabMoXPJAP/EfhlWwpWYYUs/u0gS
Wwmf5IivX5TlYScgmkvb/NYz0u4ZmOXkLTnLPtdKKFXhjXJcHjUP67jYmOxNlJLp
V4vLRnFxTpffAV+OszzRxsXX6fvruwZBANYJeXUCgYBVouLFsFgfWGYp2rpr9XP4
KK25kvrBqF6JKOIDB1zjxNJ3pUMKrl8oqccCFoCyXa4oTM2kUX0yWxHfleUjrMq4
yimwQKiOZmV7fVLSSjSw6e/VfBd0h3gb82ygcplZkN0IclkwTY5SNKqwn/3y07V5
drqdhkrgdJXtmQ6O5YYECQKBgATERcDToQ1USlI4sKrB/wyv1AlG8dg/IebiVJ4e
ZAyvcQmClFzq0qS+FiQUnB/WQw9TeeYrwGs1hxBHuJh16srwhLyDrbMvQP06qh8R
48F8UXXSRec22dV9MQphaROhu2qZdv1AC0WD3tqov6L33aqmEOi+xi8JgbT/PLk5
c/c1AoGBAI1A/02ryksW6/wc7/6SP2M2rTy4m1sD/GnrTc67EHnRcVBdKO6qH2RY
nqC8YcveC2ZghgPTDsA3VGuzuBXpwY6wTyV99q6jxQJ6/xcrD9/NUG6Uwv/xfCxl
IJLeBYEqQundSSny3VtaAUK8Ul1nxpTvVRNwtcyWTo8RHAAyNPWd
-----END RSA PRIVATE KEY-----

View file

@ -1,97 +0,0 @@
# h2i
**h2i** is an interactive HTTP/2 ("h2") console debugger. Miss the good ol'
days of telnetting to your HTTP/1.n servers? We're bringing you
back.
Features:
- send raw HTTP/2 frames
- PING
- SETTINGS
- HEADERS
- etc
- type in HTTP/1.n and have it auto-HPACK/frame-ify it for HTTP/2
- pretty print all received HTTP/2 frames from the peer (including HPACK decoding)
- tab completion of commands, options
Not yet features, but soon:
- unnecessary CONTINUATION frames on short boundaries, to test peer implementations
- request bodies (DATA frames)
- send invalid frames for testing server implementations (supported by underlying Framer)
Later:
- act like a server
## Installation
```
$ go get github.com/bradfitz/http2/h2i
$ h2i <host>
```
## Demo
```
$ h2i
Usage: h2i <hostname>
-insecure
Whether to skip TLS cert validation
-nextproto string
Comma-separated list of NPN/ALPN protocol names to negotiate. (default "h2,h2-14")
$ h2i google.com
Connecting to google.com:443 ...
Connected to 74.125.224.41:443
Negotiated protocol "h2-14"
[FrameHeader SETTINGS len=18]
[MAX_CONCURRENT_STREAMS = 100]
[INITIAL_WINDOW_SIZE = 1048576]
[MAX_FRAME_SIZE = 16384]
[FrameHeader WINDOW_UPDATE len=4]
Window-Increment = 983041
h2i> PING h2iSayHI
[FrameHeader PING flags=ACK len=8]
Data = "h2iSayHI"
h2i> headers
(as HTTP/1.1)> GET / HTTP/1.1
(as HTTP/1.1)> Host: ip.appspot.com
(as HTTP/1.1)> User-Agent: h2i/brad-n-blake
(as HTTP/1.1)>
Opening Stream-ID 1:
:authority = ip.appspot.com
:method = GET
:path = /
:scheme = https
user-agent = h2i/brad-n-blake
[FrameHeader HEADERS flags=END_HEADERS stream=1 len=77]
:status = "200"
alternate-protocol = "443:quic,p=1"
content-length = "15"
content-type = "text/html"
date = "Fri, 01 May 2015 23:06:56 GMT"
server = "Google Frontend"
[FrameHeader DATA flags=END_STREAM stream=1 len=15]
"173.164.155.78\n"
[FrameHeader PING len=8]
Data = "\x00\x00\x00\x00\x00\x00\x00\x00"
h2i> ping
[FrameHeader PING flags=ACK len=8]
Data = "h2i_ping"
h2i> ping
[FrameHeader PING flags=ACK len=8]
Data = "h2i_ping"
h2i> ping
[FrameHeader GOAWAY len=22]
Last-Stream-ID = 1; Error-Code = PROTOCOL_ERROR (1)
ReadFrame: EOF
```
## Status
Quick few hour hack. So much yet to do. Feel free to file issues for
bugs or wishlist items, but [@bmizerany](https://github.com/bmizerany/)
and I aren't yet accepting pull requests until things settle down.

View file

@ -1,489 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
/*
The h2i command is an interactive HTTP/2 console.
Usage:
$ h2i [flags] <hostname>
Interactive commands in the console: (all parts case-insensitive)
ping [data]
settings ack
settings FOO=n BAR=z
headers (open a new stream by typing HTTP/1.1)
*/
package main
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"github.com/bradfitz/http2"
"github.com/bradfitz/http2/hpack"
"golang.org/x/crypto/ssh/terminal"
)
// Flags
var (
flagNextProto = flag.String("nextproto", "h2,h2-14", "Comma-separated list of NPN/ALPN protocol names to negotiate.")
flagInsecure = flag.Bool("insecure", false, "Whether to skip TLS cert validation")
)
type command struct {
run func(*h2i, []string) error // required
// complete optionally specifies tokens (case-insensitive) which are
// valid for this subcommand.
complete func() []string
}
var commands = map[string]command{
"ping": command{run: (*h2i).cmdPing},
"settings": command{
run: (*h2i).cmdSettings,
complete: func() []string {
return []string{
"ACK",
http2.SettingHeaderTableSize.String(),
http2.SettingEnablePush.String(),
http2.SettingMaxConcurrentStreams.String(),
http2.SettingInitialWindowSize.String(),
http2.SettingMaxFrameSize.String(),
http2.SettingMaxHeaderListSize.String(),
}
},
},
"quit": command{run: (*h2i).cmdQuit},
"headers": command{run: (*h2i).cmdHeaders},
}
func usage() {
fmt.Fprintf(os.Stderr, "Usage: h2i <hostname>\n\n")
flag.PrintDefaults()
os.Exit(1)
}
// withPort adds ":443" if another port isn't already present.
func withPort(host string) string {
if _, _, err := net.SplitHostPort(host); err != nil {
return net.JoinHostPort(host, "443")
}
return host
}
// h2i is the app's state.
type h2i struct {
host string
tc *tls.Conn
framer *http2.Framer
term *terminal.Terminal
// owned by the command loop:
streamID uint32
hbuf bytes.Buffer
henc *hpack.Encoder
// owned by the readFrames loop:
peerSetting map[http2.SettingID]uint32
hdec *hpack.Decoder
}
func main() {
flag.Usage = usage
flag.Parse()
if flag.NArg() != 1 {
usage()
}
log.SetFlags(0)
host := flag.Arg(0)
app := &h2i{
host: host,
peerSetting: make(map[http2.SettingID]uint32),
}
app.henc = hpack.NewEncoder(&app.hbuf)
if err := app.Main(); err != nil {
if app.term != nil {
app.logf("%v\n", err)
} else {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
os.Exit(1)
}
fmt.Fprintf(os.Stdout, "\n")
}
func (app *h2i) Main() error {
cfg := &tls.Config{
ServerName: app.host,
NextProtos: strings.Split(*flagNextProto, ","),
InsecureSkipVerify: *flagInsecure,
}
hostAndPort := withPort(app.host)
log.Printf("Connecting to %s ...", hostAndPort)
tc, err := tls.Dial("tcp", hostAndPort, cfg)
if err != nil {
return fmt.Errorf("Error dialing %s: %v", withPort(app.host), err)
}
log.Printf("Connected to %v", tc.RemoteAddr())
defer tc.Close()
if err := tc.Handshake(); err != nil {
return fmt.Errorf("TLS handshake: %v", err)
}
if !*flagInsecure {
if err := tc.VerifyHostname(app.host); err != nil {
return fmt.Errorf("VerifyHostname: %v", err)
}
}
state := tc.ConnectionState()
log.Printf("Negotiated protocol %q", state.NegotiatedProtocol)
if !state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol == "" {
return fmt.Errorf("Could not negotiate protocol mutually")
}
if _, err := io.WriteString(tc, http2.ClientPreface); err != nil {
return err
}
app.framer = http2.NewFramer(tc, tc)
oldState, err := terminal.MakeRaw(0)
if err != nil {
return err
}
defer terminal.Restore(0, oldState)
var screen = struct {
io.Reader
io.Writer
}{os.Stdin, os.Stdout}
app.term = terminal.NewTerminal(screen, "h2i> ")
lastWord := regexp.MustCompile(`.+\W(\w+)$`)
app.term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
if key != '\t' {
return
}
if pos != len(line) {
// TODO: we're being lazy for now, only supporting tab completion at the end.
return
}
// Auto-complete for the command itself.
if !strings.Contains(line, " ") {
var name string
name, _, ok = lookupCommand(line)
if !ok {
return
}
return name, len(name), true
}
_, c, ok := lookupCommand(line[:strings.IndexByte(line, ' ')])
if !ok || c.complete == nil {
return
}
if strings.HasSuffix(line, " ") {
app.logf("%s", strings.Join(c.complete(), " "))
return line, pos, true
}
m := lastWord.FindStringSubmatch(line)
if m == nil {
return line, len(line), true
}
soFar := m[1]
var match []string
for _, cand := range c.complete() {
if len(soFar) > len(cand) || !strings.EqualFold(cand[:len(soFar)], soFar) {
continue
}
match = append(match, cand)
}
if len(match) == 0 {
return
}
if len(match) > 1 {
// TODO: auto-complete any common prefix
app.logf("%s", strings.Join(match, " "))
return line, pos, true
}
newLine = line[:len(line)-len(soFar)] + match[0]
return newLine, len(newLine), true
}
errc := make(chan error, 2)
go func() { errc <- app.readFrames() }()
go func() { errc <- app.readConsole() }()
return <-errc
}
func (app *h2i) logf(format string, args ...interface{}) {
fmt.Fprintf(app.term, format+"\n", args...)
}
func (app *h2i) readConsole() error {
for {
line, err := app.term.ReadLine()
if err == io.EOF {
return nil
}
if err != nil {
return fmt.Errorf("terminal.ReadLine: %v", err)
}
f := strings.Fields(line)
if len(f) == 0 {
continue
}
cmd, args := f[0], f[1:]
if _, c, ok := lookupCommand(cmd); ok {
err = c.run(app, args)
} else {
app.logf("Unknown command %q", line)
}
if err == errExitApp {
return nil
}
if err != nil {
return err
}
}
}
func lookupCommand(prefix string) (name string, c command, ok bool) {
prefix = strings.ToLower(prefix)
if c, ok = commands[prefix]; ok {
return prefix, c, ok
}
for full, candidate := range commands {
if strings.HasPrefix(full, prefix) {
if c.run != nil {
return "", command{}, false // ambiguous
}
c = candidate
name = full
}
}
return name, c, c.run != nil
}
var errExitApp = errors.New("internal sentinel error value to quit the console reading loop")
func (a *h2i) cmdQuit(args []string) error {
if len(args) > 0 {
a.logf("the QUIT command takes no argument")
return nil
}
return errExitApp
}
func (a *h2i) cmdSettings(args []string) error {
if len(args) == 1 && strings.EqualFold(args[0], "ACK") {
return a.framer.WriteSettingsAck()
}
var settings []http2.Setting
for _, arg := range args {
if strings.EqualFold(arg, "ACK") {
a.logf("Error: ACK must be only argument with the SETTINGS command")
return nil
}
eq := strings.Index(arg, "=")
if eq == -1 {
a.logf("Error: invalid argument %q (expected SETTING_NAME=nnnn)", arg)
return nil
}
sid, ok := settingByName(arg[:eq])
if !ok {
a.logf("Error: unknown setting name %q", arg[:eq])
return nil
}
val, err := strconv.ParseUint(arg[eq+1:], 10, 32)
if err != nil {
a.logf("Error: invalid argument %q (expected SETTING_NAME=nnnn)", arg)
return nil
}
settings = append(settings, http2.Setting{
ID: sid,
Val: uint32(val),
})
}
a.logf("Sending: %v", settings)
return a.framer.WriteSettings(settings...)
}
func settingByName(name string) (http2.SettingID, bool) {
for _, sid := range [...]http2.SettingID{
http2.SettingHeaderTableSize,
http2.SettingEnablePush,
http2.SettingMaxConcurrentStreams,
http2.SettingInitialWindowSize,
http2.SettingMaxFrameSize,
http2.SettingMaxHeaderListSize,
} {
if strings.EqualFold(sid.String(), name) {
return sid, true
}
}
return 0, false
}
func (app *h2i) cmdPing(args []string) error {
if len(args) > 1 {
app.logf("invalid PING usage: only accepts 0 or 1 args")
return nil // nil means don't end the program
}
var data [8]byte
if len(args) == 1 {
copy(data[:], args[0])
} else {
copy(data[:], "h2i_ping")
}
return app.framer.WritePing(false, data)
}
func (app *h2i) cmdHeaders(args []string) error {
if len(args) > 0 {
app.logf("Error: HEADERS doesn't yet take arguments.")
// TODO: flags for restricting window size, to force CONTINUATION
// frames.
return nil
}
var h1req bytes.Buffer
app.term.SetPrompt("(as HTTP/1.1)> ")
defer app.term.SetPrompt("h2i> ")
for {
line, err := app.term.ReadLine()
if err != nil {
return err
}
h1req.WriteString(line)
h1req.WriteString("\r\n")
if line == "" {
break
}
}
req, err := http.ReadRequest(bufio.NewReader(&h1req))
if err != nil {
app.logf("Invalid HTTP/1.1 request: %v", err)
return nil
}
if app.streamID == 0 {
app.streamID = 1
} else {
app.streamID += 2
}
app.logf("Opening Stream-ID %d:", app.streamID)
hbf := app.encodeHeaders(req)
if len(hbf) > 16<<10 {
app.logf("TODO: h2i doesn't yet write CONTINUATION frames. Copy it from transport.go")
return nil
}
return app.framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: app.streamID,
BlockFragment: hbf,
EndStream: req.Method == "GET" || req.Method == "HEAD", // good enough for now
EndHeaders: true, // for now
})
}
func (app *h2i) readFrames() error {
for {
f, err := app.framer.ReadFrame()
if err != nil {
return fmt.Errorf("ReadFrame: %v", err)
}
app.logf("%v", f)
switch f := f.(type) {
case *http2.PingFrame:
app.logf(" Data = %q", f.Data)
case *http2.SettingsFrame:
f.ForeachSetting(func(s http2.Setting) error {
app.logf(" %v", s)
app.peerSetting[s.ID] = s.Val
return nil
})
case *http2.WindowUpdateFrame:
app.logf(" Window-Increment = %v\n", f.Increment)
case *http2.GoAwayFrame:
app.logf(" Last-Stream-ID = %d; Error-Code = %v (%d)\n", f.LastStreamID, f.ErrCode, f.ErrCode)
case *http2.DataFrame:
app.logf(" %q", f.Data())
case *http2.HeadersFrame:
if f.HasPriority() {
app.logf(" PRIORITY = %v", f.Priority)
}
if app.hdec == nil {
// TODO: if the user uses h2i to send a SETTINGS frame advertising
// something larger, we'll need to respect SETTINGS_HEADER_TABLE_SIZE
// and stuff here instead of using the 4k default. But for now:
tableSize := uint32(4 << 10)
app.hdec = hpack.NewDecoder(tableSize, app.onNewHeaderField)
}
app.hdec.Write(f.HeaderBlockFragment())
}
}
}
// called from readLoop
func (app *h2i) onNewHeaderField(f hpack.HeaderField) {
if f.Sensitive {
app.logf(" %s = %q (SENSITIVE)", f.Name, f.Value)
}
app.logf(" %s = %q", f.Name, f.Value)
}
func (app *h2i) encodeHeaders(req *http.Request) []byte {
app.hbuf.Reset()
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
host := req.Host
if host == "" {
host = req.URL.Host
}
path := req.URL.Path
if path == "" {
path = "/"
}
app.writeHeader(":authority", host) // probably not right for all sites
app.writeHeader(":method", req.Method)
app.writeHeader(":path", path)
app.writeHeader(":scheme", "https")
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
if lowKey == "host" {
continue
}
for _, v := range vv {
app.writeHeader(lowKey, v)
}
}
return app.hbuf.Bytes()
}
func (app *h2i) writeHeader(name, value string) {
app.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
app.logf(" %s = %s", name, value)
}

View file

@ -1,331 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
import (
"bytes"
"encoding/hex"
"reflect"
"strings"
"testing"
)
func TestEncoderTableSizeUpdate(t *testing.T) {
tests := []struct {
size1, size2 uint32
wantHex string
}{
// Should emit 2 table size updates (2048 and 4096)
{2048, 4096, "3fe10f 3fe11f 82"},
// Should emit 1 table size update (2048)
{16384, 2048, "3fe10f 82"},
}
for _, tt := range tests {
var buf bytes.Buffer
e := NewEncoder(&buf)
e.SetMaxDynamicTableSize(tt.size1)
e.SetMaxDynamicTableSize(tt.size2)
if err := e.WriteField(pair(":method", "GET")); err != nil {
t.Fatal(err)
}
want := removeSpace(tt.wantHex)
if got := hex.EncodeToString(buf.Bytes()); got != want {
t.Errorf("e.SetDynamicTableSize %v, %v = %q; want %q", tt.size1, tt.size2, got, want)
}
}
}
func TestEncoderWriteField(t *testing.T) {
var buf bytes.Buffer
e := NewEncoder(&buf)
var got []HeaderField
d := NewDecoder(4<<10, func(f HeaderField) {
got = append(got, f)
})
tests := []struct {
hdrs []HeaderField
}{
{[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
}},
{[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
pair("cache-control", "no-cache"),
}},
{[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "https"),
pair(":path", "/index.html"),
pair(":authority", "www.example.com"),
pair("custom-key", "custom-value"),
}},
}
for i, tt := range tests {
buf.Reset()
got = got[:0]
for _, hf := range tt.hdrs {
if err := e.WriteField(hf); err != nil {
t.Fatal(err)
}
}
_, err := d.Write(buf.Bytes())
if err != nil {
t.Errorf("%d. Decoder Write = %v", i, err)
}
if !reflect.DeepEqual(got, tt.hdrs) {
t.Errorf("%d. Decoded %+v; want %+v", i, got, tt.hdrs)
}
}
}
func TestEncoderSearchTable(t *testing.T) {
e := NewEncoder(nil)
e.dynTab.add(pair("foo", "bar"))
e.dynTab.add(pair("blake", "miz"))
e.dynTab.add(pair(":method", "GET"))
tests := []struct {
hf HeaderField
wantI uint64
wantMatch bool
}{
// Name and Value match
{pair("foo", "bar"), uint64(len(staticTable) + 3), true},
{pair("blake", "miz"), uint64(len(staticTable) + 2), true},
{pair(":method", "GET"), 2, true},
// Only name match because Sensitive == true
{HeaderField{":method", "GET", true}, 2, false},
// Only Name matches
{pair("foo", "..."), uint64(len(staticTable) + 3), false},
{pair("blake", "..."), uint64(len(staticTable) + 2), false},
{pair(":method", "..."), 2, false},
// None match
{pair("foo-", "bar"), 0, false},
}
for _, tt := range tests {
if gotI, gotMatch := e.searchTable(tt.hf); gotI != tt.wantI || gotMatch != tt.wantMatch {
t.Errorf("d.search(%+v) = %v, %v; want %v, %v", tt.hf, gotI, gotMatch, tt.wantI, tt.wantMatch)
}
}
}
func TestAppendVarInt(t *testing.T) {
tests := []struct {
n byte
i uint64
want []byte
}{
// Fits in a byte:
{1, 0, []byte{0}},
{2, 2, []byte{2}},
{3, 6, []byte{6}},
{4, 14, []byte{14}},
{5, 30, []byte{30}},
{6, 62, []byte{62}},
{7, 126, []byte{126}},
{8, 254, []byte{254}},
// Multiple bytes:
{5, 1337, []byte{31, 154, 10}},
}
for _, tt := range tests {
got := appendVarInt(nil, tt.n, tt.i)
if !bytes.Equal(got, tt.want) {
t.Errorf("appendVarInt(nil, %v, %v) = %v; want %v", tt.n, tt.i, got, tt.want)
}
}
}
func TestAppendHpackString(t *testing.T) {
tests := []struct {
s, wantHex string
}{
// Huffman encoded
{"www.example.com", "8c f1e3 c2e5 f23a 6ba0 ab90 f4ff"},
// Not Huffman encoded
{"a", "01 61"},
// zero length
{"", "00"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendHpackString(nil, tt.s)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendHpackString(nil, %q) = %q; want %q", tt.s, got, want)
}
}
}
func TestAppendIndexed(t *testing.T) {
tests := []struct {
i uint64
wantHex string
}{
// 1 byte
{1, "81"},
{126, "fe"},
// 2 bytes
{127, "ff00"},
{128, "ff01"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendIndexed(nil, tt.i)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendIndex(nil, %v) = %q; want %q", tt.i, got, want)
}
}
}
func TestAppendNewName(t *testing.T) {
tests := []struct {
f HeaderField
indexing bool
wantHex string
}{
// Incremental indexing
{HeaderField{"custom-key", "custom-value", false}, true, "40 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
// Without indexing
{HeaderField{"custom-key", "custom-value", false}, false, "00 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
// Never indexed
{HeaderField{"custom-key", "custom-value", true}, true, "10 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
{HeaderField{"custom-key", "custom-value", true}, false, "10 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendNewName(nil, tt.f, tt.indexing)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendNewName(nil, %+v, %v) = %q; want %q", tt.f, tt.indexing, got, want)
}
}
}
func TestAppendIndexedName(t *testing.T) {
tests := []struct {
f HeaderField
i uint64
indexing bool
wantHex string
}{
// Incremental indexing
{HeaderField{":status", "302", false}, 8, true, "48 82 6402"},
// Without indexing
{HeaderField{":status", "302", false}, 8, false, "08 82 6402"},
// Never indexed
{HeaderField{":status", "302", true}, 8, true, "18 82 6402"},
{HeaderField{":status", "302", true}, 8, false, "18 82 6402"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendIndexedName(nil, tt.f, tt.i, tt.indexing)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendIndexedName(nil, %+v, %v) = %q; want %q", tt.f, tt.indexing, got, want)
}
}
}
func TestAppendTableSize(t *testing.T) {
tests := []struct {
i uint32
wantHex string
}{
// Fits into 1 byte
{30, "3e"},
// Extra byte
{31, "3f00"},
{32, "3f01"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendTableSize(nil, tt.i)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendTableSize(nil, %v) = %q; want %q", tt.i, got, want)
}
}
}
func TestEncoderSetMaxDynamicTableSize(t *testing.T) {
var buf bytes.Buffer
e := NewEncoder(&buf)
tests := []struct {
v uint32
wantUpdate bool
wantMinSize uint32
wantMaxSize uint32
}{
// Set new table size to 2048
{2048, true, 2048, 2048},
// Set new table size to 16384, but still limited to
// 4096
{16384, true, 2048, 4096},
}
for _, tt := range tests {
e.SetMaxDynamicTableSize(tt.v)
if got := e.tableSizeUpdate; tt.wantUpdate != got {
t.Errorf("e.tableSizeUpdate = %v; want %v", got, tt.wantUpdate)
}
if got := e.minSize; tt.wantMinSize != got {
t.Errorf("e.minSize = %v; want %v", got, tt.wantMinSize)
}
if got := e.dynTab.maxSize; tt.wantMaxSize != got {
t.Errorf("e.maxSize = %v; want %v", got, tt.wantMaxSize)
}
}
}
func TestEncoderSetMaxDynamicTableSizeLimit(t *testing.T) {
e := NewEncoder(nil)
// 4095 < initialHeaderTableSize means maxSize is truncated to
// 4095.
e.SetMaxDynamicTableSizeLimit(4095)
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
}
if got, want := e.maxSizeLimit, uint32(4095); got != want {
t.Errorf("e.maxSizeLimit = %v; want %v", got, want)
}
if got, want := e.tableSizeUpdate, true; got != want {
t.Errorf("e.tableSizeUpdate = %v; want %v", got, want)
}
// maxSize will be truncated to maxSizeLimit
e.SetMaxDynamicTableSize(16384)
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
}
// 8192 > current maxSizeLimit, so maxSize does not change.
e.SetMaxDynamicTableSizeLimit(8192)
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
}
if got, want := e.maxSizeLimit, uint32(8192); got != want {
t.Errorf("e.maxSizeLimit = %v; want %v", got, want)
}
}
func removeSpace(s string) string {
return strings.Replace(s, " ", "", -1)
}

View file

@ -1,648 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"testing"
)
func TestStaticTable(t *testing.T) {
fromSpec := `
+-------+-----------------------------+---------------+
| 1 | :authority | |
| 2 | :method | GET |
| 3 | :method | POST |
| 4 | :path | / |
| 5 | :path | /index.html |
| 6 | :scheme | http |
| 7 | :scheme | https |
| 8 | :status | 200 |
| 9 | :status | 204 |
| 10 | :status | 206 |
| 11 | :status | 304 |
| 12 | :status | 400 |
| 13 | :status | 404 |
| 14 | :status | 500 |
| 15 | accept-charset | |
| 16 | accept-encoding | gzip, deflate |
| 17 | accept-language | |
| 18 | accept-ranges | |
| 19 | accept | |
| 20 | access-control-allow-origin | |
| 21 | age | |
| 22 | allow | |
| 23 | authorization | |
| 24 | cache-control | |
| 25 | content-disposition | |
| 26 | content-encoding | |
| 27 | content-language | |
| 28 | content-length | |
| 29 | content-location | |
| 30 | content-range | |
| 31 | content-type | |
| 32 | cookie | |
| 33 | date | |
| 34 | etag | |
| 35 | expect | |
| 36 | expires | |
| 37 | from | |
| 38 | host | |
| 39 | if-match | |
| 40 | if-modified-since | |
| 41 | if-none-match | |
| 42 | if-range | |
| 43 | if-unmodified-since | |
| 44 | last-modified | |
| 45 | link | |
| 46 | location | |
| 47 | max-forwards | |
| 48 | proxy-authenticate | |
| 49 | proxy-authorization | |
| 50 | range | |
| 51 | referer | |
| 52 | refresh | |
| 53 | retry-after | |
| 54 | server | |
| 55 | set-cookie | |
| 56 | strict-transport-security | |
| 57 | transfer-encoding | |
| 58 | user-agent | |
| 59 | vary | |
| 60 | via | |
| 61 | www-authenticate | |
+-------+-----------------------------+---------------+
`
bs := bufio.NewScanner(strings.NewReader(fromSpec))
re := regexp.MustCompile(`\| (\d+)\s+\| (\S+)\s*\| (\S(.*\S)?)?\s+\|`)
for bs.Scan() {
l := bs.Text()
if !strings.Contains(l, "|") {
continue
}
m := re.FindStringSubmatch(l)
if m == nil {
continue
}
i, err := strconv.Atoi(m[1])
if err != nil {
t.Errorf("Bogus integer on line %q", l)
continue
}
if i < 1 || i > len(staticTable) {
t.Errorf("Bogus index %d on line %q", i, l)
continue
}
if got, want := staticTable[i-1].Name, m[2]; got != want {
t.Errorf("header index %d name = %q; want %q", i, got, want)
}
if got, want := staticTable[i-1].Value, m[3]; got != want {
t.Errorf("header index %d value = %q; want %q", i, got, want)
}
}
if err := bs.Err(); err != nil {
t.Error(err)
}
}
func (d *Decoder) mustAt(idx int) HeaderField {
if hf, ok := d.at(uint64(idx)); !ok {
panic(fmt.Sprintf("bogus index %d", idx))
} else {
return hf
}
}
func TestDynamicTableAt(t *testing.T) {
d := NewDecoder(4096, nil)
at := d.mustAt
if got, want := at(2), (pair(":method", "GET")); got != want {
t.Errorf("at(2) = %v; want %v", got, want)
}
d.dynTab.add(pair("foo", "bar"))
d.dynTab.add(pair("blake", "miz"))
if got, want := at(len(staticTable)+1), (pair("blake", "miz")); got != want {
t.Errorf("at(dyn 1) = %v; want %v", got, want)
}
if got, want := at(len(staticTable)+2), (pair("foo", "bar")); got != want {
t.Errorf("at(dyn 2) = %v; want %v", got, want)
}
if got, want := at(3), (pair(":method", "POST")); got != want {
t.Errorf("at(3) = %v; want %v", got, want)
}
}
func TestDynamicTableSearch(t *testing.T) {
dt := dynamicTable{}
dt.setMaxSize(4096)
dt.add(pair("foo", "bar"))
dt.add(pair("blake", "miz"))
dt.add(pair(":method", "GET"))
tests := []struct {
hf HeaderField
wantI uint64
wantMatch bool
}{
// Name and Value match
{pair("foo", "bar"), 3, true},
{pair(":method", "GET"), 1, true},
// Only name match because of Sensitive == true
{HeaderField{"blake", "miz", true}, 2, false},
// Only Name matches
{pair("foo", "..."), 3, false},
{pair("blake", "..."), 2, false},
{pair(":method", "..."), 1, false},
// None match
{pair("foo-", "bar"), 0, false},
}
for _, tt := range tests {
if gotI, gotMatch := dt.search(tt.hf); gotI != tt.wantI || gotMatch != tt.wantMatch {
t.Errorf("d.search(%+v) = %v, %v; want %v, %v", tt.hf, gotI, gotMatch, tt.wantI, tt.wantMatch)
}
}
}
func TestDynamicTableSizeEvict(t *testing.T) {
d := NewDecoder(4096, nil)
if want := uint32(0); d.dynTab.size != want {
t.Fatalf("size = %d; want %d", d.dynTab.size, want)
}
add := d.dynTab.add
add(pair("blake", "eats pizza"))
if want := uint32(15 + 32); d.dynTab.size != want {
t.Fatalf("after pizza, size = %d; want %d", d.dynTab.size, want)
}
add(pair("foo", "bar"))
if want := uint32(15 + 32 + 6 + 32); d.dynTab.size != want {
t.Fatalf("after foo bar, size = %d; want %d", d.dynTab.size, want)
}
d.dynTab.setMaxSize(15 + 32 + 1 /* slop */)
if want := uint32(6 + 32); d.dynTab.size != want {
t.Fatalf("after setMaxSize, size = %d; want %d", d.dynTab.size, want)
}
if got, want := d.mustAt(len(staticTable)+1), (pair("foo", "bar")); got != want {
t.Errorf("at(dyn 1) = %v; want %v", got, want)
}
add(pair("long", strings.Repeat("x", 500)))
if want := uint32(0); d.dynTab.size != want {
t.Fatalf("after big one, size = %d; want %d", d.dynTab.size, want)
}
}
func TestDecoderDecode(t *testing.T) {
tests := []struct {
name string
in []byte
want []HeaderField
wantDynTab []HeaderField // newest entry first
}{
// C.2.1 Literal Header Field with Indexing
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.1
{"C.2.1", dehex("400a 6375 7374 6f6d 2d6b 6579 0d63 7573 746f 6d2d 6865 6164 6572"),
[]HeaderField{pair("custom-key", "custom-header")},
[]HeaderField{pair("custom-key", "custom-header")},
},
// C.2.2 Literal Header Field without Indexing
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.2
{"C.2.2", dehex("040c 2f73 616d 706c 652f 7061 7468"),
[]HeaderField{pair(":path", "/sample/path")},
[]HeaderField{}},
// C.2.3 Literal Header Field never Indexed
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.3
{"C.2.3", dehex("1008 7061 7373 776f 7264 0673 6563 7265 74"),
[]HeaderField{{"password", "secret", true}},
[]HeaderField{}},
// C.2.4 Indexed Header Field
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.4
{"C.2.4", []byte("\x82"),
[]HeaderField{pair(":method", "GET")},
[]HeaderField{}},
}
for _, tt := range tests {
d := NewDecoder(4096, nil)
hf, err := d.DecodeFull(tt.in)
if err != nil {
t.Errorf("%s: %v", tt.name, err)
continue
}
if !reflect.DeepEqual(hf, tt.want) {
t.Errorf("%s: Got %v; want %v", tt.name, hf, tt.want)
}
gotDynTab := d.dynTab.reverseCopy()
if !reflect.DeepEqual(gotDynTab, tt.wantDynTab) {
t.Errorf("%s: dynamic table after = %v; want %v", tt.name, gotDynTab, tt.wantDynTab)
}
}
}
func (dt *dynamicTable) reverseCopy() (hf []HeaderField) {
hf = make([]HeaderField, len(dt.ents))
for i := range hf {
hf[i] = dt.ents[len(dt.ents)-1-i]
}
return
}
type encAndWant struct {
enc []byte
want []HeaderField
wantDynTab []HeaderField
wantDynSize uint32
}
// C.3 Request Examples without Huffman Coding
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.3
func TestDecodeC3_NoHuffman(t *testing.T) {
testDecodeSeries(t, 4096, []encAndWant{
{dehex("8286 8441 0f77 7777 2e65 7861 6d70 6c65 2e63 6f6d"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
},
[]HeaderField{
pair(":authority", "www.example.com"),
},
57,
},
{dehex("8286 84be 5808 6e6f 2d63 6163 6865"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
pair("cache-control", "no-cache"),
},
[]HeaderField{
pair("cache-control", "no-cache"),
pair(":authority", "www.example.com"),
},
110,
},
{dehex("8287 85bf 400a 6375 7374 6f6d 2d6b 6579 0c63 7573 746f 6d2d 7661 6c75 65"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "https"),
pair(":path", "/index.html"),
pair(":authority", "www.example.com"),
pair("custom-key", "custom-value"),
},
[]HeaderField{
pair("custom-key", "custom-value"),
pair("cache-control", "no-cache"),
pair(":authority", "www.example.com"),
},
164,
},
})
}
// C.4 Request Examples with Huffman Coding
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.4
func TestDecodeC4_Huffman(t *testing.T) {
testDecodeSeries(t, 4096, []encAndWant{
{dehex("8286 8441 8cf1 e3c2 e5f2 3a6b a0ab 90f4 ff"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
},
[]HeaderField{
pair(":authority", "www.example.com"),
},
57,
},
{dehex("8286 84be 5886 a8eb 1064 9cbf"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
pair("cache-control", "no-cache"),
},
[]HeaderField{
pair("cache-control", "no-cache"),
pair(":authority", "www.example.com"),
},
110,
},
{dehex("8287 85bf 4088 25a8 49e9 5ba9 7d7f 8925 a849 e95b b8e8 b4bf"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "https"),
pair(":path", "/index.html"),
pair(":authority", "www.example.com"),
pair("custom-key", "custom-value"),
},
[]HeaderField{
pair("custom-key", "custom-value"),
pair("cache-control", "no-cache"),
pair(":authority", "www.example.com"),
},
164,
},
})
}
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.5
// "This section shows several consecutive header lists, corresponding
// to HTTP responses, on the same connection. The HTTP/2 setting
// parameter SETTINGS_HEADER_TABLE_SIZE is set to the value of 256
// octets, causing some evictions to occur."
func TestDecodeC5_ResponsesNoHuff(t *testing.T) {
testDecodeSeries(t, 256, []encAndWant{
{dehex(`
4803 3330 3258 0770 7269 7661 7465 611d
4d6f 6e2c 2032 3120 4f63 7420 3230 3133
2032 303a 3133 3a32 3120 474d 546e 1768
7474 7073 3a2f 2f77 7777 2e65 7861 6d70
6c65 2e63 6f6d
`),
[]HeaderField{
pair(":status", "302"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("location", "https://www.example.com"),
},
[]HeaderField{
pair("location", "https://www.example.com"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("cache-control", "private"),
pair(":status", "302"),
},
222,
},
{dehex("4803 3330 37c1 c0bf"),
[]HeaderField{
pair(":status", "307"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("location", "https://www.example.com"),
},
[]HeaderField{
pair(":status", "307"),
pair("location", "https://www.example.com"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("cache-control", "private"),
},
222,
},
{dehex(`
88c1 611d 4d6f 6e2c 2032 3120 4f63 7420
3230 3133 2032 303a 3133 3a32 3220 474d
54c0 5a04 677a 6970 7738 666f 6f3d 4153
444a 4b48 514b 425a 584f 5157 454f 5049
5541 5851 5745 4f49 553b 206d 6178 2d61
6765 3d33 3630 303b 2076 6572 7369 6f6e
3d31
`),
[]HeaderField{
pair(":status", "200"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
pair("location", "https://www.example.com"),
pair("content-encoding", "gzip"),
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
},
[]HeaderField{
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
pair("content-encoding", "gzip"),
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
},
215,
},
})
}
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.6
// "This section shows the same examples as the previous section, but
// using Huffman encoding for the literal values. The HTTP/2 setting
// parameter SETTINGS_HEADER_TABLE_SIZE is set to the value of 256
// octets, causing some evictions to occur. The eviction mechanism
// uses the length of the decoded literal values, so the same
// evictions occurs as in the previous section."
func TestDecodeC6_ResponsesHuffman(t *testing.T) {
testDecodeSeries(t, 256, []encAndWant{
{dehex(`
4882 6402 5885 aec3 771a 4b61 96d0 7abe
9410 54d4 44a8 2005 9504 0b81 66e0 82a6
2d1b ff6e 919d 29ad 1718 63c7 8f0b 97c8
e9ae 82ae 43d3
`),
[]HeaderField{
pair(":status", "302"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("location", "https://www.example.com"),
},
[]HeaderField{
pair("location", "https://www.example.com"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("cache-control", "private"),
pair(":status", "302"),
},
222,
},
{dehex("4883 640e ffc1 c0bf"),
[]HeaderField{
pair(":status", "307"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("location", "https://www.example.com"),
},
[]HeaderField{
pair(":status", "307"),
pair("location", "https://www.example.com"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("cache-control", "private"),
},
222,
},
{dehex(`
88c1 6196 d07a be94 1054 d444 a820 0595
040b 8166 e084 a62d 1bff c05a 839b d9ab
77ad 94e7 821d d7f2 e6c7 b335 dfdf cd5b
3960 d5af 2708 7f36 72c1 ab27 0fb5 291f
9587 3160 65c0 03ed 4ee5 b106 3d50 07
`),
[]HeaderField{
pair(":status", "200"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
pair("location", "https://www.example.com"),
pair("content-encoding", "gzip"),
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
},
[]HeaderField{
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
pair("content-encoding", "gzip"),
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
},
215,
},
})
}
func testDecodeSeries(t *testing.T, size uint32, steps []encAndWant) {
d := NewDecoder(size, nil)
for i, step := range steps {
hf, err := d.DecodeFull(step.enc)
if err != nil {
t.Fatalf("Error at step index %d: %v", i, err)
}
if !reflect.DeepEqual(hf, step.want) {
t.Fatalf("At step index %d: Got headers %v; want %v", i, hf, step.want)
}
gotDynTab := d.dynTab.reverseCopy()
if !reflect.DeepEqual(gotDynTab, step.wantDynTab) {
t.Errorf("After step index %d, dynamic table = %v; want %v", i, gotDynTab, step.wantDynTab)
}
if d.dynTab.size != step.wantDynSize {
t.Errorf("After step index %d, dynamic table size = %v; want %v", i, d.dynTab.size, step.wantDynSize)
}
}
}
func TestHuffmanDecode(t *testing.T) {
tests := []struct {
inHex, want string
}{
{"f1e3 c2e5 f23a 6ba0 ab90 f4ff", "www.example.com"},
{"a8eb 1064 9cbf", "no-cache"},
{"25a8 49e9 5ba9 7d7f", "custom-key"},
{"25a8 49e9 5bb8 e8b4 bf", "custom-value"},
{"6402", "302"},
{"aec3 771a 4b", "private"},
{"d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff", "Mon, 21 Oct 2013 20:13:21 GMT"},
{"9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 d3", "https://www.example.com"},
{"9bd9 ab", "gzip"},
{"94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 07",
"foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"},
}
for i, tt := range tests {
var buf bytes.Buffer
in, err := hex.DecodeString(strings.Replace(tt.inHex, " ", "", -1))
if err != nil {
t.Errorf("%d. hex input error: %v", i, err)
continue
}
if _, err := HuffmanDecode(&buf, in); err != nil {
t.Errorf("%d. decode error: %v", i, err)
continue
}
if got := buf.String(); tt.want != got {
t.Errorf("%d. decode = %q; want %q", i, got, tt.want)
}
}
}
func TestAppendHuffmanString(t *testing.T) {
tests := []struct {
in, want string
}{
{"www.example.com", "f1e3 c2e5 f23a 6ba0 ab90 f4ff"},
{"no-cache", "a8eb 1064 9cbf"},
{"custom-key", "25a8 49e9 5ba9 7d7f"},
{"custom-value", "25a8 49e9 5bb8 e8b4 bf"},
{"302", "6402"},
{"private", "aec3 771a 4b"},
{"Mon, 21 Oct 2013 20:13:21 GMT", "d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff"},
{"https://www.example.com", "9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 d3"},
{"gzip", "9bd9 ab"},
{"foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1",
"94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 07"},
}
for i, tt := range tests {
buf := []byte{}
want := strings.Replace(tt.want, " ", "", -1)
buf = AppendHuffmanString(buf, tt.in)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("%d. encode = %q; want %q", i, got, want)
}
}
}
func TestReadVarInt(t *testing.T) {
type res struct {
i uint64
consumed int
err error
}
tests := []struct {
n byte
p []byte
want res
}{
// Fits in a byte:
{1, []byte{0}, res{0, 1, nil}},
{2, []byte{2}, res{2, 1, nil}},
{3, []byte{6}, res{6, 1, nil}},
{4, []byte{14}, res{14, 1, nil}},
{5, []byte{30}, res{30, 1, nil}},
{6, []byte{62}, res{62, 1, nil}},
{7, []byte{126}, res{126, 1, nil}},
{8, []byte{254}, res{254, 1, nil}},
// Doesn't fit in a byte:
{1, []byte{1}, res{0, 0, errNeedMore}},
{2, []byte{3}, res{0, 0, errNeedMore}},
{3, []byte{7}, res{0, 0, errNeedMore}},
{4, []byte{15}, res{0, 0, errNeedMore}},
{5, []byte{31}, res{0, 0, errNeedMore}},
{6, []byte{63}, res{0, 0, errNeedMore}},
{7, []byte{127}, res{0, 0, errNeedMore}},
{8, []byte{255}, res{0, 0, errNeedMore}},
// Ignoring top bits:
{5, []byte{255, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 111
{5, []byte{159, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 100
{5, []byte{191, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 101
// Extra byte:
{5, []byte{191, 154, 10, 2}, res{1337, 3, nil}}, // extra byte
// Short a byte:
{5, []byte{191, 154}, res{0, 0, errNeedMore}},
// integer overflow:
{1, []byte{255, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128}, res{0, 0, errVarintOverflow}},
}
for _, tt := range tests {
i, remain, err := readVarInt(tt.n, tt.p)
consumed := len(tt.p) - len(remain)
got := res{i, consumed, err}
if got != tt.want {
t.Errorf("readVarInt(%d, %v ~ %x) = %+v; want %+v", tt.n, tt.p, tt.p, got, tt.want)
}
}
}
func dehex(s string) []byte {
s = strings.Replace(s, " ", "", -1)
s = strings.Replace(s, "\n", "", -1)
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}

View file

@ -1,152 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"bytes"
"errors"
"flag"
"fmt"
"net/http"
"os/exec"
"strconv"
"strings"
"testing"
"github.com/bradfitz/http2/hpack"
)
var knownFailing = flag.Bool("known_failing", false, "Run known-failing tests.")
func condSkipFailingTest(t *testing.T) {
if !*knownFailing {
t.Skip("Skipping known-failing test without --known_failing")
}
}
func init() {
DebugGoroutines = true
flag.BoolVar(&VerboseLogs, "verboseh2", false, "Verbose HTTP/2 debug logging")
}
func TestSettingString(t *testing.T) {
tests := []struct {
s Setting
want string
}{
{Setting{SettingMaxFrameSize, 123}, "[MAX_FRAME_SIZE = 123]"},
{Setting{1<<16 - 1, 123}, "[UNKNOWN_SETTING_65535 = 123]"},
}
for i, tt := range tests {
got := fmt.Sprint(tt.s)
if got != tt.want {
t.Errorf("%d. for %#v, string = %q; want %q", i, tt.s, got, tt.want)
}
}
}
type twriter struct {
t testing.TB
st *serverTester // optional
}
func (w twriter) Write(p []byte) (n int, err error) {
if w.st != nil {
ps := string(p)
for _, phrase := range w.st.logFilter {
if strings.Contains(ps, phrase) {
return len(p), nil // no logging
}
}
}
w.t.Logf("%s", p)
return len(p), nil
}
// like encodeHeader, but don't add implicit psuedo headers.
func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte {
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
for len(headers) > 0 {
k, v := headers[0], headers[1]
headers = headers[2:]
if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
}
}
return buf.Bytes()
}
// Verify that curl has http2.
func requireCurl(t *testing.T) {
out, err := dockerLogs(curl(t, "--version"))
if err != nil {
t.Skipf("failed to determine curl features; skipping test")
}
if !strings.Contains(string(out), "HTTP2") {
t.Skip("curl doesn't support HTTP2; skipping test")
}
}
func curl(t *testing.T, args ...string) (container string) {
out, err := exec.Command("docker", append([]string{"run", "-d", "--net=host", "gohttp2/curl"}, args...)...).CombinedOutput()
if err != nil {
t.Skipf("Failed to run curl in docker: %v, %s", err, out)
}
return strings.TrimSpace(string(out))
}
type puppetCommand struct {
fn func(w http.ResponseWriter, r *http.Request)
done chan<- bool
}
type handlerPuppet struct {
ch chan puppetCommand
}
func newHandlerPuppet() *handlerPuppet {
return &handlerPuppet{
ch: make(chan puppetCommand),
}
}
func (p *handlerPuppet) act(w http.ResponseWriter, r *http.Request) {
for cmd := range p.ch {
cmd.fn(w, r)
cmd.done <- true
}
}
func (p *handlerPuppet) done() { close(p.ch) }
func (p *handlerPuppet) do(fn func(http.ResponseWriter, *http.Request)) {
done := make(chan bool)
p.ch <- puppetCommand{fn, done}
<-done
}
func dockerLogs(container string) ([]byte, error) {
out, err := exec.Command("docker", "wait", container).CombinedOutput()
if err != nil {
return out, err
}
exitStatus, err := strconv.Atoi(strings.TrimSpace(string(out)))
if err != nil {
return out, errors.New("unexpected exit status from docker wait")
}
out, err = exec.Command("docker", "logs", container).CombinedOutput()
exec.Command("docker", "rm", container).Run()
if err == nil && exitStatus != 0 {
err = fmt.Errorf("exit status %d: %s", exitStatus, out)
}
return out, err
}
func kill(container string) {
exec.Command("docker", "kill", container).Run()
exec.Command("docker", "rm", container).Run()
}

View file

@ -1,24 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"errors"
"testing"
)
func TestPipeClose(t *testing.T) {
var p pipe
p.c.L = &p.m
a := errors.New("a")
b := errors.New("b")
p.Close(a)
p.Close(b)
_, err := p.Read(make([]byte, 1))
if err != a {
t.Errorf("err = %v want %v", err, a)
}
}

View file

@ -1,121 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"testing"
)
func TestPriority(t *testing.T) {
// A -> B
// move A's parent to B
streams := make(map[uint32]*stream)
a := &stream{
parent: nil,
weight: 16,
}
streams[1] = a
b := &stream{
parent: a,
weight: 16,
}
streams[2] = b
adjustStreamPriority(streams, 1, PriorityParam{
Weight: 20,
StreamDep: 2,
})
if a.parent != b {
t.Errorf("Expected A's parent to be B")
}
if a.weight != 20 {
t.Errorf("Expected A's weight to be 20; got %d", a.weight)
}
if b.parent != nil {
t.Errorf("Expected B to have no parent")
}
if b.weight != 16 {
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
}
}
func TestPriorityExclusiveZero(t *testing.T) {
// A B and C are all children of the 0 stream.
// Exclusive reprioritization to any of the streams
// should bring the rest of the streams under the
// reprioritized stream
streams := make(map[uint32]*stream)
a := &stream{
parent: nil,
weight: 16,
}
streams[1] = a
b := &stream{
parent: nil,
weight: 16,
}
streams[2] = b
c := &stream{
parent: nil,
weight: 16,
}
streams[3] = c
adjustStreamPriority(streams, 3, PriorityParam{
Weight: 20,
StreamDep: 0,
Exclusive: true,
})
if a.parent != c {
t.Errorf("Expected A's parent to be C")
}
if a.weight != 16 {
t.Errorf("Expected A's weight to be 16; got %d", a.weight)
}
if b.parent != c {
t.Errorf("Expected B's parent to be C")
}
if b.weight != 16 {
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
}
if c.parent != nil {
t.Errorf("Expected C to have no parent")
}
if c.weight != 20 {
t.Errorf("Expected C's weight to be 20; got %d", b.weight)
}
}
func TestPriorityOwnParent(t *testing.T) {
streams := make(map[uint32]*stream)
a := &stream{
parent: nil,
weight: 16,
}
streams[1] = a
b := &stream{
parent: a,
weight: 16,
}
streams[2] = b
adjustStreamPriority(streams, 1, PriorityParam{
Weight: 20,
StreamDep: 1,
})
if a.parent != nil {
t.Errorf("Expected A's parent to be nil")
}
if a.weight != 20 {
t.Errorf("Expected A's weight to be 20; got %d", a.weight)
}
if b.parent != a {
t.Errorf("Expected B's parent to be A")
}
if b.weight != 16 {
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,168 +0,0 @@
// Copyright 2015 The Go Authors.
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://go.googlesource.com/go/+/master/LICENSE
package http2
import (
"flag"
"io"
"io/ioutil"
"net/http"
"os"
"reflect"
"strings"
"testing"
"time"
)
var (
extNet = flag.Bool("extnet", false, "do external network tests")
transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
insecure = flag.Bool("insecure", false, "insecure TLS dials")
)
func TestTransportExternal(t *testing.T) {
if !*extNet {
t.Skip("skipping external network test")
}
req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
rt := &Transport{
InsecureTLSDial: *insecure,
}
res, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("%v", err)
}
res.Write(os.Stdout)
}
func TestTransport(t *testing.T) {
const body = "sup"
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, body)
})
defer st.Close()
tr := &Transport{InsecureTLSDial: true}
defer tr.CloseIdleConnections()
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
t.Logf("Got res: %+v", res)
if g, w := res.StatusCode, 200; g != w {
t.Errorf("StatusCode = %v; want %v", g, w)
}
if g, w := res.Status, "200 OK"; g != w {
t.Errorf("Status = %q; want %q", g, w)
}
wantHeader := http.Header{
"Content-Length": []string{"3"},
"Content-Type": []string{"text/plain; charset=utf-8"},
}
if !reflect.DeepEqual(res.Header, wantHeader) {
t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
}
if res.Request != req {
t.Errorf("Response.Request = %p; want %p", res.Request, req)
}
if res.TLS == nil {
t.Error("Response.TLS = nil; want non-nil")
}
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("Body read: %v", err)
} else if string(slurp) != body {
t.Errorf("Body = %q; want %q", slurp, body)
}
}
func TestTransportReusesConns(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.RemoteAddr)
}, optOnlyServer)
defer st.Close()
tr := &Transport{InsecureTLSDial: true}
defer tr.CloseIdleConnections()
get := func() string {
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("Body read: %v", err)
}
addr := strings.TrimSpace(string(slurp))
if addr == "" {
t.Fatalf("didn't get an addr in response")
}
return addr
}
first := get()
second := get()
if first != second {
t.Errorf("first and second responses were on different connections: %q vs %q", first, second)
}
}
func TestTransportAbortClosesPipes(t *testing.T) {
shutdown := make(chan struct{})
st := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush()
<-shutdown
},
optOnlyServer,
)
defer st.Close()
defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
done := make(chan struct{})
requestMade := make(chan struct{})
go func() {
defer close(done)
tr := &Transport{
InsecureTLSDial: true,
}
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
close(requestMade)
_, err = ioutil.ReadAll(res.Body)
if err == nil {
t.Error("expected error from res.Body.Read")
}
}()
<-requestMade
// Now force the serve loop to end, via closing the connection.
st.closeConn()
// deadlock? that's a bug.
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("timeout")
}
}

View file

@ -1,357 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"bytes"
"encoding/xml"
"flag"
"fmt"
"io"
"os"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"testing"
)
var coverSpec = flag.Bool("coverspec", false, "Run spec coverage tests")
// The global map of sentence coverage for the http2 spec.
var defaultSpecCoverage specCoverage
var loadSpecOnce sync.Once
func loadSpec() {
if f, err := os.Open("testdata/draft-ietf-httpbis-http2.xml"); err != nil {
panic(err)
} else {
defaultSpecCoverage = readSpecCov(f)
f.Close()
}
}
// covers marks all sentences for section sec in defaultSpecCoverage. Sentences not
// "covered" will be included in report outputed by TestSpecCoverage.
func covers(sec, sentences string) {
loadSpecOnce.Do(loadSpec)
defaultSpecCoverage.cover(sec, sentences)
}
type specPart struct {
section string
sentence string
}
func (ss specPart) Less(oo specPart) bool {
atoi := func(s string) int {
n, err := strconv.Atoi(s)
if err != nil {
panic(err)
}
return n
}
a := strings.Split(ss.section, ".")
b := strings.Split(oo.section, ".")
for len(a) > 0 {
if len(b) == 0 {
return false
}
x, y := atoi(a[0]), atoi(b[0])
if x == y {
a, b = a[1:], b[1:]
continue
}
return x < y
}
if len(b) > 0 {
return true
}
return false
}
type bySpecSection []specPart
func (a bySpecSection) Len() int { return len(a) }
func (a bySpecSection) Less(i, j int) bool { return a[i].Less(a[j]) }
func (a bySpecSection) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
type specCoverage struct {
coverage map[specPart]bool
d *xml.Decoder
}
func joinSection(sec []int) string {
s := fmt.Sprintf("%d", sec[0])
for _, n := range sec[1:] {
s = fmt.Sprintf("%s.%d", s, n)
}
return s
}
func (sc specCoverage) readSection(sec []int) {
var (
buf = new(bytes.Buffer)
sub = 0
)
for {
tk, err := sc.d.Token()
if err != nil {
if err == io.EOF {
return
}
panic(err)
}
switch v := tk.(type) {
case xml.StartElement:
if skipElement(v) {
if err := sc.d.Skip(); err != nil {
panic(err)
}
if v.Name.Local == "section" {
sub++
}
break
}
switch v.Name.Local {
case "section":
sub++
sc.readSection(append(sec, sub))
case "xref":
buf.Write(sc.readXRef(v))
}
case xml.CharData:
if len(sec) == 0 {
break
}
buf.Write(v)
case xml.EndElement:
if v.Name.Local == "section" {
sc.addSentences(joinSection(sec), buf.String())
return
}
}
}
}
func (sc specCoverage) readXRef(se xml.StartElement) []byte {
var b []byte
for {
tk, err := sc.d.Token()
if err != nil {
panic(err)
}
switch v := tk.(type) {
case xml.CharData:
if b != nil {
panic("unexpected CharData")
}
b = []byte(string(v))
case xml.EndElement:
if v.Name.Local != "xref" {
panic("expected </xref>")
}
if b != nil {
return b
}
sig := attrSig(se)
switch sig {
case "target":
return []byte(fmt.Sprintf("[%s]", attrValue(se, "target")))
case "fmt-of,rel,target", "fmt-,,rel,target":
return []byte(fmt.Sprintf("[%s, %s]", attrValue(se, "target"), attrValue(se, "rel")))
case "fmt-of,sec,target", "fmt-,,sec,target":
return []byte(fmt.Sprintf("[section %s of %s]", attrValue(se, "sec"), attrValue(se, "target")))
case "fmt-of,rel,sec,target":
return []byte(fmt.Sprintf("[section %s of %s, %s]", attrValue(se, "sec"), attrValue(se, "target"), attrValue(se, "rel")))
default:
panic(fmt.Sprintf("unknown attribute signature %q in %#v", sig, fmt.Sprintf("%#v", se)))
}
default:
panic(fmt.Sprintf("unexpected tag %q", v))
}
}
}
var skipAnchor = map[string]bool{
"intro": true,
"Overview": true,
}
var skipTitle = map[string]bool{
"Acknowledgements": true,
"Change Log": true,
"Document Organization": true,
"Conventions and Terminology": true,
}
func skipElement(s xml.StartElement) bool {
switch s.Name.Local {
case "artwork":
return true
case "section":
for _, attr := range s.Attr {
switch attr.Name.Local {
case "anchor":
if skipAnchor[attr.Value] || strings.HasPrefix(attr.Value, "changes.since.") {
return true
}
case "title":
if skipTitle[attr.Value] {
return true
}
}
}
}
return false
}
func readSpecCov(r io.Reader) specCoverage {
sc := specCoverage{
coverage: map[specPart]bool{},
d: xml.NewDecoder(r)}
sc.readSection(nil)
return sc
}
func (sc specCoverage) addSentences(sec string, sentence string) {
for _, s := range parseSentences(sentence) {
sc.coverage[specPart{sec, s}] = false
}
}
func (sc specCoverage) cover(sec string, sentence string) {
for _, s := range parseSentences(sentence) {
p := specPart{sec, s}
if _, ok := sc.coverage[p]; !ok {
panic(fmt.Sprintf("Not found in spec: %q, %q", sec, s))
}
sc.coverage[specPart{sec, s}] = true
}
}
var whitespaceRx = regexp.MustCompile(`\s+`)
func parseSentences(sens string) []string {
sens = strings.TrimSpace(sens)
if sens == "" {
return nil
}
ss := strings.Split(whitespaceRx.ReplaceAllString(sens, " "), ". ")
for i, s := range ss {
s = strings.TrimSpace(s)
if !strings.HasSuffix(s, ".") {
s += "."
}
ss[i] = s
}
return ss
}
func TestSpecParseSentences(t *testing.T) {
tests := []struct {
ss string
want []string
}{
{"Sentence 1. Sentence 2.",
[]string{
"Sentence 1.",
"Sentence 2.",
}},
{"Sentence 1. \nSentence 2.\tSentence 3.",
[]string{
"Sentence 1.",
"Sentence 2.",
"Sentence 3.",
}},
}
for i, tt := range tests {
got := parseSentences(tt.ss)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("%d: got = %q, want %q", i, got, tt.want)
}
}
}
func TestSpecCoverage(t *testing.T) {
if !*coverSpec {
t.Skip()
}
loadSpecOnce.Do(loadSpec)
var (
list []specPart
cv = defaultSpecCoverage.coverage
total = len(cv)
complete = 0
)
for sp, touched := range defaultSpecCoverage.coverage {
if touched {
complete++
} else {
list = append(list, sp)
}
}
sort.Stable(bySpecSection(list))
if testing.Short() && len(list) > 5 {
list = list[:5]
}
for _, p := range list {
t.Errorf("\tSECTION %s: %s", p.section, p.sentence)
}
t.Logf("%d/%d (%d%%) sentances covered", complete, total, (complete/total)*100)
}
func attrSig(se xml.StartElement) string {
var names []string
for _, attr := range se.Attr {
if attr.Name.Local == "fmt" {
names = append(names, "fmt-"+attr.Value)
} else {
names = append(names, attr.Name.Local)
}
}
sort.Strings(names)
return strings.Join(names, ",")
}
func attrValue(se xml.StartElement, attr string) string {
for _, a := range se.Attr {
if a.Name.Local == attr {
return a.Value
}
}
panic("unknown attribute " + attr)
}
func TestSpecPartLess(t *testing.T) {
tests := []struct {
sec1, sec2 string
want bool
}{
{"6.2.1", "6.2", false},
{"6.2", "6.2.1", true},
{"6.10", "6.10.1", true},
{"6.10", "6.1.1", false}, // 10, not 1
{"6.1", "6.1", false}, // equal, so not less
}
for _, tt := range tests {
got := (specPart{tt.sec1, "foo"}).Less(specPart{tt.sec2, "foo"})
if got != tt.want {
t.Errorf("Less(%q, %q) = %v; want %v", tt.sec1, tt.sec2, got, tt.want)
}
}
}

View file

@ -1,461 +0,0 @@
package bugsnag
import (
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/bitly/go-simplejson"
)
func TestConfigure(t *testing.T) {
Configure(Configuration{
APIKey: testAPIKey,
})
if Config.APIKey != testAPIKey {
t.Errorf("Setting APIKey didn't work")
}
if New().Config.APIKey != testAPIKey {
t.Errorf("Setting APIKey didn't work for new notifiers")
}
}
var postedJSON = make(chan []byte, 10)
var testOnce sync.Once
var testEndpoint string
var testAPIKey = "166f5ad3590596f9aa8d601ea89af845"
func startTestServer() {
testOnce.Do(func() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
panic(err)
}
postedJSON <- body
})
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
panic(err)
}
testEndpoint = "http://" + l.Addr().String() + "/"
go http.Serve(l, mux)
})
}
type _recurse struct {
*_recurse
}
func TestNotify(t *testing.T) {
startTestServer()
recurse := _recurse{}
recurse._recurse = &recurse
OnBeforeNotify(func(event *Event, config *Configuration) error {
if event.Context == "testing" {
event.GroupingHash = "lol"
}
return nil
})
Notify(fmt.Errorf("hello world"),
Configuration{
APIKey: testAPIKey,
Endpoint: testEndpoint,
ReleaseStage: "test",
AppVersion: "1.2.3",
Hostname: "web1",
ProjectPackages: []string{"github.com/bugsnag/bugsnag-go"},
},
User{Id: "123", Name: "Conrad", Email: "me@cirw.in"},
Context{"testing"},
MetaData{"test": {
"password": "sneaky",
"value": "able",
"broken": complex(1, 2),
"recurse": recurse,
}},
)
json, err := simplejson.NewJson(<-postedJSON)
if err != nil {
t.Fatal(err)
}
if json.Get("apiKey").MustString() != testAPIKey {
t.Errorf("Wrong api key in payload")
}
if json.GetPath("notifier", "name").MustString() != "Bugsnag Go" {
t.Errorf("Wrong notifier name in payload")
}
event := json.Get("events").GetIndex(0)
for k, value := range map[string]string{
"payloadVersion": "2",
"severity": "warning",
"context": "testing",
"groupingHash": "lol",
"app.releaseStage": "test",
"app.version": "1.2.3",
"device.hostname": "web1",
"user.id": "123",
"user.name": "Conrad",
"user.email": "me@cirw.in",
"metaData.test.password": "[REDACTED]",
"metaData.test.value": "able",
"metaData.test.broken": "[complex128]",
"metaData.test.recurse._recurse": "[RECURSION]",
} {
key := strings.Split(k, ".")
if event.GetPath(key...).MustString() != value {
t.Errorf("Wrong %v: %v != %v", key, event.GetPath(key...).MustString(), value)
}
}
exception := event.Get("exceptions").GetIndex(0)
if exception.Get("message").MustString() != "hello world" {
t.Errorf("Wrong message in payload")
}
if exception.Get("errorClass").MustString() != "*errors.errorString" {
t.Errorf("Wrong errorClass in payload: %v", exception.Get("errorClass").MustString())
}
frame0 := exception.Get("stacktrace").GetIndex(0)
if frame0.Get("file").MustString() != "bugsnag_test.go" ||
frame0.Get("method").MustString() != "TestNotify" ||
frame0.Get("inProject").MustBool() != true ||
frame0.Get("lineNumber").MustInt() == 0 {
t.Errorf("Wrong frame0")
}
frame1 := exception.Get("stacktrace").GetIndex(1)
if frame1.Get("file").MustString() != "testing/testing.go" ||
frame1.Get("method").MustString() != "tRunner" ||
frame1.Get("inProject").MustBool() != false ||
frame1.Get("lineNumber").MustInt() == 0 {
t.Errorf("Wrong frame1")
}
}
func crashyHandler(w http.ResponseWriter, r *http.Request) {
c := make(chan int)
close(c)
c <- 1
}
func runCrashyServer(rawData ...interface{}) (net.Listener, error) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
}
mux := http.NewServeMux()
mux.HandleFunc("/", crashyHandler)
srv := http.Server{
Addr: l.Addr().String(),
Handler: Handler(mux, rawData...),
ErrorLog: log.New(ioutil.Discard, log.Prefix(), 0),
}
go srv.Serve(l)
return l, err
}
func TestHandler(t *testing.T) {
startTestServer()
l, err := runCrashyServer(Configuration{
APIKey: testAPIKey,
Endpoint: testEndpoint,
ProjectPackages: []string{"github.com/bugsnag/bugsnag-go"},
Logger: log.New(ioutil.Discard, log.Prefix(), log.Flags()),
}, SeverityInfo)
if err != nil {
t.Fatal(err)
}
http.Get("http://" + l.Addr().String() + "/ok?foo=bar")
l.Close()
json, err := simplejson.NewJson(<-postedJSON)
if err != nil {
t.Fatal(err)
}
if json.Get("apiKey").MustString() != testAPIKey {
t.Errorf("Wrong api key in payload")
}
if json.GetPath("notifier", "name").MustString() != "Bugsnag Go" {
t.Errorf("Wrong notifier name in payload")
}
event := json.Get("events").GetIndex(0)
for k, value := range map[string]string{
"payloadVersion": "2",
"severity": "info",
"user.id": "127.0.0.1",
"metaData.Request.Url": "http://" + l.Addr().String() + "/ok?foo=bar",
"metaData.Request.Method": "GET",
} {
key := strings.Split(k, ".")
if event.GetPath(key...).MustString() != value {
t.Errorf("Wrong %v: %v != %v", key, event.GetPath(key...).MustString(), value)
}
}
if event.GetPath("metaData", "Request", "Params", "foo").GetIndex(0).MustString() != "bar" {
t.Errorf("missing GET params in request metadata")
}
if event.GetPath("metaData", "Headers", "Accept-Encoding").GetIndex(0).MustString() != "gzip" {
t.Errorf("missing GET params in request metadata: %v", event.GetPath("metaData", "Headers"))
}
exception := event.Get("exceptions").GetIndex(0)
if exception.Get("message").MustString() != "runtime error: send on closed channel" {
t.Errorf("Wrong message in payload: %v", exception.Get("message").MustString())
}
if exception.Get("errorClass").MustString() != "runtime.errorCString" {
t.Errorf("Wrong errorClass in payload: %v", exception.Get("errorClass").MustString())
}
// TODO:CI these are probably dependent on go version.
frame0 := exception.Get("stacktrace").GetIndex(0)
if frame0.Get("file").MustString() != "runtime/panic.c" ||
frame0.Get("method").MustString() != "panicstring" ||
frame0.Get("inProject").MustBool() != false ||
frame0.Get("lineNumber").MustInt() == 0 {
t.Errorf("Wrong frame0: %v", frame0)
}
frame3 := exception.Get("stacktrace").GetIndex(3)
if frame3.Get("file").MustString() != "bugsnag_test.go" ||
frame3.Get("method").MustString() != "crashyHandler" ||
frame3.Get("inProject").MustBool() != true ||
frame3.Get("lineNumber").MustInt() == 0 {
t.Errorf("Wrong frame3: %v", frame3)
}
}
func TestAutoNotify(t *testing.T) {
var panicked interface{}
func() {
defer func() {
panicked = recover()
}()
defer AutoNotify(Configuration{Endpoint: testEndpoint, APIKey: testAPIKey})
panic("eggs")
}()
if panicked.(string) != "eggs" {
t.Errorf("didn't re-panic")
}
json, err := simplejson.NewJson(<-postedJSON)
if err != nil {
t.Fatal(err)
}
event := json.Get("events").GetIndex(0)
if event.Get("severity").MustString() != "error" {
t.Errorf("severity should be error")
}
exception := event.Get("exceptions").GetIndex(0)
if exception.Get("message").MustString() != "eggs" {
t.Errorf("caught wrong panic")
}
}
func TestRecover(t *testing.T) {
var panicked interface{}
func() {
defer func() {
panicked = recover()
}()
defer Recover(Configuration{Endpoint: testEndpoint, APIKey: testAPIKey})
panic("ham")
}()
if panicked != nil {
t.Errorf("re-panick'd")
}
json, err := simplejson.NewJson(<-postedJSON)
if err != nil {
t.Fatal(err)
}
event := json.Get("events").GetIndex(0)
if event.Get("severity").MustString() != "warning" {
t.Errorf("severity should be warning")
}
exception := event.Get("exceptions").GetIndex(0)
if exception.Get("message").MustString() != "ham" {
t.Errorf("caught wrong panic")
}
}
func handleGet(w http.ResponseWriter, r *http.Request) {
}
var createAccount = handleGet
type _job struct {
Name string
Process func()
}
func ExampleAutoNotify() interface{} {
return func(w http.ResponseWriter, request *http.Request) {
defer AutoNotify(request, Context{"createAccount"})
createAccount(w, request)
}
}
func ExampleRecover(job _job) {
go func() {
defer Recover(Context{job.Name}, SeverityWarning)
job.Process()
}()
}
func ExampleConfigure() {
Configure(Configuration{
APIKey: "YOUR_API_KEY_HERE",
ReleaseStage: "production",
// See Configuration{} for other fields
})
}
func ExampleHandler() {
// Set up your http handlers as usual
http.HandleFunc("/", handleGet)
// use bugsnag.Handler(nil) to wrap the default http handlers
// so that Bugsnag is automatically notified about panics.
http.ListenAndServe(":1234", Handler(nil))
}
func ExampleHandler_customServer() {
// If you're using a custom server, set the handlers explicitly.
http.HandleFunc("/", handleGet)
srv := http.Server{
Addr: ":1234",
ReadTimeout: 10 * time.Second,
// use bugsnag.Handler(nil) to wrap the default http handlers
// so that Bugsnag is automatically notified about panics.
Handler: Handler(nil),
}
srv.ListenAndServe()
}
func ExampleHandler_customHandlers() {
// If you're using custom handlers, wrap the handlers explicitly.
handler := http.NewServeMux()
http.HandleFunc("/", handleGet)
// use bugsnag.Handler(handler) to wrap the handlers so that Bugsnag is
// automatically notified about panics
http.ListenAndServe(":1234", Handler(handler))
}
func ExampleNotify() {
_, err := net.Listen("tcp", ":80")
if err != nil {
Notify(err)
}
}
func ExampleNotify_details(userID string) {
_, err := net.Listen("tcp", ":80")
if err != nil {
Notify(err,
// show as low-severity
SeverityInfo,
// set the context
Context{"createlistener"},
// pass the user id in to count users affected.
User{Id: userID},
// custom meta-data tab
MetaData{
"Listen": {
"Protocol": "tcp",
"Port": "80",
},
},
)
}
}
type Job struct {
Retry bool
UserId string
UserEmail string
Name string
Params map[string]string
}
func ExampleOnBeforeNotify() {
OnBeforeNotify(func(event *Event, config *Configuration) error {
// Search all the RawData for any *Job pointers that we're passed in
// to bugsnag.Notify() and friends.
for _, datum := range event.RawData {
if job, ok := datum.(*Job); ok {
// don't notify bugsnag about errors in retries
if job.Retry {
return fmt.Errorf("bugsnag middleware: not notifying about job retry")
}
// add the job as a tab on Bugsnag.com
event.MetaData.AddStruct("Job", job)
// set the user correctly
event.User = &User{Id: job.UserId, Email: job.UserEmail}
}
}
// continue notifying as normal
return nil
})
}

View file

@ -1,58 +0,0 @@
package bugsnag
import (
"testing"
)
func TestNotifyReleaseStages(t *testing.T) {
var testCases = []struct {
stage string
configured []string
notify bool
msg string
}{
{
stage: "production",
notify: true,
msg: "Should notify in all release stages by default",
},
{
stage: "production",
configured: []string{"development", "production"},
notify: true,
msg: "Failed to notify in configured release stage",
},
{
stage: "staging",
configured: []string{"development", "production"},
notify: false,
msg: "Failed to prevent notification in excluded release stage",
},
}
for _, testCase := range testCases {
Configure(Configuration{ReleaseStage: testCase.stage, NotifyReleaseStages: testCase.configured})
if Config.notifyInReleaseStage() != testCase.notify {
t.Error(testCase.msg)
}
}
}
func TestProjectPackages(t *testing.T) {
Configure(Configuration{ProjectPackages: []string{"main", "github.com/ConradIrwin/*"}})
if !Config.isProjectPackage("main") {
t.Error("literal project package doesn't work")
}
if !Config.isProjectPackage("github.com/ConradIrwin/foo") {
t.Error("wildcard project package doesn't work")
}
if Config.isProjectPackage("runtime") {
t.Error("wrong packges being marked in project")
}
if Config.isProjectPackage("github.com/ConradIrwin/foo/bar") {
t.Error("wrong packges being marked in project")
}
}

View file

@ -1,117 +0,0 @@
package errors
import (
"bytes"
"fmt"
"io"
"runtime/debug"
"testing"
)
func TestStackFormatMatches(t *testing.T) {
defer func() {
err := recover()
if err != 'a' {
t.Fatal(err)
}
bs := [][]byte{Errorf("hi").Stack(), debug.Stack()}
// Ignore the first line (as it contains the PC of the .Stack() call)
bs[0] = bytes.SplitN(bs[0], []byte("\n"), 2)[1]
bs[1] = bytes.SplitN(bs[1], []byte("\n"), 2)[1]
if bytes.Compare(bs[0], bs[1]) != 0 {
t.Errorf("Stack didn't match")
t.Errorf("%s", bs[0])
t.Errorf("%s", bs[1])
}
}()
a()
}
func TestSkipWorks(t *testing.T) {
defer func() {
err := recover()
if err != 'a' {
t.Fatal(err)
}
bs := [][]byte{New("hi", 2).Stack(), debug.Stack()}
// should skip four lines of debug.Stack()
bs[1] = bytes.SplitN(bs[1], []byte("\n"), 5)[4]
if bytes.Compare(bs[0], bs[1]) != 0 {
t.Errorf("Stack didn't match")
t.Errorf("%s", bs[0])
t.Errorf("%s", bs[1])
}
}()
a()
}
func TestNewError(t *testing.T) {
e := func() error {
return New("hi", 1)
}()
if e.Error() != "hi" {
t.Errorf("Constructor with a string failed")
}
if New(fmt.Errorf("yo"), 0).Error() != "yo" {
t.Errorf("Constructor with an error failed")
}
if New(e, 0) != e {
t.Errorf("Constructor with an Error failed")
}
if New(nil, 0).Error() != "<nil>" {
t.Errorf("Constructor with nil failed")
}
}
func ExampleErrorf(x int) (int, error) {
if x%2 == 1 {
return 0, Errorf("can only halve even numbers, got %d", x)
}
return x / 2, nil
}
func ExampleNewError() (error, error) {
// Wrap io.EOF with the current stack-trace and return it
return nil, New(io.EOF, 0)
}
func ExampleNewError_skip() {
defer func() {
if err := recover(); err != nil {
// skip 1 frame (the deferred function) and then return the wrapped err
err = New(err, 1)
}
}()
}
func ExampleError_Stack(err Error) {
fmt.Printf("Error: %s\n%s", err.Error(), err.Stack())
}
func a() error {
b(5)
return nil
}
func b(i int) {
c()
}
func c() {
panic('a')
}

View file

@ -1,142 +0,0 @@
package errors
import (
"reflect"
"testing"
)
var createdBy = `panic: hello!
goroutine 54 [running]:
runtime.panic(0x35ce40, 0xc208039db0)
/0/c/go/src/pkg/runtime/panic.c:279 +0xf5
github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers.func·001()
/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers/app.go:13 +0x74
net/http.(*Server).Serve(0xc20806c780, 0x910c88, 0xc20803e168, 0x0, 0x0)
/0/c/go/src/pkg/net/http/server.go:1698 +0x91
created by github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers.App.Index
/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers/app.go:14 +0x3e
goroutine 16 [IO wait]:
net.runtime_pollWait(0x911c30, 0x72, 0x0)
/0/c/go/src/pkg/runtime/netpoll.goc:146 +0x66
net.(*pollDesc).Wait(0xc2080ba990, 0x72, 0x0, 0x0)
/0/c/go/src/pkg/net/fd_poll_runtime.go:84 +0x46
net.(*pollDesc).WaitRead(0xc2080ba990, 0x0, 0x0)
/0/c/go/src/pkg/net/fd_poll_runtime.go:89 +0x42
net.(*netFD).accept(0xc2080ba930, 0x58be30, 0x0, 0x9103f0, 0x23)
/0/c/go/src/pkg/net/fd_unix.go:409 +0x343
net.(*TCPListener).AcceptTCP(0xc20803e168, 0x8, 0x0, 0x0)
/0/c/go/src/pkg/net/tcpsock_posix.go:234 +0x5d
net.(*TCPListener).Accept(0xc20803e168, 0x0, 0x0, 0x0, 0x0)
/0/c/go/src/pkg/net/tcpsock_posix.go:244 +0x4b
github.com/revel/revel.Run(0xe6d9)
/0/go/src/github.com/revel/revel/server.go:113 +0x926
main.main()
/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/tmp/main.go:109 +0xe1a
`
var normalSplit = `panic: hello!
goroutine 54 [running]:
runtime.panic(0x35ce40, 0xc208039db0)
/0/c/go/src/pkg/runtime/panic.c:279 +0xf5
github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers.func·001()
/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers/app.go:13 +0x74
net/http.(*Server).Serve(0xc20806c780, 0x910c88, 0xc20803e168, 0x0, 0x0)
/0/c/go/src/pkg/net/http/server.go:1698 +0x91
goroutine 16 [IO wait]:
net.runtime_pollWait(0x911c30, 0x72, 0x0)
/0/c/go/src/pkg/runtime/netpoll.goc:146 +0x66
net.(*pollDesc).Wait(0xc2080ba990, 0x72, 0x0, 0x0)
/0/c/go/src/pkg/net/fd_poll_runtime.go:84 +0x46
net.(*pollDesc).WaitRead(0xc2080ba990, 0x0, 0x0)
/0/c/go/src/pkg/net/fd_poll_runtime.go:89 +0x42
net.(*netFD).accept(0xc2080ba930, 0x58be30, 0x0, 0x9103f0, 0x23)
/0/c/go/src/pkg/net/fd_unix.go:409 +0x343
net.(*TCPListener).AcceptTCP(0xc20803e168, 0x8, 0x0, 0x0)
/0/c/go/src/pkg/net/tcpsock_posix.go:234 +0x5d
net.(*TCPListener).Accept(0xc20803e168, 0x0, 0x0, 0x0, 0x0)
/0/c/go/src/pkg/net/tcpsock_posix.go:244 +0x4b
github.com/revel/revel.Run(0xe6d9)
/0/go/src/github.com/revel/revel/server.go:113 +0x926
main.main()
/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/tmp/main.go:109 +0xe1a
`
var lastGoroutine = `panic: hello!
goroutine 16 [IO wait]:
net.runtime_pollWait(0x911c30, 0x72, 0x0)
/0/c/go/src/pkg/runtime/netpoll.goc:146 +0x66
net.(*pollDesc).Wait(0xc2080ba990, 0x72, 0x0, 0x0)
/0/c/go/src/pkg/net/fd_poll_runtime.go:84 +0x46
net.(*pollDesc).WaitRead(0xc2080ba990, 0x0, 0x0)
/0/c/go/src/pkg/net/fd_poll_runtime.go:89 +0x42
net.(*netFD).accept(0xc2080ba930, 0x58be30, 0x0, 0x9103f0, 0x23)
/0/c/go/src/pkg/net/fd_unix.go:409 +0x343
net.(*TCPListener).AcceptTCP(0xc20803e168, 0x8, 0x0, 0x0)
/0/c/go/src/pkg/net/tcpsock_posix.go:234 +0x5d
net.(*TCPListener).Accept(0xc20803e168, 0x0, 0x0, 0x0, 0x0)
/0/c/go/src/pkg/net/tcpsock_posix.go:244 +0x4b
github.com/revel/revel.Run(0xe6d9)
/0/go/src/github.com/revel/revel/server.go:113 +0x926
main.main()
/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/tmp/main.go:109 +0xe1a
goroutine 54 [running]:
runtime.panic(0x35ce40, 0xc208039db0)
/0/c/go/src/pkg/runtime/panic.c:279 +0xf5
github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers.func·001()
/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers/app.go:13 +0x74
net/http.(*Server).Serve(0xc20806c780, 0x910c88, 0xc20803e168, 0x0, 0x0)
/0/c/go/src/pkg/net/http/server.go:1698 +0x91
`
var result = []StackFrame{
StackFrame{File: "/0/c/go/src/pkg/runtime/panic.c", LineNumber: 279, Name: "panic", Package: "runtime"},
StackFrame{File: "/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers/app.go", LineNumber: 13, Name: "func.001", Package: "github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers"},
StackFrame{File: "/0/c/go/src/pkg/net/http/server.go", LineNumber: 1698, Name: "(*Server).Serve", Package: "net/http"},
}
var resultCreatedBy = append(result,
StackFrame{File: "/0/go/src/github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers/app.go", LineNumber: 14, Name: "App.Index", Package: "github.com/loopj/bugsnag-example-apps/go/revelapp/app/controllers", ProgramCounter: 0x0})
func TestParsePanic(t *testing.T) {
todo := map[string]string{
"createdBy": createdBy,
"normalSplit": normalSplit,
"lastGoroutine": lastGoroutine,
}
for key, val := range todo {
Err, err := ParsePanic(val)
if err != nil {
t.Fatal(err)
}
if Err.TypeName() != "panic" {
t.Errorf("Wrong type: %s", Err.TypeName())
}
if Err.Error() != "hello!" {
t.Errorf("Wrong message: %s", Err.TypeName())
}
if Err.StackFrames()[0].Func() != nil {
t.Errorf("Somehow managed to find a func...")
}
result := result
if key == "createdBy" {
result = resultCreatedBy
}
if !reflect.DeepEqual(Err.StackFrames(), result) {
t.Errorf("Wrong stack for %s: %#v", key, Err.StackFrames())
}
}
}

View file

@ -1,182 +0,0 @@
package bugsnag
import (
"reflect"
"testing"
"unsafe"
"github.com/bugsnag/bugsnag-go/errors"
)
type _account struct {
ID string
Name string
Plan struct {
Premium bool
}
Password string
secret string
Email string `json:"email"`
EmptyEmail string `json:"emptyemail,omitempty"`
NotEmptyEmail string `json:"not_empty_email,omitempty"`
}
type _broken struct {
Me *_broken
Data string
}
var account = _account{}
var notifier = New(Configuration{})
func TestMetaDataAdd(t *testing.T) {
m := MetaData{
"one": {
"key": "value",
"override": false,
}}
m.Add("one", "override", true)
m.Add("one", "new", "key")
m.Add("new", "tab", account)
m.AddStruct("lol", "not really a struct")
m.AddStruct("account", account)
if !reflect.DeepEqual(m, MetaData{
"one": {
"key": "value",
"override": true,
"new": "key",
},
"new": {
"tab": account,
},
"Extra data": {
"lol": "not really a struct",
},
"account": {
"ID": "",
"Name": "",
"Plan": map[string]interface{}{
"Premium": false,
},
"Password": "",
"email": "",
},
}) {
t.Errorf("metadata.Add didn't work: %#v", m)
}
}
func TestMetaDataUpdate(t *testing.T) {
m := MetaData{
"one": {
"key": "value",
"override": false,
}}
m.Update(MetaData{
"one": {
"override": true,
"new": "key",
},
"new": {
"tab": account,
},
})
if !reflect.DeepEqual(m, MetaData{
"one": {
"key": "value",
"override": true,
"new": "key",
},
"new": {
"tab": account,
},
}) {
t.Errorf("metadata.Update didn't work: %#v", m)
}
}
func TestMetaDataSanitize(t *testing.T) {
var broken = _broken{}
broken.Me = &broken
broken.Data = "ohai"
account.Name = "test"
account.ID = "test"
account.secret = "hush"
account.Email = "example@example.com"
account.EmptyEmail = ""
account.NotEmptyEmail = "not_empty_email@example.com"
m := MetaData{
"one": {
"bool": true,
"int": 7,
"float": 7.1,
"complex": complex(1, 1),
"func": func() {},
"unsafe": unsafe.Pointer(broken.Me),
"string": "string",
"password": "secret",
"array": []hash{{
"creditcard": "1234567812345678",
"broken": broken,
}},
"broken": broken,
"account": account,
},
}
n := m.sanitize([]string{"password", "creditcard"})
if !reflect.DeepEqual(n, map[string]interface{}{
"one": map[string]interface{}{
"bool": true,
"int": 7,
"float": 7.1,
"complex": "[complex128]",
"string": "string",
"unsafe": "[unsafe.Pointer]",
"func": "[func()]",
"password": "[REDACTED]",
"array": []interface{}{map[string]interface{}{
"creditcard": "[REDACTED]",
"broken": map[string]interface{}{
"Me": "[RECURSION]",
"Data": "ohai",
},
}},
"broken": map[string]interface{}{
"Me": "[RECURSION]",
"Data": "ohai",
},
"account": map[string]interface{}{
"ID": "test",
"Name": "test",
"Plan": map[string]interface{}{
"Premium": false,
},
"Password": "[REDACTED]",
"email": "example@example.com",
"not_empty_email": "not_empty_email@example.com",
},
},
}) {
t.Errorf("metadata.Sanitize didn't work: %#v", n)
}
}
func ExampleMetaData() {
notifier.Notify(errors.Errorf("hi world"),
MetaData{"Account": {
"id": account.ID,
"name": account.Name,
"paying?": account.Plan.Premium,
}})
}

View file

@ -1,88 +0,0 @@
package bugsnag
import (
"bytes"
"fmt"
"log"
"reflect"
"testing"
)
func TestMiddlewareOrder(t *testing.T) {
result := make([]int, 0, 7)
stack := middlewareStack{}
stack.OnBeforeNotify(func(e *Event, c *Configuration) error {
result = append(result, 2)
return nil
})
stack.OnBeforeNotify(func(e *Event, c *Configuration) error {
result = append(result, 1)
return nil
})
stack.OnBeforeNotify(func(e *Event, c *Configuration) error {
result = append(result, 0)
return nil
})
stack.Run(nil, nil, func() error {
result = append(result, 3)
return nil
})
if !reflect.DeepEqual(result, []int{0, 1, 2, 3}) {
t.Errorf("unexpected middleware order %v", result)
}
}
func TestBeforeNotifyReturnErr(t *testing.T) {
stack := middlewareStack{}
err := fmt.Errorf("test")
stack.OnBeforeNotify(func(e *Event, c *Configuration) error {
return err
})
called := false
e := stack.Run(nil, nil, func() error {
called = true
return nil
})
if e != err {
t.Errorf("Middleware didn't return the error")
}
if called == true {
t.Errorf("Notify was called when BeforeNotify returned False")
}
}
func TestBeforeNotifyPanic(t *testing.T) {
stack := middlewareStack{}
stack.OnBeforeNotify(func(e *Event, c *Configuration) error {
panic("oops")
})
called := false
b := &bytes.Buffer{}
stack.Run(nil, &Configuration{Logger: log.New(b, log.Prefix(), 0)}, func() error {
called = true
return nil
})
logged := b.String()
if logged != "bugsnag/middleware: unexpected panic: oops\n" {
t.Errorf("Logged: %s", logged)
}
if called == false {
t.Errorf("Notify was not called when BeforeNotify panicked")
}
}

View file

@ -1,79 +0,0 @@
// +build !appengine
package bugsnag
import (
"github.com/bitly/go-simplejson"
"github.com/mitchellh/osext"
"os"
"os/exec"
"testing"
"time"
)
func TestPanicHandler(t *testing.T) {
startTestServer()
exePath, err := osext.Executable()
if err != nil {
t.Fatal(err)
}
// Use the same trick as panicwrap() to re-run ourselves.
// In the init() block below, we will then panic.
cmd := exec.Command(exePath, os.Args[1:]...)
cmd.Env = append(os.Environ(), "BUGSNAG_API_KEY="+testAPIKey, "BUGSNAG_ENDPOINT="+testEndpoint, "please_panic=please_panic")
if err = cmd.Start(); err != nil {
t.Fatal(err)
}
if err = cmd.Wait(); err.Error() != "exit status 2" {
t.Fatal(err)
}
json, err := simplejson.NewJson(<-postedJSON)
if err != nil {
t.Fatal(err)
}
event := json.Get("events").GetIndex(0)
if event.Get("severity").MustString() != "error" {
t.Errorf("severity should be error")
}
exception := event.Get("exceptions").GetIndex(0)
if exception.Get("message").MustString() != "ruh roh" {
t.Errorf("caught wrong panic")
}
if exception.Get("errorClass").MustString() != "panic" {
t.Errorf("caught wrong panic")
}
frame := exception.Get("stacktrace").GetIndex(1)
// Yeah, we just caught a panic from the init() function below and sent it to the server running above (mindblown)
if frame.Get("inProject").MustBool() != true ||
frame.Get("file").MustString() != "panicwrap_test.go" ||
frame.Get("method").MustString() != "panick" ||
frame.Get("lineNumber").MustInt() == 0 {
t.Errorf("stack trace seemed wrong")
}
}
func init() {
if os.Getenv("please_panic") != "" {
Configure(Configuration{APIKey: os.Getenv("BUGSNAG_API_KEY"), Endpoint: os.Getenv("BUGSNAG_ENDPOINT"), ProjectPackages: []string{"github.com/bugsnag/bugsnag-go"}})
go func() {
panick()
}()
// Plenty of time to crash, it shouldn't need any of it.
time.Sleep(1 * time.Second)
}
}
func panick() {
panic("ruh roh")
}

View file

@ -1,60 +0,0 @@
// Package bugsnagrevel adds Bugsnag to revel.
// It lets you pass *revel.Controller into bugsnag.Notify(),
// and provides a Filter to catch errors.
package bugsnagrevel
import (
"strings"
"sync"
"github.com/bugsnag/bugsnag-go"
"github.com/revel/revel"
)
var once sync.Once
// Filter should be added to the filter chain just after the PanicFilter.
// It sends errors to Bugsnag automatically. Configuration is read out of
// conf/app.conf, you should set bugsnag.apikey, and can also set
// bugsnag.endpoint, bugsnag.releasestage, bugsnag.appversion,
// bugsnag.projectroot, bugsnag.projectpackages if needed.
func Filter(c *revel.Controller, fc []revel.Filter) {
defer bugsnag.AutoNotify(c)
fc[0](c, fc[1:])
}
// Add support to bugsnag for reading data out of *revel.Controllers
func middleware(event *bugsnag.Event, config *bugsnag.Configuration) error {
for _, datum := range event.RawData {
if controller, ok := datum.(*revel.Controller); ok {
// make the request visible to the builtin HttpMIddleware
event.RawData = append(event.RawData, controller.Request.Request)
event.Context = controller.Action
event.MetaData.AddStruct("Session", controller.Session)
}
}
return nil
}
func init() {
revel.OnAppStart(func() {
bugsnag.OnBeforeNotify(middleware)
var projectPackages []string
if packages, ok := revel.Config.String("bugsnag.projectpackages"); ok {
projectPackages = strings.Split(packages, ",")
} else {
projectPackages = []string{revel.ImportPath + "/app/*", revel.ImportPath + "/app"}
}
bugsnag.Configure(bugsnag.Configuration{
APIKey: revel.Config.StringDefault("bugsnag.apikey", ""),
Endpoint: revel.Config.StringDefault("bugsnag.endpoint", ""),
AppVersion: revel.Config.StringDefault("bugsnag.appversion", ""),
ReleaseStage: revel.Config.StringDefault("bugsnag.releasestage", revel.RunMode),
ProjectPackages: projectPackages,
Logger: revel.ERROR,
})
})
}

View file

@ -1,79 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin linux freebsd netbsd windows
package osext
import (
"fmt"
"os"
oexec "os/exec"
"path/filepath"
"runtime"
"testing"
)
const execPath_EnvVar = "OSTEST_OUTPUT_EXECPATH"
func TestExecPath(t *testing.T) {
ep, err := Executable()
if err != nil {
t.Fatalf("ExecPath failed: %v", err)
}
// we want fn to be of the form "dir/prog"
dir := filepath.Dir(filepath.Dir(ep))
fn, err := filepath.Rel(dir, ep)
if err != nil {
t.Fatalf("filepath.Rel: %v", err)
}
cmd := &oexec.Cmd{}
// make child start with a relative program path
cmd.Dir = dir
cmd.Path = fn
// forge argv[0] for child, so that we can verify we could correctly
// get real path of the executable without influenced by argv[0].
cmd.Args = []string{"-", "-test.run=XXXX"}
cmd.Env = []string{fmt.Sprintf("%s=1", execPath_EnvVar)}
out, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("exec(self) failed: %v", err)
}
outs := string(out)
if !filepath.IsAbs(outs) {
t.Fatalf("Child returned %q, want an absolute path", out)
}
if !sameFile(outs, ep) {
t.Fatalf("Child returned %q, not the same file as %q", out, ep)
}
}
func sameFile(fn1, fn2 string) bool {
fi1, err := os.Stat(fn1)
if err != nil {
return false
}
fi2, err := os.Stat(fn2)
if err != nil {
return false
}
return os.SameFile(fi1, fi2)
}
func init() {
if e := os.Getenv(execPath_EnvVar); e != "" {
// first chdir to another path
dir := "/"
if runtime.GOOS == "windows" {
dir = filepath.VolumeName(".")
}
os.Chdir(dir)
if ep, err := Executable(); err != nil {
fmt.Fprint(os.Stderr, "ERROR: ", err)
} else {
fmt.Fprint(os.Stderr, ep)
}
os.Exit(0)
}
}

View file

@ -23,7 +23,7 @@ func EncodeSorted(values url.Values) string {
// preallocate the arrays for perfomance // preallocate the arrays for perfomance
keys := make([]string, 0, len(values)) keys := make([]string, 0, len(values))
sarray := make([]string, 0, len(values)) sarray := make([]string, 0, len(values))
for k, _ := range values { for k := range values {
keys = append(keys, k) keys = append(keys, k)
} }
sort.Strings(keys) sort.Strings(keys)
@ -65,7 +65,7 @@ func (s *V2Signer) Sign(method, path string, params map[string]string) {
// from the natural order of the encoded value of key=value. // from the natural order of the encoded value of key=value.
// Percent and gocheck.Equals affect the sorting order. // Percent and gocheck.Equals affect the sorting order.
var keys, sarray []string var keys, sarray []string
for k, _ := range params { for k := range params {
keys = append(keys, k) keys = append(keys, k)
} }
sort.Strings(keys) sort.Strings(keys)
@ -372,7 +372,7 @@ func (s *V4Signer) canonicalHeaders(h http.Header) string {
func (s *V4Signer) signedHeaders(h http.Header) string { func (s *V4Signer) signedHeaders(h http.Header) string {
i, a := 0, make([]string, len(h)) i, a := 0, make([]string, len(h))
for k, _ := range h { for k := range h {
a[i] = strings.ToLower(k) a[i] = strings.ToLower(k)
i++ i++
} }

File diff suppressed because it is too large Load diff

View file

@ -1,111 +0,0 @@
package libtrust
import (
"encoding/pem"
"io/ioutil"
"net"
"os"
"path"
"testing"
)
func TestGenerateCertificates(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
_, err = GenerateSelfSignedServerCert(key, []string{"localhost"}, []net.IP{net.ParseIP("127.0.0.1")})
if err != nil {
t.Fatal(err)
}
_, err = GenerateSelfSignedClientCert(key)
if err != nil {
t.Fatal(err)
}
}
func TestGenerateCACertPool(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
caKey1, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
caKey2, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
_, err = GenerateCACertPool(key, []PublicKey{caKey1.PublicKey(), caKey2.PublicKey()})
if err != nil {
t.Fatal(err)
}
}
func TestLoadCertificates(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
caKey1, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
caKey2, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
cert1, err := GenerateCACert(caKey1, key)
if err != nil {
t.Fatal(err)
}
cert2, err := GenerateCACert(caKey2, key)
if err != nil {
t.Fatal(err)
}
d, err := ioutil.TempDir("/tmp", "cert-test")
if err != nil {
t.Fatal(err)
}
caFile := path.Join(d, "ca.pem")
f, err := os.OpenFile(caFile, os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
t.Fatal(err)
}
err = pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert1.Raw})
if err != nil {
t.Fatal(err)
}
err = pem.Encode(f, &pem.Block{Type: "CERTIFICATE", Bytes: cert2.Raw})
if err != nil {
t.Fatal(err)
}
f.Close()
certs, err := LoadCertificateBundle(caFile)
if err != nil {
t.Fatal(err)
}
if len(certs) != 2 {
t.Fatalf("Wrong number of certs received, expected: %d, received %d", 2, len(certs))
}
pool, err := LoadCertificatePool(caFile)
if err != nil {
t.Fatal(err)
}
if len(pool.Subjects()) != 2 {
t.Fatalf("Invalid certificate pool")
}
}

View file

@ -1,157 +0,0 @@
package libtrust
import (
"bytes"
"encoding/json"
"testing"
)
func generateECTestKeys(t *testing.T) []PrivateKey {
p256Key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
p384Key, err := GenerateECP384PrivateKey()
if err != nil {
t.Fatal(err)
}
p521Key, err := GenerateECP521PrivateKey()
if err != nil {
t.Fatal(err)
}
return []PrivateKey{p256Key, p384Key, p521Key}
}
func TestECKeys(t *testing.T) {
ecKeys := generateECTestKeys(t)
for _, ecKey := range ecKeys {
if ecKey.KeyType() != "EC" {
t.Fatalf("key type must be %q, instead got %q", "EC", ecKey.KeyType())
}
}
}
func TestECSignVerify(t *testing.T) {
ecKeys := generateECTestKeys(t)
message := "Hello, World!"
data := bytes.NewReader([]byte(message))
sigAlgs := []*signatureAlgorithm{es256, es384, es512}
for i, ecKey := range ecKeys {
sigAlg := sigAlgs[i]
t.Logf("%s signature of %q with kid: %s\n", sigAlg.HeaderParam(), message, ecKey.KeyID())
data.Seek(0, 0) // Reset the byte reader
// Sign
sig, alg, err := ecKey.Sign(data, sigAlg.HashID())
if err != nil {
t.Fatal(err)
}
data.Seek(0, 0) // Reset the byte reader
// Verify
err = ecKey.Verify(data, alg, sig)
if err != nil {
t.Fatal(err)
}
}
}
func TestMarshalUnmarshalECKeys(t *testing.T) {
ecKeys := generateECTestKeys(t)
data := bytes.NewReader([]byte("This is a test. I repeat: this is only a test."))
sigAlgs := []*signatureAlgorithm{es256, es384, es512}
for i, ecKey := range ecKeys {
sigAlg := sigAlgs[i]
privateJWKJSON, err := json.MarshalIndent(ecKey, "", " ")
if err != nil {
t.Fatal(err)
}
publicJWKJSON, err := json.MarshalIndent(ecKey.PublicKey(), "", " ")
if err != nil {
t.Fatal(err)
}
t.Logf("JWK Private Key: %s", string(privateJWKJSON))
t.Logf("JWK Public Key: %s", string(publicJWKJSON))
privKey2, err := UnmarshalPrivateKeyJWK(privateJWKJSON)
if err != nil {
t.Fatal(err)
}
pubKey2, err := UnmarshalPublicKeyJWK(publicJWKJSON)
if err != nil {
t.Fatal(err)
}
// Ensure we can sign/verify a message with the unmarshalled keys.
data.Seek(0, 0) // Reset the byte reader
signature, alg, err := privKey2.Sign(data, sigAlg.HashID())
if err != nil {
t.Fatal(err)
}
data.Seek(0, 0) // Reset the byte reader
err = pubKey2.Verify(data, alg, signature)
if err != nil {
t.Fatal(err)
}
}
}
func TestFromCryptoECKeys(t *testing.T) {
ecKeys := generateECTestKeys(t)
for _, ecKey := range ecKeys {
cryptoPrivateKey := ecKey.CryptoPrivateKey()
cryptoPublicKey := ecKey.CryptoPublicKey()
pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
if err != nil {
t.Fatal(err)
}
if pubKey.KeyID() != ecKey.KeyID() {
t.Fatal("public key key ID mismatch")
}
privKey, err := FromCryptoPrivateKey(cryptoPrivateKey)
if err != nil {
t.Fatal(err)
}
if privKey.KeyID() != ecKey.KeyID() {
t.Fatal("public key key ID mismatch")
}
}
}
func TestExtendedFields(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
key.AddExtendedField("test", "foobar")
val := key.GetExtendedField("test")
gotVal, ok := val.(string)
if !ok {
t.Fatalf("value is not a string")
} else if gotVal != val {
t.Fatalf("value %q is not equal to %q", gotVal, val)
}
}

View file

@ -1,81 +0,0 @@
package libtrust
import (
"testing"
)
func compareKeySlices(t *testing.T, sliceA, sliceB []PublicKey) {
if len(sliceA) != len(sliceB) {
t.Fatalf("slice size %d, expected %d", len(sliceA), len(sliceB))
}
for i, itemA := range sliceA {
itemB := sliceB[i]
if itemA != itemB {
t.Fatalf("slice index %d not equal: %#v != %#v", i, itemA, itemB)
}
}
}
func TestFilter(t *testing.T) {
keys := make([]PublicKey, 0, 8)
// Create 8 keys and add host entries.
for i := 0; i < cap(keys); i++ {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
// we use both []interface{} and []string here because jwt uses
// []interface{} format, while PEM uses []string
switch {
case i == 0:
// Don't add entries for this key, key 0.
break
case i%2 == 0:
// Should catch keys 2, 4, and 6.
key.AddExtendedField("hosts", []interface{}{"*.even.example.com"})
case i == 7:
// Should catch only the last key, and make it match any hostname.
key.AddExtendedField("hosts", []string{"*"})
default:
// should catch keys 1, 3, 5.
key.AddExtendedField("hosts", []string{"*.example.com"})
}
keys = append(keys, key)
}
// Should match 2 keys, the empty one, and the one that matches all hosts.
matchedKeys, err := FilterByHosts(keys, "foo.bar.com", true)
if err != nil {
t.Fatal(err)
}
expectedMatch := []PublicKey{keys[0], keys[7]}
compareKeySlices(t, expectedMatch, matchedKeys)
// Should match 1 key, the one that matches any host.
matchedKeys, err = FilterByHosts(keys, "foo.bar.com", false)
if err != nil {
t.Fatal(err)
}
expectedMatch = []PublicKey{keys[7]}
compareKeySlices(t, expectedMatch, matchedKeys)
// Should match keys that end in "example.com", and the key that matches anything.
matchedKeys, err = FilterByHosts(keys, "foo.example.com", false)
if err != nil {
t.Fatal(err)
}
expectedMatch = []PublicKey{keys[1], keys[3], keys[5], keys[7]}
compareKeySlices(t, expectedMatch, matchedKeys)
// Should match all of the keys except the empty key.
matchedKeys, err = FilterByHosts(keys, "foo.even.example.com", false)
if err != nil {
t.Fatal(err)
}
expectedMatch = keys[1:]
compareKeySlices(t, expectedMatch, matchedKeys)
}

View file

@ -1,380 +0,0 @@
package libtrust
import (
"bytes"
"crypto/rand"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"testing"
"github.com/docker/libtrust/testutil"
)
func createTestJSON(sigKey string, indent string) (map[string]interface{}, []byte) {
testMap := map[string]interface{}{
"name": "dmcgowan/mycontainer",
"config": map[string]interface{}{
"ports": []int{9101, 9102},
"run": "/bin/echo \"Hello\"",
},
"layers": []string{
"2893c080-27f5-11e4-8c21-0800200c9a66",
"c54bc25b-fbb2-497b-a899-a8bc1b5b9d55",
"4d5d7e03-f908-49f3-a7f6-9ba28dfe0fb4",
"0b6da891-7f7f-4abf-9c97-7887549e696c",
"1d960389-ae4f-4011-85fd-18d0f96a67ad",
},
}
formattedSection := `{"config":{"ports":[9101,9102],"run":"/bin/echo \"Hello\""},"layers":["2893c080-27f5-11e4-8c21-0800200c9a66","c54bc25b-fbb2-497b-a899-a8bc1b5b9d55","4d5d7e03-f908-49f3-a7f6-9ba28dfe0fb4","0b6da891-7f7f-4abf-9c97-7887549e696c","1d960389-ae4f-4011-85fd-18d0f96a67ad"],"name":"dmcgowan/mycontainer","%s":[{"header":{`
formattedSection = fmt.Sprintf(formattedSection, sigKey)
if indent != "" {
buf := bytes.NewBuffer(nil)
json.Indent(buf, []byte(formattedSection), "", indent)
return testMap, buf.Bytes()
}
return testMap, []byte(formattedSection)
}
func TestSignJSON(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generating EC key: %s", err)
}
testMap, _ := createTestJSON("buildSignatures", " ")
indented, err := json.MarshalIndent(testMap, "", " ")
if err != nil {
t.Fatalf("Marshall error: %s", err)
}
js, err := NewJSONSignature(indented)
if err != nil {
t.Fatalf("Error creating JSON signature: %s", err)
}
err = js.Sign(key)
if err != nil {
t.Fatalf("Error signing content: %s", err)
}
keys, err := js.Verify()
if err != nil {
t.Fatalf("Error verifying signature: %s", err)
}
if len(keys) != 1 {
t.Fatalf("Error wrong number of keys returned")
}
if keys[0].KeyID() != key.KeyID() {
t.Fatalf("Unexpected public key returned")
}
}
func TestSignMap(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generating EC key: %s", err)
}
testMap, _ := createTestJSON("buildSignatures", " ")
js, err := NewJSONSignatureFromMap(testMap)
if err != nil {
t.Fatalf("Error creating JSON signature: %s", err)
}
err = js.Sign(key)
if err != nil {
t.Fatalf("Error signing JSON signature: %s", err)
}
keys, err := js.Verify()
if err != nil {
t.Fatalf("Error verifying signature: %s", err)
}
if len(keys) != 1 {
t.Fatalf("Error wrong number of keys returned")
}
if keys[0].KeyID() != key.KeyID() {
t.Fatalf("Unexpected public key returned")
}
}
func TestFormattedJson(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generating EC key: %s", err)
}
testMap, firstSection := createTestJSON("buildSignatures", " ")
indented, err := json.MarshalIndent(testMap, "", " ")
if err != nil {
t.Fatalf("Marshall error: %s", err)
}
js, err := NewJSONSignature(indented)
if err != nil {
t.Fatalf("Error creating JSON signature: %s", err)
}
err = js.Sign(key)
if err != nil {
t.Fatalf("Error signing content: %s", err)
}
b, err := js.PrettySignature("buildSignatures")
if err != nil {
t.Fatalf("Error signing map: %s", err)
}
if bytes.Compare(b[:len(firstSection)], firstSection) != 0 {
t.Fatalf("Wrong signed value\nExpected:\n%s\nActual:\n%s", firstSection, b[:len(firstSection)])
}
parsed, err := ParsePrettySignature(b, "buildSignatures")
if err != nil {
t.Fatalf("Error parsing formatted signature: %s", err)
}
keys, err := parsed.Verify()
if err != nil {
t.Fatalf("Error verifying signature: %s", err)
}
if len(keys) != 1 {
t.Fatalf("Error wrong number of keys returned")
}
if keys[0].KeyID() != key.KeyID() {
t.Fatalf("Unexpected public key returned")
}
var unmarshalled map[string]interface{}
err = json.Unmarshal(b, &unmarshalled)
if err != nil {
t.Fatalf("Could not unmarshall after parse: %s", err)
}
}
func TestFormattedFlatJson(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generating EC key: %s", err)
}
testMap, firstSection := createTestJSON("buildSignatures", "")
unindented, err := json.Marshal(testMap)
if err != nil {
t.Fatalf("Marshall error: %s", err)
}
js, err := NewJSONSignature(unindented)
if err != nil {
t.Fatalf("Error creating JSON signature: %s", err)
}
err = js.Sign(key)
if err != nil {
t.Fatalf("Error signing JSON signature: %s", err)
}
b, err := js.PrettySignature("buildSignatures")
if err != nil {
t.Fatalf("Error signing map: %s", err)
}
if bytes.Compare(b[:len(firstSection)], firstSection) != 0 {
t.Fatalf("Wrong signed value\nExpected:\n%s\nActual:\n%s", firstSection, b[:len(firstSection)])
}
parsed, err := ParsePrettySignature(b, "buildSignatures")
if err != nil {
t.Fatalf("Error parsing formatted signature: %s", err)
}
keys, err := parsed.Verify()
if err != nil {
t.Fatalf("Error verifying signature: %s", err)
}
if len(keys) != 1 {
t.Fatalf("Error wrong number of keys returned")
}
if keys[0].KeyID() != key.KeyID() {
t.Fatalf("Unexpected public key returned")
}
}
func generateTrustChain(t *testing.T, key PrivateKey, ca *x509.Certificate) (PrivateKey, []*x509.Certificate) {
parent := ca
parentKey := key
chain := make([]*x509.Certificate, 6)
for i := 5; i > 0; i-- {
intermediatekey, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generate key: %s", err)
}
chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
if err != nil {
t.Fatalf("Error generating intermdiate certificate: %s", err)
}
parent = chain[i]
parentKey = intermediatekey
}
trustKey, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generate key: %s", err)
}
chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
if err != nil {
t.Fatalf("Error generate trust cert: %s", err)
}
return trustKey, chain
}
func TestChainVerify(t *testing.T) {
caKey, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generating key: %s", err)
}
ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
if err != nil {
t.Fatalf("Error generating ca: %s", err)
}
trustKey, chain := generateTrustChain(t, caKey, ca)
testMap, _ := createTestJSON("verifySignatures", " ")
js, err := NewJSONSignatureFromMap(testMap)
if err != nil {
t.Fatalf("Error creating JSONSignature from map: %s", err)
}
err = js.SignWithChain(trustKey, chain)
if err != nil {
t.Fatalf("Error signing with chain: %s", err)
}
pool := x509.NewCertPool()
pool.AddCert(ca)
chains, err := js.VerifyChains(pool)
if err != nil {
t.Fatalf("Error verifying content: %s", err)
}
if len(chains) != 1 {
t.Fatalf("Unexpected chains length: %d", len(chains))
}
if len(chains[0]) != 7 {
t.Fatalf("Unexpected chain length: %d", len(chains[0]))
}
}
func TestInvalidChain(t *testing.T) {
caKey, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generating key: %s", err)
}
ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
if err != nil {
t.Fatalf("Error generating ca: %s", err)
}
trustKey, chain := generateTrustChain(t, caKey, ca)
testMap, _ := createTestJSON("verifySignatures", " ")
js, err := NewJSONSignatureFromMap(testMap)
if err != nil {
t.Fatalf("Error creating JSONSignature from map: %s", err)
}
err = js.SignWithChain(trustKey, chain[:5])
if err != nil {
t.Fatalf("Error signing with chain: %s", err)
}
pool := x509.NewCertPool()
pool.AddCert(ca)
chains, err := js.VerifyChains(pool)
if err == nil {
t.Fatalf("Expected error verifying with bad chain")
}
if len(chains) != 0 {
t.Fatalf("Unexpected chains returned from invalid verify")
}
}
func TestMergeSignatures(t *testing.T) {
pk1, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("unexpected error generating private key 1: %v", err)
}
pk2, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("unexpected error generating private key 2: %v", err)
}
payload := make([]byte, 1<<10)
if _, err = io.ReadFull(rand.Reader, payload); err != nil {
t.Fatalf("error generating payload: %v", err)
}
payload, _ = json.Marshal(map[string]interface{}{"data": payload})
sig1, err := NewJSONSignature(payload)
if err != nil {
t.Fatalf("unexpected error creating signature 1: %v", err)
}
if err := sig1.Sign(pk1); err != nil {
t.Fatalf("unexpected error signing with pk1: %v", err)
}
sig2, err := NewJSONSignature(payload)
if err != nil {
t.Fatalf("unexpected error creating signature 2: %v", err)
}
if err := sig2.Sign(pk2); err != nil {
t.Fatalf("unexpected error signing with pk2: %v", err)
}
// Now, we actually merge into sig1
if err := sig1.Merge(sig2); err != nil {
t.Fatalf("unexpected error merging: %v", err)
}
// Verify the new signature package
pubkeys, err := sig1.Verify()
if err != nil {
t.Fatalf("unexpected error during verify: %v", err)
}
// Make sure the pubkeys match the two private keys from before
privkeys := map[string]PrivateKey{
pk1.KeyID(): pk1,
pk2.KeyID(): pk2,
}
found := map[string]struct{}{}
for _, pubkey := range pubkeys {
if _, ok := privkeys[pubkey.KeyID()]; !ok {
t.Fatalf("unexpected public key found during verification: %v", pubkey)
}
found[pubkey.KeyID()] = struct{}{}
}
// Make sure we've found all the private keys from verification
for keyid, _ := range privkeys {
if _, ok := found[keyid]; !ok {
t.Fatalf("public key %v not found during verification", keyid)
}
}
// Create another signature, with a different payload, and ensure we get an error.
sig3, err := NewJSONSignature([]byte("{}"))
if err != nil {
t.Fatalf("unexpected error making signature for sig3: %v", err)
}
if err := sig1.Merge(sig3); err == nil {
t.Fatalf("error expected during invalid merge with different payload")
}
}

View file

@ -1,220 +0,0 @@
package libtrust
import (
"errors"
"io/ioutil"
"os"
"testing"
)
func makeTempFile(t *testing.T, prefix string) (filename string) {
file, err := ioutil.TempFile("", prefix)
if err != nil {
t.Fatal(err)
}
filename = file.Name()
file.Close()
return
}
func TestKeyFiles(t *testing.T) {
key, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
testKeyFiles(t, key)
key, err = GenerateRSA2048PrivateKey()
if err != nil {
t.Fatal(err)
}
testKeyFiles(t, key)
}
func testKeyFiles(t *testing.T, key PrivateKey) {
var err error
privateKeyFilename := makeTempFile(t, "private_key")
privateKeyFilenamePEM := privateKeyFilename + ".pem"
privateKeyFilenameJWK := privateKeyFilename + ".jwk"
publicKeyFilename := makeTempFile(t, "public_key")
publicKeyFilenamePEM := publicKeyFilename + ".pem"
publicKeyFilenameJWK := publicKeyFilename + ".jwk"
if err = SaveKey(privateKeyFilenamePEM, key); err != nil {
t.Fatal(err)
}
if err = SaveKey(privateKeyFilenameJWK, key); err != nil {
t.Fatal(err)
}
if err = SavePublicKey(publicKeyFilenamePEM, key.PublicKey()); err != nil {
t.Fatal(err)
}
if err = SavePublicKey(publicKeyFilenameJWK, key.PublicKey()); err != nil {
t.Fatal(err)
}
loadedPEMKey, err := LoadKeyFile(privateKeyFilenamePEM)
if err != nil {
t.Fatal(err)
}
loadedJWKKey, err := LoadKeyFile(privateKeyFilenameJWK)
if err != nil {
t.Fatal(err)
}
loadedPEMPublicKey, err := LoadPublicKeyFile(publicKeyFilenamePEM)
if err != nil {
t.Fatal(err)
}
loadedJWKPublicKey, err := LoadPublicKeyFile(publicKeyFilenameJWK)
if err != nil {
t.Fatal(err)
}
if key.KeyID() != loadedPEMKey.KeyID() {
t.Fatal(errors.New("key IDs do not match"))
}
if key.KeyID() != loadedJWKKey.KeyID() {
t.Fatal(errors.New("key IDs do not match"))
}
if key.KeyID() != loadedPEMPublicKey.KeyID() {
t.Fatal(errors.New("key IDs do not match"))
}
if key.KeyID() != loadedJWKPublicKey.KeyID() {
t.Fatal(errors.New("key IDs do not match"))
}
os.Remove(privateKeyFilename)
os.Remove(privateKeyFilenamePEM)
os.Remove(privateKeyFilenameJWK)
os.Remove(publicKeyFilename)
os.Remove(publicKeyFilenamePEM)
os.Remove(publicKeyFilenameJWK)
}
func TestTrustedHostKeysFile(t *testing.T) {
trustedHostKeysFilename := makeTempFile(t, "trusted_host_keys")
trustedHostKeysFilenamePEM := trustedHostKeysFilename + ".pem"
trustedHostKeysFilenameJWK := trustedHostKeysFilename + ".json"
testTrustedHostKeysFile(t, trustedHostKeysFilenamePEM)
testTrustedHostKeysFile(t, trustedHostKeysFilenameJWK)
os.Remove(trustedHostKeysFilename)
os.Remove(trustedHostKeysFilenamePEM)
os.Remove(trustedHostKeysFilenameJWK)
}
func testTrustedHostKeysFile(t *testing.T, trustedHostKeysFilename string) {
hostAddress1 := "docker.example.com:2376"
hostKey1, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
hostKey1.AddExtendedField("hosts", []string{hostAddress1})
err = AddKeySetFile(trustedHostKeysFilename, hostKey1.PublicKey())
if err != nil {
t.Fatal(err)
}
trustedHostKeysMapping, err := LoadKeySetFile(trustedHostKeysFilename)
if err != nil {
t.Fatal(err)
}
for addr, hostKey := range trustedHostKeysMapping {
t.Logf("Host Address: %d\n", addr)
t.Logf("Host Key: %s\n\n", hostKey)
}
hostAddress2 := "192.168.59.103:2376"
hostKey2, err := GenerateRSA2048PrivateKey()
if err != nil {
t.Fatal(err)
}
hostKey2.AddExtendedField("hosts", hostAddress2)
err = AddKeySetFile(trustedHostKeysFilename, hostKey2.PublicKey())
if err != nil {
t.Fatal(err)
}
trustedHostKeysMapping, err = LoadKeySetFile(trustedHostKeysFilename)
if err != nil {
t.Fatal(err)
}
for addr, hostKey := range trustedHostKeysMapping {
t.Logf("Host Address: %d\n", addr)
t.Logf("Host Key: %s\n\n", hostKey)
}
}
func TestTrustedClientKeysFile(t *testing.T) {
trustedClientKeysFilename := makeTempFile(t, "trusted_client_keys")
trustedClientKeysFilenamePEM := trustedClientKeysFilename + ".pem"
trustedClientKeysFilenameJWK := trustedClientKeysFilename + ".json"
testTrustedClientKeysFile(t, trustedClientKeysFilenamePEM)
testTrustedClientKeysFile(t, trustedClientKeysFilenameJWK)
os.Remove(trustedClientKeysFilename)
os.Remove(trustedClientKeysFilenamePEM)
os.Remove(trustedClientKeysFilenameJWK)
}
func testTrustedClientKeysFile(t *testing.T, trustedClientKeysFilename string) {
clientKey1, err := GenerateECP256PrivateKey()
if err != nil {
t.Fatal(err)
}
err = AddKeySetFile(trustedClientKeysFilename, clientKey1.PublicKey())
if err != nil {
t.Fatal(err)
}
trustedClientKeys, err := LoadKeySetFile(trustedClientKeysFilename)
if err != nil {
t.Fatal(err)
}
for _, clientKey := range trustedClientKeys {
t.Logf("Client Key: %s\n", clientKey)
}
clientKey2, err := GenerateRSA2048PrivateKey()
if err != nil {
t.Fatal(err)
}
err = AddKeySetFile(trustedClientKeysFilename, clientKey2.PublicKey())
if err != nil {
t.Fatal(err)
}
trustedClientKeys, err = LoadKeySetFile(trustedClientKeysFilename)
if err != nil {
t.Fatal(err)
}
for _, clientKey := range trustedClientKeys {
t.Logf("Client Key: %s\n", clientKey)
}
}

View file

@ -1,80 +0,0 @@
package libtrust
import (
"testing"
)
type generateFunc func() (PrivateKey, error)
func runGenerateBench(b *testing.B, f generateFunc, name string) {
for i := 0; i < b.N; i++ {
_, err := f()
if err != nil {
b.Fatalf("Error generating %s: %s", name, err)
}
}
}
func runFingerprintBench(b *testing.B, f generateFunc, name string) {
b.StopTimer()
// Don't count this relatively slow generation call.
key, err := f()
if err != nil {
b.Fatalf("Error generating %s: %s", name, err)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
if key.KeyID() == "" {
b.Fatalf("Error generating key ID for %s", name)
}
}
}
func BenchmarkECP256Generate(b *testing.B) {
runGenerateBench(b, GenerateECP256PrivateKey, "P256")
}
func BenchmarkECP384Generate(b *testing.B) {
runGenerateBench(b, GenerateECP384PrivateKey, "P384")
}
func BenchmarkECP521Generate(b *testing.B) {
runGenerateBench(b, GenerateECP521PrivateKey, "P521")
}
func BenchmarkRSA2048Generate(b *testing.B) {
runGenerateBench(b, GenerateRSA2048PrivateKey, "RSA2048")
}
func BenchmarkRSA3072Generate(b *testing.B) {
runGenerateBench(b, GenerateRSA3072PrivateKey, "RSA3072")
}
func BenchmarkRSA4096Generate(b *testing.B) {
runGenerateBench(b, GenerateRSA4096PrivateKey, "RSA4096")
}
func BenchmarkECP256Fingerprint(b *testing.B) {
runFingerprintBench(b, GenerateECP256PrivateKey, "P256")
}
func BenchmarkECP384Fingerprint(b *testing.B) {
runFingerprintBench(b, GenerateECP384PrivateKey, "P384")
}
func BenchmarkECP521Fingerprint(b *testing.B) {
runFingerprintBench(b, GenerateECP521PrivateKey, "P521")
}
func BenchmarkRSA2048Fingerprint(b *testing.B) {
runFingerprintBench(b, GenerateRSA2048PrivateKey, "RSA2048")
}
func BenchmarkRSA3072Fingerprint(b *testing.B) {
runFingerprintBench(b, GenerateRSA3072PrivateKey, "RSA3072")
}
func BenchmarkRSA4096Fingerprint(b *testing.B) {
runFingerprintBench(b, GenerateRSA4096PrivateKey, "RSA4096")
}

View file

@ -1,157 +0,0 @@
package libtrust
import (
"bytes"
"encoding/json"
"log"
"testing"
)
var rsaKeys []PrivateKey
func init() {
var err error
rsaKeys, err = generateRSATestKeys()
if err != nil {
log.Fatal(err)
}
}
func generateRSATestKeys() (keys []PrivateKey, err error) {
log.Println("Generating RSA 2048-bit Test Key")
rsa2048Key, err := GenerateRSA2048PrivateKey()
if err != nil {
return
}
log.Println("Generating RSA 3072-bit Test Key")
rsa3072Key, err := GenerateRSA3072PrivateKey()
if err != nil {
return
}
log.Println("Generating RSA 4096-bit Test Key")
rsa4096Key, err := GenerateRSA4096PrivateKey()
if err != nil {
return
}
log.Println("Done generating RSA Test Keys!")
keys = []PrivateKey{rsa2048Key, rsa3072Key, rsa4096Key}
return
}
func TestRSAKeys(t *testing.T) {
for _, rsaKey := range rsaKeys {
if rsaKey.KeyType() != "RSA" {
t.Fatalf("key type must be %q, instead got %q", "RSA", rsaKey.KeyType())
}
}
}
func TestRSASignVerify(t *testing.T) {
message := "Hello, World!"
data := bytes.NewReader([]byte(message))
sigAlgs := []*signatureAlgorithm{rs256, rs384, rs512}
for i, rsaKey := range rsaKeys {
sigAlg := sigAlgs[i]
t.Logf("%s signature of %q with kid: %s\n", sigAlg.HeaderParam(), message, rsaKey.KeyID())
data.Seek(0, 0) // Reset the byte reader
// Sign
sig, alg, err := rsaKey.Sign(data, sigAlg.HashID())
if err != nil {
t.Fatal(err)
}
data.Seek(0, 0) // Reset the byte reader
// Verify
err = rsaKey.Verify(data, alg, sig)
if err != nil {
t.Fatal(err)
}
}
}
func TestMarshalUnmarshalRSAKeys(t *testing.T) {
data := bytes.NewReader([]byte("This is a test. I repeat: this is only a test."))
sigAlgs := []*signatureAlgorithm{rs256, rs384, rs512}
for i, rsaKey := range rsaKeys {
sigAlg := sigAlgs[i]
privateJWKJSON, err := json.MarshalIndent(rsaKey, "", " ")
if err != nil {
t.Fatal(err)
}
publicJWKJSON, err := json.MarshalIndent(rsaKey.PublicKey(), "", " ")
if err != nil {
t.Fatal(err)
}
t.Logf("JWK Private Key: %s", string(privateJWKJSON))
t.Logf("JWK Public Key: %s", string(publicJWKJSON))
privKey2, err := UnmarshalPrivateKeyJWK(privateJWKJSON)
if err != nil {
t.Fatal(err)
}
pubKey2, err := UnmarshalPublicKeyJWK(publicJWKJSON)
if err != nil {
t.Fatal(err)
}
// Ensure we can sign/verify a message with the unmarshalled keys.
data.Seek(0, 0) // Reset the byte reader
signature, alg, err := privKey2.Sign(data, sigAlg.HashID())
if err != nil {
t.Fatal(err)
}
data.Seek(0, 0) // Reset the byte reader
err = pubKey2.Verify(data, alg, signature)
if err != nil {
t.Fatal(err)
}
// It's a good idea to validate the Private Key to make sure our
// (un)marshal process didn't corrupt the extra parameters.
k := privKey2.(*rsaPrivateKey)
err = k.PrivateKey.Validate()
if err != nil {
t.Fatal(err)
}
}
}
func TestFromCryptoRSAKeys(t *testing.T) {
for _, rsaKey := range rsaKeys {
cryptoPrivateKey := rsaKey.CryptoPrivateKey()
cryptoPublicKey := rsaKey.CryptoPublicKey()
pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
if err != nil {
t.Fatal(err)
}
if pubKey.KeyID() != rsaKey.KeyID() {
t.Fatal("public key key ID mismatch")
}
privKey, err := FromCryptoPrivateKey(cryptoPrivateKey)
if err != nil {
t.Fatal(err)
}
if privKey.KeyID() != rsaKey.KeyID() {
t.Fatal("public key key ID mismatch")
}
}
}

View file

@ -1,94 +0,0 @@
package testutil
import (
"crypto"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"time"
)
// GenerateTrustCA generates a new certificate authority for testing.
func GenerateTrustCA(pub crypto.PublicKey, priv crypto.PrivateKey) (*x509.Certificate, error) {
cert := &x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{
CommonName: "CA Root",
},
NotBefore: time.Now().Add(-time.Second),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, cert, cert, pub, priv)
if err != nil {
return nil, err
}
cert, err = x509.ParseCertificate(certDER)
if err != nil {
return nil, err
}
return cert, nil
}
// GenerateIntermediate generates an intermediate certificate for testing using
// the parent certificate (likely a CA) and the provided keys.
func GenerateIntermediate(key crypto.PublicKey, parentKey crypto.PrivateKey, parent *x509.Certificate) (*x509.Certificate, error) {
cert := &x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{
CommonName: "Intermediate",
},
NotBefore: time.Now().Add(-time.Second),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, cert, parent, key, parentKey)
if err != nil {
return nil, err
}
cert, err = x509.ParseCertificate(certDER)
if err != nil {
return nil, err
}
return cert, nil
}
// GenerateTrustCert generates a new trust certificate for testing. Unlike the
// intermediate certificates, this certificate should be used for signature
// only, not creating certificates.
func GenerateTrustCert(key crypto.PublicKey, parentKey crypto.PrivateKey, parent *x509.Certificate) (*x509.Certificate, error) {
cert := &x509.Certificate{
SerialNumber: big.NewInt(0),
Subject: pkix.Name{
CommonName: "Trust Cert",
},
NotBefore: time.Now().Add(-time.Second),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, cert, parent, key, parentKey)
if err != nil {
return nil, err
}
cert, err = x509.ParseCertificate(certDER)
if err != nil {
return nil, err
}
return cert, nil
}

View file

@ -1,50 +0,0 @@
## Libtrust TLS Config Demo
This program generates key pairs and trust files for a TLS client and server.
To generate the keys, run:
```
$ go run genkeys.go
```
The generated files are:
```
$ ls -l client_data/ server_data/
client_data/:
total 24
-rw------- 1 jlhawn staff 281 Aug 8 16:21 private_key.json
-rw-r--r-- 1 jlhawn staff 225 Aug 8 16:21 public_key.json
-rw-r--r-- 1 jlhawn staff 275 Aug 8 16:21 trusted_hosts.json
server_data/:
total 24
-rw-r--r-- 1 jlhawn staff 348 Aug 8 16:21 trusted_clients.json
-rw------- 1 jlhawn staff 281 Aug 8 16:21 private_key.json
-rw-r--r-- 1 jlhawn staff 225 Aug 8 16:21 public_key.json
```
The private key and public key for the client and server are stored in `private_key.json` and `public_key.json`, respectively, and in their respective directories. They are represented as JSON Web Keys: JSON objects which represent either an ECDSA or RSA private key. The host keys trusted by the client are stored in `trusted_hosts.json` and contain a mapping of an internet address, `<HOSTNAME_OR_IP>:<PORT>`, to a JSON Web Key which is a JSON object representing either an ECDSA or RSA public key of the trusted server. The client keys trusted by the server are stored in `trusted_clients.json` and contain an array of JSON objects which contain a comment field which can be used describe the key and a JSON Web Key which is a JSON object representing either an ECDSA or RSA public key of the trusted client.
To start the server, run:
```
$ go run server.go
```
This starts an HTTPS server which listens on `localhost:8888`. The server configures itself with a certificate which is valid for both `localhost` and `127.0.0.1` and uses the key from `server_data/private_key.json`. It accepts connections from clients which present a certificate for a key that it is configured to trust from the `trusted_clients.json` file and returns a simple 'hello' message.
To make a request using the client, run:
```
$ go run client.go
```
This command creates an HTTPS client which makes a GET request to `https://localhost:8888`. The client configures itself with a certificate using the key from `client_data/private_key.json`. It only connects to a server which presents a certificate signed by the key specified for the `localhost:8888` address from `client_data/trusted_hosts.json` and made to be used for the `localhost` hostname. If the connection succeeds, it prints the response from the server.
The file `gencert.go` can be used to generate PEM encoded version of the client key and certificate. If you save them to `key.pem` and `cert.pem` respectively, you can use them with `curl` to test out the server (if it is still running).
```
curl --cert cert.pem --key key.pem -k https://localhost:8888
```

View file

@ -1,89 +0,0 @@
package main
import (
"crypto/tls"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"github.com/docker/libtrust"
)
var (
serverAddress = "localhost:8888"
privateKeyFilename = "client_data/private_key.pem"
trustedHostsFilename = "client_data/trusted_hosts.pem"
)
func main() {
// Load Client Key.
clientKey, err := libtrust.LoadKeyFile(privateKeyFilename)
if err != nil {
log.Fatal(err)
}
// Generate Client Certificate.
selfSignedClientCert, err := libtrust.GenerateSelfSignedClientCert(clientKey)
if err != nil {
log.Fatal(err)
}
// Load trusted host keys.
hostKeys, err := libtrust.LoadKeySetFile(trustedHostsFilename)
if err != nil {
log.Fatal(err)
}
// Ensure the host we want to connect to is trusted!
host, _, err := net.SplitHostPort(serverAddress)
if err != nil {
log.Fatal(err)
}
serverKeys, err := libtrust.FilterByHosts(hostKeys, host, false)
if err != nil {
log.Fatalf("%q is not a known and trusted host", host)
}
// Generate a CA pool with the trusted host's key.
caPool, err := libtrust.GenerateCACertPool(clientKey, serverKeys)
if err != nil {
log.Fatal(err)
}
// Create HTTP Client.
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{
tls.Certificate{
Certificate: [][]byte{selfSignedClientCert.Raw},
PrivateKey: clientKey.CryptoPrivateKey(),
Leaf: selfSignedClientCert,
},
},
RootCAs: caPool,
},
},
}
var makeRequest = func(url string) {
resp, err := client.Get(url)
if err != nil {
log.Fatal(err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)
}
log.Println(resp.Status)
log.Println(string(body))
}
// Make the request to the trusted server!
makeRequest(fmt.Sprintf("https://%s", serverAddress))
}

View file

@ -1,62 +0,0 @@
package main
import (
"encoding/pem"
"fmt"
"log"
"net"
"github.com/docker/libtrust"
)
var (
serverAddress = "localhost:8888"
clientPrivateKeyFilename = "client_data/private_key.pem"
trustedHostsFilename = "client_data/trusted_hosts.pem"
)
func main() {
key, err := libtrust.LoadKeyFile(clientPrivateKeyFilename)
if err != nil {
log.Fatal(err)
}
keyPEMBlock, err := key.PEMBlock()
if err != nil {
log.Fatal(err)
}
encodedPrivKey := pem.EncodeToMemory(keyPEMBlock)
fmt.Printf("Client Key:\n\n%s\n", string(encodedPrivKey))
cert, err := libtrust.GenerateSelfSignedClientCert(key)
if err != nil {
log.Fatal(err)
}
encodedCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
fmt.Printf("Client Cert:\n\n%s\n", string(encodedCert))
trustedServerKeys, err := libtrust.LoadKeySetFile(trustedHostsFilename)
if err != nil {
log.Fatal(err)
}
hostname, _, err := net.SplitHostPort(serverAddress)
if err != nil {
log.Fatal(err)
}
trustedServerKeys, err = libtrust.FilterByHosts(trustedServerKeys, hostname, false)
if err != nil {
log.Fatal(err)
}
caCert, err := libtrust.GenerateCACert(key, trustedServerKeys[0])
if err != nil {
log.Fatal(err)
}
encodedCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw})
fmt.Printf("CA Cert:\n\n%s\n", string(encodedCert))
}

View file

@ -1,61 +0,0 @@
package main
import (
"log"
"github.com/docker/libtrust"
)
func main() {
// Generate client key.
clientKey, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
log.Fatal(err)
}
// Add a comment for the client key.
clientKey.AddExtendedField("comment", "TLS Demo Client")
// Save the client key, public and private versions.
err = libtrust.SaveKey("client_data/private_key.pem", clientKey)
if err != nil {
log.Fatal(err)
}
err = libtrust.SavePublicKey("client_data/public_key.pem", clientKey.PublicKey())
if err != nil {
log.Fatal(err)
}
// Generate server key.
serverKey, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
log.Fatal(err)
}
// Set the list of addresses to use for the server.
serverKey.AddExtendedField("hosts", []string{"localhost", "docker.example.com"})
// Save the server key, public and private versions.
err = libtrust.SaveKey("server_data/private_key.pem", serverKey)
if err != nil {
log.Fatal(err)
}
err = libtrust.SavePublicKey("server_data/public_key.pem", serverKey.PublicKey())
if err != nil {
log.Fatal(err)
}
// Generate Authorized Keys file for server.
err = libtrust.AddKeySetFile("server_data/trusted_clients.pem", clientKey.PublicKey())
if err != nil {
log.Fatal(err)
}
// Generate Known Host Keys file for client.
err = libtrust.AddKeySetFile("client_data/trusted_hosts.pem", serverKey.PublicKey())
if err != nil {
log.Fatal(err)
}
}

View file

@ -1,80 +0,0 @@
package main
import (
"crypto/tls"
"fmt"
"html"
"log"
"net"
"net/http"
"github.com/docker/libtrust"
)
var (
serverAddress = "localhost:8888"
privateKeyFilename = "server_data/private_key.pem"
authorizedClientsFilename = "server_data/trusted_clients.pem"
)
func requestHandler(w http.ResponseWriter, r *http.Request) {
clientCert := r.TLS.PeerCertificates[0]
keyID := clientCert.Subject.CommonName
log.Printf("Request from keyID: %s\n", keyID)
fmt.Fprintf(w, "Hello, client! I'm a server! And you are %T: %s.\n", clientCert.PublicKey, html.EscapeString(keyID))
}
func main() {
// Load server key.
serverKey, err := libtrust.LoadKeyFile(privateKeyFilename)
if err != nil {
log.Fatal(err)
}
// Generate server certificate.
selfSignedServerCert, err := libtrust.GenerateSelfSignedServerCert(
serverKey, []string{"localhost"}, []net.IP{net.ParseIP("127.0.0.1")},
)
if err != nil {
log.Fatal(err)
}
// Load authorized client keys.
authorizedClients, err := libtrust.LoadKeySetFile(authorizedClientsFilename)
if err != nil {
log.Fatal(err)
}
// Create CA pool using trusted client keys.
caPool, err := libtrust.GenerateCACertPool(serverKey, authorizedClients)
if err != nil {
log.Fatal(err)
}
// Create TLS config, requiring client certificates.
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
tls.Certificate{
Certificate: [][]byte{selfSignedServerCert.Raw},
PrivateKey: serverKey.CryptoPrivateKey(),
Leaf: selfSignedServerCert,
},
},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: caPool,
}
// Create HTTP server with simple request handler.
server := &http.Server{
Addr: serverAddress,
Handler: http.HandlerFunc(requestHandler),
}
// Listen and server HTTPS using the libtrust TLS config.
listener, err := net.Listen("tcp", server.Addr)
if err != nil {
log.Fatal(err)
}
tlsListener := tls.NewListener(listener, tlsConfig)
server.Serve(tlsListener)
}

View file

@ -1,50 +0,0 @@
package trustgraph
import "github.com/docker/libtrust"
// TrustGraph represents a graph of authorization mapping
// public keys to nodes and grants between nodes.
type TrustGraph interface {
// Verifies that the given public key is allowed to perform
// the given action on the given node according to the trust
// graph.
Verify(libtrust.PublicKey, string, uint16) (bool, error)
// GetGrants returns an array of all grant chains which are used to
// allow the requested permission.
GetGrants(libtrust.PublicKey, string, uint16) ([][]*Grant, error)
}
// Grant represents a transfer of permission from one part of the
// trust graph to another. This is the only way to delegate
// permission between two different sub trees in the graph.
type Grant struct {
// Subject is the namespace being granted
Subject string
// Permissions is a bit map of permissions
Permission uint16
// Grantee represents the node being granted
// a permission scope. The grantee can be
// either a namespace item or a key id where namespace
// items will always start with a '/'.
Grantee string
// statement represents the statement used to create
// this object.
statement *Statement
}
// Permissions
// Read node 0x01 (can read node, no sub nodes)
// Write node 0x02 (can write to node object, cannot create subnodes)
// Read subtree 0x04 (delegates read to each sub node)
// Write subtree 0x08 (delegates write to each sub node, included create on the subject)
//
// Permission shortcuts
// ReadItem = 0x01
// WriteItem = 0x03
// ReadAccess = 0x07
// WriteAccess = 0x0F
// Delegate = 0x0F

View file

@ -1,133 +0,0 @@
package trustgraph
import (
"strings"
"github.com/docker/libtrust"
)
type grantNode struct {
grants []*Grant
children map[string]*grantNode
}
type memoryGraph struct {
roots map[string]*grantNode
}
func newGrantNode() *grantNode {
return &grantNode{
grants: []*Grant{},
children: map[string]*grantNode{},
}
}
// NewMemoryGraph returns a new in memory trust graph created from
// a static list of grants. This graph is immutable after creation
// and any alterations should create a new instance.
func NewMemoryGraph(grants []*Grant) TrustGraph {
roots := map[string]*grantNode{}
for _, grant := range grants {
parts := strings.Split(grant.Grantee, "/")
nodes := roots
var node *grantNode
var nodeOk bool
for _, part := range parts {
node, nodeOk = nodes[part]
if !nodeOk {
node = newGrantNode()
nodes[part] = node
}
if part != "" {
node.grants = append(node.grants, grant)
}
nodes = node.children
}
}
return &memoryGraph{roots}
}
func (g *memoryGraph) getGrants(name string) []*Grant {
nameParts := strings.Split(name, "/")
nodes := g.roots
var node *grantNode
var nodeOk bool
for _, part := range nameParts {
node, nodeOk = nodes[part]
if !nodeOk {
return nil
}
nodes = node.children
}
return node.grants
}
func isSubName(name, sub string) bool {
if strings.HasPrefix(name, sub) {
if len(name) == len(sub) || name[len(sub)] == '/' {
return true
}
}
return false
}
type walkFunc func(*Grant, []*Grant) bool
func foundWalkFunc(*Grant, []*Grant) bool {
return true
}
func (g *memoryGraph) walkGrants(start, target string, permission uint16, f walkFunc, chain []*Grant, visited map[*Grant]bool, collect bool) bool {
if visited == nil {
visited = map[*Grant]bool{}
}
grants := g.getGrants(start)
subGrants := make([]*Grant, 0, len(grants))
for _, grant := range grants {
if visited[grant] {
continue
}
visited[grant] = true
if grant.Permission&permission == permission {
if isSubName(target, grant.Subject) {
if f(grant, chain) {
return true
}
} else {
subGrants = append(subGrants, grant)
}
}
}
for _, grant := range subGrants {
var chainCopy []*Grant
if collect {
chainCopy = make([]*Grant, len(chain)+1)
copy(chainCopy, chain)
chainCopy[len(chainCopy)-1] = grant
} else {
chainCopy = nil
}
if g.walkGrants(grant.Subject, target, permission, f, chainCopy, visited, collect) {
return true
}
}
return false
}
func (g *memoryGraph) Verify(key libtrust.PublicKey, node string, permission uint16) (bool, error) {
return g.walkGrants(key.KeyID(), node, permission, foundWalkFunc, nil, nil, false), nil
}
func (g *memoryGraph) GetGrants(key libtrust.PublicKey, node string, permission uint16) ([][]*Grant, error) {
grants := [][]*Grant{}
collect := func(grant *Grant, chain []*Grant) bool {
grantChain := make([]*Grant, len(chain)+1)
copy(grantChain, chain)
grantChain[len(grantChain)-1] = grant
grants = append(grants, grantChain)
return false
}
g.walkGrants(key.KeyID(), node, permission, collect, nil, nil, true)
return grants, nil
}

View file

@ -1,174 +0,0 @@
package trustgraph
import (
"fmt"
"testing"
"github.com/docker/libtrust"
)
func createTestKeysAndGrants(count int) ([]*Grant, []libtrust.PrivateKey) {
grants := make([]*Grant, count)
keys := make([]libtrust.PrivateKey, count)
for i := 0; i < count; i++ {
pk, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
panic(err)
}
grant := &Grant{
Subject: fmt.Sprintf("/user-%d", i+1),
Permission: 0x0f,
Grantee: pk.KeyID(),
}
keys[i] = pk
grants[i] = grant
}
return grants, keys
}
func testVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) {
if ok, err := g.Verify(k, target, permission); err != nil {
t.Fatalf("Unexpected error during verification: %s", err)
} else if !ok {
t.Errorf("key failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target)
}
}
func testNotVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) {
if ok, err := g.Verify(k, target, permission); err != nil {
t.Fatalf("Unexpected error during verification: %s", err)
} else if ok {
t.Errorf("key should have failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target)
}
}
func TestVerify(t *testing.T) {
grants, keys := createTestKeysAndGrants(4)
extraGrants := make([]*Grant, 3)
extraGrants[0] = &Grant{
Subject: "/user-3",
Permission: 0x0f,
Grantee: "/user-2",
}
extraGrants[1] = &Grant{
Subject: "/user-3/sub-project",
Permission: 0x0f,
Grantee: "/user-4",
}
extraGrants[2] = &Grant{
Subject: "/user-4",
Permission: 0x07,
Grantee: "/user-1",
}
grants = append(grants, extraGrants...)
g := NewMemoryGraph(grants)
testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1/some-project/sub-value", 0x0f)
testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x07)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2/", 0x0f)
testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3/sub-value", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-value", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/", 0x0f)
testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f)
testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project/app", 0x0f)
testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f)
testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f)
testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3/sub-value", 0x0f)
testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x0f)
testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1/", 0x0f)
testNotVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-2", 0x0f)
testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-4", 0x0f)
testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f)
}
func TestCircularWalk(t *testing.T) {
grants, keys := createTestKeysAndGrants(3)
user1Grant := &Grant{
Subject: "/user-2",
Permission: 0x0f,
Grantee: "/user-1",
}
user2Grant := &Grant{
Subject: "/user-1",
Permission: 0x0f,
Grantee: "/user-2",
}
grants = append(grants, user1Grant, user2Grant)
g := NewMemoryGraph(grants)
testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1", 0x0f)
testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f)
testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3", 0x0f)
testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
}
func assertGrantSame(t *testing.T, actual, expected *Grant) {
if actual != expected {
t.Fatalf("Unexpected grant retrieved\n\tExpected: %v\n\tActual: %v", expected, actual)
}
}
func TestGetGrants(t *testing.T) {
grants, keys := createTestKeysAndGrants(5)
extraGrants := make([]*Grant, 4)
extraGrants[0] = &Grant{
Subject: "/user-3/friend-project",
Permission: 0x0f,
Grantee: "/user-2/friends",
}
extraGrants[1] = &Grant{
Subject: "/user-3/sub-project",
Permission: 0x0f,
Grantee: "/user-4",
}
extraGrants[2] = &Grant{
Subject: "/user-2/friends",
Permission: 0x0f,
Grantee: "/user-5/fun-project",
}
extraGrants[3] = &Grant{
Subject: "/user-5/fun-project",
Permission: 0x0f,
Grantee: "/user-1",
}
grants = append(grants, extraGrants...)
g := NewMemoryGraph(grants)
grantChains, err := g.GetGrants(keys[3], "/user-3/sub-project/specific-app", 0x0f)
if err != nil {
t.Fatalf("Error getting grants: %s", err)
}
if len(grantChains) != 1 {
t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains))
}
if len(grantChains[0]) != 2 {
t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0]))
}
assertGrantSame(t, grantChains[0][0], grants[3])
assertGrantSame(t, grantChains[0][1], extraGrants[1])
grantChains, err = g.GetGrants(keys[0], "/user-3/friend-project/fun-app", 0x0f)
if err != nil {
t.Fatalf("Error getting grants: %s", err)
}
if len(grantChains) != 1 {
t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains))
}
if len(grantChains[0]) != 4 {
t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0]))
}
assertGrantSame(t, grantChains[0][0], grants[0])
assertGrantSame(t, grantChains[0][1], extraGrants[3])
assertGrantSame(t, grantChains[0][2], extraGrants[2])
assertGrantSame(t, grantChains[0][3], extraGrants[0])
}

View file

@ -1,227 +0,0 @@
package trustgraph
import (
"crypto/x509"
"encoding/json"
"io"
"io/ioutil"
"sort"
"strings"
"time"
"github.com/docker/libtrust"
)
type jsonGrant struct {
Subject string `json:"subject"`
Permission uint16 `json:"permission"`
Grantee string `json:"grantee"`
}
type jsonRevocation struct {
Subject string `json:"subject"`
Revocation uint16 `json:"revocation"`
Grantee string `json:"grantee"`
}
type jsonStatement struct {
Revocations []*jsonRevocation `json:"revocations"`
Grants []*jsonGrant `json:"grants"`
Expiration time.Time `json:"expiration"`
IssuedAt time.Time `json:"issuedAt"`
}
func (g *jsonGrant) Grant(statement *Statement) *Grant {
return &Grant{
Subject: g.Subject,
Permission: g.Permission,
Grantee: g.Grantee,
statement: statement,
}
}
// Statement represents a set of grants made from a verifiable
// authority. A statement has an expiration associated with it
// set by the authority.
type Statement struct {
jsonStatement
signature *libtrust.JSONSignature
}
// IsExpired returns whether the statement has expired
func (s *Statement) IsExpired() bool {
return s.Expiration.Before(time.Now().Add(-10 * time.Second))
}
// Bytes returns an indented json representation of the statement
// in a byte array. This value can be written to a file or stream
// without alteration.
func (s *Statement) Bytes() ([]byte, error) {
return s.signature.PrettySignature("signatures")
}
// LoadStatement loads and verifies a statement from an input stream.
func LoadStatement(r io.Reader, authority *x509.CertPool) (*Statement, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
js, err := libtrust.ParsePrettySignature(b, "signatures")
if err != nil {
return nil, err
}
payload, err := js.Payload()
if err != nil {
return nil, err
}
var statement Statement
err = json.Unmarshal(payload, &statement.jsonStatement)
if err != nil {
return nil, err
}
if authority == nil {
_, err = js.Verify()
if err != nil {
return nil, err
}
} else {
_, err = js.VerifyChains(authority)
if err != nil {
return nil, err
}
}
statement.signature = js
return &statement, nil
}
// CreateStatements creates and signs a statement from a stream of grants
// and revocations in a JSON array.
func CreateStatement(grants, revocations io.Reader, expiration time.Duration, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) {
var statement Statement
err := json.NewDecoder(grants).Decode(&statement.jsonStatement.Grants)
if err != nil {
return nil, err
}
err = json.NewDecoder(revocations).Decode(&statement.jsonStatement.Revocations)
if err != nil {
return nil, err
}
statement.jsonStatement.Expiration = time.Now().UTC().Add(expiration)
statement.jsonStatement.IssuedAt = time.Now().UTC()
b, err := json.MarshalIndent(&statement.jsonStatement, "", " ")
if err != nil {
return nil, err
}
statement.signature, err = libtrust.NewJSONSignature(b)
if err != nil {
return nil, err
}
err = statement.signature.SignWithChain(key, chain)
if err != nil {
return nil, err
}
return &statement, nil
}
type statementList []*Statement
func (s statementList) Len() int {
return len(s)
}
func (s statementList) Less(i, j int) bool {
return s[i].IssuedAt.Before(s[j].IssuedAt)
}
func (s statementList) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
// CollapseStatements returns a single list of the valid statements as well as the
// time when the next grant will expire.
func CollapseStatements(statements []*Statement, useExpired bool) ([]*Grant, time.Time, error) {
sorted := make(statementList, 0, len(statements))
for _, statement := range statements {
if useExpired || !statement.IsExpired() {
sorted = append(sorted, statement)
}
}
sort.Sort(sorted)
var minExpired time.Time
var grantCount int
roots := map[string]*grantNode{}
for i, statement := range sorted {
if statement.Expiration.Before(minExpired) || i == 0 {
minExpired = statement.Expiration
}
for _, grant := range statement.Grants {
parts := strings.Split(grant.Grantee, "/")
nodes := roots
g := grant.Grant(statement)
grantCount = grantCount + 1
for _, part := range parts {
node, nodeOk := nodes[part]
if !nodeOk {
node = newGrantNode()
nodes[part] = node
}
node.grants = append(node.grants, g)
nodes = node.children
}
}
for _, revocation := range statement.Revocations {
parts := strings.Split(revocation.Grantee, "/")
nodes := roots
var node *grantNode
var nodeOk bool
for _, part := range parts {
node, nodeOk = nodes[part]
if !nodeOk {
break
}
nodes = node.children
}
if node != nil {
for _, grant := range node.grants {
if isSubName(grant.Subject, revocation.Subject) {
grant.Permission = grant.Permission &^ revocation.Revocation
}
}
}
}
}
retGrants := make([]*Grant, 0, grantCount)
for _, rootNodes := range roots {
retGrants = append(retGrants, rootNodes.grants...)
}
return retGrants, minExpired, nil
}
// FilterStatements filters the statements to statements including the given grants.
func FilterStatements(grants []*Grant) ([]*Statement, error) {
statements := map[*Statement]bool{}
for _, grant := range grants {
if grant.statement != nil {
statements[grant.statement] = true
}
}
retStatements := make([]*Statement, len(statements))
var i int
for statement := range statements {
retStatements[i] = statement
i++
}
return retStatements, nil
}

View file

@ -1,417 +0,0 @@
package trustgraph
import (
"bytes"
"crypto/x509"
"encoding/json"
"testing"
"time"
"github.com/docker/libtrust"
"github.com/docker/libtrust/testutil"
)
const testStatementExpiration = time.Hour * 5
func generateStatement(grants []*Grant, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) {
var statement Statement
statement.Grants = make([]*jsonGrant, len(grants))
for i, grant := range grants {
statement.Grants[i] = &jsonGrant{
Subject: grant.Subject,
Permission: grant.Permission,
Grantee: grant.Grantee,
}
}
statement.IssuedAt = time.Now()
statement.Expiration = time.Now().Add(testStatementExpiration)
statement.Revocations = make([]*jsonRevocation, 0)
marshalled, err := json.MarshalIndent(statement.jsonStatement, "", " ")
if err != nil {
return nil, err
}
sig, err := libtrust.NewJSONSignature(marshalled)
if err != nil {
return nil, err
}
err = sig.SignWithChain(key, chain)
if err != nil {
return nil, err
}
statement.signature = sig
return &statement, nil
}
func generateTrustChain(t *testing.T, chainLen int) (libtrust.PrivateKey, *x509.CertPool, []*x509.Certificate) {
caKey, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generating key: %s", err)
}
ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey())
if err != nil {
t.Fatalf("Error generating ca: %s", err)
}
parent := ca
parentKey := caKey
chain := make([]*x509.Certificate, chainLen)
for i := chainLen - 1; i > 0; i-- {
intermediatekey, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generate key: %s", err)
}
chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
if err != nil {
t.Fatalf("Error generating intermdiate certificate: %s", err)
}
parent = chain[i]
parentKey = intermediatekey
}
trustKey, err := libtrust.GenerateECP256PrivateKey()
if err != nil {
t.Fatalf("Error generate key: %s", err)
}
chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent)
if err != nil {
t.Fatalf("Error generate trust cert: %s", err)
}
caPool := x509.NewCertPool()
caPool.AddCert(ca)
return trustKey, caPool, chain
}
func TestLoadStatement(t *testing.T) {
grantCount := 4
grants, _ := createTestKeysAndGrants(grantCount)
trustKey, caPool, chain := generateTrustChain(t, 6)
statement, err := generateStatement(grants, trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
statementBytes, err := statement.Bytes()
if err != nil {
t.Fatalf("Error getting statement bytes: %s", err)
}
s2, err := LoadStatement(bytes.NewReader(statementBytes), caPool)
if err != nil {
t.Fatalf("Error loading statement: %s", err)
}
if len(s2.Grants) != grantCount {
t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants))
}
pool := x509.NewCertPool()
_, err = LoadStatement(bytes.NewReader(statementBytes), pool)
if err == nil {
t.Fatalf("No error thrown verifying without an authority")
} else if _, ok := err.(x509.UnknownAuthorityError); !ok {
t.Fatalf("Unexpected error verifying without authority: %s", err)
}
s2, err = LoadStatement(bytes.NewReader(statementBytes), nil)
if err != nil {
t.Fatalf("Error loading statement: %s", err)
}
if len(s2.Grants) != grantCount {
t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants))
}
badData := make([]byte, len(statementBytes))
copy(badData, statementBytes)
badData[0] = '['
_, err = LoadStatement(bytes.NewReader(badData), nil)
if err == nil {
t.Fatalf("No error thrown parsing bad json")
}
alteredData := make([]byte, len(statementBytes))
copy(alteredData, statementBytes)
alteredData[30] = '0'
_, err = LoadStatement(bytes.NewReader(alteredData), nil)
if err == nil {
t.Fatalf("No error thrown from bad data")
}
}
func TestCollapseGrants(t *testing.T) {
grantCount := 8
grants, keys := createTestKeysAndGrants(grantCount)
linkGrants := make([]*Grant, 4)
linkGrants[0] = &Grant{
Subject: "/user-3",
Permission: 0x0f,
Grantee: "/user-2",
}
linkGrants[1] = &Grant{
Subject: "/user-3/sub-project",
Permission: 0x0f,
Grantee: "/user-4",
}
linkGrants[2] = &Grant{
Subject: "/user-6",
Permission: 0x0f,
Grantee: "/user-7",
}
linkGrants[3] = &Grant{
Subject: "/user-6/sub-project/specific-app",
Permission: 0x0f,
Grantee: "/user-5",
}
trustKey, pool, chain := generateTrustChain(t, 3)
statements := make([]*Statement, 3)
var err error
statements[0], err = generateStatement(grants[0:4], trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
statements[1], err = generateStatement(grants[4:], trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
statements[2], err = generateStatement(linkGrants, trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
statementsCopy := make([]*Statement, len(statements))
for i, statement := range statements {
b, err := statement.Bytes()
if err != nil {
t.Fatalf("Error getting statement bytes: %s", err)
}
verifiedStatement, err := LoadStatement(bytes.NewReader(b), pool)
if err != nil {
t.Fatalf("Error loading statement: %s", err)
}
// Force sort by reversing order
statementsCopy[len(statementsCopy)-i-1] = verifiedStatement
}
statements = statementsCopy
collapsedGrants, expiration, err := CollapseStatements(statements, false)
if len(collapsedGrants) != 12 {
t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %d", 12, len(collapsedGrants))
}
if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) {
t.Fatalf("Unexpected expiration time: %s", expiration.String())
}
g := NewMemoryGraph(collapsedGrants)
testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f)
testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f)
testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-5", 0x0f)
testVerified(t, g, keys[5].PublicKey(), "user-key-6", "/user-6", 0x0f)
testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-7", 0x0f)
testVerified(t, g, keys[7].PublicKey(), "user-key-8", "/user-8", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-project/specific-app", 0x0f)
testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f)
testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6", 0x0f)
testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f)
testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project/specific-app", 0x0f)
testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f)
testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-6/sub-project", 0x0f)
testNotVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project", 0x0f)
// Add revocation grant
statements = append(statements, &Statement{
jsonStatement{
IssuedAt: time.Now(),
Expiration: time.Now().Add(testStatementExpiration),
Grants: []*jsonGrant{},
Revocations: []*jsonRevocation{
&jsonRevocation{
Subject: "/user-1",
Revocation: 0x0f,
Grantee: keys[0].KeyID(),
},
&jsonRevocation{
Subject: "/user-2",
Revocation: 0x08,
Grantee: keys[1].KeyID(),
},
&jsonRevocation{
Subject: "/user-6",
Revocation: 0x0f,
Grantee: "/user-7",
},
&jsonRevocation{
Subject: "/user-9",
Revocation: 0x0f,
Grantee: "/user-10",
},
},
},
nil,
})
collapsedGrants, expiration, err = CollapseStatements(statements, false)
if len(collapsedGrants) != 12 {
t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %d", 12, len(collapsedGrants))
}
if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) {
t.Fatalf("Unexpected expiration time: %s", expiration.String())
}
g = NewMemoryGraph(collapsedGrants)
testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f)
testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f)
testNotVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f)
testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x07)
}
func TestFilterStatements(t *testing.T) {
grantCount := 8
grants, keys := createTestKeysAndGrants(grantCount)
linkGrants := make([]*Grant, 3)
linkGrants[0] = &Grant{
Subject: "/user-3",
Permission: 0x0f,
Grantee: "/user-2",
}
linkGrants[1] = &Grant{
Subject: "/user-5",
Permission: 0x0f,
Grantee: "/user-4",
}
linkGrants[2] = &Grant{
Subject: "/user-7",
Permission: 0x0f,
Grantee: "/user-6",
}
trustKey, _, chain := generateTrustChain(t, 3)
statements := make([]*Statement, 5)
var err error
statements[0], err = generateStatement(grants[0:2], trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
statements[1], err = generateStatement(grants[2:4], trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
statements[2], err = generateStatement(grants[4:6], trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
statements[3], err = generateStatement(grants[6:], trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
statements[4], err = generateStatement(linkGrants, trustKey, chain)
if err != nil {
t.Fatalf("Error generating statement: %s", err)
}
collapsed, _, err := CollapseStatements(statements, false)
if err != nil {
t.Fatalf("Error collapsing grants: %s", err)
}
// Filter 1, all 5 statements
filter1, err := FilterStatements(collapsed)
if err != nil {
t.Fatalf("Error filtering statements: %s", err)
}
if len(filter1) != 5 {
t.Fatalf("Wrong number of statements, expected %d, received %d", 5, len(filter1))
}
// Filter 2, one statement
filter2, err := FilterStatements([]*Grant{collapsed[0]})
if err != nil {
t.Fatalf("Error filtering statements: %s", err)
}
if len(filter2) != 1 {
t.Fatalf("Wrong number of statements, expected %d, received %d", 1, len(filter2))
}
// Filter 3, 2 statements, from graph lookup
g := NewMemoryGraph(collapsed)
lookupGrants, err := g.GetGrants(keys[1], "/user-3", 0x0f)
if err != nil {
t.Fatalf("Error looking up grants: %s", err)
}
if len(lookupGrants) != 1 {
t.Fatalf("Wrong numberof grant chains returned from lookup, expected %d, received %d", 1, len(lookupGrants))
}
if len(lookupGrants[0]) != 2 {
t.Fatalf("Wrong number of grants looked up, expected %d, received %d", 2, len(lookupGrants))
}
filter3, err := FilterStatements(lookupGrants[0])
if err != nil {
t.Fatalf("Error filtering statements: %s", err)
}
if len(filter3) != 2 {
t.Fatalf("Wrong number of statements, expected %d, received %d", 2, len(filter3))
}
}
func TestCreateStatement(t *testing.T) {
grantJSON := bytes.NewReader([]byte(`[
{
"subject": "/user-2",
"permission": 15,
"grantee": "/user-1"
},
{
"subject": "/user-7",
"permission": 1,
"grantee": "/user-9"
},
{
"subject": "/user-3",
"permission": 15,
"grantee": "/user-2"
}
]`))
revocationJSON := bytes.NewReader([]byte(`[
{
"subject": "user-8",
"revocation": 12,
"grantee": "user-9"
}
]`))
trustKey, pool, chain := generateTrustChain(t, 3)
statement, err := CreateStatement(grantJSON, revocationJSON, testStatementExpiration, trustKey, chain)
if err != nil {
t.Fatalf("Error creating statement: %s", err)
}
b, err := statement.Bytes()
if err != nil {
t.Fatalf("Error retrieving bytes: %s", err)
}
verified, err := LoadStatement(bytes.NewReader(b), pool)
if err != nil {
t.Fatalf("Error loading statement: %s", err)
}
if len(verified.Grants) != 3 {
t.Errorf("Unexpected number of grants, expected %d, received %d", 3, len(verified.Grants))
}
if len(verified.Revocations) != 1 {
t.Errorf("Unexpected number of revocations, expected %d, received %d", 1, len(verified.Revocations))
}
}

View file

@ -1,23 +0,0 @@
package libtrust
import (
"encoding/pem"
"reflect"
"testing"
)
func TestAddPEMHeadersToKey(t *testing.T) {
pk := &rsaPublicKey{nil, map[string]interface{}{}}
blk := &pem.Block{Headers: map[string]string{"hosts": "localhost,127.0.0.1"}}
addPEMHeadersToKey(blk, pk)
val := pk.GetExtendedField("hosts")
hosts, ok := val.([]string)
if !ok {
t.Fatalf("hosts type(%v), expected []string", reflect.TypeOf(val))
}
expected := []string{"localhost", "127.0.0.1"}
if !reflect.DeepEqual(hosts, expected) {
t.Errorf("hosts(%v), expected %v", hosts, expected)
}
}

View file

@ -1,65 +0,0 @@
// Copyright 2014 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package redistest contains utilities for writing Redigo tests.
package redistest
import (
"errors"
"time"
"github.com/garyburd/redigo/redis"
)
type testConn struct {
redis.Conn
}
func (t testConn) Close() error {
_, err := t.Conn.Do("SELECT", "9")
if err != nil {
return nil
}
_, err = t.Conn.Do("FLUSHDB")
if err != nil {
return err
}
return t.Conn.Close()
}
// Dial dials the local Redis server and selects database 9. To prevent
// stomping on real data, DialTestDB fails if database 9 contains data. The
// returned connection flushes database 9 on close.
func Dial() (redis.Conn, error) {
c, err := redis.DialTimeout("tcp", ":6379", 0, 1*time.Second, 1*time.Second)
if err != nil {
return nil, err
}
_, err = c.Do("SELECT", "9")
if err != nil {
return nil, err
}
n, err := redis.Int(c.Do("DBSIZE"))
if err != nil {
return nil, err
}
if n != 0 {
return nil, errors.New("database #9 is not empty, test can not continue")
}
return testConn{c}, nil
}

View file

@ -1,542 +0,0 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"bufio"
"bytes"
"math"
"net"
"reflect"
"strings"
"testing"
"time"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
)
var writeTests = []struct {
args []interface{}
expected string
}{
{
[]interface{}{"SET", "key", "value"},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
[]interface{}{"SET", "key", "value"},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
[]interface{}{"SET", "key", byte(100)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
},
{
[]interface{}{"SET", "key", 100},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$3\r\n100\r\n",
},
{
[]interface{}{"SET", "key", int64(math.MinInt64)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$20\r\n-9223372036854775808\r\n",
},
{
[]interface{}{"SET", "key", float64(1349673917.939762)},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$21\r\n1.349673917939762e+09\r\n",
},
{
[]interface{}{"SET", "key", ""},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
[]interface{}{"SET", "key", nil},
"*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
[]interface{}{"ECHO", true, false},
"*3\r\n$4\r\nECHO\r\n$1\r\n1\r\n$1\r\n0\r\n",
},
}
func TestWrite(t *testing.T) {
for _, tt := range writeTests {
var buf bytes.Buffer
rw := bufio.ReadWriter{Writer: bufio.NewWriter(&buf)}
c := redis.NewConnBufio(rw)
err := c.Send(tt.args[0].(string), tt.args[1:]...)
if err != nil {
t.Errorf("Send(%v) returned error %v", tt.args, err)
continue
}
rw.Flush()
actual := buf.String()
if actual != tt.expected {
t.Errorf("Send(%v) = %q, want %q", tt.args, actual, tt.expected)
}
}
}
var errorSentinel = &struct{}{}
var readTests = []struct {
reply string
expected interface{}
}{
{
"+OK\r\n",
"OK",
},
{
"+PONG\r\n",
"PONG",
},
{
"@OK\r\n",
errorSentinel,
},
{
"$6\r\nfoobar\r\n",
[]byte("foobar"),
},
{
"$-1\r\n",
nil,
},
{
":1\r\n",
int64(1),
},
{
":-2\r\n",
int64(-2),
},
{
"*0\r\n",
[]interface{}{},
},
{
"*-1\r\n",
nil,
},
{
"*4\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$5\r\nHello\r\n$5\r\nWorld\r\n",
[]interface{}{[]byte("foo"), []byte("bar"), []byte("Hello"), []byte("World")},
},
{
"*3\r\n$3\r\nfoo\r\n$-1\r\n$3\r\nbar\r\n",
[]interface{}{[]byte("foo"), nil, []byte("bar")},
},
{
// "x" is not a valid length
"$x\r\nfoobar\r\n",
errorSentinel,
},
{
// -2 is not a valid length
"$-2\r\n",
errorSentinel,
},
{
// "x" is not a valid integer
":x\r\n",
errorSentinel,
},
{
// missing \r\n following value
"$6\r\nfoobar",
errorSentinel,
},
{
// short value
"$6\r\nxx",
errorSentinel,
},
{
// long value
"$6\r\nfoobarx\r\n",
errorSentinel,
},
}
func TestRead(t *testing.T) {
for _, tt := range readTests {
rw := bufio.ReadWriter{
Reader: bufio.NewReader(strings.NewReader(tt.reply)),
Writer: bufio.NewWriter(nil), // writer need to support Flush
}
c := redis.NewConnBufio(rw)
actual, err := c.Receive()
if tt.expected == errorSentinel {
if err == nil {
t.Errorf("Receive(%q) did not return expected error", tt.reply)
}
} else {
if err != nil {
t.Errorf("Receive(%q) returned error %v", tt.reply, err)
continue
}
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("Receive(%q) = %v, want %v", tt.reply, actual, tt.expected)
}
}
}
}
var testCommands = []struct {
args []interface{}
expected interface{}
}{
{
[]interface{}{"PING"},
"PONG",
},
{
[]interface{}{"SET", "foo", "bar"},
"OK",
},
{
[]interface{}{"GET", "foo"},
[]byte("bar"),
},
{
[]interface{}{"GET", "nokey"},
nil,
},
{
[]interface{}{"MGET", "nokey", "foo"},
[]interface{}{nil, []byte("bar")},
},
{
[]interface{}{"INCR", "mycounter"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "foo"},
int64(1),
},
{
[]interface{}{"LPUSH", "mylist", "bar"},
int64(2),
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
[]interface{}{[]byte("bar"), []byte("foo")},
},
{
[]interface{}{"MULTI"},
"OK",
},
{
[]interface{}{"LRANGE", "mylist", 0, -1},
"QUEUED",
},
{
[]interface{}{"PING"},
"QUEUED",
},
{
[]interface{}{"EXEC"},
[]interface{}{
[]interface{}{[]byte("bar"), []byte("foo")},
"PONG",
},
},
}
func TestDoCommands(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
actual, err := c.Do(cmd.args[0].(string), cmd.args[1:]...)
if err != nil {
t.Errorf("Do(%v) returned error %v", cmd.args, err)
continue
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Do(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestPipelineCommands(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
}
}
if err := c.Flush(); err != nil {
t.Errorf("Flush() returned error %v", err)
}
for _, cmd := range testCommands {
actual, err := c.Receive()
if err != nil {
t.Fatalf("Receive(%v) returned error %v", cmd.args, err)
}
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestBlankCommmand(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
for _, cmd := range testCommands {
if err := c.Send(cmd.args[0].(string), cmd.args[1:]...); err != nil {
t.Fatalf("Send(%v) returned error %v", cmd.args, err)
}
}
reply, err := redis.Values(c.Do(""))
if err != nil {
t.Fatalf("Do() returned error %v", err)
}
if len(reply) != len(testCommands) {
t.Fatalf("len(reply)=%d, want %d", len(reply), len(testCommands))
}
for i, cmd := range testCommands {
actual := reply[i]
if !reflect.DeepEqual(actual, cmd.expected) {
t.Errorf("Receive(%v) = %v, want %v", cmd.args, actual, cmd.expected)
}
}
}
func TestRecvBeforeSend(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
done := make(chan struct{})
go func() {
c.Receive()
close(done)
}()
time.Sleep(time.Millisecond)
c.Send("PING")
c.Flush()
<-done
_, err = c.Do("")
if err != nil {
t.Fatalf("error=%v", err)
}
}
func TestError(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
c.Do("SET", "key", "val")
_, err = c.Do("HSET", "key", "fld", "val")
if err == nil {
t.Errorf("Expected err for HSET on string key.")
}
if c.Err() != nil {
t.Errorf("Conn has Err()=%v, expect nil", c.Err())
}
_, err = c.Do("SET", "key", "val")
if err != nil {
t.Errorf("Do(SET, key, val) returned error %v, expected nil.", err)
}
}
func TestReadDeadline(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen returned %v", err)
}
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}
go func() {
time.Sleep(time.Second)
c.Write([]byte("+OK\r\n"))
c.Close()
}()
}
}()
c1, err := redis.DialTimeout(l.Addr().Network(), l.Addr().String(), 0, time.Millisecond, 0)
if err != nil {
t.Fatalf("redis.Dial returned %v", err)
}
defer c1.Close()
_, err = c1.Do("PING")
if err == nil {
t.Fatalf("c1.Do() returned nil, expect error")
}
if c1.Err() == nil {
t.Fatalf("c1.Err() = nil, expect error")
}
c2, err := redis.DialTimeout(l.Addr().Network(), l.Addr().String(), 0, time.Millisecond, 0)
if err != nil {
t.Fatalf("redis.Dial returned %v", err)
}
defer c2.Close()
c2.Send("PING")
c2.Flush()
_, err = c2.Receive()
if err == nil {
t.Fatalf("c2.Receive() returned nil, expect error")
}
if c2.Err() == nil {
t.Fatalf("c2.Err() = nil, expect error")
}
}
// Connect to local instance of Redis running on the default port.
func ExampleDial(x int) {
c, err := redis.Dial("tcp", ":6379")
if err != nil {
// handle error
}
defer c.Close()
}
// TextExecError tests handling of errors in a transaction. See
// http://redis.io/topics/transactions for information on how Redis handles
// errors in a transaction.
func TestExecError(t *testing.T) {
c, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer c.Close()
// Execute commands that fail before EXEC is called.
c.Do("ZADD", "k0", 0, 0)
c.Send("MULTI")
c.Send("NOTACOMMAND", "k0", 0, 0)
c.Send("ZINCRBY", "k0", 0, 0)
v, err := c.Do("EXEC")
if err == nil {
t.Fatalf("EXEC returned values %v, expected error", v)
}
// Execute commands that fail after EXEC is called. The first command
// returns an error.
c.Do("ZADD", "k1", 0, 0)
c.Send("MULTI")
c.Send("HSET", "k1", 0, 0)
c.Send("ZINCRBY", "k1", 0, 0)
v, err = c.Do("EXEC")
if err != nil {
t.Fatalf("EXEC returned error %v", err)
}
vs, err := redis.Values(v, nil)
if err != nil {
t.Fatalf("Values(v) returned error %v", err)
}
if len(vs) != 2 {
t.Fatalf("len(vs) == %d, want 2", len(vs))
}
if _, ok := vs[0].(error); !ok {
t.Fatalf("first result is type %T, expected error", vs[0])
}
if _, ok := vs[1].([]byte); !ok {
t.Fatalf("second result is type %T, expected []byte", vs[2])
}
// Execute commands that fail after EXEC is called. The second command
// returns an error.
c.Do("ZADD", "k2", 0, 0)
c.Send("MULTI")
c.Send("ZINCRBY", "k2", 0, 0)
c.Send("HSET", "k2", 0, 0)
v, err = c.Do("EXEC")
if err != nil {
t.Fatalf("EXEC returned error %v", err)
}
vs, err = redis.Values(v, nil)
if err != nil {
t.Fatalf("Values(v) returned error %v", err)
}
if len(vs) != 2 {
t.Fatalf("len(vs) == %d, want 2", len(vs))
}
if _, ok := vs[0].([]byte); !ok {
t.Fatalf("first result is type %T, expected []byte", vs[0])
}
if _, ok := vs[1].(error); !ok {
t.Fatalf("second result is type %T, expected error", vs[2])
}
}
func BenchmarkDoEmpty(b *testing.B) {
b.StopTimer()
c, err := redistest.Dial()
if err != nil {
b.Fatal(err)
}
defer c.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do(""); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkDoPing(b *testing.B) {
b.StopTimer()
c, err := redistest.Dial()
if err != nil {
b.Fatal(err)
}
defer c.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
}
}

View file

@ -1,674 +0,0 @@
// Copyright 2011 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"errors"
"io"
"reflect"
"sync"
"testing"
"time"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
)
type poolTestConn struct {
d *poolDialer
err error
redis.Conn
}
func (c *poolTestConn) Close() error { c.d.open -= 1; return nil }
func (c *poolTestConn) Err() error { return c.err }
func (c *poolTestConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
if commandName == "ERR" {
c.err = args[0].(error)
commandName = "PING"
}
if commandName != "" {
c.d.commands = append(c.d.commands, commandName)
}
return c.Conn.Do(commandName, args...)
}
func (c *poolTestConn) Send(commandName string, args ...interface{}) error {
c.d.commands = append(c.d.commands, commandName)
return c.Conn.Send(commandName, args...)
}
type poolDialer struct {
t *testing.T
dialed int
open int
commands []string
dialErr error
}
func (d *poolDialer) dial() (redis.Conn, error) {
d.dialed += 1
if d.dialErr != nil {
return nil, d.dialErr
}
c, err := redistest.Dial()
if err != nil {
return nil, err
}
d.open += 1
return &poolTestConn{d: d, Conn: c}, nil
}
func (d *poolDialer) check(message string, p *redis.Pool, dialed, open int) {
if d.dialed != dialed {
d.t.Errorf("%s: dialed=%d, want %d", message, d.dialed, dialed)
}
if d.open != open {
d.t.Errorf("%s: open=%d, want %d", message, d.open, open)
}
if active := p.ActiveCount(); active != open {
d.t.Errorf("%s: active=%d, want %d", message, active, open)
}
}
func TestPoolReuse(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
}
for i := 0; i < 10; i++ {
c1 := p.Get()
c1.Do("PING")
c2 := p.Get()
c2.Do("PING")
c1.Close()
c2.Close()
}
d.check("before close", p, 2, 2)
p.Close()
d.check("after close", p, 2, 0)
}
func TestPoolMaxIdle(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
}
for i := 0; i < 10; i++ {
c1 := p.Get()
c1.Do("PING")
c2 := p.Get()
c2.Do("PING")
c3 := p.Get()
c3.Do("PING")
c1.Close()
c2.Close()
c3.Close()
}
d.check("before close", p, 12, 2)
p.Close()
d.check("after close", p, 12, 0)
}
func TestPoolError(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
}
c := p.Get()
c.Do("ERR", io.EOF)
if c.Err() == nil {
t.Errorf("expected c.Err() != nil")
}
c.Close()
c = p.Get()
c.Do("ERR", io.EOF)
c.Close()
d.check(".", p, 2, 0)
}
func TestPoolClose(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
}
c1 := p.Get()
c1.Do("PING")
c2 := p.Get()
c2.Do("PING")
c3 := p.Get()
c3.Do("PING")
c1.Close()
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after connection closed")
}
c2.Close()
c2.Close()
p.Close()
d.check("after pool close", p, 3, 1)
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after connection and pool closed")
}
c3.Close()
d.check("after conn close", p, 3, 0)
c1 = p.Get()
if _, err := c1.Do("PING"); err == nil {
t.Errorf("expected error after pool closed")
}
}
func TestPoolTimeout(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
IdleTimeout: 300 * time.Second,
Dial: d.dial,
}
now := time.Now()
redis.SetNowFunc(func() time.Time { return now })
defer redis.SetNowFunc(time.Now)
c := p.Get()
c.Do("PING")
c.Close()
d.check("1", p, 1, 1)
now = now.Add(p.IdleTimeout)
c = p.Get()
c.Do("PING")
c.Close()
d.check("2", p, 2, 1)
p.Close()
}
func TestPoolConcurrenSendReceive(t *testing.T) {
p := &redis.Pool{
Dial: redistest.Dial,
}
c := p.Get()
done := make(chan error, 1)
go func() {
_, err := c.Receive()
done <- err
}()
c.Send("PING")
c.Flush()
err := <-done
if err != nil {
t.Fatalf("Receive() returned error %v", err)
}
_, err = c.Do("")
if err != nil {
t.Fatalf("Do() returned error %v", err)
}
c.Close()
p.Close()
}
func TestPoolBorrowCheck(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
Dial: d.dial,
TestOnBorrow: func(redis.Conn, time.Time) error { return redis.Error("BLAH") },
}
for i := 0; i < 10; i++ {
c := p.Get()
c.Do("PING")
c.Close()
}
d.check("1", p, 10, 1)
p.Close()
}
func TestPoolMaxActive(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
MaxActive: 2,
Dial: d.dial,
}
c1 := p.Get()
c1.Do("PING")
c2 := p.Get()
c2.Do("PING")
d.check("1", p, 2, 2)
c3 := p.Get()
if _, err := c3.Do("PING"); err != redis.ErrPoolExhausted {
t.Errorf("expected pool exhausted")
}
c3.Close()
d.check("2", p, 2, 2)
c2.Close()
d.check("3", p, 2, 2)
c3 = p.Get()
if _, err := c3.Do("PING"); err != nil {
t.Errorf("expected good channel, err=%v", err)
}
c3.Close()
d.check("4", p, 2, 2)
p.Close()
}
func TestPoolMonitorCleanup(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
MaxActive: 2,
Dial: d.dial,
}
c := p.Get()
c.Send("MONITOR")
c.Close()
d.check("", p, 1, 0)
p.Close()
}
func TestPoolPubSubCleanup(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
MaxActive: 2,
Dial: d.dial,
}
c := p.Get()
c.Send("SUBSCRIBE", "x")
c.Close()
want := []string{"SUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Send("PSUBSCRIBE", "x*")
c.Close()
want = []string{"PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE", "ECHO"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
p.Close()
}
func TestPoolTransactionCleanup(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 2,
MaxActive: 2,
Dial: d.dial,
}
c := p.Get()
c.Do("WATCH", "key")
c.Do("PING")
c.Close()
want := []string{"WATCH", "PING", "UNWATCH"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Do("WATCH", "key")
c.Do("UNWATCH")
c.Do("PING")
c.Close()
want = []string{"WATCH", "UNWATCH", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "PING", "DISCARD"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("DISCARD")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "DISCARD", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
c = p.Get()
c.Do("WATCH", "key")
c.Do("MULTI")
c.Do("EXEC")
c.Do("PING")
c.Close()
want = []string{"WATCH", "MULTI", "EXEC", "PING"}
if !reflect.DeepEqual(d.commands, want) {
t.Errorf("got commands %v, want %v", d.commands, want)
}
d.commands = nil
p.Close()
}
func startGoroutines(p *redis.Pool, cmd string, args ...interface{}) chan error {
errs := make(chan error, 10)
for i := 0; i < cap(errs); i++ {
go func() {
c := p.Get()
_, err := c.Do(cmd, args...)
errs <- err
c.Close()
}()
}
// Wait for goroutines to block.
time.Sleep(time.Second / 4)
return errs
}
func TestWaitPool(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
Dial: d.dial,
Wait: true,
}
defer p.Close()
c := p.Get()
errs := startGoroutines(p, "PING")
d.check("before close", p, 1, 1)
c.Close()
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
if err != nil {
t.Fatal(err)
}
case <-timeout:
t.Fatalf("timeout waiting for blocked goroutine %d", i)
}
}
d.check("done", p, 1, 1)
}
func TestWaitPoolClose(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
Dial: d.dial,
Wait: true,
}
c := p.Get()
if _, err := c.Do("PING"); err != nil {
t.Fatal(err)
}
errs := startGoroutines(p, "PING")
d.check("before close", p, 1, 1)
p.Close()
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
switch err {
case nil:
t.Fatal("blocked goroutine did not get error")
case redis.ErrPoolExhausted:
t.Fatal("blocked goroutine got pool exhausted error")
}
case <-timeout:
t.Fatal("timeout waiting for blocked goroutine")
}
}
c.Close()
d.check("done", p, 1, 0)
}
func TestWaitPoolCommandError(t *testing.T) {
testErr := errors.New("test")
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
Dial: d.dial,
Wait: true,
}
defer p.Close()
c := p.Get()
errs := startGoroutines(p, "ERR", testErr)
d.check("before close", p, 1, 1)
c.Close()
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
if err != nil {
t.Fatal(err)
}
case <-timeout:
t.Fatalf("timeout waiting for blocked goroutine %d", i)
}
}
d.check("done", p, cap(errs), 0)
}
func TestWaitPoolDialError(t *testing.T) {
testErr := errors.New("test")
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
Dial: d.dial,
Wait: true,
}
defer p.Close()
c := p.Get()
errs := startGoroutines(p, "ERR", testErr)
d.check("before close", p, 1, 1)
d.dialErr = errors.New("dial")
c.Close()
nilCount := 0
errCount := 0
timeout := time.After(2 * time.Second)
for i := 0; i < cap(errs); i++ {
select {
case err := <-errs:
switch err {
case nil:
nilCount++
case d.dialErr:
errCount++
default:
t.Fatalf("expected dial error or nil, got %v", err)
}
case <-timeout:
t.Fatalf("timeout waiting for blocked goroutine %d", i)
}
}
if nilCount != 1 {
t.Errorf("expected one nil error, got %d", nilCount)
}
if errCount != cap(errs)-1 {
t.Errorf("expected %d dial erors, got %d", cap(errs)-1, errCount)
}
d.check("done", p, cap(errs), 0)
}
// Borrowing requires us to iterate over the idle connections, unlock the pool,
// and perform a blocking operation to check the connection still works. If
// TestOnBorrow fails, we must reacquire the lock and continue iteration. This
// test ensures that iteration will work correctly if multiple threads are
// iterating simultaneously.
func TestLocking_TestOnBorrowFails_PoolDoesntCrash(t *testing.T) {
count := 100
// First we'll Create a pool where the pilfering of idle connections fails.
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: count,
MaxActive: count,
Dial: d.dial,
TestOnBorrow: func(c redis.Conn, t time.Time) error {
return errors.New("No way back into the real world.")
},
}
defer p.Close()
// Fill the pool with idle connections.
b1 := sync.WaitGroup{}
b1.Add(count)
b2 := sync.WaitGroup{}
b2.Add(count)
for i := 0; i < count; i++ {
go func() {
c := p.Get()
if c.Err() != nil {
t.Errorf("pool get failed: %v", c.Err())
}
b1.Done()
b1.Wait()
c.Close()
b2.Done()
}()
}
b2.Wait()
if d.dialed != count {
t.Errorf("Expected %d dials, got %d", count, d.dialed)
}
// Spawn a bunch of goroutines to thrash the pool.
b2.Add(count)
for i := 0; i < count; i++ {
go func() {
c := p.Get()
if c.Err() != nil {
t.Errorf("pool get failed: %v", c.Err())
}
c.Close()
b2.Done()
}()
}
b2.Wait()
if d.dialed != count*2 {
t.Errorf("Expected %d dials, got %d", count*2, d.dialed)
}
}
func BenchmarkPoolGet(b *testing.B) {
b.StopTimer()
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
c := p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c = p.Get()
c.Close()
}
}
func BenchmarkPoolGetErr(b *testing.B) {
b.StopTimer()
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
c := p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c = p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
}
}
func BenchmarkPoolGetPing(b *testing.B) {
b.StopTimer()
p := redis.Pool{Dial: redistest.Dial, MaxIdle: 2}
c := p.Get()
if err := c.Err(); err != nil {
b.Fatal(err)
}
c.Close()
defer p.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
c = p.Get()
if _, err := c.Do("PING"); err != nil {
b.Fatal(err)
}
c.Close()
}
}

View file

@ -1,143 +0,0 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"fmt"
"net"
"reflect"
"sync"
"testing"
"time"
"github.com/garyburd/redigo/internal/redistest"
"github.com/garyburd/redigo/redis"
)
func publish(channel, value interface{}) {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
c.Do("PUBLISH", channel, value)
}
// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
func ExamplePubSubConn() {
c, err := dial()
if err != nil {
panic(err)
}
defer c.Close()
var wg sync.WaitGroup
wg.Add(2)
psc := redis.PubSubConn{Conn: c}
// This goroutine receives and prints pushed notifications from the server.
// The goroutine exits when the connection is unsubscribed from all
// channels or there is an error.
go func() {
defer wg.Done()
for {
switch n := psc.Receive().(type) {
case redis.Message:
fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
case redis.PMessage:
fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
case redis.Subscription:
fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
if n.Count == 0 {
return
}
case error:
fmt.Printf("error: %v\n", n)
return
}
}
}()
// This goroutine manages subscriptions for the connection.
go func() {
defer wg.Done()
psc.Subscribe("example")
psc.PSubscribe("p*")
// The following function calls publish a message using another
// connection to the Redis server.
publish("example", "hello")
publish("example", "world")
publish("pexample", "foo")
publish("pexample", "bar")
// Unsubscribe from all connections. This will cause the receiving
// goroutine to exit.
psc.Unsubscribe()
psc.PUnsubscribe()
}()
wg.Wait()
// Output:
// Subscription: subscribe example 1
// Subscription: psubscribe p* 2
// Message: example hello
// Message: example world
// PMessage: p* pexample foo
// PMessage: p* pexample bar
// Subscription: unsubscribe example 1
// Subscription: punsubscribe p* 0
}
func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) {
actual := c.Receive()
if !reflect.DeepEqual(actual, expected) {
t.Errorf("%s = %v, want %v", message, actual, expected)
}
}
func TestPushed(t *testing.T) {
pc, err := redistest.Dial()
if err != nil {
t.Fatalf("error connection to database, %v", err)
}
defer pc.Close()
nc, err := net.Dial("tcp", ":6379")
if err != nil {
t.Fatal(err)
}
defer nc.Close()
nc.SetReadDeadline(time.Now().Add(4 * time.Second))
c := redis.PubSubConn{Conn: redis.NewConn(nc, 0, 0)}
c.Subscribe("c1")
expectPushed(t, c, "Subscribe(c1)", redis.Subscription{Kind: "subscribe", Channel: "c1", Count: 1})
c.Subscribe("c2")
expectPushed(t, c, "Subscribe(c2)", redis.Subscription{Kind: "subscribe", Channel: "c2", Count: 2})
c.PSubscribe("p1")
expectPushed(t, c, "PSubscribe(p1)", redis.Subscription{Kind: "psubscribe", Channel: "p1", Count: 3})
c.PSubscribe("p2")
expectPushed(t, c, "PSubscribe(p2)", redis.Subscription{Kind: "psubscribe", Channel: "p2", Count: 4})
c.PUnsubscribe()
expectPushed(t, c, "Punsubscribe(p1)", redis.Subscription{Kind: "punsubscribe", Channel: "p1", Count: 3})
expectPushed(t, c, "Punsubscribe()", redis.Subscription{Kind: "punsubscribe", Channel: "p2", Count: 2})
pc.Do("PUBLISH", "c1", "hello")
expectPushed(t, c, "PUBLISH c1 hello", redis.Message{Channel: "c1", Data: []byte("hello")})
}

Some files were not shown because too many files have changed in this diff Show more