update vendor

Signed-off-by: Jess Frazelle <acidburn@microsoft.com>
Jess Frazelle 2018-09-25 12:27:46 -04:00
parent 19a32db84d
commit 94d1cfbfbf
10501 changed files with 2307943 additions and 29279 deletions

vendor/google.golang.org/grpc/.github/ISSUE_TEMPLATE generated vendored Normal file

@ -0,0 +1,14 @@
Please answer these questions before submitting your issue.
### What version of gRPC are you using?
### What version of Go are you using (`go version`)?
### What operating system (Linux, Windows, …) and version?
### What did you do?
If possible, provide a recipe for reproducing the error.
### What did you expect to see?
### What did you see instead?


@ -1,23 +1,38 @@
language: go
go:
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
matrix:
include:
- go: 1.9.x
- go: 1.11.x
env: VET=1 GO111MODULE=on
- go: 1.11.x
env: RACE=1 GO111MODULE=on
- go: 1.11.x
env: RUN386=1
- go: 1.11.x
env: GRPC_GO_RETRY=on
- go: 1.10.x
- go: 1.9.x
- go: 1.9.x
env: GAE=1
- go: 1.8.x
- go: 1.6.x
go_import_path: google.golang.org/grpc
before_install:
- if [[ "$TRAVIS_GO_VERSION" = 1.9* && "$GOARCH" != "386" ]]; then ./vet.sh -install || exit 1; fi
- if [[ "${GO111MODULE}" = "on" ]]; then mkdir "${HOME}/go"; export GOPATH="${HOME}/go"; fi
- if [[ -n "${RUN386}" ]]; then export GOARCH=386; fi
- if [[ "${TRAVIS_EVENT_TYPE}" = "cron" && -z "${RUN386}" ]]; then RACE=1; fi
- if [[ "${TRAVIS_EVENT_TYPE}" != "cron" ]]; then VET_SKIP_PROTO=1; fi
install:
- if [[ "${GO111MODULE}" = "on" ]]; then go mod download; else make testdeps; fi
- if [[ "${GAE}" = 1 ]]; then source ./install_gae.sh; make testappenginedeps; fi
- if [[ "${VET}" = 1 ]]; then ./vet.sh -install; fi
script:
- if [[ -n "$RUN386" ]]; then export GOARCH=386; fi
- if [[ "$TRAVIS_GO_VERSION" = 1.9* && "$GOARCH" != "386" ]]; then ./vet.sh || exit 1; fi
- make test || exit 1
- if [[ "$GOARCH" != "386" ]]; then make testrace; fi
- set -e
- if [[ "${VET}" = 1 ]]; then ./vet.sh; fi
- if [[ "${GAE}" = 1 ]]; then make testappengine; exit 0; fi
- if [[ "${RACE}" = 1 ]]; then make testrace; exit 0; fi
- make test


@ -27,6 +27,10 @@ How to get your contributions merged smoothly and quickly.
- Keep your PR up to date with upstream/master (if there are merge conflicts, we can't really merge your change).
- **All tests need to be passing** before your change can be merged. We recommend you **run tests locally** before creating your PR to catch breakages early on.
- `make all` to test everything, OR
- `make vet` to catch vet errors
- `make test` to run the tests
- `make testrace` to run tests in race mode
- Exceptions to the rules can be made if there's a compelling reason for doing so.


@ -0,0 +1,80 @@
# Compression
The preferred method for configuring message compression on both clients and
servers is to use
[`encoding.RegisterCompressor`](https://godoc.org/google.golang.org/grpc/encoding#RegisterCompressor)
to register an implementation of a compression algorithm. See
`grpc/encoding/gzip/gzip.go` for an example of how to implement one.
Once a compressor has been registered on the client-side, RPCs may be sent using
it via the
[`UseCompressor`](https://godoc.org/google.golang.org/grpc#UseCompressor)
`CallOption`. Remember that `CallOption`s may be turned into defaults for all
calls from a `ClientConn` by using the
[`WithDefaultCallOptions`](https://godoc.org/google.golang.org/grpc#WithDefaultCallOptions)
`DialOption`. If `UseCompressor` is used and the corresponding compressor has
not been installed, an `Internal` error will be returned to the application
before the RPC is sent.
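For illustration, a minimal client sketch might look like the following (assumptions: the registered `gzip` compressor from `encoding/gzip`, the generated `helloworld` Greeter client from this repository's examples, and a placeholder server address):
```go
package main

import (
	"context"
	"log"

	"google.golang.org/grpc"
	_ "google.golang.org/grpc/encoding/gzip" // registers the "gzip" compressor
	pb "google.golang.org/grpc/examples/helloworld/helloworld"
)

func main() {
	// Compress every RPC on this ClientConn by default via WithDefaultCallOptions;
	// alternatively, pass grpc.UseCompressor("gzip") as a CallOption on a single call.
	conn, err := grpc.Dial("localhost:50051",
		grpc.WithInsecure(), // plaintext for brevity only
		grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	client := pb.NewGreeterClient(conn)
	if _, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "gzip"}); err != nil {
		log.Fatalf("SayHello: %v", err)
	}
}
```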
Server-side, registered compressors will be used automatically to decode request
messages and encode the responses. Servers currently always respond using the
same compression method specified by the client. If the corresponding
compressor has not been registered, an `Unimplemented` status will be returned
to the client.
## Deprecated API
There is a deprecated API for setting compression as well. It is not
recommended for use. However, if you were previously using it, the following
section may be helpful in understanding how it works in combination with the new
API.
### Client-Side
There are two legacy functions and one new function to configure compression:
```go
func WithCompressor(grpc.Compressor) DialOption {}
func WithDecompressor(grpc.Decompressor) DialOption {}
func UseCompressor(name string) CallOption {}
```
For outgoing requests, the following rules are applied in order:
1. If `UseCompressor` is used, messages will be compressed using the compressor
named.
* If the named compressor is not registered, an `Internal` error is returned
to the application before the RPC is sent.
* If UseCompressor("identity"), no compressor will be used, but "identity"
will be sent in the header to the server.
1. If `WithCompressor` is used, messages will be compressed using that
compressor implementation.
1. Otherwise, outbound messages will be uncompressed.
For incoming responses, the following rules are applied in order:
1. If `WithDecompressor` is used and it matches the message's encoding, it will
be used.
1. If a registered compressor matches the response's encoding, it will be used.
1. Otherwise, the stream will be closed and an `Unimplemented` status error will
be returned to the application.
### Server-Side
There are two legacy functions to configure compression:
```go
func RPCCompressor(grpc.Compressor) ServerOption {}
func RPCDecompressor(grpc.Decompressor) ServerOption {}
```
For incoming requests, the following rules are applied in order:
1. If `RPCDecompressor` is used and that decompressor matches the request's
encoding: it will be used.
1. If a registered compressor matches the request's encoding, it will be used.
1. Otherwise, an `Unimplemented` status will be returned to the client.
For outgoing responses, the following rules are applied in order:
1. If `RPCCompressor` is used, that compressor will be used to compress all
response messages.
1. If compression was used for the incoming request and a registered compressor
supports it, that same compression method will be used for the outgoing
response.
1. Otherwise, no compression will be used for the outgoing response.


@ -0,0 +1,33 @@
# Concurrency
In general, gRPC-go provides a concurrency-friendly API. What follows are some
guidelines.
## Clients
A [ClientConn][client-conn] can safely be accessed concurrently. Using
[helloworld][helloworld] as an example, one could share the `ClientConn` across
multiple goroutines to create multiple `GreeterClient` types. In this case, RPCs
would be sent in parallel.
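As a sketch of that pattern (assuming the `helloworld` example and a placeholder address), several goroutines can share one `ClientConn`:
```go
package main

import (
	"context"
	"log"
	"sync"

	"google.golang.org/grpc"
	pb "google.golang.org/grpc/examples/helloworld/helloworld"
)

func main() {
	conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Stubs are lightweight; each goroutine gets its own, the conn is shared.
			client := pb.NewGreeterClient(conn)
			if _, err := client.SayHello(context.Background(), &pb.HelloRequest{Name: "concurrent"}); err != nil {
				log.Printf("goroutine %d: %v", i, err)
			}
		}(i)
	}
	wg.Wait()
}
```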
## Streams
When using streams, one must take care to avoid calling either `SendMsg` or
`RecvMsg` multiple times against the same [Stream][stream] from different
goroutines. In other words, it's safe to have a goroutine calling `SendMsg` and
another goroutine calling `RecvMsg` on the same stream at the same time. But it
is not safe to call `SendMsg` on the same stream in different goroutines, or to
call `RecvMsg` on the same stream in different goroutines.
## Servers
Each RPC handler attached to a registered server will be invoked in its own
goroutine. For example, [SayHello][say-hello] will be invoked in its own
goroutine. The same is true for service handlers for streaming RPCs, as seen
in the route guide example [here][route-guide-stream].
[helloworld]: https://github.com/grpc/grpc-go/blob/master/examples/helloworld/greeter_client/main.go#L43
[client-conn]: https://godoc.org/google.golang.org/grpc#ClientConn
[stream]: https://godoc.org/google.golang.org/grpc#Stream
[say-hello]: https://github.com/grpc/grpc-go/blob/master/examples/helloworld/greeter_server/main.go#L41
[route-guide-stream]: https://github.com/grpc/grpc-go/blob/master/examples/route_guide/server/server.go#L126

vendor/google.golang.org/grpc/Documentation/encoding.md generated vendored Normal file

@ -0,0 +1,146 @@
# Encoding
The gRPC API for sending and receiving is based upon *messages*. However,
messages cannot be transmitted directly over a network; they must first be
converted into *bytes*. This document describes how gRPC-Go converts messages
into bytes and vice-versa for the purposes of network transmission.
## Codecs (Serialization and Deserialization)
A `Codec` contains code to serialize a message into a byte slice (`Marshal`) and
deserialize a byte slice back into a message (`Unmarshal`). `Codec`s are
registered by name into a global registry maintained in the `encoding` package.
### Implementing a `Codec`
A typical `Codec` will be implemented in its own package with an `init` function
that registers itself, and is imported anonymously. For example:
```go
package proto
import "google.golang.org/grpc/encoding"
func init() {
encoding.RegisterCodec(protoCodec{})
}
// ... implementation of protoCodec ...
```
For an example, gRPC's implementation of the `proto` codec can be found in
[`encoding/proto`](https://godoc.org/google.golang.org/grpc/encoding/proto).
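To make that concrete, a rough sketch of a complete `Codec` could look like this (the package and registered name `mycodec` are hypothetical; it simply delegates to the standard proto marshaling):
```go
package mycodec

import (
	"fmt"

	"github.com/golang/protobuf/proto"
	"google.golang.org/grpc/encoding"
)

type codec struct{}

func (codec) Marshal(v interface{}) ([]byte, error) {
	msg, ok := v.(proto.Message)
	if !ok {
		return nil, fmt.Errorf("mycodec: %T is not a proto.Message", v)
	}
	return proto.Marshal(msg)
}

func (codec) Unmarshal(data []byte, v interface{}) error {
	msg, ok := v.(proto.Message)
	if !ok {
		return fmt.Errorf("mycodec: %T is not a proto.Message", v)
	}
	return proto.Unmarshal(data, msg)
}

func (codec) Name() string { return "mycodec" }

func init() {
	// Importing this package (even anonymously) registers the codec.
	encoding.RegisterCodec(codec{})
}
```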
### Using a `Codec`
By default, gRPC registers and uses the "proto" codec, so it is not necessary to
do this in your own code to send and receive proto messages. To use another
`Codec` from a client or server:
```go
package myclient
import _ "path/to/another/codec"
```
`Codec`s, by definition, must be symmetric, so the same desired `Codec` should
be registered in both client and server binaries.
On the client-side, to specify a `Codec` to use for message transmission, the
`CallOption` `CallContentSubtype` should be used as follows:
```go
response, err := myclient.MyCall(ctx, request, grpc.CallContentSubtype("mycodec"))
```
As a reminder, all `CallOption`s may be converted into `DialOption`s that become
the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`:
```go
myclient, err := grpc.Dial(target, grpc.WithDefaultCallOptions(grpc.CallContentSubtype("mycodec")))
```
When specified in either of these ways, messages will be encoded using this
codec and sent along with headers indicating the codec (`content-type` set to
`application/grpc+<codec name>`).
On the server-side, using a `Codec` is as simple as registering it into the
global registry (i.e. `import`ing it). If a message is encoded with the content
sub-type supported by a registered `Codec`, it will be used automatically for
decoding the request and encoding the response. Otherwise, for
backward-compatibility reasons, gRPC will attempt to use the "proto" codec. In
an upcoming change (tracked in [this
issue](https://github.com/grpc/grpc-go/issues/1824)), such requests will be
rejected with status code `Unimplemented` instead.
## Compressors (Compression and Decompression)
Sometimes, the resulting serialization of a message is not space-efficient, and
it may be beneficial to compress this byte stream before transmitting it over
the network. To facilitate this operation, gRPC supports a mechanism for
performing compression and decompression.
A `Compressor` contains code to compress and decompress by wrapping `io.Writer`s
and `io.Reader`s, respectively. (The forms of `Compress` and `Decompress` were
chosen to most closely match Go's standard package
[implementations](https://golang.org/pkg/compress/) of compressors.) Like
`Codec`s, `Compressor`s are registered by name into a global registry maintained
in the `encoding` package.
### Implementing a `Compressor`
A typical `Compressor` will be implemented in its own package with an `init`
function that registers itself, and is imported anonymously. For example:
```go
package gzip
import "google.golang.org/grpc/encoding"
func init() {
encoding.RegisterCompressor(compressor{})
}
// ... implementation of compressor ...
```
An implementation of a `gzip` compressor can be found in
[`encoding/gzip`](https://godoc.org/google.golang.org/grpc/encoding/gzip).
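As an illustrative sketch only (not the library's own `gzip` package), a `Compressor` can wrap the standard library's `compress/gzip`; the name `mygzip` is hypothetical:
```go
package mygzip

import (
	"compress/gzip"
	"io"

	"google.golang.org/grpc/encoding"
)

type compressor struct{}

func (compressor) Compress(w io.Writer) (io.WriteCloser, error) {
	// Each message gets its own writer; Close flushes the gzip stream.
	return gzip.NewWriter(w), nil
}

func (compressor) Decompress(r io.Reader) (io.Reader, error) {
	return gzip.NewReader(r)
}

func (compressor) Name() string { return "mygzip" }

func init() {
	encoding.RegisterCompressor(compressor{})
}
```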
### Using a `Compressor`
By default, gRPC does not register or use any compressors. To use a
`Compressor` from a client or server:
```go
package myclient
import _ "google.golang.org/grpc/encoding/gzip"
```
`Compressor`s, by definition, must be symmetric, so the same desired
`Compressor` should be registered in both client and server binaries.
On the client-side, to specify a `Compressor` to use for message transmission,
the `CallOption` `UseCompressor` should be used as follows:
```go
response, err := myclient.MyCall(ctx, request, grpc.UseCompressor("gzip"))
```
As a reminder, all `CallOption`s may be converted into `DialOption`s that become
the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`:
```go
myclient, err := grpc.Dial(target, grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
```
When specified in either of these ways, messages will be compressed using this
compressor and sent along with headers indicating the compressor
(`content-coding` set to `<compressor name>`).
On the server-side, using a `Compressor` is as simple as registering it into the
global registry (i.e. `import`ing it). If a message is compressed with the
content coding supported by a registered `Compressor`, it will be used
automatically for decompressing the request and compressing the response.
Otherwise, the request will be rejected with status code `Unimplemented`.


@ -0,0 +1,182 @@
# Mocking Service for gRPC
[Example code unary RPC](https://github.com/grpc/grpc-go/tree/master/examples/helloworld/mock_helloworld)
[Example code streaming RPC](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/mock_routeguide)
## Why?
To test client-side logic without the overhead of connecting to a real server. Mocking enables users to write light-weight unit tests to check functionalities on client-side without invoking RPC calls to a server.
## Idea: Mock the client stub that connects to the server.
We use Gomock to mock the client interface (in the generated code) and programmatically set its methods to expect and return pre-determined values. This enables users to write tests around the client logic and use this mocked stub while making RPC calls.
## How to use Gomock?
Documentation on Gomock can be found [here](https://github.com/golang/mock).
A quick reading of the documentation should enable users to follow the code below.
Consider a gRPC service based on following proto file:
```proto
//helloworld.proto
package helloworld;
message HelloRequest {
string name = 1;
}
message HelloReply {
string name = 1;
}
service Greeter {
rpc SayHello (HelloRequest) returns (HelloReply) {}
}
```
The generated file helloworld.pb.go will have a client interface for each service defined in the proto file. This interface will have methods corresponding to each rpc inside that service.
```Go
type GreeterClient interface {
SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error)
}
```
The generated code also contains a struct that implements this interface.
```Go
type greeterClient struct {
cc *grpc.ClientConn
}
func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error){
// ...
// gRPC specific code here
// ...
}
```
Along with this the generated code has a method to create an instance of this struct.
```Go
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient
```
The user code uses this function to create an instance of the struct greeterClient which then can be used to make rpc calls to the server.
We will mock this interface GreeterClient and use an instance of that mock to make RPC calls. These calls, instead of going to the server, will return pre-determined values.
To create a mock we'll use [mockgen](https://github.com/golang/mock#running-mockgen).
From the directory ``` examples/helloworld/ ``` run ``` mockgen google.golang.org/grpc/examples/helloworld/helloworld GreeterClient > mock_helloworld/hw_mock.go ```
Notice that in the above command we specify GreeterClient as the interface to be mocked.
The user test code can import the package generated by mockgen along with library package gomock to write unit tests around client-side logic.
```Go
import "github.com/golang/mock/gomock"
import hwmock "google.golang.org/grpc/examples/helloworld/mock_helloworld"
```
An instance of the mocked interface can be created as:
```Go
mockGreeterClient := hwmock.NewMockGreeterClient(ctrl)
```
This mocked object can be programmed to expect calls to its methods and return pre-determined values. For instance, we can program mockGreeterClient to expect a call to its method SayHello and return a HelloReply with message "Mocked RPC".
```Go
mockGreeterClient.EXPECT().SayHello(
gomock.Any(), // expect any value for first parameter
gomock.Any(), // expect any value for second parameter
).Return(&helloworld.HelloReply{Message: "Mocked RPC"}, nil)
```
gomock.Any() indicates that the parameter can have any value or type. We can indicate specific values for built-in types with gomock.Eq().
However, if the test code needs to specify the parameter to have a proto message type, we can replace gomock.Any() with an instance of a struct that implements the gomock.Matcher interface.
```Go
type rpcMsg struct {
msg proto.Message
}
func (r *rpcMsg) Matches(msg interface{}) bool {
m, ok := msg.(proto.Message)
if !ok {
return false
}
return proto.Equal(m, r.msg)
}
func (r *rpcMsg) String() string {
return fmt.Sprintf("is %s", r.msg)
}
...
req := &helloworld.HelloRequest{Name: "unit_test"}
mockGreeterClient.EXPECT().SayHello(
gomock.Any(),
&rpcMsg{msg: req},
).Return(&helloworld.HelloReply{Message: "Mocked Interface"}, nil)
```
## Mock streaming RPCs:
For our example we consider the case of bi-directional streaming RPCs. Concretely, we'll write a test for RouteChat function from the route guide example to demonstrate how to write mocks for streams.
RouteChat is a bi-directional streaming RPC, which means calling RouteChat returns a stream that can __Send__ and __Recv__ messages to and from the server, respectively. We'll start by creating a mock of this stream interface returned by RouteChat and then we'll mock the client interface and set expectation on the method RouteChat to return our mocked stream.
### Generating mocking code:
Like before we'll use [mockgen](https://github.com/golang/mock#running-mockgen). From the `examples/route_guide` directory run: `mockgen google.golang.org/grpc/examples/route_guide/routeguide RouteGuideClient,RouteGuide_RouteChatClient > mock_route_guide/rg_mock.go`
Notice that we are mocking both the client (`RouteGuideClient`) and the stream (`RouteGuide_RouteChatClient`) interfaces here.
This will create a file `rg_mock.go` under the directory `mock_route_guide`. This file contains all the mocking code we need to write our test.
In our test code, like before, we import this mocking code along with the generated code:
```go
import (
rgmock "google.golang.org/grpc/examples/route_guide/mock_routeguide"
rgpb "google.golang.org/grpc/examples/route_guide/routeguide"
)
```
Now consider a test that takes the RouteGuide client object as a parameter, makes a RouteChat RPC call, and sends a message on the resulting stream. Furthermore, this test expects to receive the same message back on the stream.
```go
var msg = ...
// Creates a RouteChat call and sends msg on it.
// Checks if the received message was equal to msg.
func testRouteChat(client rgpb.RouteGuideClient) error {
...
}
```
We can inject our mock in here by simply passing it as an argument to the method.
Creating mock for stream interface:
```go
stream := rgmock.NewMockRouteGuide_RouteChatClient(ctrl)
```
Setting Expectations:
```go
stream.EXPECT().Send(gomock.Any()).Return(nil)
stream.EXPECT().Recv().Return(msg, nil)
```
Creating mock for client interface:
```go
rgclient := rgmock.NewMockRouteGuideClient(ctrl)
```
Setting Expectations:
```go
rgclient.EXPECT().RouteChat(gomock.Any()).Return(stream, nil)
```
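Putting the pieces together, one possible self-contained test is sketched below (`testRouteChat` and `msg` are the hypothetical helpers described above, inlined here; the message contents are placeholders):
```go
package mock_routeguide_test

import (
	"context"
	"fmt"
	"testing"

	"github.com/golang/mock/gomock"
	rgmock "google.golang.org/grpc/examples/route_guide/mock_routeguide"
	rgpb "google.golang.org/grpc/examples/route_guide/routeguide"
)

var msg = &rgpb.RouteNote{Message: "hello"}

// testRouteChat creates a RouteChat call, sends msg on it, and checks that the
// same message comes back.
func testRouteChat(client rgpb.RouteGuideClient) error {
	stream, err := client.RouteChat(context.Background())
	if err != nil {
		return err
	}
	if err := stream.Send(msg); err != nil {
		return err
	}
	got, err := stream.Recv()
	if err != nil {
		return err
	}
	if got.Message != msg.Message {
		return fmt.Errorf("got %q, want %q", got.Message, msg.Message)
	}
	return nil
}

func TestRouteChat(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the stream returned by RouteChat, then the client that returns it.
	stream := rgmock.NewMockRouteGuide_RouteChatClient(ctrl)
	stream.EXPECT().Send(gomock.Any()).Return(nil)
	stream.EXPECT().Recv().Return(msg, nil)

	rgclient := rgmock.NewMockRouteGuideClient(ctrl)
	rgclient.EXPECT().RouteChat(gomock.Any()).Return(stream, nil)

	if err := testRouteChat(rgclient); err != nil {
		t.Fatal(err)
	}
}
```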


@ -0,0 +1,78 @@
# Authentication
As outlined in the [gRPC authentication guide](https://grpc.io/docs/guides/auth.html) there are a number of different mechanisms for asserting identity between a client and server. We'll present some code samples here demonstrating how to provide TLS encryption and identity assertions, as well as how to pass OAuth2 tokens to services that support it.
# Enabling TLS on a gRPC client
```Go
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")))
```
# Enabling TLS on a gRPC server
```Go
creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
if err != nil {
log.Fatalf("Failed to generate credentials %v", err)
}
lis, err := net.Listen("tcp", ":0")
server := grpc.NewServer(grpc.Creds(creds))
...
server.Serve(lis)
```
# OAuth2
For an example of how to configure client and server to use OAuth2 tokens, see
[here](https://github.com/grpc/grpc-go/blob/master/examples/oauth/).
## Validating a token on the server
Clients may use
[metadata.MD](https://godoc.org/google.golang.org/grpc/metadata#MD)
to store tokens and other authentication-related data. To gain access to the
`metadata.MD` object, a server may use
[metadata.FromIncomingContext](https://godoc.org/google.golang.org/grpc/metadata#FromIncomingContext).
With a reference to `metadata.MD` on the server, one simply needs to look up the
`authorization` key. Note that all keys stored within `metadata.MD` are normalized
to lowercase. See [here](https://godoc.org/google.golang.org/grpc/metadata#New).
It is possible to configure token validation for all RPCs using an interceptor.
A server may configure either a
[grpc.UnaryInterceptor](https://godoc.org/google.golang.org/grpc#UnaryInterceptor)
or a
[grpc.StreamInterceptor](https://godoc.org/google.golang.org/grpc#StreamInterceptor).
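One way to wire this up (a sketch, not the library's API; the `valid` helper and the expected token value are hypothetical) is a unary interceptor that rejects RPCs lacking a valid `authorization` entry:
```go
package server

import (
	"context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// valid stands in for real token verification.
func valid(authorization string) bool {
	return authorization == "Bearer some-secret-token"
}

// ensureValidToken looks up the lowercase "authorization" key in incoming metadata.
func ensureValidToken(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil, status.Error(codes.InvalidArgument, "missing metadata")
	}
	if vals := md["authorization"]; len(vals) == 0 || !valid(vals[0]) {
		return nil, status.Error(codes.Unauthenticated, "invalid token")
	}
	return handler(ctx, req)
}

// newServer installs the interceptor for every unary RPC on the server.
func newServer() *grpc.Server {
	return grpc.NewServer(grpc.UnaryInterceptor(ensureValidToken))
}
```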
## Adding a token to all outgoing client RPCs
To send an OAuth2 token with each RPC, a client may configure the
`grpc.DialOption`
[grpc.WithPerRPCCredentials](https://godoc.org/google.golang.org/grpc#WithPerRPCCredentials).
Alternatively, a client may also use the `grpc.CallOption`
[grpc.PerRPCCredentials](https://godoc.org/google.golang.org/grpc#PerRPCCredentials)
on each invocation of an RPC.
To create a `credentials.PerRPCCredentials`, use
[oauth.NewOauthAccess](https://godoc.org/google.golang.org/grpc/credentials/oauth#NewOauthAccess).
Note, the OAuth2 implementation of `grpc.PerRPCCredentials` requires a client to use
[grpc.WithTransportCredentials](https://godoc.org/google.golang.org/grpc#WithTransportCredentials)
to prevent any insecure transmission of tokens.
# Authenticating with Google
## Google Compute Engine (GCE)
```Go
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(oauth.NewComputeEngine()))
```
## JWT
```Go
jwtCreds, err := oauth.NewServiceAccountFromFile(*serviceAccountKeyFile, *oauthScope)
if err != nil {
log.Fatalf("Failed to create JWT credentials: %v", err)
}
conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), grpc.WithPerRPCCredentials(jwtCreds))
```


@ -0,0 +1,227 @@
# Metadata
gRPC supports sending metadata between client and server.
This doc shows how to send and receive metadata in gRPC-go.
## Background
There are four kinds of service methods:
- [Unary RPC](https://grpc.io/docs/guides/concepts.html#unary-rpc)
- [Server streaming RPC](https://grpc.io/docs/guides/concepts.html#server-streaming-rpc)
- [Client streaming RPC](https://grpc.io/docs/guides/concepts.html#client-streaming-rpc)
- [Bidirectional streaming RPC](https://grpc.io/docs/guides/concepts.html#bidirectional-streaming-rpc)
There is also the concept of [metadata](https://grpc.io/docs/guides/concepts.html#metadata).
## Constructing metadata
Metadata can be created using the package [metadata](https://godoc.org/google.golang.org/grpc/metadata).
The type `MD` is actually a map from string to a list of strings:
```go
type MD map[string][]string
```
Metadata can be read like a normal map.
Note that the value type of this map is `[]string`,
so that users can attach multiple values using a single key.
### Creating a new metadata
Metadata can be created from a `map[string]string` using the function `New`:
```go
md := metadata.New(map[string]string{"key1": "val1", "key2": "val2"})
```
Another way is to use `Pairs`.
Values with the same key will be merged into a list:
```go
md := metadata.Pairs(
"key1", "val1",
"key1", "val1-2", // "key1" will have map value []string{"val1", "val1-2"}
"key2", "val2",
)
```
__Note:__ all the keys will be automatically converted to lowercase,
so "key1" and "kEy1" will be the same key and their values will be merged into the same list.
This happens for both `New` and `Pairs`.
### Storing binary data in metadata
In metadata, keys are always strings. But values can be strings or binary data.
To store binary data value in metadata, simply add "-bin" suffix to the key.
The values with "-bin" suffixed keys will be encoded when creating the metadata:
```go
md := metadata.Pairs(
"key", "string value",
"key-bin", string([]byte{96, 102}), // this binary data will be encoded (base64) before sending
// and will be decoded after being transferred.
)
```
## Retrieving metadata from context
Metadata can be retrieved from context using `FromIncomingContext`:
```go
func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeResponse, error) {
md, ok := metadata.FromIncomingContext(ctx)
// do something with metadata
}
```
## Sending and receiving metadata - client side
[//]: # "TODO: uncomment next line after example source added"
[//]: # "Real metadata sending and receiving examples are available [here](TODO:example_dir)."
### Sending metadata
There are two ways to send metadata to the server. The recommended way is to append kv pairs to the context using
`AppendToOutgoingContext`. This can be used with or without existing metadata on the context. When there is no prior
metadata, metadata is added; when metadata already exists on the context, kv pairs are merged in.
```go
// create a new context with some metadata
ctx := metadata.AppendToOutgoingContext(ctx, "k1", "v1", "k1", "v2", "k2", "v3")
// later, add some more metadata to the context (e.g. in an interceptor)
ctx := metadata.AppendToOutgoingContext(ctx, "k3", "v4")
// make unary RPC
response, err := client.SomeRPC(ctx, someRequest)
// or make streaming RPC
stream, err := client.SomeStreamingRPC(ctx)
```
Alternatively, metadata may be attached to the context using `NewOutgoingContext`. However, this
replaces any existing metadata in the context, so care must be taken to preserve the existing
metadata if desired. This is slower than using `AppendToOutgoingContext`. An example of this
is below:
```go
// create a new context with some metadata
md := metadata.Pairs("k1", "v1", "k1", "v2", "k2", "v3")
ctx := metadata.NewOutgoingContext(context.Background(), md)
// later, add some more metadata to the context (e.g. in an interceptor)
md, _ := metadata.FromOutgoingContext(ctx)
newMD := metadata.Pairs("k3", "v3")
ctx = metadata.NewOutgoingContext(ctx, metadata.Join(md, newMD))
// make unary RPC
response, err := client.SomeRPC(ctx, someRequest)
// or make streaming RPC
stream, err := client.SomeStreamingRPC(ctx)
```
### Receiving metadata
Metadata that a client can receive includes header and trailer.
#### Unary call
Header and trailer sent along with a unary call can be retrieved using function [Header](https://godoc.org/google.golang.org/grpc#Header) and [Trailer](https://godoc.org/google.golang.org/grpc#Trailer) in [CallOption](https://godoc.org/google.golang.org/grpc#CallOption):
```go
var header, trailer metadata.MD // variable to store header and trailer
r, err := client.SomeRPC(
ctx,
someRequest,
grpc.Header(&header), // will retrieve header
grpc.Trailer(&trailer), // will retrieve trailer
)
// do something with header and trailer
```
#### Streaming call
For streaming calls including:
- Server streaming RPC
- Client streaming RPC
- Bidirectional streaming RPC
Header and trailer can be retrieved from the returned stream using function `Header` and `Trailer` in interface [ClientStream](https://godoc.org/google.golang.org/grpc#ClientStream):
```go
stream, err := client.SomeStreamingRPC(ctx)
// retrieve header
header, err := stream.Header()
// retrieve trailer
trailer := stream.Trailer()
```
## Sending and receiving metadata - server side
[//]: # "TODO: uncomment next line after example source added"
[//]: # "Real metadata sending and receiving examples are available [here](TODO:example_dir)."
### Receiving metadata
To read metadata sent by the client, the server needs to retrieve it from RPC context.
If it is a unary call, the RPC handler's context can be used.
For streaming calls, the server needs to get context from the stream.
#### Unary call
```go
func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) {
md, ok := metadata.FromIncomingContext(ctx)
// do something with metadata
}
```
#### Streaming call
```go
func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error {
md, ok := metadata.FromIncomingContext(stream.Context()) // get context from stream
// do something with metadata
}
```
### Sending metadata
#### Unary call
To send header and trailer to client in unary call, the server can call [SendHeader](https://godoc.org/google.golang.org/grpc#SendHeader) and [SetTrailer](https://godoc.org/google.golang.org/grpc#SetTrailer) functions in module [grpc](https://godoc.org/google.golang.org/grpc).
These two functions take a context as the first parameter.
It should be the RPC handler's context or one derived from it:
```go
func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) {
// create and send header
header := metadata.Pairs("header-key", "val")
grpc.SendHeader(ctx, header)
// create and set trailer
trailer := metadata.Pairs("trailer-key", "val")
grpc.SetTrailer(ctx, trailer)
}
```
#### Streaming call
For streaming calls, header and trailer can be sent using function `SendHeader` and `SetTrailer` in interface [ServerStream](https://godoc.org/google.golang.org/grpc#ServerStream):
```go
func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error {
// create and send header
header := metadata.Pairs("header-key", "val")
stream.SendHeader(header)
// create and set trailer
trailer := metadata.Pairs("trailer-key", "val")
stream.SetTrailer(trailer)
}
```


@ -0,0 +1,49 @@
# Log Levels
This document describes the different log levels supported by the grpc-go
library, and under what conditions they should be used.
### Info
Info messages are for informational purposes and may aid in the debugging of
applications or the gRPC library.
Examples:
- The name resolver received an update.
- The balancer updated its picker.
- Significant gRPC state is changing.
At verbosity of 0 (the default), any single info message should not be output
more than once every 5 minutes under normal operation.
### Warning
Warning messages indicate problems that are non-fatal for the application, but
could lead to unexpected behavior or subsequent errors.
Examples:
- Resolver could not resolve target name.
- Error received while connecting to a server.
- Lost or corrupt connection with remote endpoint.
### Error
Error messages represent errors in the usage of gRPC that cannot be returned to
the application as errors, or internal gRPC-Go errors that are recoverable.
Internal errors are detected during gRPC tests and will result in test failures.
Examples:
- Invalid arguments passed to a function that cannot return an error.
- An internal error that cannot be returned or would be inappropriate to return
to the user.
### Fatal
Fatal errors are severe internal errors that are unrecoverable. These lead
directly to panics, and are avoided as much as possible.
Example:
- Internal invariant was violated.
- User attempted an action that cannot return an error gracefully, but would
lead to an invalid state if performed.


@ -0,0 +1,68 @@
# RPC Errors
All service method handlers should return `nil` or errors from the
`status.Status` type. Clients have direct access to the errors.
Upon encountering an error, a gRPC server method handler should create a
`status.Status`. In typical usage, one would use [status.New][new-status]
passing in an appropriate [codes.Code][code] as well as a description of the
error to produce a `status.Status`. Calling [status.Err][status-err] converts
the `status.Status` type into an `error`. As a convenience method, there is also
[status.Error][status-error] which obviates the conversion step. Compare:
```
st := status.New(codes.NotFound, "some description")
err := st.Err()
// vs.
err := status.Error(codes.NotFound, "some description")
```
## Adding additional details to errors
In some cases, it may be necessary to add details for a particular error on the
server side. The [status.WithDetails][with-details] method exists for this
purpose. Clients may then read those details by first converting the plain
`error` type back to a [status.Status][status] and then using
[status.Details][details].
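For a feel of the flow, here is a small sketch of both sides (assuming the `errdetails` protos from `google.golang.org/genproto`; the values mirror the rate-limit example referenced below):
```go
package main

import (
	"log"

	"google.golang.org/genproto/googleapis/rpc/errdetails"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// quotaErr is what a server handler might return when a quota is exceeded.
func quotaErr() error {
	st := status.New(codes.ResourceExhausted, "quota exceeded")
	ds, err := st.WithDetails(&errdetails.QuotaFailure{
		Violations: []*errdetails.QuotaFailure_Violation{{
			Subject:     "name:world",
			Description: "Limit one greeting per person",
		}},
	})
	if err != nil {
		return st.Err() // fall back to the bare status if details cannot be attached
	}
	return ds.Err()
}

func main() {
	// Client side: convert the error back into a status and inspect its details.
	err := quotaErr()
	if st, ok := status.FromError(err); ok {
		for _, d := range st.Details() {
			if qf, ok := d.(*errdetails.QuotaFailure); ok {
				log.Printf("quota failure: %v", qf)
			}
		}
	}
}
```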
## Example
The [example][example] demonstrates the API discussed above and shows how to add
information about rate limits to the error message using `status.Status`.
To run the example, first start the server:
```
$ go run examples/rpc_errors/server/main.go
```
In a separate session, run the client:
```
$ go run examples/rpc_errors/client/main.go
```
On the first run of the client, all is well:
```
2018/03/12 19:39:33 Greeting: Hello world
```
Upon running the client a second time, the client exceeds the rate limit and
receives an error with details:
```
2018/03/19 16:42:01 Quota failure: violations:<subject:"name:world" description:"Limit one greeting per person" >
exit status 1
```
[status]: https://godoc.org/google.golang.org/grpc/status#Status
[new-status]: https://godoc.org/google.golang.org/grpc/status#New
[code]: https://godoc.org/google.golang.org/grpc/codes#Code
[with-details]: https://godoc.org/google.golang.org/grpc/status#Status.WithDetails
[details]: https://godoc.org/google.golang.org/grpc/status#Status.Details
[status-err]: https://godoc.org/google.golang.org/grpc/status#Status.Err
[status-error]: https://godoc.org/google.golang.org/grpc/status#Error
[example]: https://github.com/grpc/grpc-go/blob/master/examples/rpc_errors


@ -0,0 +1,152 @@
# gRPC Server Reflection Tutorial
gRPC Server Reflection provides information about publicly-accessible gRPC
services on a server, and assists clients at runtime to construct RPC
requests and responses without precompiled service information. It is used by
gRPC CLI, which can be used to introspect server protos and send/receive test
RPCs.
## Enable Server Reflection
gRPC-go Server Reflection is implemented in package [reflection](https://github.com/grpc/grpc-go/tree/master/reflection). To enable server reflection, you need to import this package and register reflection service on your gRPC server.
For example, to enable server reflection in `examples/helloworld`, we need to make the following changes:
```diff
--- a/examples/helloworld/greeter_server/main.go
+++ b/examples/helloworld/greeter_server/main.go
@@ -40,6 +40,7 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc"
pb "google.golang.org/grpc/examples/helloworld/helloworld"
+ "google.golang.org/grpc/reflection"
)
const (
@@ -61,6 +62,8 @@ func main() {
}
s := grpc.NewServer()
pb.RegisterGreeterServer(s, &server{})
+ // Register reflection service on gRPC server.
+ reflection.Register(s)
if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
```
We have made this change in `examples/helloworld`, and we will use it as an example to show the use of gRPC server reflection and gRPC CLI in this tutorial.
## gRPC CLI
After enabling Server Reflection in a server application, you can use gRPC CLI to check its services.
gRPC CLI is currently only available in C++. Instructions on how to use gRPC CLI can be found at [command_line_tool.md](https://github.com/grpc/grpc/blob/master/doc/command_line_tool.md).
To build gRPC CLI:
```sh
git clone https://github.com/grpc/grpc
cd grpc
make grpc_cli
cd bins/opt # grpc_cli is in directory bins/opt/
```
## Use gRPC CLI to check services
First, start the helloworld server in grpc-go directory:
```sh
$ cd <grpc-go-directory>
$ go run examples/helloworld/greeter_server/main.go
```
Open a new terminal and make sure you are in the directory where grpc_cli lives:
```sh
$ cd <grpc-cpp-directory>/bins/opt
```
### List services
`grpc_cli ls` command lists services and methods exposed at a given port:
- List all the services exposed at a given port
```sh
$ ./grpc_cli ls localhost:50051
```
output:
```sh
helloworld.Greeter
grpc.reflection.v1alpha.ServerReflection
```
- List one service with details
`grpc_cli ls` command inspects a service given its full name (in the format of
\<package\>.\<service\>). It can print information with a long listing format
when `-l` flag is set. This flag can be used to get more details about a
service.
```sh
$ ./grpc_cli ls localhost:50051 helloworld.Greeter -l
```
output:
```sh
filename: helloworld.proto
package: helloworld;
service Greeter {
rpc SayHello(helloworld.HelloRequest) returns (helloworld.HelloReply) {}
}
```
### List methods
- List one method with details
`grpc_cli ls` command also inspects a method given its full name (in the
format of \<package\>.\<service\>.\<method\>).
```sh
$ ./grpc_cli ls localhost:50051 helloworld.Greeter.SayHello -l
```
output:
```sh
rpc SayHello(helloworld.HelloRequest) returns (helloworld.HelloReply) {}
```
### Inspect message types
We can use the `grpc_cli type` command to inspect request/response types given the
full name of the type (in the format of \<package\>.\<type\>).
- Get information about the request type
```sh
$ ./grpc_cli type localhost:50051 helloworld.HelloRequest
```
output:
```sh
message HelloRequest {
optional string name = 1[json_name = "name"];
}
```
### Call a remote method
We can send RPCs to a server and get responses using `grpc_cli call` command.
- Call a unary method
```sh
$ ./grpc_cli call localhost:50051 SayHello "name: 'gRPC CLI'"
```
output:
```sh
message: "Hello gRPC CLI"
```


@ -0,0 +1,34 @@
# Versioning and Releases
Note: This document references terminology defined at http://semver.org.
## Release Frequency
Regular MINOR releases of gRPC-Go are performed every six weeks. Patch releases
to the previous two MINOR releases may be performed on demand or if serious
security problems are discovered.
## Versioning Policy
The gRPC-Go versioning policy follows the Semantic Versioning 2.0.0
specification, with the following exceptions:
- A MINOR version will not _necessarily_ add new functionality.
- MINOR releases will not break backward compatibility, except in the following
circumstances:
- An API was marked as EXPERIMENTAL upon its introduction.
- An API was marked as DEPRECATED in the initial MAJOR release.
- An API is inherently flawed and cannot provide correct or secure behavior.
In these cases, APIs MAY be changed or removed without a MAJOR release.
Otherwise, backward compatibility will be preserved by MINOR releases.
For an API marked as DEPRECATED, an alternative will be available (if
appropriate) for at least three months prior to its removal.
## Release History
Please see our release history on GitHub:
https://github.com/grpc/grpc-go/releases


@ -1,20 +1,14 @@
all: test testrace
deps:
go get -d -v google.golang.org/grpc/...
updatedeps:
go get -d -v -u -f google.golang.org/grpc/...
testdeps:
go get -d -v -t google.golang.org/grpc/...
updatetestdeps:
go get -d -v -t -u -f google.golang.org/grpc/...
all: vet test testrace testappengine
build: deps
go build google.golang.org/grpc/...
clean:
go clean -i google.golang.org/grpc/...
deps:
go get -d -v google.golang.org/grpc/...
proto:
@ if ! which protoc > /dev/null; then \
echo "error: protoc not installed" >&2; \
@ -25,21 +19,42 @@ proto:
test: testdeps
go test -cpu 1,4 -timeout 5m google.golang.org/grpc/...
testappengine: testappenginedeps
goapp test -cpu 1,4 -timeout 5m google.golang.org/grpc/...
testappenginedeps:
goapp get -d -v -t -tags 'appengine appenginevm' google.golang.org/grpc/...
testdeps:
go get -d -v -t google.golang.org/grpc/...
testrace: testdeps
go test -race -cpu 1,4 -timeout 7m google.golang.org/grpc/...
clean:
go clean -i google.golang.org/grpc/...
updatedeps:
go get -d -v -u -f google.golang.org/grpc/...
updatetestdeps:
go get -d -v -t -u -f google.golang.org/grpc/...
vet: vetdeps
./vet.sh
vetdeps:
./vet.sh -install
.PHONY: \
all \
deps \
updatedeps \
testdeps \
updatetestdeps \
build \
clean \
deps \
proto \
test \
testappengine \
testappenginedeps \
testdeps \
testrace \
clean \
coverage
updatedeps \
updatetestdeps \
vet \
vetdeps


@ -16,8 +16,7 @@ $ go get -u google.golang.org/grpc
Prerequisites
-------------
This requires Go 1.6 or later. Go 1.7 will be required as of the next gRPC-Go
release (1.8).
This requires Go 1.6 or later. Go 1.7 will be required soon.
Constraints
-----------


@ -16,81 +16,23 @@
*
*/
// See internal/backoff package for the backoff implementation. This file is
// kept for the exported types and API backward compatibility.
package grpc
import (
"math/rand"
"time"
)
// DefaultBackoffConfig uses values specified for backoff in
// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
var DefaultBackoffConfig = BackoffConfig{
MaxDelay: 120 * time.Second,
baseDelay: 1.0 * time.Second,
factor: 1.6,
jitter: 0.2,
}
// backoffStrategy defines the methodology for backing off after a grpc
// connection failure.
//
// This is unexported until the gRPC project decides whether or not to allow
// alternative backoff strategies. Once a decision is made, this type and its
// method may be exported.
type backoffStrategy interface {
// backoff returns the amount of time to wait before the next retry given
// the number of consecutive failures.
backoff(retries int) time.Duration
MaxDelay: 120 * time.Second,
}
// BackoffConfig defines the parameters for the default gRPC backoff strategy.
type BackoffConfig struct {
// MaxDelay is the upper bound of backoff delay.
MaxDelay time.Duration
// TODO(stevvooe): The following fields are not exported, as allowing
// changes would violate the current gRPC specification for backoff. If
// gRPC decides to allow more interesting backoff strategies, these fields
// may be opened up in the future.
// baseDelay is the amount of time to wait before retrying after the first
// failure.
baseDelay time.Duration
// factor is applied to the backoff after each retry.
factor float64
// jitter provides a range to randomize backoff delays.
jitter float64
}
func setDefaults(bc *BackoffConfig) {
md := bc.MaxDelay
*bc = DefaultBackoffConfig
if md > 0 {
bc.MaxDelay = md
}
}
func (bc BackoffConfig) backoff(retries int) time.Duration {
if retries == 0 {
return bc.baseDelay
}
backoff, max := float64(bc.baseDelay), float64(bc.MaxDelay)
for backoff < max && retries > 0 {
backoff *= bc.factor
retries--
}
if backoff > max {
backoff = max
}
// Randomize backoff delays so that if a cluster of requests start at
// the same time, they won't operate in lockstep.
backoff *= 1 + bc.jitter*(rand.Float64()*2-1)
if backoff < 0 {
return 0
}
return time.Duration(backoff)
}


@ -19,7 +19,6 @@
package grpc
import (
"fmt"
"net"
"sync"
@ -32,7 +31,8 @@ import (
)
// Address represents a server the client connects to.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
//
// Deprecated: please use package balancer.
type Address struct {
// Addr is the server address on which a connection will be established.
Addr string
@ -42,6 +42,8 @@ type Address struct {
}
// BalancerConfig specifies the configurations for Balancer.
//
// Deprecated: please use package balancer.
type BalancerConfig struct {
// DialCreds is the transport credential the Balancer implementation can
// use to dial to a remote load balancer server. The Balancer implementations
@ -54,7 +56,8 @@ type BalancerConfig struct {
}
// BalancerGetOptions configures a Get call.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
//
// Deprecated: please use package balancer.
type BalancerGetOptions struct {
// BlockingWait specifies whether Get should block when there is no
// connected address.
@ -62,7 +65,8 @@ type BalancerGetOptions struct {
}
// Balancer chooses network addresses for RPCs.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
//
// Deprecated: please use package balancer.
type Balancer interface {
// Start does the initialization work to bootstrap a Balancer. For example,
// this function may start the name resolution and watch the updates. It will
@ -113,28 +117,10 @@ type Balancer interface {
Close() error
}
// downErr implements net.Error. It is constructed by gRPC internals and passed to the down
// call of Balancer.
type downErr struct {
timeout bool
temporary bool
desc string
}
func (e downErr) Error() string { return e.desc }
func (e downErr) Timeout() bool { return e.timeout }
func (e downErr) Temporary() bool { return e.temporary }
func downErrorf(timeout, temporary bool, format string, a ...interface{}) downErr {
return downErr{
timeout: timeout,
temporary: temporary,
desc: fmt.Sprintf(format, a...),
}
}
// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
// the name resolution updates and updates the addresses available correspondingly.
//
// Deprecated: please use package balancer/roundrobin.
func RoundRobin(r naming.Resolver) Balancer {
return &roundRobin{r: r}
}
@ -403,7 +389,3 @@ func (rr *roundRobin) Close() error {
type pickFirst struct {
*roundRobin
}
func pickFirstBalancerV1(r naming.Resolver) Balancer {
return &pickFirst{&roundRobin{r: r}}
}

vendor/google.golang.org/grpc/balancer/balancer.go generated vendored Normal file

@ -0,0 +1,274 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package balancer defines APIs for load balancing in gRPC.
// All APIs in this package are experimental.
package balancer
import (
"errors"
"net"
"strings"
"golang.org/x/net/context"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/resolver"
)
var (
// m is a map from name to balancer builder.
m = make(map[string]Builder)
)
// Register registers the balancer builder to the balancer map. b.Name
// (lowercased) will be used as the name registered with this builder.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Balancers are
// registered with the same name, the one registered last will take effect.
func Register(b Builder) {
m[strings.ToLower(b.Name())] = b
}
// Get returns the balancer builder registered with the given name.
// Note that the comparison is done in a case-insensitive fashion.
// If no builder is registered with the name, nil will be returned.
func Get(name string) Builder {
if b, ok := m[strings.ToLower(name)]; ok {
return b
}
return nil
}
// SubConn represents a gRPC sub connection.
// Each sub connection contains a list of addresses. gRPC will
// try to connect to them (in sequence), and stop trying the
// remainder once one connection is successful.
//
// The reconnect backoff will be applied on the list, not a single address.
// For example, try_on_all_addresses -> backoff -> try_on_all_addresses.
//
// All SubConns start in IDLE, and will not try to connect. To trigger
// the connecting, Balancers must call Connect.
// When the connection encounters an error, it will reconnect immediately.
// When the connection becomes IDLE, it will not reconnect unless Connect is
// called.
//
// This interface is to be implemented by gRPC. Users should not need a
// brand new implementation of this interface. For the situations like
// testing, the new implementation should embed this interface. This allows
// gRPC to add new methods to this interface.
type SubConn interface {
// UpdateAddresses updates the addresses used in this SubConn.
// gRPC checks if currently-connected address is still in the new list.
// If it's in the list, the connection will be kept.
// If it's not in the list, the connection will be gracefully closed, and
// a new connection will be created.
//
// This will trigger a state transition for the SubConn.
UpdateAddresses([]resolver.Address)
// Connect starts the connecting for this SubConn.
Connect()
}
// NewSubConnOptions contains options to create new SubConn.
type NewSubConnOptions struct{}
// ClientConn represents a gRPC ClientConn.
//
// This interface is to be implemented by gRPC. Users should not need a
// brand new implementation of this interface. For the situations like
// testing, the new implementation should embed this interface. This allows
// gRPC to add new methods to this interface.
type ClientConn interface {
// NewSubConn is called by balancer to create a new SubConn.
// It doesn't block and wait for the connections to be established.
// Behaviors of the SubConn can be controlled by options.
NewSubConn([]resolver.Address, NewSubConnOptions) (SubConn, error)
// RemoveSubConn removes the SubConn from ClientConn.
// The SubConn will be shutdown.
RemoveSubConn(SubConn)
// UpdateBalancerState is called by balancer to notify gRPC that some internal
// state in balancer has changed.
//
// gRPC will update the connectivity state of the ClientConn, and will call pick
// on the new picker to pick new SubConn.
UpdateBalancerState(s connectivity.State, p Picker)
// ResolveNow is called by balancer to notify gRPC to do a name resolving.
ResolveNow(resolver.ResolveNowOption)
// Target returns the dial target for this ClientConn.
Target() string
}
// BuildOptions contains additional information for Build.
type BuildOptions struct {
// DialCreds is the transport credential the Balancer implementation can
// use to dial to a remote load balancer server. The Balancer implementations
// can ignore this if it does not need to talk to another party securely.
DialCreds credentials.TransportCredentials
// Dialer is the custom dialer the Balancer implementation can use to dial
// to a remote load balancer server. The Balancer implementations
// can ignore this if it doesn't need to talk to remote balancer.
Dialer func(context.Context, string) (net.Conn, error)
// ChannelzParentID is the entity parent's channelz unique identification number.
ChannelzParentID int64
}
// Builder creates a balancer.
type Builder interface {
// Build creates a new balancer with the ClientConn.
Build(cc ClientConn, opts BuildOptions) Balancer
// Name returns the name of balancers built by this builder.
// It will be used to pick balancers (for example in service config).
Name() string
}
// PickOptions contains additional information for the Pick operation.
type PickOptions struct {
// FullMethodName is the method name that NewClientStream() is called
// with. The canonical format is /service/Method.
FullMethodName string
}
// DoneInfo contains additional information for done.
type DoneInfo struct {
// Err is the rpc error the RPC finished with. It could be nil.
Err error
// BytesSent indicates if any bytes have been sent to the server.
BytesSent bool
// BytesReceived indicates if any byte has been received from the server.
BytesReceived bool
}
var (
// ErrNoSubConnAvailable indicates no SubConn is available for pick().
// gRPC will block the RPC until a new picker is available via UpdateBalancerState().
ErrNoSubConnAvailable = errors.New("no SubConn is available")
// ErrTransientFailure indicates all SubConns are in TransientFailure.
// WaitForReady RPCs will block, non-WaitForReady RPCs will fail.
ErrTransientFailure = errors.New("all SubConns are in TransientFailure")
)
// Picker is used by gRPC to pick a SubConn to send an RPC.
// Balancer is expected to generate a new picker from its snapshot every time its
// internal state has changed.
//
// The pickers used by gRPC can be updated by ClientConn.UpdateBalancerState().
type Picker interface {
// Pick returns the SubConn to be used to send the RPC.
// The returned SubConn must be one returned by NewSubConn().
//
// This function is expected to return:
// - a SubConn that is known to be READY;
// - ErrNoSubConnAvailable if no SubConn is available, but progress is being
// made (for example, some SubConn is in CONNECTING mode);
// - other errors if no active connecting is happening (for example, all SubConn
// are in TRANSIENT_FAILURE mode).
//
// If a SubConn is returned:
// - If it is READY, gRPC will send the RPC on it;
// - If it is not ready, or becomes not ready after it's returned, gRPC will block
// until UpdateBalancerState() is called and will call pick on the new picker.
//
// If the returned error is not nil:
// - If the error is ErrNoSubConnAvailable, gRPC will block until UpdateBalancerState()
// - If the error is ErrTransientFailure:
// - If the RPC is wait-for-ready, gRPC will block until UpdateBalancerState()
// is called to pick again;
// - Otherwise, RPC will fail with unavailable error.
// - Else (error is other non-nil error):
// - The RPC will fail with unavailable error.
//
// The returned done() function will be called once the rpc has finished, with the
// final status of that RPC.
// done may be nil if balancer doesn't care about the RPC status.
Pick(ctx context.Context, opts PickOptions) (conn SubConn, done func(DoneInfo), err error)
}
// Balancer takes input from gRPC, manages SubConns, and collects and aggregates
// the connectivity states.
//
// It also generates and updates the Picker used by gRPC to pick SubConns for RPCs.
//
// HandleSubConnStateChange, HandleResolvedAddrs and Close are guaranteed
// to be called synchronously from the same goroutine.
// There's no such guarantee for picker.Pick; it may be called at any time.
type Balancer interface {
// HandleSubConnStateChange is called by gRPC when the connectivity state
// of sc has changed.
// Balancer is expected to aggregate all the state of SubConn and report
// that back to gRPC.
// Balancer should also generate and update Pickers when its internal state has
// been changed by the new state.
HandleSubConnStateChange(sc SubConn, state connectivity.State)
// HandleResolvedAddrs is called by gRPC to send updated resolved addresses to
// balancers.
// Balancer can create new SubConn or remove SubConn with the addresses.
// An empty address slice and a non-nil error will be passed if the resolver returns
// non-nil error to gRPC.
HandleResolvedAddrs([]resolver.Address, error)
// Close closes the balancer. The balancer is not required to call
// ClientConn.RemoveSubConn for its existing SubConns.
Close()
}
// ConnectivityStateEvaluator takes the connectivity states of multiple SubConns
// and returns one aggregated connectivity state.
//
// It's not thread safe.
type ConnectivityStateEvaluator struct {
numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transientFailure.
}
// RecordTransition records state change happening in subConn and based on that
// it evaluates what aggregated state should be.
//
// - If at least one SubConn is in Ready, the aggregated state is Ready;
// - Else if at least one SubConn is in Connecting, the aggregated state is Connecting;
// - Else the aggregated state is TransientFailure.
//
// Idle and Shutdown are not considered.
func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState connectivity.State) connectivity.State {
// Update counters.
for idx, state := range []connectivity.State{oldState, newState} {
updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
switch state {
case connectivity.Ready:
cse.numReady += updateVal
case connectivity.Connecting:
cse.numConnecting += updateVal
case connectivity.TransientFailure:
cse.numTransientFailure += updateVal
}
}
// Evaluate.
if cse.numReady > 0 {
return connectivity.Ready
}
if cse.numConnecting > 0 {
return connectivity.Connecting
}
return connectivity.TransientFailure
}

vendor/google.golang.org/grpc/balancer/base/balancer.go generated vendored Normal file
View file

@@ -0,0 +1,208 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package base
import (
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/resolver"
)
type baseBuilder struct {
name string
pickerBuilder PickerBuilder
}
func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
return &baseBalancer{
cc: cc,
pickerBuilder: bb.pickerBuilder,
subConns: make(map[resolver.Address]balancer.SubConn),
scStates: make(map[balancer.SubConn]connectivity.State),
csEvltr: &connectivityStateEvaluator{},
// Initialize picker to a picker that always returns
// ErrNoSubConnAvailable, because when state of a SubConn changes, we
// may call UpdateBalancerState with this picker.
picker: NewErrPicker(balancer.ErrNoSubConnAvailable),
}
}
func (bb *baseBuilder) Name() string {
return bb.name
}
type baseBalancer struct {
cc balancer.ClientConn
pickerBuilder PickerBuilder
csEvltr *connectivityStateEvaluator
state connectivity.State
subConns map[resolver.Address]balancer.SubConn
scStates map[balancer.SubConn]connectivity.State
picker balancer.Picker
}
func (b *baseBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
if err != nil {
grpclog.Infof("base.baseBalancer: HandleResolvedAddrs called with error %v", err)
return
}
grpclog.Infoln("base.baseBalancer: got new resolved addresses: ", addrs)
// addrsSet is the set converted from addrs; it's used for quick lookup of an address.
addrsSet := make(map[resolver.Address]struct{})
for _, a := range addrs {
addrsSet[a] = struct{}{}
if _, ok := b.subConns[a]; !ok {
// a is a new address (not existing in b.subConns).
sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{})
if err != nil {
grpclog.Warningf("base.baseBalancer: failed to create new SubConn: %v", err)
continue
}
b.subConns[a] = sc
b.scStates[sc] = connectivity.Idle
sc.Connect()
}
}
for a, sc := range b.subConns {
// a was removed by resolver.
if _, ok := addrsSet[a]; !ok {
b.cc.RemoveSubConn(sc)
delete(b.subConns, a)
// Keep the state of this sc in b.scStates until sc's state becomes Shutdown.
// The entry will be deleted in HandleSubConnStateChange.
}
}
}
// regeneratePicker takes a snapshot of the balancer, and generates a picker
// from it. The picker is
// - errPicker with ErrTransientFailure if the balancer is in TransientFailure,
// - built by the pickerBuilder with all READY SubConns otherwise.
func (b *baseBalancer) regeneratePicker() {
if b.state == connectivity.TransientFailure {
b.picker = NewErrPicker(balancer.ErrTransientFailure)
return
}
readySCs := make(map[resolver.Address]balancer.SubConn)
// Filter out all ready SCs from full subConn map.
for addr, sc := range b.subConns {
if st, ok := b.scStates[sc]; ok && st == connectivity.Ready {
readySCs[addr] = sc
}
}
b.picker = b.pickerBuilder.Build(readySCs)
}
func (b *baseBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
grpclog.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s)
oldS, ok := b.scStates[sc]
if !ok {
grpclog.Infof("base.baseBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
return
}
b.scStates[sc] = s
switch s {
case connectivity.Idle:
sc.Connect()
case connectivity.Shutdown:
// When an address was removed by resolver, b called RemoveSubConn but
// kept the sc's state in scStates. Remove state for this sc here.
delete(b.scStates, sc)
}
oldAggrState := b.state
b.state = b.csEvltr.recordTransition(oldS, s)
// Regenerate picker when one of the following happens:
// - this sc became ready from not-ready
// - this sc became not-ready from ready
// - the aggregated state of balancer became TransientFailure from non-TransientFailure
// - the aggregated state of balancer became non-TransientFailure from TransientFailure
if (s == connectivity.Ready) != (oldS == connectivity.Ready) ||
(b.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) {
b.regeneratePicker()
}
b.cc.UpdateBalancerState(b.state, b.picker)
}
// Close is a nop because base balancer doesn't have internal state to clean up,
// and it doesn't need to call RemoveSubConn for the SubConns.
func (b *baseBalancer) Close() {
}
// NewErrPicker returns a picker that always returns err on Pick().
func NewErrPicker(err error) balancer.Picker {
return &errPicker{err: err}
}
type errPicker struct {
err error // Pick() always returns this err.
}
func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
return nil, nil, p.err
}
// connectivityStateEvaluator gets updated by addrConns when their
// states transition, based on which it evaluates the state of
// ClientConn.
type connectivityStateEvaluator struct {
numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transientFailure.
}
// recordTransition records a state change happening in a subConn and, based on
// that, evaluates what the aggregated state should be.
// The aggregated state can only be Ready, Connecting or TransientFailure. The other
// states, Idle and Shutdown, are entered directly by the ClientConn: before any
// subConn is created the ClientConn is in Idle state, and when the ClientConn
// closes it is in Shutdown state.
//
// recordTransition should only be called synchronously from the same goroutine.
func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State {
// Update counters.
for idx, state := range []connectivity.State{oldState, newState} {
updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
switch state {
case connectivity.Ready:
cse.numReady += updateVal
case connectivity.Connecting:
cse.numConnecting += updateVal
case connectivity.TransientFailure:
cse.numTransientFailure += updateVal
}
}
// Evaluate.
if cse.numReady > 0 {
return connectivity.Ready
}
if cse.numConnecting > 0 {
return connectivity.Connecting
}
return connectivity.TransientFailure
}

52
vendor/google.golang.org/grpc/balancer/base/base.go generated vendored Normal file
View file

@@ -0,0 +1,52 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package base defines a balancer base that can be used to build balancers with
// different picking algorithms.
//
// The base balancer creates a new SubConn for each resolved address. The
// provided picker will only be notified about READY SubConns.
//
// This package is the base of the round_robin balancer; its purpose is to be used
// to build round_robin-like balancers with complex picking algorithms.
// Balancers with more complicated logic should implement a balancer builder
// from scratch.
//
// All APIs in this package are experimental.
package base
import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
// PickerBuilder creates balancer.Picker.
type PickerBuilder interface {
// Build takes a slice of ready SubConns, and returns a picker that will be
// used by gRPC to pick a SubConn.
Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker
}
// NewBalancerBuilder returns a balancer builder. The balancers
// built by this builder will use the picker builder to build pickers.
func NewBalancerBuilder(name string, pb PickerBuilder) balancer.Builder {
return &baseBuilder{
name: name,
pickerBuilder: pb,
}
}
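// Illustrative sketch (not part of the vendored source): registering a
// hypothetical "first_ready" balancer built on top of this package, mirroring
// how round_robin uses NewBalancerBuilder. The names below are assumptions for
// illustration, and the snippet assumes "golang.org/x/net/context" is imported.
type firstReadyPickerBuilder struct{}

func (firstReadyPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
	for _, sc := range readySCs {
		// Pin all RPCs to the first READY SubConn found in the map.
		return &firstReadyPicker{sc: sc}
	}
	// No READY SubConns yet: make RPCs block (or fail, if not wait-for-ready).
	return NewErrPicker(balancer.ErrNoSubConnAvailable)
}

type firstReadyPicker struct{ sc balancer.SubConn }

func (p *firstReadyPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
	return p.sc, nil, nil
}

func init() {
	// Registering makes the balancer selectable by name, e.g. via
	// grpc.WithBalancerName("first_ready") on the dialing side.
	balancer.Register(NewBalancerBuilder("first_ready", firstReadyPickerBuilder{}))
}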

View file

@@ -0,0 +1,839 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc/lb/v1/load_balancer.proto
package grpc_lb_v1 // import "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import duration "github.com/golang/protobuf/ptypes/duration"
import timestamp "github.com/golang/protobuf/ptypes/timestamp"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type LoadBalanceRequest struct {
// Types that are valid to be assigned to LoadBalanceRequestType:
// *LoadBalanceRequest_InitialRequest
// *LoadBalanceRequest_ClientStats
LoadBalanceRequestType isLoadBalanceRequest_LoadBalanceRequestType `protobuf_oneof:"load_balance_request_type"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *LoadBalanceRequest) Reset() { *m = LoadBalanceRequest{} }
func (m *LoadBalanceRequest) String() string { return proto.CompactTextString(m) }
func (*LoadBalanceRequest) ProtoMessage() {}
func (*LoadBalanceRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{0}
}
func (m *LoadBalanceRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_LoadBalanceRequest.Unmarshal(m, b)
}
func (m *LoadBalanceRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_LoadBalanceRequest.Marshal(b, m, deterministic)
}
func (dst *LoadBalanceRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_LoadBalanceRequest.Merge(dst, src)
}
func (m *LoadBalanceRequest) XXX_Size() int {
return xxx_messageInfo_LoadBalanceRequest.Size(m)
}
func (m *LoadBalanceRequest) XXX_DiscardUnknown() {
xxx_messageInfo_LoadBalanceRequest.DiscardUnknown(m)
}
var xxx_messageInfo_LoadBalanceRequest proto.InternalMessageInfo
type isLoadBalanceRequest_LoadBalanceRequestType interface {
isLoadBalanceRequest_LoadBalanceRequestType()
}
type LoadBalanceRequest_InitialRequest struct {
InitialRequest *InitialLoadBalanceRequest `protobuf:"bytes,1,opt,name=initial_request,json=initialRequest,proto3,oneof"`
}
type LoadBalanceRequest_ClientStats struct {
ClientStats *ClientStats `protobuf:"bytes,2,opt,name=client_stats,json=clientStats,proto3,oneof"`
}
func (*LoadBalanceRequest_InitialRequest) isLoadBalanceRequest_LoadBalanceRequestType() {}
func (*LoadBalanceRequest_ClientStats) isLoadBalanceRequest_LoadBalanceRequestType() {}
func (m *LoadBalanceRequest) GetLoadBalanceRequestType() isLoadBalanceRequest_LoadBalanceRequestType {
if m != nil {
return m.LoadBalanceRequestType
}
return nil
}
func (m *LoadBalanceRequest) GetInitialRequest() *InitialLoadBalanceRequest {
if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_InitialRequest); ok {
return x.InitialRequest
}
return nil
}
func (m *LoadBalanceRequest) GetClientStats() *ClientStats {
if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_ClientStats); ok {
return x.ClientStats
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*LoadBalanceRequest) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _LoadBalanceRequest_OneofMarshaler, _LoadBalanceRequest_OneofUnmarshaler, _LoadBalanceRequest_OneofSizer, []interface{}{
(*LoadBalanceRequest_InitialRequest)(nil),
(*LoadBalanceRequest_ClientStats)(nil),
}
}
func _LoadBalanceRequest_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*LoadBalanceRequest)
// load_balance_request_type
switch x := m.LoadBalanceRequestType.(type) {
case *LoadBalanceRequest_InitialRequest:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.InitialRequest); err != nil {
return err
}
case *LoadBalanceRequest_ClientStats:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ClientStats); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("LoadBalanceRequest.LoadBalanceRequestType has unexpected type %T", x)
}
return nil
}
func _LoadBalanceRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*LoadBalanceRequest)
switch tag {
case 1: // load_balance_request_type.initial_request
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(InitialLoadBalanceRequest)
err := b.DecodeMessage(msg)
m.LoadBalanceRequestType = &LoadBalanceRequest_InitialRequest{msg}
return true, err
case 2: // load_balance_request_type.client_stats
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ClientStats)
err := b.DecodeMessage(msg)
m.LoadBalanceRequestType = &LoadBalanceRequest_ClientStats{msg}
return true, err
default:
return false, nil
}
}
func _LoadBalanceRequest_OneofSizer(msg proto.Message) (n int) {
m := msg.(*LoadBalanceRequest)
// load_balance_request_type
switch x := m.LoadBalanceRequestType.(type) {
case *LoadBalanceRequest_InitialRequest:
s := proto.Size(x.InitialRequest)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *LoadBalanceRequest_ClientStats:
s := proto.Size(x.ClientStats)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
type InitialLoadBalanceRequest struct {
// The name of the load balanced service (e.g., service.googleapis.com). Its
// length should be less than 256 bytes.
// The name might include a port number. How to handle the port number is up
// to the balancer.
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *InitialLoadBalanceRequest) Reset() { *m = InitialLoadBalanceRequest{} }
func (m *InitialLoadBalanceRequest) String() string { return proto.CompactTextString(m) }
func (*InitialLoadBalanceRequest) ProtoMessage() {}
func (*InitialLoadBalanceRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{1}
}
func (m *InitialLoadBalanceRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_InitialLoadBalanceRequest.Unmarshal(m, b)
}
func (m *InitialLoadBalanceRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_InitialLoadBalanceRequest.Marshal(b, m, deterministic)
}
func (dst *InitialLoadBalanceRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_InitialLoadBalanceRequest.Merge(dst, src)
}
func (m *InitialLoadBalanceRequest) XXX_Size() int {
return xxx_messageInfo_InitialLoadBalanceRequest.Size(m)
}
func (m *InitialLoadBalanceRequest) XXX_DiscardUnknown() {
xxx_messageInfo_InitialLoadBalanceRequest.DiscardUnknown(m)
}
var xxx_messageInfo_InitialLoadBalanceRequest proto.InternalMessageInfo
func (m *InitialLoadBalanceRequest) GetName() string {
if m != nil {
return m.Name
}
return ""
}
// Contains the number of calls finished for a particular load balance token.
type ClientStatsPerToken struct {
// See Server.load_balance_token.
LoadBalanceToken string `protobuf:"bytes,1,opt,name=load_balance_token,json=loadBalanceToken,proto3" json:"load_balance_token,omitempty"`
// The total number of RPCs that finished associated with the token.
NumCalls int64 `protobuf:"varint,2,opt,name=num_calls,json=numCalls,proto3" json:"num_calls,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStatsPerToken) Reset() { *m = ClientStatsPerToken{} }
func (m *ClientStatsPerToken) String() string { return proto.CompactTextString(m) }
func (*ClientStatsPerToken) ProtoMessage() {}
func (*ClientStatsPerToken) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{2}
}
func (m *ClientStatsPerToken) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStatsPerToken.Unmarshal(m, b)
}
func (m *ClientStatsPerToken) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStatsPerToken.Marshal(b, m, deterministic)
}
func (dst *ClientStatsPerToken) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStatsPerToken.Merge(dst, src)
}
func (m *ClientStatsPerToken) XXX_Size() int {
return xxx_messageInfo_ClientStatsPerToken.Size(m)
}
func (m *ClientStatsPerToken) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStatsPerToken.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStatsPerToken proto.InternalMessageInfo
func (m *ClientStatsPerToken) GetLoadBalanceToken() string {
if m != nil {
return m.LoadBalanceToken
}
return ""
}
func (m *ClientStatsPerToken) GetNumCalls() int64 {
if m != nil {
return m.NumCalls
}
return 0
}
// Contains client level statistics that are useful to load balancing. Each
// count except the timestamp should be reset to zero after reporting the stats.
type ClientStats struct {
// The timestamp of generating the report.
Timestamp *timestamp.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// The total number of RPCs that started.
NumCallsStarted int64 `protobuf:"varint,2,opt,name=num_calls_started,json=numCallsStarted,proto3" json:"num_calls_started,omitempty"`
// The total number of RPCs that finished.
NumCallsFinished int64 `protobuf:"varint,3,opt,name=num_calls_finished,json=numCallsFinished,proto3" json:"num_calls_finished,omitempty"`
// The total number of RPCs that failed to reach a server except dropped RPCs.
NumCallsFinishedWithClientFailedToSend int64 `protobuf:"varint,6,opt,name=num_calls_finished_with_client_failed_to_send,json=numCallsFinishedWithClientFailedToSend,proto3" json:"num_calls_finished_with_client_failed_to_send,omitempty"`
// The total number of RPCs that finished and are known to have been received
// by a server.
NumCallsFinishedKnownReceived int64 `protobuf:"varint,7,opt,name=num_calls_finished_known_received,json=numCallsFinishedKnownReceived,proto3" json:"num_calls_finished_known_received,omitempty"`
// The list of dropped calls.
CallsFinishedWithDrop []*ClientStatsPerToken `protobuf:"bytes,8,rep,name=calls_finished_with_drop,json=callsFinishedWithDrop,proto3" json:"calls_finished_with_drop,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStats) Reset() { *m = ClientStats{} }
func (m *ClientStats) String() string { return proto.CompactTextString(m) }
func (*ClientStats) ProtoMessage() {}
func (*ClientStats) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{3}
}
func (m *ClientStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStats.Unmarshal(m, b)
}
func (m *ClientStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStats.Marshal(b, m, deterministic)
}
func (dst *ClientStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStats.Merge(dst, src)
}
func (m *ClientStats) XXX_Size() int {
return xxx_messageInfo_ClientStats.Size(m)
}
func (m *ClientStats) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStats.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStats proto.InternalMessageInfo
func (m *ClientStats) GetTimestamp() *timestamp.Timestamp {
if m != nil {
return m.Timestamp
}
return nil
}
func (m *ClientStats) GetNumCallsStarted() int64 {
if m != nil {
return m.NumCallsStarted
}
return 0
}
func (m *ClientStats) GetNumCallsFinished() int64 {
if m != nil {
return m.NumCallsFinished
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedWithClientFailedToSend() int64 {
if m != nil {
return m.NumCallsFinishedWithClientFailedToSend
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedKnownReceived() int64 {
if m != nil {
return m.NumCallsFinishedKnownReceived
}
return 0
}
func (m *ClientStats) GetCallsFinishedWithDrop() []*ClientStatsPerToken {
if m != nil {
return m.CallsFinishedWithDrop
}
return nil
}
type LoadBalanceResponse struct {
// Types that are valid to be assigned to LoadBalanceResponseType:
// *LoadBalanceResponse_InitialResponse
// *LoadBalanceResponse_ServerList
LoadBalanceResponseType isLoadBalanceResponse_LoadBalanceResponseType `protobuf_oneof:"load_balance_response_type"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *LoadBalanceResponse) Reset() { *m = LoadBalanceResponse{} }
func (m *LoadBalanceResponse) String() string { return proto.CompactTextString(m) }
func (*LoadBalanceResponse) ProtoMessage() {}
func (*LoadBalanceResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{4}
}
func (m *LoadBalanceResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_LoadBalanceResponse.Unmarshal(m, b)
}
func (m *LoadBalanceResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_LoadBalanceResponse.Marshal(b, m, deterministic)
}
func (dst *LoadBalanceResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_LoadBalanceResponse.Merge(dst, src)
}
func (m *LoadBalanceResponse) XXX_Size() int {
return xxx_messageInfo_LoadBalanceResponse.Size(m)
}
func (m *LoadBalanceResponse) XXX_DiscardUnknown() {
xxx_messageInfo_LoadBalanceResponse.DiscardUnknown(m)
}
var xxx_messageInfo_LoadBalanceResponse proto.InternalMessageInfo
type isLoadBalanceResponse_LoadBalanceResponseType interface {
isLoadBalanceResponse_LoadBalanceResponseType()
}
type LoadBalanceResponse_InitialResponse struct {
InitialResponse *InitialLoadBalanceResponse `protobuf:"bytes,1,opt,name=initial_response,json=initialResponse,proto3,oneof"`
}
type LoadBalanceResponse_ServerList struct {
ServerList *ServerList `protobuf:"bytes,2,opt,name=server_list,json=serverList,proto3,oneof"`
}
func (*LoadBalanceResponse_InitialResponse) isLoadBalanceResponse_LoadBalanceResponseType() {}
func (*LoadBalanceResponse_ServerList) isLoadBalanceResponse_LoadBalanceResponseType() {}
func (m *LoadBalanceResponse) GetLoadBalanceResponseType() isLoadBalanceResponse_LoadBalanceResponseType {
if m != nil {
return m.LoadBalanceResponseType
}
return nil
}
func (m *LoadBalanceResponse) GetInitialResponse() *InitialLoadBalanceResponse {
if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_InitialResponse); ok {
return x.InitialResponse
}
return nil
}
func (m *LoadBalanceResponse) GetServerList() *ServerList {
if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_ServerList); ok {
return x.ServerList
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*LoadBalanceResponse) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _LoadBalanceResponse_OneofMarshaler, _LoadBalanceResponse_OneofUnmarshaler, _LoadBalanceResponse_OneofSizer, []interface{}{
(*LoadBalanceResponse_InitialResponse)(nil),
(*LoadBalanceResponse_ServerList)(nil),
}
}
func _LoadBalanceResponse_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*LoadBalanceResponse)
// load_balance_response_type
switch x := m.LoadBalanceResponseType.(type) {
case *LoadBalanceResponse_InitialResponse:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.InitialResponse); err != nil {
return err
}
case *LoadBalanceResponse_ServerList:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ServerList); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("LoadBalanceResponse.LoadBalanceResponseType has unexpected type %T", x)
}
return nil
}
func _LoadBalanceResponse_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*LoadBalanceResponse)
switch tag {
case 1: // load_balance_response_type.initial_response
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(InitialLoadBalanceResponse)
err := b.DecodeMessage(msg)
m.LoadBalanceResponseType = &LoadBalanceResponse_InitialResponse{msg}
return true, err
case 2: // load_balance_response_type.server_list
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ServerList)
err := b.DecodeMessage(msg)
m.LoadBalanceResponseType = &LoadBalanceResponse_ServerList{msg}
return true, err
default:
return false, nil
}
}
func _LoadBalanceResponse_OneofSizer(msg proto.Message) (n int) {
m := msg.(*LoadBalanceResponse)
// load_balance_response_type
switch x := m.LoadBalanceResponseType.(type) {
case *LoadBalanceResponse_InitialResponse:
s := proto.Size(x.InitialResponse)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *LoadBalanceResponse_ServerList:
s := proto.Size(x.ServerList)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
type InitialLoadBalanceResponse struct {
// This is an application layer redirect that indicates the client should use
// the specified server for load balancing. When this field is non-empty in
// the response, the client should open a separate connection to the
// load_balancer_delegate and call the BalanceLoad method. Its length should
// be less than 64 bytes.
LoadBalancerDelegate string `protobuf:"bytes,1,opt,name=load_balancer_delegate,json=loadBalancerDelegate,proto3" json:"load_balancer_delegate,omitempty"`
// This interval defines how often the client should send the client stats
// to the load balancer. Stats should only be reported when the duration is
// positive.
ClientStatsReportInterval *duration.Duration `protobuf:"bytes,2,opt,name=client_stats_report_interval,json=clientStatsReportInterval,proto3" json:"client_stats_report_interval,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *InitialLoadBalanceResponse) Reset() { *m = InitialLoadBalanceResponse{} }
func (m *InitialLoadBalanceResponse) String() string { return proto.CompactTextString(m) }
func (*InitialLoadBalanceResponse) ProtoMessage() {}
func (*InitialLoadBalanceResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{5}
}
func (m *InitialLoadBalanceResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_InitialLoadBalanceResponse.Unmarshal(m, b)
}
func (m *InitialLoadBalanceResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_InitialLoadBalanceResponse.Marshal(b, m, deterministic)
}
func (dst *InitialLoadBalanceResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_InitialLoadBalanceResponse.Merge(dst, src)
}
func (m *InitialLoadBalanceResponse) XXX_Size() int {
return xxx_messageInfo_InitialLoadBalanceResponse.Size(m)
}
func (m *InitialLoadBalanceResponse) XXX_DiscardUnknown() {
xxx_messageInfo_InitialLoadBalanceResponse.DiscardUnknown(m)
}
var xxx_messageInfo_InitialLoadBalanceResponse proto.InternalMessageInfo
func (m *InitialLoadBalanceResponse) GetLoadBalancerDelegate() string {
if m != nil {
return m.LoadBalancerDelegate
}
return ""
}
func (m *InitialLoadBalanceResponse) GetClientStatsReportInterval() *duration.Duration {
if m != nil {
return m.ClientStatsReportInterval
}
return nil
}
type ServerList struct {
// Contains a list of servers selected by the load balancer. The list will
// be updated when server resolutions change or as needed to balance load
// across more servers. The client should consume the server list in order
// unless instructed otherwise via the client_config.
Servers []*Server `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerList) Reset() { *m = ServerList{} }
func (m *ServerList) String() string { return proto.CompactTextString(m) }
func (*ServerList) ProtoMessage() {}
func (*ServerList) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{6}
}
func (m *ServerList) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerList.Unmarshal(m, b)
}
func (m *ServerList) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerList.Marshal(b, m, deterministic)
}
func (dst *ServerList) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerList.Merge(dst, src)
}
func (m *ServerList) XXX_Size() int {
return xxx_messageInfo_ServerList.Size(m)
}
func (m *ServerList) XXX_DiscardUnknown() {
xxx_messageInfo_ServerList.DiscardUnknown(m)
}
var xxx_messageInfo_ServerList proto.InternalMessageInfo
func (m *ServerList) GetServers() []*Server {
if m != nil {
return m.Servers
}
return nil
}
// Contains server information. When the drop field is not true, use the other
// fields.
type Server struct {
// A resolved address for the server, serialized in network-byte-order. It may
// either be an IPv4 or IPv6 address.
IpAddress []byte `protobuf:"bytes,1,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"`
// A resolved port number for the server.
Port int32 `protobuf:"varint,2,opt,name=port,proto3" json:"port,omitempty"`
// An opaque but printable token for load reporting. The client must include
// the token of the picked server into the initial metadata when it starts a
// call to that server. The token is used by the server to verify the request
// and to allow the server to report load to the gRPC LB system. The token is
// also used in client stats for reporting dropped calls.
//
// Its length can be variable but must be less than 50 bytes.
LoadBalanceToken string `protobuf:"bytes,3,opt,name=load_balance_token,json=loadBalanceToken,proto3" json:"load_balance_token,omitempty"`
// Indicates whether this particular request should be dropped by the client.
// If the request is dropped, there will be a corresponding entry in
// ClientStats.calls_finished_with_drop.
Drop bool `protobuf:"varint,4,opt,name=drop,proto3" json:"drop,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Server) Reset() { *m = Server{} }
func (m *Server) String() string { return proto.CompactTextString(m) }
func (*Server) ProtoMessage() {}
func (*Server) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{7}
}
func (m *Server) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Server.Unmarshal(m, b)
}
func (m *Server) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Server.Marshal(b, m, deterministic)
}
func (dst *Server) XXX_Merge(src proto.Message) {
xxx_messageInfo_Server.Merge(dst, src)
}
func (m *Server) XXX_Size() int {
return xxx_messageInfo_Server.Size(m)
}
func (m *Server) XXX_DiscardUnknown() {
xxx_messageInfo_Server.DiscardUnknown(m)
}
var xxx_messageInfo_Server proto.InternalMessageInfo
func (m *Server) GetIpAddress() []byte {
if m != nil {
return m.IpAddress
}
return nil
}
func (m *Server) GetPort() int32 {
if m != nil {
return m.Port
}
return 0
}
func (m *Server) GetLoadBalanceToken() string {
if m != nil {
return m.LoadBalanceToken
}
return ""
}
func (m *Server) GetDrop() bool {
if m != nil {
return m.Drop
}
return false
}
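// Illustrative sketch (not part of the generated file): turning a Server entry
// into a dialable "host:port" string while honoring the drop flag, roughly how
// a grpclb client consumes the list. It assumes the standard library "net"
// package is imported; the helper name is hypothetical.
func exampleServerAddr(s *Server) (addr string, ok bool) {
	if s.GetDrop() {
		// Drop entries carry no usable backend; the client fails the RPC instead.
		return "", false
	}
	ip := net.IP(s.GetIpAddress()) // raw bytes in network byte order, IPv4 or IPv6
	return fmt.Sprintf("%s:%d", ip.String(), s.GetPort()), true
}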
func init() {
proto.RegisterType((*LoadBalanceRequest)(nil), "grpc.lb.v1.LoadBalanceRequest")
proto.RegisterType((*InitialLoadBalanceRequest)(nil), "grpc.lb.v1.InitialLoadBalanceRequest")
proto.RegisterType((*ClientStatsPerToken)(nil), "grpc.lb.v1.ClientStatsPerToken")
proto.RegisterType((*ClientStats)(nil), "grpc.lb.v1.ClientStats")
proto.RegisterType((*LoadBalanceResponse)(nil), "grpc.lb.v1.LoadBalanceResponse")
proto.RegisterType((*InitialLoadBalanceResponse)(nil), "grpc.lb.v1.InitialLoadBalanceResponse")
proto.RegisterType((*ServerList)(nil), "grpc.lb.v1.ServerList")
proto.RegisterType((*Server)(nil), "grpc.lb.v1.Server")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// LoadBalancerClient is the client API for LoadBalancer service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type LoadBalancerClient interface {
// Bidirectional rpc to get a list of servers.
BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error)
}
type loadBalancerClient struct {
cc *grpc.ClientConn
}
func NewLoadBalancerClient(cc *grpc.ClientConn) LoadBalancerClient {
return &loadBalancerClient{cc}
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error) {
stream, err := c.cc.NewStream(ctx, &_LoadBalancer_serviceDesc.Streams[0], "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
if err != nil {
return nil, err
}
x := &loadBalancerBalanceLoadClient{stream}
return x, nil
}
type LoadBalancer_BalanceLoadClient interface {
Send(*LoadBalanceRequest) error
Recv() (*LoadBalanceResponse, error)
grpc.ClientStream
}
type loadBalancerBalanceLoadClient struct {
grpc.ClientStream
}
func (x *loadBalancerBalanceLoadClient) Send(m *LoadBalanceRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *loadBalancerBalanceLoadClient) Recv() (*LoadBalanceResponse, error) {
m := new(LoadBalanceResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
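// Illustrative sketch (not part of the generated file): opening the BalanceLoad
// stream and performing the initial handshake. The function name and the
// service name "service.example.com" are assumptions for illustration.
func exampleBalanceLoad(ctx context.Context, cc *grpc.ClientConn) (*InitialLoadBalanceResponse, error) {
	stream, err := NewLoadBalancerClient(cc).BalanceLoad(ctx)
	if err != nil {
		return nil, err
	}
	// The first message on the stream must be the initial request naming the
	// load balanced service.
	initReq := &LoadBalanceRequest{
		LoadBalanceRequestType: &LoadBalanceRequest_InitialRequest{
			InitialRequest: &InitialLoadBalanceRequest{Name: "service.example.com"},
		},
	}
	if err := stream.Send(initReq); err != nil {
		return nil, err
	}
	reply, err := stream.Recv()
	if err != nil {
		return nil, err
	}
	// The balancer answers with an initial response; subsequent replies carry
	// server lists.
	return reply.GetInitialResponse(), nil
}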
// LoadBalancerServer is the server API for LoadBalancer service.
type LoadBalancerServer interface {
// Bidirectional rpc to get a list of servers.
BalanceLoad(LoadBalancer_BalanceLoadServer) error
}
func RegisterLoadBalancerServer(s *grpc.Server, srv LoadBalancerServer) {
s.RegisterService(&_LoadBalancer_serviceDesc, srv)
}
func _LoadBalancer_BalanceLoad_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(LoadBalancerServer).BalanceLoad(&loadBalancerBalanceLoadServer{stream})
}
type LoadBalancer_BalanceLoadServer interface {
Send(*LoadBalanceResponse) error
Recv() (*LoadBalanceRequest, error)
grpc.ServerStream
}
type loadBalancerBalanceLoadServer struct {
grpc.ServerStream
}
func (x *loadBalancerBalanceLoadServer) Send(m *LoadBalanceResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *loadBalancerBalanceLoadServer) Recv() (*LoadBalanceRequest, error) {
m := new(LoadBalanceRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
var _LoadBalancer_serviceDesc = grpc.ServiceDesc{
ServiceName: "grpc.lb.v1.LoadBalancer",
HandlerType: (*LoadBalancerServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "BalanceLoad",
Handler: _LoadBalancer_BalanceLoad_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "grpc/lb/v1/load_balancer.proto",
}
func init() {
proto.RegisterFile("grpc/lb/v1/load_balancer.proto", fileDescriptor_load_balancer_12026aec3f0251ba)
}
var fileDescriptor_load_balancer_12026aec3f0251ba = []byte{
// 752 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xdd, 0x6e, 0x23, 0x35,
0x14, 0xee, 0x90, 0x69, 0x36, 0x39, 0x29, 0x34, 0xeb, 0x85, 0x65, 0x92, 0xdd, 0x6d, 0x4b, 0x24,
0x56, 0x11, 0x2a, 0x13, 0x52, 0xb8, 0x00, 0x89, 0x0b, 0x48, 0xab, 0x2a, 0x2d, 0xbd, 0x88, 0x9c,
0x4a, 0x45, 0x95, 0x90, 0x99, 0xc9, 0xb8, 0xa9, 0x55, 0xc7, 0x1e, 0x3c, 0x4e, 0x2a, 0xae, 0x79,
0x1f, 0xc4, 0x2b, 0x20, 0x5e, 0x0c, 0x8d, 0xed, 0x49, 0xa6, 0x49, 0xa3, 0xbd, 0xca, 0xf8, 0x9c,
0xcf, 0xdf, 0xf9, 0xfd, 0x1c, 0x38, 0x98, 0xaa, 0x74, 0xd2, 0xe3, 0x71, 0x6f, 0xd1, 0xef, 0x71,
0x19, 0x25, 0x24, 0x8e, 0x78, 0x24, 0x26, 0x54, 0x85, 0xa9, 0x92, 0x5a, 0x22, 0xc8, 0xfd, 0x21,
0x8f, 0xc3, 0x45, 0xbf, 0x7d, 0x30, 0x95, 0x72, 0xca, 0x69, 0xcf, 0x78, 0xe2, 0xf9, 0x5d, 0x2f,
0x99, 0xab, 0x48, 0x33, 0x29, 0x2c, 0xb6, 0x7d, 0xb8, 0xee, 0xd7, 0x6c, 0x46, 0x33, 0x1d, 0xcd,
0x52, 0x0b, 0xe8, 0xfc, 0xeb, 0x01, 0xba, 0x92, 0x51, 0x32, 0xb0, 0x31, 0x30, 0xfd, 0x63, 0x4e,
0x33, 0x8d, 0x46, 0xb0, 0xcf, 0x04, 0xd3, 0x2c, 0xe2, 0x44, 0x59, 0x53, 0xe0, 0x1d, 0x79, 0xdd,
0xc6, 0xc9, 0x97, 0xe1, 0x2a, 0x7a, 0x78, 0x61, 0x21, 0x9b, 0xf7, 0x87, 0x3b, 0xf8, 0x13, 0x77,
0xbf, 0x60, 0xfc, 0x11, 0xf6, 0x26, 0x9c, 0x51, 0xa1, 0x49, 0xa6, 0x23, 0x9d, 0x05, 0x1f, 0x19,
0xba, 0xcf, 0xcb, 0x74, 0xa7, 0xc6, 0x3f, 0xce, 0xdd, 0xc3, 0x1d, 0xdc, 0x98, 0xac, 0x8e, 0x83,
0x37, 0xd0, 0x2a, 0xb7, 0xa2, 0x48, 0x8a, 0xe8, 0x3f, 0x53, 0xda, 0xe9, 0x41, 0x6b, 0x6b, 0x26,
0x08, 0x81, 0x2f, 0xa2, 0x19, 0x35, 0xe9, 0xd7, 0xb1, 0xf9, 0xee, 0xfc, 0x0e, 0xaf, 0x4a, 0xb1,
0x46, 0x54, 0x5d, 0xcb, 0x07, 0x2a, 0xd0, 0x31, 0xa0, 0x27, 0x41, 0x74, 0x6e, 0x75, 0x17, 0x9b,
0x7c, 0x45, 0x6d, 0xd1, 0x6f, 0xa0, 0x2e, 0xe6, 0x33, 0x32, 0x89, 0x38, 0xb7, 0xd5, 0x54, 0x70,
0x4d, 0xcc, 0x67, 0xa7, 0xf9, 0xb9, 0xf3, 0x4f, 0x05, 0x1a, 0xa5, 0x10, 0xe8, 0x7b, 0xa8, 0x2f,
0x3b, 0xef, 0x3a, 0xd9, 0x0e, 0xed, 0x6c, 0xc2, 0x62, 0x36, 0xe1, 0x75, 0x81, 0xc0, 0x2b, 0x30,
0xfa, 0x0a, 0x5e, 0x2e, 0xc3, 0xe4, 0xad, 0x53, 0x9a, 0x26, 0x2e, 0xdc, 0x7e, 0x11, 0x6e, 0x6c,
0xcd, 0x79, 0x01, 0x2b, 0xec, 0x1d, 0x13, 0x2c, 0xbb, 0xa7, 0x49, 0x50, 0x31, 0xe0, 0x66, 0x01,
0x3e, 0x77, 0x76, 0xf4, 0x1b, 0x7c, 0xbd, 0x89, 0x26, 0x8f, 0x4c, 0xdf, 0x13, 0x37, 0xa9, 0xbb,
0x88, 0x71, 0x9a, 0x10, 0x2d, 0x49, 0x46, 0x45, 0x12, 0x54, 0x0d, 0xd1, 0xfb, 0x75, 0xa2, 0x1b,
0xa6, 0xef, 0x6d, 0xad, 0xe7, 0x06, 0x7f, 0x2d, 0xc7, 0x54, 0x24, 0x68, 0x08, 0x5f, 0x3c, 0x43,
0xff, 0x20, 0xe4, 0xa3, 0x20, 0x8a, 0x4e, 0x28, 0x5b, 0xd0, 0x24, 0x78, 0x61, 0x28, 0xdf, 0xad,
0x53, 0xfe, 0x92, 0xa3, 0xb0, 0x03, 0xa1, 0x5f, 0x21, 0x78, 0x2e, 0xc9, 0x44, 0xc9, 0x34, 0xa8,
0x1d, 0x55, 0xba, 0x8d, 0x93, 0xc3, 0x2d, 0x6b, 0x54, 0x8c, 0x16, 0x7f, 0x36, 0x59, 0xcf, 0xf8,
0x4c, 0xc9, 0xf4, 0xd2, 0xaf, 0xf9, 0xcd, 0xdd, 0x4b, 0xbf, 0xb6, 0xdb, 0xac, 0x76, 0xfe, 0xf3,
0xe0, 0xd5, 0x93, 0xfd, 0xc9, 0x52, 0x29, 0x32, 0x8a, 0xc6, 0xd0, 0x5c, 0x49, 0xc1, 0xda, 0xdc,
0x04, 0xdf, 0x7f, 0x48, 0x0b, 0x16, 0x3d, 0xdc, 0xc1, 0xfb, 0x4b, 0x31, 0x38, 0xd2, 0x1f, 0xa0,
0x91, 0x51, 0xb5, 0xa0, 0x8a, 0x70, 0x96, 0x69, 0x27, 0x86, 0xd7, 0x65, 0xbe, 0xb1, 0x71, 0x5f,
0x31, 0x23, 0x26, 0xc8, 0x96, 0xa7, 0xc1, 0x5b, 0x68, 0xaf, 0x49, 0xc1, 0x72, 0x5a, 0x2d, 0xfc,
0xed, 0x41, 0x7b, 0x7b, 0x2a, 0xe8, 0x3b, 0x78, 0xfd, 0xe4, 0x49, 0x21, 0x09, 0xe5, 0x74, 0x1a,
0xe9, 0x42, 0x1f, 0x9f, 0x96, 0xd6, 0x5c, 0x9d, 0x39, 0x1f, 0xba, 0x85, 0xb7, 0x65, 0xed, 0x12,
0x45, 0x53, 0xa9, 0x34, 0x61, 0x42, 0x53, 0xb5, 0x88, 0xb8, 0x4b, 0xbf, 0xb5, 0xb1, 0xd0, 0x67,
0xee, 0x31, 0xc2, 0xad, 0x92, 0x96, 0xb1, 0xb9, 0x7c, 0xe1, 0xee, 0x76, 0x7e, 0x02, 0x58, 0x95,
0x8a, 0x8e, 0xe1, 0x85, 0x2d, 0x35, 0x0b, 0x3c, 0x33, 0x59, 0xb4, 0xd9, 0x13, 0x5c, 0x40, 0x2e,
0xfd, 0x5a, 0xa5, 0xe9, 0x77, 0xfe, 0xf2, 0xa0, 0x6a, 0x3d, 0xe8, 0x1d, 0x00, 0x4b, 0x49, 0x94,
0x24, 0x8a, 0x66, 0x99, 0x29, 0x69, 0x0f, 0xd7, 0x59, 0xfa, 0xb3, 0x35, 0xe4, 0x6f, 0x41, 0x1e,
0xdb, 0xe4, 0xbb, 0x8b, 0xcd, 0xf7, 0x16, 0xd1, 0x57, 0xb6, 0x88, 0x1e, 0x81, 0x6f, 0xd6, 0xce,
0x3f, 0xf2, 0xba, 0x35, 0x6c, 0xbe, 0xed, 0xfa, 0x9c, 0xc4, 0xb0, 0x57, 0x6a, 0xb8, 0x42, 0x18,
0x1a, 0xee, 0x3b, 0x37, 0xa3, 0x83, 0x72, 0x1d, 0x9b, 0xcf, 0x54, 0xfb, 0x70, 0xab, 0xdf, 0x4e,
0xae, 0xeb, 0x7d, 0xe3, 0x0d, 0x6e, 0xe0, 0x63, 0x26, 0x4b, 0xc0, 0xc1, 0xcb, 0x72, 0xc8, 0x51,
0xde, 0xf6, 0x91, 0x77, 0xdb, 0x77, 0x63, 0x98, 0x4a, 0x1e, 0x89, 0x69, 0x28, 0xd5, 0xb4, 0x67,
0xfe, 0x51, 0x8a, 0x99, 0x9b, 0x13, 0x8f, 0xcd, 0x0f, 0xe1, 0x31, 0x59, 0xf4, 0xe3, 0xaa, 0x19,
0xd9, 0xb7, 0xff, 0x07, 0x00, 0x00, 0xff, 0xff, 0x81, 0x14, 0xee, 0xd1, 0x7b, 0x06, 0x00, 0x00,
}

View file

@@ -16,19 +16,29 @@
*
*/
package grpc
//go:generate ./regenerate.sh
// Package grpclb defines a grpclb balancer.
//
// To install grpclb balancer, import this package as:
// import _ "google.golang.org/grpc/balancer/grpclb"
package grpclb
import (
"errors"
"strconv"
"strings"
"sync"
"time"
durationpb "github.com/golang/protobuf/ptypes/duration"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/connectivity"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/resolver"
)
@@ -38,7 +48,21 @@ const (
grpclbName = "grpclb"
)
func convertDuration(d *lbpb.Duration) time.Duration {
var (
// defaultBackoffConfig configures the backoff strategy that's used when the
// init handshake in the RPC is unsuccessful. It's not for the clientconn
// reconnect backoff.
//
// It has the same value as the default grpc.DefaultBackoffConfig.
//
// TODO: make backoff configurable.
defaultBackoffConfig = backoff.Exponential{
MaxDelay: 120 * time.Second,
}
errServerTerminatedConnection = errors.New("grpclb: failed to recv server list: server terminated connection")
)
func convertDuration(d *durationpb.Duration) time.Duration {
if d == nil {
return 0
}
@@ -49,16 +73,16 @@ func convertDuration(d *lbpb.Duration) time.Duration {
// Mostly copied from generated pb.go file.
// To avoid circular dependency.
type loadBalancerClient struct {
cc *ClientConn
cc *grpc.ClientConn
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...CallOption) (*balanceLoadClientStream, error) {
desc := &StreamDesc{
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (*balanceLoadClientStream, error) {
desc := &grpc.StreamDesc{
StreamName: "BalanceLoad",
ServerStreams: true,
ClientStreams: true,
}
stream, err := NewClientStream(ctx, desc, c.cc, "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
stream, err := c.cc.NewStream(ctx, desc, "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
if err != nil {
return nil, err
}
@@ -67,7 +91,7 @@ func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...CallOption
}
type balanceLoadClientStream struct {
ClientStream
grpc.ClientStream
}
func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error {
@@ -88,16 +112,16 @@ func init() {
// newLBBuilder creates a builder for grpclb.
func newLBBuilder() balancer.Builder {
return NewLBBuilderWithFallbackTimeout(defaultFallbackTimeout)
return newLBBuilderWithFallbackTimeout(defaultFallbackTimeout)
}
// NewLBBuilderWithFallbackTimeout creates a grpclb builder with the given
// newLBBuilderWithFallbackTimeout creates a grpclb builder with the given
// fallbackTimeout. If no response is received from the remote balancer within
// fallbackTimeout, the backend addresses from the resolved address list will be
// used.
//
// Only call this function when a non-default fallback timeout is needed.
func NewLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder {
func newLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder {
return &lbBuilder{
fallbackTimeout: fallbackTimeout,
}
@@ -127,25 +151,26 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal
}
lb := &lbBalancer{
cc: cc,
cc: newLBCacheClientConn(cc),
target: target,
opt: opt,
fallbackTimeout: b.fallbackTimeout,
doneCh: make(chan struct{}),
manualResolver: r,
csEvltr: &connectivityStateEvaluator{},
csEvltr: &balancer.ConnectivityStateEvaluator{},
subConns: make(map[resolver.Address]balancer.SubConn),
scStates: make(map[balancer.SubConn]connectivity.State),
picker: &errPicker{err: balancer.ErrNoSubConnAvailable},
clientStats: &rpcStats{},
clientStats: newRPCStats(),
backoff: defaultBackoffConfig, // TODO: make backoff configurable.
}
return lb
}
type lbBalancer struct {
cc balancer.ClientConn
cc *lbCacheClientConn
target string
opt balancer.BuildOptions
fallbackTimeout time.Duration
@@ -156,7 +181,9 @@ type lbBalancer struct {
// send to remote LB ClientConn through this resolver.
manualResolver *lbManualResolver
// The ClientConn to talk to the remote balancer.
ccRemoteLB *ClientConn
ccRemoteLB *grpc.ClientConn
// backoff for calling remote balancer.
backoff backoff.Strategy
// Support client side load reporting. Each picker gets a reference to this,
// and will update its content.
@@ -173,7 +200,7 @@ type lbBalancer struct {
// but with only READY SCs will be generated.
backendAddrs []resolver.Address
// Roundrobin functionalities.
csEvltr *connectivityStateEvaluator
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
subConns map[resolver.Address]balancer.SubConn // Used to new/remove SubConn.
scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns.
@@ -220,7 +247,6 @@ func (lb *lbBalancer) regeneratePicker() {
subConns: readySCs,
stats: lb.clientStats,
}
return
}
func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
@@ -244,7 +270,7 @@ func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivi
}
oldAggrState := lb.state
lb.state = lb.csEvltr.recordTransition(oldS, s)
lb.state = lb.csEvltr.RecordTransition(oldS, s)
// Regenerate picker when one of the following happens:
// - this sc became ready from not-ready
@@ -257,7 +283,6 @@ func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivi
}
lb.cc.UpdateBalancerState(lb.state, lb.picker)
return
}
// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to use
@@ -339,4 +364,5 @@ func (lb *lbBalancer) Close() {
if lb.ccRemoteLB != nil {
lb.ccRemoteLB.Close()
}
lb.cc.close()
}

View file

@@ -16,7 +16,7 @@
*
*/
package grpc
package grpclb
import (
"sync"
@@ -24,55 +24,70 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/status"
)
// rpcStats is the same as lbpb.ClientStats, except that numCallsDropped is a map
// instead of a slice.
type rpcStats struct {
NumCallsStarted int64
NumCallsFinished int64
NumCallsFinishedWithDropForRateLimiting int64
NumCallsFinishedWithDropForLoadBalancing int64
NumCallsFinishedWithClientFailedToSend int64
NumCallsFinishedKnownReceived int64
// Only access the following fields atomically.
numCallsStarted int64
numCallsFinished int64
numCallsFinishedWithClientFailedToSend int64
numCallsFinishedKnownReceived int64
mu sync.Mutex
// map load_balance_token -> num_calls_dropped
numCallsDropped map[string]int64
}
func newRPCStats() *rpcStats {
return &rpcStats{
numCallsDropped: make(map[string]int64),
}
}
// toClientStats converts rpcStats to lbpb.ClientStats, and clears rpcStats.
func (s *rpcStats) toClientStats() *lbpb.ClientStats {
stats := &lbpb.ClientStats{
NumCallsStarted: atomic.SwapInt64(&s.NumCallsStarted, 0),
NumCallsFinished: atomic.SwapInt64(&s.NumCallsFinished, 0),
NumCallsFinishedWithDropForRateLimiting: atomic.SwapInt64(&s.NumCallsFinishedWithDropForRateLimiting, 0),
NumCallsFinishedWithDropForLoadBalancing: atomic.SwapInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 0),
NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.NumCallsFinishedWithClientFailedToSend, 0),
NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.NumCallsFinishedKnownReceived, 0),
NumCallsStarted: atomic.SwapInt64(&s.numCallsStarted, 0),
NumCallsFinished: atomic.SwapInt64(&s.numCallsFinished, 0),
NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.numCallsFinishedWithClientFailedToSend, 0),
NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.numCallsFinishedKnownReceived, 0),
}
s.mu.Lock()
dropped := s.numCallsDropped
s.numCallsDropped = make(map[string]int64)
s.mu.Unlock()
for token, count := range dropped {
stats.CallsFinishedWithDrop = append(stats.CallsFinishedWithDrop, &lbpb.ClientStatsPerToken{
LoadBalanceToken: token,
NumCalls: count,
})
}
return stats
}
func (s *rpcStats) dropForRateLimiting() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithDropForRateLimiting, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
func (s *rpcStats) dropForLoadBalancing() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
func (s *rpcStats) drop(token string) {
atomic.AddInt64(&s.numCallsStarted, 1)
s.mu.Lock()
s.numCallsDropped[token]++
s.mu.Unlock()
atomic.AddInt64(&s.numCallsFinished, 1)
}
func (s *rpcStats) failedToSend() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithClientFailedToSend, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
atomic.AddInt64(&s.numCallsStarted, 1)
atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, 1)
atomic.AddInt64(&s.numCallsFinished, 1)
}
func (s *rpcStats) knownReceived() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedKnownReceived, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
atomic.AddInt64(&s.numCallsStarted, 1)
atomic.AddInt64(&s.numCallsFinishedKnownReceived, 1)
atomic.AddInt64(&s.numCallsFinished, 1)
}
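// Illustrative sketch (not part of the vendored diff): recording one dropped
// RPC and flushing the accumulated counters into the protobuf ClientStats that
// is reported to the remote balancer. The token value is hypothetical.
func exampleReportDrop(s *rpcStats) *lbpb.ClientStats {
	s.drop("example-lb-token") // counts as started, finished, and dropped for this token
	return s.toClientStats()   // swaps the counters back to zero and builds the report
}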
type errPicker struct {
@@ -131,12 +146,8 @@ func (p *lbPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balance
p.serverListNext = (p.serverListNext + 1) % len(p.serverList)
// If it's a drop, return an error and fail the RPC.
if s.DropForRateLimiting {
p.stats.dropForRateLimiting()
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}
if s.DropForLoadBalancing {
p.stats.dropForLoadBalancing()
if s.Drop {
p.stats.drop(s.LoadBalanceToken)
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}

View file

@@ -16,19 +16,24 @@
*
*/
package grpc
package grpclb
import (
"fmt"
"io"
"net"
"reflect"
"time"
timestamppb "github.com/golang/protobuf/ptypes/timestamp"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/connectivity"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)
@@ -52,8 +57,8 @@ func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
lb.fullServerList = l.Servers
var backendAddrs []resolver.Address
for _, s := range l.Servers {
if s.DropForLoadBalancing || s.DropForRateLimiting {
for i, s := range l.Servers {
if s.Drop {
continue
}
@@ -69,20 +74,22 @@ func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
Addr: fmt.Sprintf("%s:%d", ipStr, s.Port),
Metadata: &md,
}
grpclog.Infof("lbBalancer: server list entry[%d]: ipStr:|%s|, port:|%d|, load balancer token:|%v|",
i, ipStr, s.Port, s.LoadBalanceToken)
backendAddrs = append(backendAddrs, addr)
}
// Call refreshSubConns to create/remove SubConns.
backendsUpdated := lb.refreshSubConns(backendAddrs)
// If no backend was updated, no SubConn will be newed/removed. But since
// the full serverList was different, there might be updates in drops or
// pick weights(different number of duplicates). We need to update picker
// with the fulllist.
if !backendsUpdated {
lb.regeneratePicker()
lb.cc.UpdateBalancerState(lb.state, lb.picker)
}
lb.refreshSubConns(backendAddrs)
// Regenerate and update the picker regardless of whether any backend was updated
// (i.e., whether any SubConn was created or removed). Since the full serverList was
// different, there might be updates in drops or pick weights (a different number of
// duplicates), so we need to update the picker with the full list.
//
// Now with cache, even if SubConn was newed/removed, there might be no
// state changes.
lb.regeneratePicker()
lb.cc.UpdateBalancerState(lb.state, lb.picker)
}
// refreshSubConns creates/removes SubConns with backendAddrs. It returns a bool
@@ -112,7 +119,11 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address) bool {
continue
}
lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map.
lb.scStates[sc] = connectivity.Idle
if _, ok := lb.scStates[sc]; !ok {
// Only set state of new sc to IDLE. The state could already be
// READY for cached SubConns.
lb.scStates[sc] = connectivity.Idle
}
sc.Connect()
}
}
@@ -136,6 +147,9 @@ func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error {
for {
reply, err := s.Recv()
if err != nil {
if err == io.EOF {
return errServerTerminatedConnection
}
return fmt.Errorf("grpclb: failed to recv server list: %v", err)
}
if serverList := reply.GetServerList(); serverList != nil {
@@ -155,7 +169,7 @@ func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.D
}
stats := lb.clientStats.toClientStats()
t := time.Now()
stats.Timestamp = &lbpb.Timestamp{
stats.Timestamp = &timestamppb.Timestamp{
Seconds: t.Unix(),
Nanos: int32(t.Nanosecond()),
}
@@ -168,13 +182,14 @@ func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.D
}
}
}
func (lb *lbBalancer) callRemoteBalancer() error {
func (lb *lbBalancer) callRemoteBalancer() (backoff bool, _ error) {
lbClient := &loadBalancerClient{cc: lb.ccRemoteLB}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbClient.BalanceLoad(ctx, FailFast(false))
stream, err := lbClient.BalanceLoad(ctx, grpc.FailFast(false))
if err != nil {
return fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
return true, fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
}
// grpclb handshake on the stream.
@@ -186,18 +201,18 @@ func (lb *lbBalancer) callRemoteBalancer() error {
},
}
if err := stream.Send(initReq); err != nil {
return fmt.Errorf("grpclb: failed to send init request: %v", err)
return true, fmt.Errorf("grpclb: failed to send init request: %v", err)
}
reply, err := stream.Recv()
if err != nil {
return fmt.Errorf("grpclb: failed to recv init response: %v", err)
return true, fmt.Errorf("grpclb: failed to recv init response: %v", err)
}
initResp := reply.GetInitialResponse()
if initResp == nil {
return fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
return true, fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
}
if initResp.LoadBalancerDelegate != "" {
return fmt.Errorf("grpclb: Delegation is not supported")
return true, fmt.Errorf("grpclb: Delegation is not supported")
}
go func() {
@@ -205,47 +220,72 @@ func (lb *lbBalancer) callRemoteBalancer() error {
lb.sendLoadReport(stream, d)
}
}()
return lb.readServerList(stream)
// No backoff if init req/resp handshake was successful.
return false, lb.readServerList(stream)
}
func (lb *lbBalancer) watchRemoteBalancer() {
var retryCount int
for {
err := lb.callRemoteBalancer()
doBackoff, err := lb.callRemoteBalancer()
select {
case <-lb.doneCh:
return
default:
if err != nil {
grpclog.Error(err)
if err == errServerTerminatedConnection {
grpclog.Info(err)
} else {
grpclog.Error(err)
}
}
}
if !doBackoff {
retryCount = 0
continue
}
timer := time.NewTimer(lb.backoff.Backoff(retryCount))
select {
case <-timer.C:
case <-lb.doneCh:
timer.Stop()
return
}
retryCount++
}
}
func (lb *lbBalancer) dialRemoteLB(remoteLBName string) {
var dopts []DialOption
var dopts []grpc.DialOption
if creds := lb.opt.DialCreds; creds != nil {
if err := creds.OverrideServerName(remoteLBName); err == nil {
dopts = append(dopts, WithTransportCredentials(creds))
dopts = append(dopts, grpc.WithTransportCredentials(creds))
} else {
grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v, using Insecure", err)
dopts = append(dopts, WithInsecure())
dopts = append(dopts, grpc.WithInsecure())
}
} else {
dopts = append(dopts, WithInsecure())
dopts = append(dopts, grpc.WithInsecure())
}
if lb.opt.Dialer != nil {
// WithDialer takes a different type of function, so we instead use a
// special DialOption here.
dopts = append(dopts, withContextDialer(lb.opt.Dialer))
wcd := internal.WithContextDialer.(func(func(context.Context, string) (net.Conn, error)) grpc.DialOption)
dopts = append(dopts, wcd(lb.opt.Dialer))
}
// Explicitly set pickfirst as the balancer.
dopts = append(dopts, WithBalancerName(PickFirstBalancerName))
dopts = append(dopts, withResolverBuilder(lb.manualResolver))
// Dial using manualResolver.Scheme, which is a random scheme generated
dopts = append(dopts, grpc.WithBalancerName(grpc.PickFirstBalancerName))
wrb := internal.WithResolverBuilder.(func(resolver.Builder) grpc.DialOption)
dopts = append(dopts, wrb(lb.manualResolver))
if channelz.IsOn() {
dopts = append(dopts, grpc.WithChannelzParentID(lb.opt.ChannelzParentID))
}
// DialContext using manualResolver.Scheme, which is a random scheme generated
// when init grpclb. The target name is not important.
cc, err := Dial("grpclb:///grpclb.server", dopts...)
cc, err := grpc.DialContext(context.Background(), "grpclb:///grpclb.server", dopts...)
if err != nil {
grpclog.Fatalf("failed to dial: %v", err)
}
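The comment above notes that grpc.WithDialer takes a func(addr string, timeout time.Duration) signature, while the context-based option obtained through the internal package wants func(ctx context.Context, addr string). A hedged sketch of how such an adapter could look; adaptDialer is a hypothetical helper for illustration, not part of this diff:
package main
import (
"context"
"net"
"time"
)
// adaptDialer wraps a legacy timeout-based dialer in the context-based
// shape, deriving the timeout from the context deadline when one is set.
func adaptDialer(old func(addr string, timeout time.Duration) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
timeout := 20 * time.Second // arbitrary fallback when the context has no deadline
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
return old(addr, timeout)
}
}
func main() {
dial := adaptDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if conn, err := dial(ctx, "localhost:0"); err == nil {
conn.Close()
}
}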

View file

@ -0,0 +1,970 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpclb
import (
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
durationpb "github.com/golang/protobuf/ptypes/duration"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)
var (
lbServerName = "bar.com"
beServerName = "foo.com"
lbToken = "iamatoken"
// Resolver replaces localhost with fakeName in Next().
// Dialer replaces fakeName with localhost when dialing.
// This will test that the custom dialer is passed from Dial to grpclb.
fakeName = "fake.Name"
)
type serverNameCheckCreds struct {
mu sync.Mutex
sn string
expected string
}
func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if _, err := io.WriteString(rawConn, c.sn); err != nil {
fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err)
return nil, nil, err
}
return rawConn, nil, nil
}
func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.mu.Lock()
defer c.mu.Unlock()
b := make([]byte, len(c.expected))
errCh := make(chan error, 1)
go func() {
_, err := rawConn.Read(b)
errCh <- err
}()
select {
case err := <-errCh:
if err != nil {
fmt.Printf("Failed to read the server name from the server %v", err)
return nil, nil, err
}
case <-ctx.Done():
return nil, nil, ctx.Err()
}
if c.expected != string(b) {
fmt.Printf("Read the server name %s want %s", string(b), c.expected)
return nil, nil, errors.New("received unexpected server name")
}
return rawConn, nil, nil
}
func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
c.mu.Lock()
defer c.mu.Unlock()
return credentials.ProtocolInfo{}
}
func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
c.mu.Lock()
defer c.mu.Unlock()
return &serverNameCheckCreds{
expected: c.expected,
}
}
func (c *serverNameCheckCreds) OverrideServerName(s string) error {
c.mu.Lock()
defer c.mu.Unlock()
c.expected = s
return nil
}
// fakeNameDialer replaces fakeName with localhost when dialing.
// This will test that the custom dialer is passed from Dial to grpclb.
func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
addr = strings.Replace(addr, fakeName, "localhost", 1)
return net.DialTimeout("tcp", addr, timeout)
}
// merge merges the new client stats into current stats.
//
// It's a test-only method. rpcStats is defined in grpclb_picker.
func (s *rpcStats) merge(cs *lbpb.ClientStats) {
atomic.AddInt64(&s.numCallsStarted, cs.NumCallsStarted)
atomic.AddInt64(&s.numCallsFinished, cs.NumCallsFinished)
atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, cs.NumCallsFinishedWithClientFailedToSend)
atomic.AddInt64(&s.numCallsFinishedKnownReceived, cs.NumCallsFinishedKnownReceived)
s.mu.Lock()
for _, perToken := range cs.CallsFinishedWithDrop {
s.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls
}
s.mu.Unlock()
}
func mapsEqual(a, b map[string]int64) bool {
if len(a) != len(b) {
return false
}
for k, v1 := range a {
if v2, ok := b[k]; !ok || v1 != v2 {
return false
}
}
return true
}
func atomicEqual(a, b *int64) bool {
return atomic.LoadInt64(a) == atomic.LoadInt64(b)
}
// equal compares two rpcStats.
//
// It's a test-only method. rpcStats is defined in grpclb_picker.
func (s *rpcStats) equal(o *rpcStats) bool {
if !atomicEqual(&s.numCallsStarted, &o.numCallsStarted) {
return false
}
if !atomicEqual(&s.numCallsFinished, &o.numCallsFinished) {
return false
}
if !atomicEqual(&s.numCallsFinishedWithClientFailedToSend, &o.numCallsFinishedWithClientFailedToSend) {
return false
}
if !atomicEqual(&s.numCallsFinishedKnownReceived, &o.numCallsFinishedKnownReceived) {
return false
}
s.mu.Lock()
defer s.mu.Unlock()
o.mu.Lock()
defer o.mu.Unlock()
if !mapsEqual(s.numCallsDropped, o.numCallsDropped) {
return false
}
return true
}
type remoteBalancer struct {
sls chan *lbpb.ServerList
statsDura time.Duration
done chan struct{}
stats *rpcStats
}
func newRemoteBalancer(intervals []time.Duration) *remoteBalancer {
return &remoteBalancer{
sls: make(chan *lbpb.ServerList, 1),
done: make(chan struct{}),
stats: newRPCStats(),
}
}
func (b *remoteBalancer) stop() {
close(b.sls)
close(b.done)
}
func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
req, err := stream.Recv()
if err != nil {
return err
}
initReq := req.GetInitialRequest()
if initReq.Name != beServerName {
return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
}
resp := &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
InitialResponse: &lbpb.InitialLoadBalanceResponse{
ClientStatsReportInterval: &durationpb.Duration{
Seconds: int64(b.statsDura.Seconds()),
Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
},
},
},
}
if err := stream.Send(resp); err != nil {
return err
}
go func() {
for {
var (
req *lbpb.LoadBalanceRequest
err error
)
if req, err = stream.Recv(); err != nil {
return
}
b.stats.merge(req.GetClientStats())
}
}()
for v := range b.sls {
resp = &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
ServerList: v,
},
}
if err := stream.Send(resp); err != nil {
return err
}
}
<-b.done
return nil
}
type testServer struct {
testpb.TestServiceServer
addr string
fallback bool
}
const testmdkey = "testmd"
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.Internal, "failed to receive metadata")
}
if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) {
return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md)
}
grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
return &testpb.Empty{}, nil
}
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
return nil
}
func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) {
for _, l := range lis {
creds := &serverNameCheckCreds{
sn: sn,
}
s := grpc.NewServer(grpc.Creds(creds))
testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback})
servers = append(servers, s)
go func(s *grpc.Server, l net.Listener) {
s.Serve(l)
}(s, l)
}
return
}
func stopBackends(servers []*grpc.Server) {
for _, s := range servers {
s.Stop()
}
}
type testServers struct {
lbAddr string
ls *remoteBalancer
lb *grpc.Server
beIPs []net.IP
bePorts []int
}
func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) {
var (
beListeners []net.Listener
ls *remoteBalancer
lb *grpc.Server
beIPs []net.IP
bePorts []int
)
for i := 0; i < numberOfBackends; i++ {
// Start a backend.
beLis, e := net.Listen("tcp", "localhost:0")
if e != nil {
err = fmt.Errorf("Failed to listen %v", err)
return
}
beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
bePorts = append(bePorts, beLis.Addr().(*net.TCPAddr).Port)
beListeners = append(beListeners, beLis)
}
backends := startBackends(beServerName, false, beListeners...)
// Start a load balancer.
lbLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
err = fmt.Errorf("Failed to create the listener for the load balancer %v", err)
return
}
lbCreds := &serverNameCheckCreds{
sn: lbServerName,
}
lb = grpc.NewServer(grpc.Creds(lbCreds))
ls = newRemoteBalancer(nil)
lbgrpc.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
}()
tss = &testServers{
lbAddr: fakeName + ":" + strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port),
ls: ls,
lb: lb,
beIPs: beIPs,
bePorts: bePorts,
}
cleanup = func() {
defer stopBackends(backends)
defer func() {
ls.stop()
lb.Stop()
}()
}
return
}
func TestGRPCLB(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
}
// The remote balancer sends response with duplicates to grpclb client.
func TestGRPCLBWeighted(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(2)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
beServers := []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}, {
IpAddress: tss.beIPs[1],
Port: int32(tss.bePorts[1]),
LoadBalanceToken: lbToken,
}}
portsToIndex := make(map[int]int)
for i := range beServers {
portsToIndex[tss.bePorts[i]] = i
}
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
sequences := []string{"00101", "00011"}
for _, seq := range sequences {
var (
bes []*lbpb.Server
p peer.Peer
result string
)
for _, s := range seq {
bes = append(bes, beServers[s-'0'])
}
tss.ls.sls <- &lbpb.ServerList{Servers: bes}
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
}
// The generated result will be in the format of "0010100101".
if !strings.Contains(result, strings.Repeat(seq, 2)) {
t.Errorf("got result sequence %q, want pattern %q", result, seq)
}
}
}
func TestDropRequest(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls <- &lbpb.ServerList{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
Drop: false,
}, {
Drop: true,
}},
}
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
// Wait for the 1st, non-fail-fast RPC to succeed. This ensures the
// connection to the backend is made, even though the second entry in the
// server list has Drop set to true.
var i int
for i = 0; i < 1000; i++ {
if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
break
}
time.Sleep(time.Millisecond)
}
if i >= 1000 {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
}
select {
case <-ctx.Done():
t.Fatal("timed out", ctx.Err())
default:
}
for _, failfast := range []bool{true, false} {
for i := 0; i < 3; i++ {
// Even RPCs should fail, because the 2nd entry in the server list has
// Drop set to true.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); status.Code(err) != codes.Unavailable {
t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
}
// Odd RPCs should succeed since they choose the non-drop-request
// backend according to the round robin policy.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); err != nil {
t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
}
}
}
// When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list.
func TestBalancerDisconnects(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
var (
tests []*testServers
lbs []*grpc.Server
)
for i := 0; i < 2; i++ {
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
tests = append(tests, tss)
lbs = append(lbs, tss.lb)
}
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: tests[0].lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: tests[1].lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
var p peer.Peer
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] {
t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0])
}
lbs[0].Stop()
// Stop balancer[0], balancer[1] should be used by grpclb.
// Check peer address to see if that happened.
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("No RPC sent to second backend after 1 second")
}
type customGRPCLBBuilder struct {
balancer.Builder
name string
}
func (b *customGRPCLBBuilder) Name() string {
return b.name
}
const grpclbCustomFallbackName = "grpclb_with_custom_fallback_timeout"
func init() {
balancer.Register(&customGRPCLBBuilder{
Builder: newLBBuilderWithFallbackTimeout(100 * time.Millisecond),
name: grpclbCustomFallbackName,
})
}
func TestFallback(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
// Start a standalone backend.
beLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen %v", err)
}
defer beLis.Close()
standaloneBEs := startBackends(beServerName, true, beLis)
defer stopBackends(standaloneBEs)
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
creds := serverNameCheckCreds{
expected: beServerName,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithBalancerName(grpclbCustomFallbackName),
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
testC := testpb.NewTestServiceClient(cc)
r.NewAddress([]resolver.Address{{
Addr: "",
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: beLis.Addr().String(),
Type: resolver.Backend,
ServerName: beServerName,
}})
var p peer.Peer
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.String() != beLis.Addr().String() {
t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr())
}
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}, {
Addr: beLis.Addr().String(),
Type: resolver.Backend,
ServerName: beServerName,
}})
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
}
type failPreRPCCred struct{}
func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
if strings.Contains(uri[0], failtosendURI) {
return nil, fmt.Errorf("rpc should fail to send")
}
return nil, nil
}
func (failPreRPCCred) RequireTransportSecurity() bool {
return false
}
func checkStats(stats, expected *rpcStats) error {
if !stats.equal(expected) {
return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
}
return nil
}
func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rpcStats {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
tss, cleanup, err := newLoadBalancer(1)
if err != nil {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls <- &lbpb.ServerList{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
Drop: drop,
}},
}
tss.ls.statsDura = 100 * time.Millisecond
creds := serverNameCheckCreds{expected: beServerName}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds),
grpc.WithPerRPCCredentials(failPreRPCCred{}),
grpc.WithDialer(fakeNameDialer))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
defer cc.Close()
r.NewAddress([]resolver.Address{{
Addr: tss.lbAddr,
Type: resolver.GRPCLB,
ServerName: lbServerName,
}})
runRPCs(cc)
time.Sleep(1 * time.Second)
stats := tss.ls.stats
return stats
}
const (
countRPC = 40
failtosendURI = "failtosend"
dropErrDesc = "request dropped by grpclb"
)
func TestGRPCLBStatsUnarySuccess(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < countRPC-1; i++ {
testC.EmptyCall(context.Background(), &testpb.Empty{})
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedKnownReceived: int64(countRPC),
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsUnaryDrop(t *testing.T) {
defer leakcheck.Check(t)
c := 0
stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
for {
c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), dropErrDesc) {
break
}
}
}
for i := 0; i < countRPC; i++ {
testC.EmptyCall(context.Background(), &testpb.Empty{})
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC + c),
numCallsFinished: int64(countRPC + c),
numCallsFinishedWithClientFailedToSend: int64(c - 1),
numCallsDropped: map[string]int64{lbToken: int64(countRPC + 1)},
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < countRPC-1; i++ {
cc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil)
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
numCallsFinishedKnownReceived: 1,
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
for i := 0; i < countRPC-1; i++ {
stream, err = testC.FullDuplexCall(context.Background())
if err == nil {
// Wait for stream to end if err is nil.
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
}
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedKnownReceived: int64(countRPC),
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsStreamingDrop(t *testing.T) {
defer leakcheck.Check(t)
c := 0
stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
for {
c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), dropErrDesc) {
break
}
}
}
for i := 0; i < countRPC; i++ {
testC.FullDuplexCall(context.Background())
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC + c),
numCallsFinished: int64(countRPC + c),
numCallsFinishedWithClientFailedToSend: int64(c - 1),
numCallsDropped: map[string]int64{lbToken: int64(countRPC + 1)},
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
if err != nil {
t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for {
if _, err = stream.Recv(); err == io.EOF {
break
}
}
for i := 0; i < countRPC-1; i++ {
cc.NewStream(context.Background(), &grpc.StreamDesc{}, failtosendURI)
}
})
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
numCallsFinishedKnownReceived: 1,
}); err != nil {
t.Fatal(err)
}
}

View file

@ -16,10 +16,15 @@
*
*/
package grpc
package grpclb
import (
"fmt"
"sync"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
@ -88,3 +93,122 @@ func (r *lbManualResolver) NewAddress(addrs []resolver.Address) {
func (r *lbManualResolver) NewServiceConfig(sc string) {
r.ccr.NewServiceConfig(sc)
}
const subConnCacheTime = time.Second * 10
// lbCacheClientConn is a wrapper balancer.ClientConn with a SubConn cache.
// SubConns will be kept in cache for subConnCacheTime before being removed.
//
// Its new and remove methods are updated to do cache first.
type lbCacheClientConn struct {
cc balancer.ClientConn
timeout time.Duration
mu sync.Mutex
// subConnCache only keeps subConns that are being deleted.
subConnCache map[resolver.Address]*subConnCacheEntry
subConnToAddr map[balancer.SubConn]resolver.Address
}
type subConnCacheEntry struct {
sc balancer.SubConn
cancel func()
abortDeleting bool
}
func newLBCacheClientConn(cc balancer.ClientConn) *lbCacheClientConn {
return &lbCacheClientConn{
cc: cc,
timeout: subConnCacheTime,
subConnCache: make(map[resolver.Address]*subConnCacheEntry),
subConnToAddr: make(map[balancer.SubConn]resolver.Address),
}
}
func (ccc *lbCacheClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
if len(addrs) != 1 {
return nil, fmt.Errorf("grpclb calling NewSubConn with addrs of length %v", len(addrs))
}
addrWithoutMD := addrs[0]
addrWithoutMD.Metadata = nil
ccc.mu.Lock()
defer ccc.mu.Unlock()
if entry, ok := ccc.subConnCache[addrWithoutMD]; ok {
// If entry is in subConnCache, the SubConn was being deleted.
// cancel function will never be nil.
entry.cancel()
delete(ccc.subConnCache, addrWithoutMD)
return entry.sc, nil
}
scNew, err := ccc.cc.NewSubConn(addrs, opts)
if err != nil {
return nil, err
}
ccc.subConnToAddr[scNew] = addrWithoutMD
return scNew, nil
}
func (ccc *lbCacheClientConn) RemoveSubConn(sc balancer.SubConn) {
ccc.mu.Lock()
defer ccc.mu.Unlock()
addr, ok := ccc.subConnToAddr[sc]
if !ok {
return
}
if entry, ok := ccc.subConnCache[addr]; ok {
if entry.sc != sc {
// This could happen if NewSubConn was called multiple times for the
// same address, and those SubConns are all removed. We remove sc
// immediately here.
delete(ccc.subConnToAddr, sc)
ccc.cc.RemoveSubConn(sc)
}
return
}
entry := &subConnCacheEntry{
sc: sc,
}
ccc.subConnCache[addr] = entry
timer := time.AfterFunc(ccc.timeout, func() {
ccc.mu.Lock()
defer ccc.mu.Unlock()
if entry.abortDeleting {
return
}
ccc.cc.RemoveSubConn(sc)
delete(ccc.subConnToAddr, sc)
delete(ccc.subConnCache, addr)
})
entry.cancel = func() {
if !timer.Stop() {
// If stop was not successful, the timer has fired (this can only
// happen in a race). But the deleting function is blocked on ccc.mu
// because the mutex was held by the caller of this function.
//
// Set abortDeleting to true to abort the deleting function. When
// the lock is released, the deleting function will acquire the
// lock, check the value of abortDeleting and return.
entry.abortDeleting = true
}
}
}
func (ccc *lbCacheClientConn) UpdateBalancerState(s connectivity.State, p balancer.Picker) {
ccc.cc.UpdateBalancerState(s, p)
}
func (ccc *lbCacheClientConn) close() {
ccc.mu.Lock()
// Only cancel all existing timers. There's no need to remove SubConns.
for _, entry := range ccc.subConnCache {
entry.cancel()
}
ccc.mu.Unlock()
}
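The abortDeleting comments above describe the only subtle part of this cache: when timer.Stop() fails because the timer already fired, the delayed delete is blocked on ccc.mu, so a flag checked under the same mutex is enough to abort it. A standalone sketch of that pattern (the cache type and the Println markers are illustrative only, not part of this diff):
package main
import (
"fmt"
"sync"
"time"
)
type cache struct {
mu    sync.Mutex
abort bool
timer *time.Timer
}
func (c *cache) scheduleDelete() {
c.timer = time.AfterFunc(10*time.Millisecond, func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.abort {
return // cancelled after the timer fired
}
fmt.Println("deleted")
})
}
func main() {
c := &cache{}
c.scheduleDelete()
c.mu.Lock()                       // mirror NewSubConn holding ccc.mu while it calls entry.cancel
time.Sleep(20 * time.Millisecond) // the timer fires and its callback blocks on c.mu
if !c.timer.Stop() {
c.abort = true // same role as entry.abortDeleting
}
c.mu.Unlock()
time.Sleep(10 * time.Millisecond)
fmt.Println("done") // "deleted" never prints: the delete was aborted
}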

View file

@ -0,0 +1,219 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpclb
import (
"fmt"
"sync"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
type mockSubConn struct {
balancer.SubConn
}
type mockClientConn struct {
balancer.ClientConn
mu sync.Mutex
subConns map[balancer.SubConn]resolver.Address
}
func newMockClientConn() *mockClientConn {
return &mockClientConn{
subConns: make(map[balancer.SubConn]resolver.Address),
}
}
func (mcc *mockClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
sc := &mockSubConn{}
mcc.mu.Lock()
defer mcc.mu.Unlock()
mcc.subConns[sc] = addrs[0]
return sc, nil
}
func (mcc *mockClientConn) RemoveSubConn(sc balancer.SubConn) {
mcc.mu.Lock()
defer mcc.mu.Unlock()
delete(mcc.subConns, sc)
}
const testCacheTimeout = 100 * time.Millisecond
func checkMockCC(mcc *mockClientConn, scLen int) error {
mcc.mu.Lock()
defer mcc.mu.Unlock()
if len(mcc.subConns) != scLen {
return fmt.Errorf("mcc = %+v, want len(mcc.subConns) = %v", mcc.subConns, scLen)
}
return nil
}
func checkCacheCC(ccc *lbCacheClientConn, sccLen, sctaLen int) error {
ccc.mu.Lock()
defer ccc.mu.Unlock()
if len(ccc.subConnCache) != sccLen {
return fmt.Errorf("ccc = %+v, want len(ccc.subConnCache) = %v", ccc.subConnCache, sccLen)
}
if len(ccc.subConnToAddr) != sctaLen {
return fmt.Errorf("ccc = %+v, want len(ccc.subConnToAddr) = %v", ccc.subConnToAddr, sctaLen)
}
return nil
}
// Test that SubConn won't be immediately removed.
func TestLBCacheClientConnExpire(t *testing.T) {
mcc := newMockClientConn()
if err := checkMockCC(mcc, 0); err != nil {
t.Fatal(err)
}
ccc := newLBCacheClientConn(mcc)
ccc.timeout = testCacheTimeout
if err := checkCacheCC(ccc, 0, 0); err != nil {
t.Fatal(err)
}
sc, _ := ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Should all become empty after timeout.
var err error
for i := 0; i < 2; i++ {
time.Sleep(testCacheTimeout)
err = checkMockCC(mcc, 0)
if err != nil {
continue
}
err = checkCacheCC(ccc, 0, 0)
if err != nil {
continue
}
}
if err != nil {
t.Fatal(err)
}
}
// Test that NewSubConn with the same address of a SubConn being removed will
// reuse the SubConn and cancel the removing.
func TestLBCacheClientConnReuse(t *testing.T) {
mcc := newMockClientConn()
if err := checkMockCC(mcc, 0); err != nil {
t.Fatal(err)
}
ccc := newLBCacheClientConn(mcc)
ccc.timeout = testCacheTimeout
if err := checkCacheCC(ccc, 0, 0); err != nil {
t.Fatal(err)
}
sc, _ := ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Recreate the old subconn, this should cancel the deleting process.
sc, _ = ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
var err error
// Should not become empty after 2*timeout.
time.Sleep(2 * testCacheTimeout)
err = checkMockCC(mcc, 1)
if err != nil {
t.Fatal(err)
}
err = checkCacheCC(ccc, 0, 1)
if err != nil {
t.Fatal(err)
}
// Call remove again, will delete after timeout.
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Should all become empty after timeout.
for i := 0; i < 2; i++ {
time.Sleep(testCacheTimeout)
err = checkMockCC(mcc, 0)
if err != nil {
continue
}
err = checkCacheCC(ccc, 0, 0)
if err != nil {
continue
}
}
if err != nil {
t.Fatal(err)
}
}

33
vendor/google.golang.org/grpc/balancer/grpclb/regenerate.sh generated vendored Executable file
View file

@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux -o pipefail
TMP=$(mktemp -d)
function finish {
rm -rf "$TMP"
}
trap finish EXIT
pushd "$TMP"
mkdir -p grpc/lb/v1
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/lb/v1/load_balancer.proto > grpc/lb/v1/load_balancer.proto
protoc --go_out=plugins=grpc,paths=source_relative:. -I. grpc/lb/v1/*.proto
popd
rm -f grpc_lb_v1/*.pb.go
cp "$TMP"/grpc/lb/v1/*.pb.go grpc_lb_v1/

View file

@ -0,0 +1,79 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package roundrobin defines a roundrobin balancer. Roundrobin balancer is
// installed as one of the default balancers in gRPC, so users don't need to
// explicitly install this balancer.
package roundrobin
import (
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/resolver"
)
// Name is the name of round_robin balancer.
const Name = "round_robin"
// newBuilder creates a new roundrobin balancer builder.
func newBuilder() balancer.Builder {
return base.NewBalancerBuilder(Name, &rrPickerBuilder{})
}
func init() {
balancer.Register(newBuilder())
}
type rrPickerBuilder struct{}
func (*rrPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
grpclog.Infof("roundrobinPicker: newPicker called with readySCs: %v", readySCs)
var scs []balancer.SubConn
for _, sc := range readySCs {
scs = append(scs, sc)
}
return &rrPicker{
subConns: scs,
}
}
type rrPicker struct {
// subConns is the snapshot of the roundrobin balancer when this picker was
// created. The slice is immutable. Each Get() will do a round robin
// selection from it and return the selected SubConn.
subConns []balancer.SubConn
mu sync.Mutex
next int
}
func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
if len(p.subConns) <= 0 {
return nil, nil, balancer.ErrNoSubConnAvailable
}
p.mu.Lock()
sc := p.subConns[p.next]
p.next = (p.next + 1) % len(p.subConns)
p.mu.Unlock()
return sc, nil, nil
}
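Since init registers the builder, a client only needs to ask for the balancer by name, as the tests below also do. A minimal dial sketch for illustration (the target is a placeholder and WithInsecure stands in for real credentials):
package main
import (
"log"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/roundrobin"
)
func main() {
cc, err := grpc.Dial("dns:///lb.example.com:443",
grpc.WithInsecure(), // placeholder; use transport credentials in practice
grpc.WithBalancerName(roundrobin.Name)) // pick the registered round_robin balancer
if err != nil {
log.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
}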

View file

@ -0,0 +1,477 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package roundrobin_test
import (
"fmt"
"net"
"sync"
"testing"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/codes"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)
type testServer struct {
testpb.TestServiceServer
}
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
}
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
return nil
}
type test struct {
servers []*grpc.Server
addresses []string
}
func (t *test) cleanup() {
for _, s := range t.servers {
s.Stop()
}
}
func startTestServers(count int) (_ *test, err error) {
t := &test{}
defer func() {
if err != nil {
for _, s := range t.servers {
s.Stop()
}
}
}()
for i := 0; i < count; i++ {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("Failed to listen %v", err)
}
s := grpc.NewServer()
testpb.RegisterTestServiceServer(s, &testServer{})
t.servers = append(t.servers, s)
t.addresses = append(t.addresses, lis.Addr().String())
go func(s *grpc.Server, l net.Listener) {
s.Serve(l)
}(s, lis)
}
return t, nil
}
func TestOneBackend(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
test, err := startTestServers(1)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
}
func TestBackendsRoundRobin(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
backendCount := 5
test, err := startTestServers(backendCount)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
var resolvedAddrs []resolver.Address
for i := 0; i < backendCount; i++ {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
}
r.NewAddress(resolvedAddrs)
var p peer.Peer
// Make sure connections to all servers are up.
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() == test.addresses[si] {
connected = true
break
}
time.Sleep(time.Millisecond)
}
if !connected {
t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
}
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
}
}
}
func TestAddressesRemoved(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
test, err := startTestServers(1)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
r.NewAddress([]resolver.Address{})
for i := 0; i < 1000; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("No RPC failed after removing all addresses, want RPC to fail with DeadlineExceeded")
}
func TestCloseWithPendingRPC(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
test, err := startTestServers(1)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
testc := testpb.NewTestServiceClient(cc)
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// This RPC blocks until cc is closed.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) == codes.DeadlineExceeded {
t.Errorf("RPC failed because of deadline after cc is closed; want error the client connection is closing")
}
cancel()
}()
}
cc.Close()
wg.Wait()
}
func TestNewAddressWhileBlocking(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
test, err := startTestServers(1)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
// The second RPC should succeed.
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, nil", err)
}
r.NewAddress([]resolver.Address{})
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// This RPC blocks until NewAddress is called.
testc.EmptyCall(context.Background(), &testpb.Empty{})
}()
}
time.Sleep(50 * time.Millisecond)
r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
wg.Wait()
}
func TestOneServerDown(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
backendCount := 3
test, err := startTestServers(backendCount)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name), grpc.WithWaitForHandshake())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
var resolvedAddrs []resolver.Address
for i := 0; i < backendCount; i++ {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
}
r.NewAddress(resolvedAddrs)
var p peer.Peer
// Make sure connections to all servers are up.
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() == test.addresses[si] {
connected = true
break
}
time.Sleep(time.Millisecond)
}
if !connected {
t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
}
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
}
}
// Stop one server, RPCs should roundrobin among the remaining servers.
backendCount--
test.servers[backendCount].Stop()
// Loop until we see server[backendCount-1] twice without seeing server[backendCount].
var targetSeen int
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
targetSeen = 0
t.Logf("EmptyCall() = _, %v, want _, <nil>", err)
// Due to a race, this RPC could possibly get the connection that
// was closing, and this RPC may fail. Keep trying when this
// happens.
continue
}
switch p.Addr.String() {
case test.addresses[backendCount-1]:
targetSeen++
case test.addresses[backendCount]:
// Reset targetSeen if peer is server[backendCount].
targetSeen = 0
}
// Break to make sure the last picked address is the last remaining server (server[backendCount-1]), so the following for loop won't be flaky.
if targetSeen >= 2 {
break
}
}
if targetSeen != 2 {
t.Fatal("Failed to see server[backendCount-1] twice without seeing server[backendCount]")
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
t.Errorf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
}
}
}
func TestAllServersDown(t *testing.T) {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
backendCount := 3
test, err := startTestServers(backendCount)
if err != nil {
t.Fatalf("failed to start servers: %v", err)
}
defer test.cleanup()
cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name), grpc.WithWaitForHandshake())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer cc.Close()
testc := testpb.NewTestServiceClient(cc)
// The first RPC should fail because there's no address.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
}
var resolvedAddrs []resolver.Address
for i := 0; i < backendCount; i++ {
resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
}
r.NewAddress(resolvedAddrs)
var p peer.Peer
// Make sure connections to all servers are up.
for si := 0; si < backendCount; si++ {
var connected bool
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() == test.addresses[si] {
connected = true
break
}
time.Sleep(time.Millisecond)
}
if !connected {
t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
}
}
for i := 0; i < 3*backendCount; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
}
if p.Addr.String() != test.addresses[i%backendCount] {
t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
}
}
// All servers are stopped, failfast RPC should fail with unavailable.
for i := 0; i < backendCount; i++ {
test.servers[i].Stop()
}
time.Sleep(100 * time.Millisecond)
for i := 0; i < 1000; i++ {
if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); status.Code(err) == codes.Unavailable {
return
}
time.Sleep(time.Millisecond)
}
t.Fatalf("Failfast RPCs didn't fail with Unavailable after all servers are stopped")
}

View file

@ -115,7 +115,7 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui
return ccb
}
// watcher balancer functions sequencially, so the balancer can be implemeneted
// watcher balancer functions sequentially, so the balancer can be implemented
// lock-free.
func (ccb *ccBalancerWrapper) watcher() {
for {

View file

@ -25,13 +25,39 @@ import (
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/connectivity"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/test/leakcheck"
)
var _ balancer.Builder = &magicalLB{}
var _ balancer.Balancer = &magicalLB{}
// magicalLB is a ringer for grpclb. It is used to avoid circular dependencies on the grpclb package
type magicalLB struct{}
func (b *magicalLB) Name() string {
return "grpclb"
}
func (b *magicalLB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
return b
}
func (b *magicalLB) HandleSubConnStateChange(balancer.SubConn, connectivity.State) {}
func (b *magicalLB) HandleResolvedAddrs([]resolver.Address, error) {}
func (b *magicalLB) Close() {}
func init() {
balancer.Register(&magicalLB{})
}
func checkPickFirst(cc *ClientConn, servers []*server) error {
var (
req = "port"
@ -40,7 +66,7 @@ func checkPickFirst(cc *ClientConn, servers []*server) error {
)
connected := false
for i := 0; i < 5000; i++ {
if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); errorDesc(err) == servers[0].port {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == servers[0].port {
if connected {
// connected is set to false if peer is not server[0]. So if
// connected is true here, this is the second time we saw
@ -58,7 +84,7 @@ func checkPickFirst(cc *ClientConn, servers []*server) error {
}
// The following RPCs should all succeed with the first server.
for i := 0; i < 3; i++ {
err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if errorDesc(err) != servers[0].port {
return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[0].port, err)
}
@ -80,7 +106,7 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error {
for _, s := range servers {
var up bool
for i := 0; i < 5000; i++ {
if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); errorDesc(err) == s.port {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == s.port {
up = true
break
}
@ -94,7 +120,7 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error {
serverCount := len(servers)
for i := 0; i < 3*serverCount; i++ {
err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if errorDesc(err) != servers[i%serverCount].port {
return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err)
}

View file

@ -29,9 +29,9 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/leakcheck"
// V1 balancer tests use passthrough resolver instead of dns.
// TODO(bar) remove this when removing v1 balancer entirely.
@ -39,12 +39,16 @@ import (
_ "google.golang.org/grpc/resolver/passthrough"
)
func pickFirstBalancerV1(r naming.Resolver) Balancer {
return &pickFirst{&roundRobin{r: r}}
}
type testWatcher struct {
// the channel to receive name resolution updates
update chan *naming.Update
// the side channel to get to know how many updates in a batch
side chan int
// the channel to notifiy update injector that the update reading is done
// the channel to notify update injector that the update reading is done
readDone chan int
}
@ -130,7 +134,7 @@ func TestNameDiscovery(t *testing.T) {
defer cc.Close()
req := "port"
var reply string
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Inject the name resolution change to remove servers[0] and add servers[1].
@ -146,7 +150,7 @@ func TestNameDiscovery(t *testing.T) {
r.w.inject(updates)
// Loop until the rpcs in flight talks to servers[1].
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
@ -163,7 +167,7 @@ func TestEmptyAddrs(t *testing.T) {
}
defer cc.Close()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse)
}
// Inject name resolution change to remove the server so that there is no address
@ -177,7 +181,7 @@ func TestEmptyAddrs(t *testing.T) {
for {
time.Sleep(10 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil {
cancel()
break
}
@ -206,7 +210,7 @@ func TestRoundRobin(t *testing.T) {
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
@ -219,14 +223,14 @@ func TestRoundRobin(t *testing.T) {
r.w.inject([]*naming.Update{u})
// Loop until servers[2] is up.
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[2].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[2].port {
break
}
time.Sleep(10 * time.Millisecond)
}
// Check the incoming RPCs served in a round-robin manner.
for i := 0; i < 10; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[i%numServers].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[i%numServers].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port)
}
}
@ -242,7 +246,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
}
defer cc.Close()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
@ -254,7 +258,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
// Loop until the above update applies.
for {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
@ -267,7 +271,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
@ -275,7 +279,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
defer wg.Done()
var reply string
time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
@ -302,7 +306,7 @@ func TestGetOnWaitChannel(t *testing.T) {
for {
var reply string
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
@ -314,7 +318,7 @@ func TestGetOnWaitChannel(t *testing.T) {
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
}()
@ -350,7 +354,7 @@ func TestOneServerDown(t *testing.T) {
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
@ -374,7 +378,7 @@ func TestOneServerDown(t *testing.T) {
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is killed around the same time to make it racy between balancer and gRPC internals.
Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
cc.Invoke(context.Background(), "/foo/bar", &req, &reply, FailFast(false))
wg.Done()
}()
}
@ -403,7 +407,7 @@ func TestOneAddressRemoval(t *testing.T) {
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
@ -433,8 +437,8 @@ func TestOneAddressRemoval(t *testing.T) {
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want nil", err)
}
wg.Done()
}()
@ -452,7 +456,7 @@ func checkServerUp(t *testing.T, currentServer *server) {
defer cc.Close()
var reply string
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == port {
break
}
time.Sleep(10 * time.Millisecond)
@ -469,7 +473,7 @@ func TestPickFirstEmptyAddrs(t *testing.T) {
}
defer cc.Close()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse)
}
// Inject name resolution change to remove the server so that there is no address
@ -483,7 +487,7 @@ func TestPickFirstEmptyAddrs(t *testing.T) {
for {
time.Sleep(10 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil {
cancel()
break
}
@ -501,7 +505,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
}
defer cc.Close()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
@ -513,7 +517,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
// Loop until the above update applies.
for {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
@ -526,7 +530,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
@ -534,7 +538,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
defer wg.Done()
var reply string
time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
@ -576,7 +580,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
req := "port"
var reply string
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
@ -591,13 +595,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
r.w.inject([]*naming.Update{u})
// Loop until it changes to server[1]
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
@ -611,7 +615,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
}
r.w.inject([]*naming.Update{u})
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
@ -624,13 +628,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
}
r.w.inject([]*naming.Update{u})
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[2].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[2].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[2].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[2].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port)
}
time.Sleep(10 * time.Millisecond)
@ -643,13 +647,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
}
r.w.inject([]*naming.Update{u})
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[0].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
@ -689,7 +693,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) {
req := "port"
var reply string
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
@ -700,13 +704,13 @@ func TestPickFirstOrderOneServerDown(t *testing.T) {
servers[0].stop()
// Loop until it changes to server[1]
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
@ -721,7 +725,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) {
checkServerUp(t, servers[0])
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
@ -734,13 +738,13 @@ func TestPickFirstOrderOneServerDown(t *testing.T) {
}
r.w.inject([]*naming.Update{u})
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[0].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
@ -794,8 +798,8 @@ func TestPickFirstOneAddressRemoval(t *testing.T) {
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want nil", err)
}
wg.Done()
}()

View file

@ -55,7 +55,7 @@ func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.B
startCh: make(chan struct{}),
conns: make(map[resolver.Address]balancer.SubConn),
connSt: make(map[balancer.SubConn]*scState),
csEvltr: &connectivityStateEvaluator{},
csEvltr: &balancer.ConnectivityStateEvaluator{},
state: connectivity.Idle,
}
cc.UpdateBalancerState(connectivity.Idle, bw)
@ -80,10 +80,6 @@ type balancerWrapper struct {
cc balancer.ClientConn
targetAddr string // Target without the scheme.
// To aggregate the connectivity state.
csEvltr *connectivityStateEvaluator
state connectivity.State
mu sync.Mutex
conns map[resolver.Address]balancer.SubConn
connSt map[balancer.SubConn]*scState
@ -92,6 +88,10 @@ type balancerWrapper struct {
// - NewSubConn is created, cc wants to notify balancer of state changes;
// - Build hasn't return, cc doesn't have access to balancer.
startCh chan struct{}
// To aggregate the connectivity state.
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
}
// lbWatcher watches the Notify channel of the balancer and manages
@ -248,7 +248,7 @@ func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s conne
scSt.down(errConnClosing)
}
}
sa := bw.csEvltr.recordTransition(oldS, s)
sa := bw.csEvltr.RecordTransition(oldS, s)
if bw.state != sa {
bw.state = sa
}
@ -257,7 +257,6 @@ func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s conne
// Remove state for this sc.
delete(bw.connSt, sc)
}
return
}
func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) {
@ -270,7 +269,6 @@ func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) {
}
// There should be a resolver inside the balancer.
// All updates here, if any, are ignored.
return
}
func (bw *balancerWrapper) Close() {
@ -282,7 +280,6 @@ func (bw *balancerWrapper) Close() {
close(bw.startCh)
}
bw.balancer.Close()
return
}
// The picker is the balancerWrapper itself.
@ -329,47 +326,3 @@ func (bw *balancerWrapper) Pick(ctx context.Context, opts balancer.PickOptions)
return sc, done, nil
}
// connectivityStateEvaluator gets updated by addrConns when their
// states transition, based on which it evaluates the state of
// ClientConn.
type connectivityStateEvaluator struct {
mu sync.Mutex
numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transientFailure.
}
// recordTransition records state change happening in every subConn and based on
// that it evaluates what aggregated state should be.
// It can only transition between Ready, Connecting and TransientFailure. Other states,
// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection
// before any subConn is created ClientConn is in idle state. In the end when ClientConn
// closes it is in Shutdown state.
// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state.
func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State {
cse.mu.Lock()
defer cse.mu.Unlock()
// Update counters.
for idx, state := range []connectivity.State{oldState, newState} {
updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
switch state {
case connectivity.Ready:
cse.numReady += updateVal
case connectivity.Connecting:
cse.numConnecting += updateVal
case connectivity.TransientFailure:
cse.numTransientFailure += updateVal
}
}
// Evaluate.
if cse.numReady > 0 {
return connectivity.Ready
}
if cse.numConnecting > 0 {
return connectivity.Connecting
}
return connectivity.TransientFailure
}

View file

@ -0,0 +1,547 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
Package main provides benchmarks whose settings are configured via flags.
An example to run some benchmarks with profiling enabled:
go run benchmark/benchmain/main.go -benchtime=10s -workloads=all \
-compression=on -maxConcurrentCalls=1 -trace=off \
-reqSizeBytes=1,1048576 -respSizeBytes=1,1048576 -networkMode=Local \
-cpuProfile=cpuProf -memProfile=memProf -memProfileRate=10000 -resultFile=result
As a suggestion, when creating a branch, run this benchmark and save the result
file with "-resultFile=basePerf"; later, in the middle of the work or once it is
finished, run the benchmark again and compare the new result with the base at any time.
Assume there are two result files named "basePerf" and "curPerf", created by adding
-resultFile=basePerf and -resultFile=curPerf.
To format the curPerf, run:
go run benchmark/benchresult/main.go curPerf
To observe how the performance changes based on a base result, run:
go run benchmark/benchresult/main.go basePerf curPerf
*/
package main
import (
"encoding/gob"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"reflect"
"runtime"
"runtime/pprof"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
bm "google.golang.org/grpc/benchmark"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/latency"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/test/bufconn"
)
const (
modeOn = "on"
modeOff = "off"
modeBoth = "both"
)
var allCompressionModes = []string{modeOn, modeOff, modeBoth}
var allTraceModes = []string{modeOn, modeOff, modeBoth}
const (
workloadsUnary = "unary"
workloadsStreaming = "streaming"
workloadsAll = "all"
)
var allWorkloads = []string{workloadsUnary, workloadsStreaming, workloadsAll}
var (
runMode = []bool{true, true} // {runUnary, runStream}
// When the latency is set to 0 (no delay), the result is slower than a true no-delay run
// because the latency simulation section performs extra operations.
ltc = []time.Duration{0, 40 * time.Millisecond} // if non-positive, no delay.
kbps = []int{0, 10240} // if non-positive, infinite
mtu = []int{0} // if non-positive, infinite
maxConcurrentCalls = []int{1, 8, 64, 512}
reqSizeBytes = []int{1, 1024, 1024 * 1024}
respSizeBytes = []int{1, 1024, 1024 * 1024}
enableTrace []bool
benchtime time.Duration
memProfile, cpuProfile string
memProfileRate int
enableCompressor []bool
enableChannelz []bool
networkMode string
benchmarkResultFile string
networks = map[string]latency.Network{
"Local": latency.Local,
"LAN": latency.LAN,
"WAN": latency.WAN,
"Longhaul": latency.Longhaul,
}
)
func unaryBenchmark(startTimer func(), stopTimer func(int32), benchFeatures stats.Features, benchtime time.Duration, s *stats.Stats) {
caller, cleanup := makeFuncUnary(benchFeatures)
defer cleanup()
runBenchmark(caller, startTimer, stopTimer, benchFeatures, benchtime, s)
}
func streamBenchmark(startTimer func(), stopTimer func(int32), benchFeatures stats.Features, benchtime time.Duration, s *stats.Stats) {
caller, cleanup := makeFuncStream(benchFeatures)
defer cleanup()
runBenchmark(caller, startTimer, stopTimer, benchFeatures, benchtime, s)
}
func makeFuncUnary(benchFeatures stats.Features) (func(int), func()) {
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
opts := []grpc.DialOption{}
sopts := []grpc.ServerOption{}
if benchFeatures.EnableCompressor {
sopts = append(sopts,
grpc.RPCCompressor(nopCompressor{}),
grpc.RPCDecompressor(nopDecompressor{}),
)
opts = append(opts,
grpc.WithCompressor(nopCompressor{}),
grpc.WithDecompressor(nopDecompressor{}),
)
}
sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
opts = append(opts, grpc.WithInsecure())
var lis net.Listener
if *useBufconn {
bcLis := bufconn.Listen(256 * 1024)
lis = bcLis
opts = append(opts, grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(
func(string, string, time.Duration) (net.Conn, error) {
return bcLis.Dial()
})("", "", 0)
}))
} else {
var err error
lis, err = net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
opts = append(opts, grpc.WithDialer(func(_ string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", lis.Addr().String(), timeout)
}))
}
lis = nw.Listener(lis)
stopper := bm.StartServer(bm.ServerInfo{Type: "protobuf", Listener: lis}, sopts...)
conn := bm.NewClientConn("" /* target not used */, opts...)
tc := testpb.NewBenchmarkServiceClient(conn)
return func(int) {
unaryCaller(tc, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
}, func() {
conn.Close()
stopper()
}
}
func makeFuncStream(benchFeatures stats.Features) (func(int), func()) {
// TODO: Refactor to remove duplication with makeFuncUnary.
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
opts := []grpc.DialOption{}
sopts := []grpc.ServerOption{}
if benchFeatures.EnableCompressor {
sopts = append(sopts,
grpc.RPCCompressor(grpc.NewGZIPCompressor()),
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
)
opts = append(opts,
grpc.WithCompressor(grpc.NewGZIPCompressor()),
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
)
}
sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
opts = append(opts, grpc.WithInsecure())
var lis net.Listener
if *useBufconn {
bcLis := bufconn.Listen(256 * 1024)
lis = bcLis
opts = append(opts, grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(
func(string, string, time.Duration) (net.Conn, error) {
return bcLis.Dial()
})("", "", 0)
}))
} else {
var err error
lis, err = net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
opts = append(opts, grpc.WithDialer(func(_ string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", lis.Addr().String(), timeout)
}))
}
lis = nw.Listener(lis)
stopper := bm.StartServer(bm.ServerInfo{Type: "protobuf", Listener: lis}, sopts...)
conn := bm.NewClientConn("" /* target not used */, opts...)
tc := testpb.NewBenchmarkServiceClient(conn)
streams := make([]testpb.BenchmarkService_StreamingCallClient, benchFeatures.MaxConcurrentCalls)
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
stream, err := tc.StreamingCall(context.Background())
if err != nil {
grpclog.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
streams[i] = stream
}
return func(pos int) {
streamCaller(streams[pos], benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
}, func() {
conn.Close()
stopper()
}
}
func unaryCaller(client testpb.BenchmarkServiceClient, reqSize, respSize int) {
if err := bm.DoUnaryCall(client, reqSize, respSize); err != nil {
grpclog.Fatalf("DoUnaryCall failed: %v", err)
}
}
func streamCaller(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) {
if err := bm.DoStreamingRoundTrip(stream, reqSize, respSize); err != nil {
grpclog.Fatalf("DoStreamingRoundTrip failed: %v", err)
}
}
func runBenchmark(caller func(int), startTimer func(), stopTimer func(int32), benchFeatures stats.Features, benchtime time.Duration, s *stats.Stats) {
// Warm up connection.
for i := 0; i < 10; i++ {
caller(0)
}
// Run benchmark.
startTimer()
var (
mu sync.Mutex
wg sync.WaitGroup
)
wg.Add(benchFeatures.MaxConcurrentCalls)
bmEnd := time.Now().Add(benchtime)
var count int32
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
go func(pos int) {
for {
t := time.Now()
if t.After(bmEnd) {
break
}
start := time.Now()
caller(pos)
elapse := time.Since(start)
atomic.AddInt32(&count, 1)
mu.Lock()
s.Add(elapse)
mu.Unlock()
}
wg.Done()
}(i)
}
wg.Wait()
stopTimer(count)
}
var useBufconn = flag.Bool("bufconn", false, "Use in-memory connection instead of system network I/O")
// init parses the command-line flags that select the benchmark feature settings.
func init() {
var (
workloads, traceMode, compressorMode, readLatency, channelzOn string
readKbps, readMtu, readMaxConcurrentCalls intSliceType
readReqSizeBytes, readRespSizeBytes intSliceType
)
flag.StringVar(&workloads, "workloads", workloadsAll,
fmt.Sprintf("Workloads to execute - One of: %v", strings.Join(allWorkloads, ", ")))
flag.StringVar(&traceMode, "trace", modeOff,
fmt.Sprintf("Trace mode - One of: %v", strings.Join(allTraceModes, ", ")))
flag.StringVar(&readLatency, "latency", "", "Simulated one-way network latency - may be a comma-separated list")
flag.StringVar(&channelzOn, "channelz", modeOff, "whether channelz should be turned on")
flag.DurationVar(&benchtime, "benchtime", time.Second, "Configures the amount of time to run each benchmark")
flag.Var(&readKbps, "kbps", "Simulated network throughput (in kbps) - may be a comma-separated list")
flag.Var(&readMtu, "mtu", "Simulated network MTU (Maximum Transmission Unit) - may be a comma-separated list")
flag.Var(&readMaxConcurrentCalls, "maxConcurrentCalls", "Number of concurrent RPCs during benchmarks")
flag.Var(&readReqSizeBytes, "reqSizeBytes", "Request size in bytes - may be a comma-separated list")
flag.Var(&readRespSizeBytes, "respSizeBytes", "Response size in bytes - may be a comma-separated list")
flag.StringVar(&memProfile, "memProfile", "", "Enables memory profiling output to the filename provided.")
flag.IntVar(&memProfileRate, "memProfileRate", 512*1024, "Configures the memory profiling rate. \n"+
"memProfile should be set before setting profile rate. To include every allocated block in the profile, "+
"set MemProfileRate to 1. To turn off profiling entirely, set MemProfileRate to 0. 512 * 1024 by default.")
flag.StringVar(&cpuProfile, "cpuProfile", "", "Enables CPU profiling output to the filename provided")
flag.StringVar(&compressorMode, "compression", modeOff,
fmt.Sprintf("Compression mode - One of: %v", strings.Join(allCompressionModes, ", ")))
flag.StringVar(&benchmarkResultFile, "resultFile", "", "Save the benchmark result into a binary file")
flag.StringVar(&networkMode, "networkMode", "", "Network mode includes LAN, WAN, Local and Longhaul")
flag.Parse()
if flag.NArg() != 0 {
log.Fatal("Error: unparsed arguments: ", flag.Args())
}
switch workloads {
case workloadsUnary:
runMode[0] = true
runMode[1] = false
case workloadsStreaming:
runMode[0] = false
runMode[1] = true
case workloadsAll:
runMode[0] = true
runMode[1] = true
default:
log.Fatalf("Unknown workloads setting: %v (want one of: %v)",
workloads, strings.Join(allWorkloads, ", "))
}
enableCompressor = setMode(compressorMode)
enableTrace = setMode(traceMode)
enableChannelz = setMode(channelzOn)
// Time inputs are given as a value plus a unit, e.g. 40ms.
readTimeFromInput(&ltc, readLatency)
readIntFromIntSlice(&kbps, readKbps)
readIntFromIntSlice(&mtu, readMtu)
readIntFromIntSlice(&maxConcurrentCalls, readMaxConcurrentCalls)
readIntFromIntSlice(&reqSizeBytes, readReqSizeBytes)
readIntFromIntSlice(&respSizeBytes, readRespSizeBytes)
// Re-write latency, kbps and mtu if network mode is set.
if network, ok := networks[networkMode]; ok {
ltc = []time.Duration{network.Latency}
kbps = []int{network.Kbps}
mtu = []int{network.MTU}
}
}
func setMode(name string) []bool {
switch name {
case modeOn:
return []bool{true}
case modeOff:
return []bool{false}
case modeBoth:
return []bool{false, true}
default:
log.Fatalf("Unknown %s setting: %v (want one of: %v)",
name, name, strings.Join(allCompressionModes, ", "))
return []bool{}
}
}
type intSliceType []int
func (intSlice *intSliceType) String() string {
return fmt.Sprintf("%v", *intSlice)
}
func (intSlice *intSliceType) Set(value string) error {
if len(*intSlice) > 0 {
return errors.New("interval flag already set")
}
for _, num := range strings.Split(value, ",") {
next, err := strconv.Atoi(num)
if err != nil {
return err
}
*intSlice = append(*intSlice, next)
}
return nil
}
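// exampleIntSliceFlag is an illustrative sketch, not part of the original file:
// intSliceType implements flag.Value, so a flag such as -kbps=0,10240 parses
// into the slice [0 10240]. The flag name and values here are assumptions taken
// from the defaults above.
func exampleIntSliceFlag() {
	var v intSliceType
	if err := v.Set("0,10240"); err != nil {
		log.Fatal(err)
	}
	fmt.Println(v) // prints [0 10240]
}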
func readIntFromIntSlice(values *[]int, replace intSliceType) {
// If the flag did not provide a replacement, return and keep the default settings.
if len(replace) == 0 {
return
}
*values = replace
}
func readTimeFromInput(values *[]time.Duration, replace string) {
if strings.Compare(replace, "") != 0 {
*values = []time.Duration{}
for _, ltc := range strings.Split(replace, ",") {
duration, err := time.ParseDuration(ltc)
if err != nil {
log.Fatal(err.Error())
}
*values = append(*values, duration)
}
}
}
func main() {
before()
featuresPos := make([]int, 9)
// 0:enableTracing 1:ltc 2:kbps 3:mtu 4:maxC 5:reqSize 6:respSize 7:enableCompressor 8:enableChannelz
featuresNum := []int{len(enableTrace), len(ltc), len(kbps), len(mtu),
len(maxConcurrentCalls), len(reqSizeBytes), len(respSizeBytes), len(enableCompressor), len(enableChannelz)}
initalPos := make([]int, len(featuresPos))
s := stats.NewStats(10)
s.SortLatency()
var memStats runtime.MemStats
var results testing.BenchmarkResult
var startAllocs, startBytes uint64
var startTime time.Time
start := true
var startTimer = func() {
runtime.ReadMemStats(&memStats)
startAllocs = memStats.Mallocs
startBytes = memStats.TotalAlloc
startTime = time.Now()
}
var stopTimer = func(count int32) {
runtime.ReadMemStats(&memStats)
results = testing.BenchmarkResult{N: int(count), T: time.Since(startTime),
Bytes: 0, MemAllocs: memStats.Mallocs - startAllocs, MemBytes: memStats.TotalAlloc - startBytes}
}
sharedPos := make([]bool, len(featuresPos))
for i := 0; i < len(featuresPos); i++ {
if featuresNum[i] <= 1 {
sharedPos[i] = true
}
}
// Run benchmarks
resultSlice := []stats.BenchResults{}
for !reflect.DeepEqual(featuresPos, initalPos) || start {
start = false
benchFeature := stats.Features{
NetworkMode: networkMode,
EnableTrace: enableTrace[featuresPos[0]],
Latency: ltc[featuresPos[1]],
Kbps: kbps[featuresPos[2]],
Mtu: mtu[featuresPos[3]],
MaxConcurrentCalls: maxConcurrentCalls[featuresPos[4]],
ReqSizeBytes: reqSizeBytes[featuresPos[5]],
RespSizeBytes: respSizeBytes[featuresPos[6]],
EnableCompressor: enableCompressor[featuresPos[7]],
EnableChannelz: enableChannelz[featuresPos[8]],
}
grpc.EnableTracing = enableTrace[featuresPos[0]]
if enableChannelz[featuresPos[8]] {
channelz.TurnOn()
}
if runMode[0] {
unaryBenchmark(startTimer, stopTimer, benchFeature, benchtime, s)
s.SetBenchmarkResult("Unary", benchFeature, results.N,
results.AllocedBytesPerOp(), results.AllocsPerOp(), sharedPos)
fmt.Println(s.BenchString())
fmt.Println(s.String())
resultSlice = append(resultSlice, s.GetBenchmarkResults())
s.Clear()
}
if runMode[1] {
streamBenchmark(startTimer, stopTimer, benchFeature, benchtime, s)
s.SetBenchmarkResult("Stream", benchFeature, results.N,
results.AllocedBytesPerOp(), results.AllocsPerOp(), sharedPos)
fmt.Println(s.BenchString())
fmt.Println(s.String())
resultSlice = append(resultSlice, s.GetBenchmarkResults())
s.Clear()
}
bm.AddOne(featuresPos, featuresNum)
}
after(resultSlice)
}
func before() {
if memProfile != "" {
runtime.MemProfileRate = memProfileRate
}
if cpuProfile != "" {
f, err := os.Create(cpuProfile)
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
return
}
if err := pprof.StartCPUProfile(f); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't start cpu profile: %s\n", err)
f.Close()
return
}
}
}
func after(data []stats.BenchResults) {
if cpuProfile != "" {
pprof.StopCPUProfile() // flushes profile to disk
}
if memProfile != "" {
f, err := os.Create(memProfile)
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
os.Exit(2)
}
runtime.GC() // materialize all statistics
if err = pprof.WriteHeapProfile(f); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't write heap profile %s: %s\n", memProfile, err)
os.Exit(2)
}
f.Close()
}
if benchmarkResultFile != "" {
f, err := os.Create(benchmarkResultFile)
if err != nil {
log.Fatalf("testing: can't write benchmark result %s: %s\n", benchmarkResultFile, err)
}
dataEncoder := gob.NewEncoder(f)
dataEncoder.Encode(data)
f.Close()
}
}
// nopCompressor is a compressor that just copies data.
type nopCompressor struct{}
func (nopCompressor) Do(w io.Writer, p []byte) error {
n, err := w.Write(p)
if err != nil {
return err
}
if n != len(p) {
return fmt.Errorf("nopCompressor.Write: wrote %v bytes; want %v", n, len(p))
}
return nil
}
func (nopCompressor) Type() string { return "nop" }
// nopDecompressor is a decompressor that just copies data.
type nopDecompressor struct{}
func (nopDecompressor) Do(r io.Reader) ([]byte, error) { return ioutil.ReadAll(r) }
func (nopDecompressor) Type() string { return "nop" }

369
vendor/google.golang.org/grpc/benchmark/benchmark.go generated vendored Normal file
View file

@ -0,0 +1,369 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate protoc -I grpc_testing --go_out=plugins=grpc:grpc_testing grpc_testing/control.proto grpc_testing/messages.proto grpc_testing/payloads.proto grpc_testing/services.proto grpc_testing/stats.proto
/*
Package benchmark implements the building blocks to setup end-to-end gRPC benchmarks.
*/
package benchmark
import (
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/latency"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/grpclog"
)
// AddOne adds 1 to the features slice, treating it as a mixed-radix counter whose per-position bases are given by featuresMaxPosition.
func AddOne(features []int, featuresMaxPosition []int) {
for i := len(features) - 1; i >= 0; i-- {
features[i] = (features[i] + 1)
if features[i]/featuresMaxPosition[i] == 0 {
break
}
features[i] = features[i] % featuresMaxPosition[i]
}
}
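// exampleAddOne is an illustrative sketch, not part of the original file. It
// shows AddOne acting as a mixed-radix counter: with bases {2, 3} the position
// slice cycles through every combination and wraps back to all zeros.
func exampleAddOne() {
	features := []int{0, 0}
	featuresMaxPosition := []int{2, 3}
	for i := 0; i < 6; i++ {
		AddOne(features, featuresMaxPosition)
		fmt.Println(features) // [0 1] [0 2] [1 0] [1 1] [1 2] [0 0]
	}
}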
// Allows reuse of the same testpb.Payload object.
func setPayload(p *testpb.Payload, t testpb.PayloadType, size int) {
if size < 0 {
grpclog.Fatalf("Requested a response with invalid length %d", size)
}
body := make([]byte, size)
switch t {
case testpb.PayloadType_COMPRESSABLE:
case testpb.PayloadType_UNCOMPRESSABLE:
grpclog.Fatalf("PayloadType UNCOMPRESSABLE is not supported")
default:
grpclog.Fatalf("Unsupported payload type: %d", t)
}
p.Type = t
p.Body = body
}
func newPayload(t testpb.PayloadType, size int) *testpb.Payload {
p := new(testpb.Payload)
setPayload(p, t, size)
return p
}
type testServer struct {
}
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{
Payload: newPayload(in.ResponseType, int(in.ResponseSize)),
}, nil
}
func (s *testServer) StreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error {
response := &testpb.SimpleResponse{
Payload: new(testpb.Payload),
}
in := new(testpb.SimpleRequest)
for {
// use ServerStream directly to reuse the same testpb.SimpleRequest object
err := stream.(grpc.ServerStream).RecvMsg(in)
if err == io.EOF {
// read done.
return nil
}
if err != nil {
return err
}
setPayload(response.Payload, in.ResponseType, int(in.ResponseSize))
if err := stream.Send(response); err != nil {
return err
}
}
}
// byteBufServer is a gRPC server that sends and receives byte buffer.
// The purpose is to benchmark the gRPC performance without protobuf serialization/deserialization overhead.
type byteBufServer struct {
respSize int32
}
// UnaryCall is an empty function and is not used for the benchmark.
// If a bytebuf UnaryCall benchmark is needed later, the function body needs to be updated.
func (s *byteBufServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
}
func (s *byteBufServer) StreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error {
for {
var in []byte
err := stream.(grpc.ServerStream).RecvMsg(&in)
if err == io.EOF {
return nil
}
if err != nil {
return err
}
out := make([]byte, s.respSize)
if err := stream.(grpc.ServerStream).SendMsg(&out); err != nil {
return err
}
}
}
// ServerInfo contains the information to create a gRPC benchmark server.
type ServerInfo struct {
// Type is the type of the server.
// It should be "protobuf" or "bytebuf".
Type string
// Metadata is an optional configuration.
// For "protobuf", it's ignored.
// For "bytebuf", it should be an int representing response size.
Metadata interface{}
// Listener is the network listener for the server to use
Listener net.Listener
}
// StartServer starts a gRPC server serving a benchmark service according to info.
// It returns a function to stop the server.
func StartServer(info ServerInfo, opts ...grpc.ServerOption) func() {
opts = append(opts, grpc.WriteBufferSize(128*1024))
opts = append(opts, grpc.ReadBufferSize(128*1024))
s := grpc.NewServer(opts...)
switch info.Type {
case "protobuf":
testpb.RegisterBenchmarkServiceServer(s, &testServer{})
case "bytebuf":
respSize, ok := info.Metadata.(int32)
if !ok {
grpclog.Fatalf("failed to StartServer, invalid metadata: %v, for Type: %v", info.Metadata, info.Type)
}
testpb.RegisterBenchmarkServiceServer(s, &byteBufServer{respSize: respSize})
default:
grpclog.Fatalf("failed to StartServer, unknown Type: %v", info.Type)
}
go s.Serve(info.Listener)
return func() {
s.Stop()
}
}
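// exampleStartServerUsage is an illustrative sketch, not part of the original
// file, showing how StartServer and NewClientConn are typically combined. The
// listener address and the request/response sizes (1 byte each) are assumptions
// made only for the example.
func exampleStartServerUsage() {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		grpclog.Fatalf("Failed to listen: %v", err)
	}
	stop := StartServer(ServerInfo{Type: "protobuf", Listener: lis})
	defer stop()
	conn := NewClientConn(lis.Addr().String(), grpc.WithInsecure())
	defer conn.Close()
	tc := testpb.NewBenchmarkServiceClient(conn)
	if err := DoUnaryCall(tc, 1, 1); err != nil {
		grpclog.Fatalf("DoUnaryCall failed: %v", err)
	}
}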
// DoUnaryCall performs a unary RPC with the given stub and request and response sizes.
func DoUnaryCall(tc testpb.BenchmarkServiceClient, reqSize, respSize int) error {
pl := newPayload(testpb.PayloadType_COMPRESSABLE, reqSize)
req := &testpb.SimpleRequest{
ResponseType: pl.Type,
ResponseSize: int32(respSize),
Payload: pl,
}
if _, err := tc.UnaryCall(context.Background(), req); err != nil {
return fmt.Errorf("/BenchmarkService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
}
return nil
}
// DoStreamingRoundTrip performs a round trip for a single streaming rpc.
func DoStreamingRoundTrip(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) error {
pl := newPayload(testpb.PayloadType_COMPRESSABLE, reqSize)
req := &testpb.SimpleRequest{
ResponseType: pl.Type,
ResponseSize: int32(respSize),
Payload: pl,
}
if err := stream.Send(req); err != nil {
return fmt.Errorf("/BenchmarkService/StreamingCall.Send(_) = %v, want <nil>", err)
}
if _, err := stream.Recv(); err != nil {
// EOF is a valid error here.
if err == io.EOF {
return nil
}
return fmt.Errorf("/BenchmarkService/StreamingCall.Recv(_) = %v, want <nil>", err)
}
return nil
}
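// exampleStreamingRoundTrip is an illustrative sketch, not part of the original
// file, showing how DoStreamingRoundTrip is typically driven: one stream is
// opened per worker and reused for every iteration. The iteration count and
// message sizes are assumptions made only for the example.
func exampleStreamingRoundTrip(tc testpb.BenchmarkServiceClient) {
	stream, err := tc.StreamingCall(context.Background())
	if err != nil {
		grpclog.Fatalf("StreamingCall(_) = _, %v", err)
	}
	for i := 0; i < 10; i++ {
		if err := DoStreamingRoundTrip(stream, 1, 1); err != nil {
			grpclog.Fatalf("DoStreamingRoundTrip failed: %v", err)
		}
	}
}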
// DoByteBufStreamingRoundTrip performs a round trip for a single streaming rpc, using a custom codec for byte buffer.
func DoByteBufStreamingRoundTrip(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) error {
out := make([]byte, reqSize)
if err := stream.(grpc.ClientStream).SendMsg(&out); err != nil {
return fmt.Errorf("/BenchmarkService/StreamingCall.(ClientStream).SendMsg(_) = %v, want <nil>", err)
}
var in []byte
if err := stream.(grpc.ClientStream).RecvMsg(&in); err != nil {
// EOF is a valid error here.
if err == io.EOF {
return nil
}
return fmt.Errorf("/BenchmarkService/StreamingCall.(ClientStream).RecvMsg(_) = %v, want <nil>", err)
}
return nil
}
// NewClientConn creates a gRPC client connection to addr.
func NewClientConn(addr string, opts ...grpc.DialOption) *grpc.ClientConn {
return NewClientConnWithContext(context.Background(), addr, opts...)
}
// NewClientConnWithContext creates a gRPC client connection to addr using ctx.
func NewClientConnWithContext(ctx context.Context, addr string, opts ...grpc.DialOption) *grpc.ClientConn {
opts = append(opts, grpc.WithWriteBufferSize(128*1024))
opts = append(opts, grpc.WithReadBufferSize(128*1024))
conn, err := grpc.DialContext(ctx, addr, opts...)
if err != nil {
grpclog.Fatalf("NewClientConn(%q) failed to create a ClientConn %v", addr, err)
}
return conn
}
func runUnary(b *testing.B, benchFeatures stats.Features) {
s := stats.AddStats(b, 38)
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
target := lis.Addr().String()
lis = nw.Listener(lis)
stopper := StartServer(ServerInfo{Type: "protobuf", Listener: lis}, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
defer stopper()
conn := NewClientConn(
target, grpc.WithInsecure(),
grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", address, timeout)
}),
)
tc := testpb.NewBenchmarkServiceClient(conn)
// Warm up connection.
for i := 0; i < 10; i++ {
unaryCaller(tc, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
}
ch := make(chan int, benchFeatures.MaxConcurrentCalls*4)
var (
mu sync.Mutex
wg sync.WaitGroup
)
wg.Add(benchFeatures.MaxConcurrentCalls)
// Distribute the b.N calls over maxConcurrentCalls workers.
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
go func() {
for range ch {
start := time.Now()
unaryCaller(tc, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
elapse := time.Since(start)
mu.Lock()
s.Add(elapse)
mu.Unlock()
}
wg.Done()
}()
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ch <- i
}
close(ch)
wg.Wait()
b.StopTimer()
conn.Close()
}
func runStream(b *testing.B, benchFeatures stats.Features) {
s := stats.AddStats(b, 38)
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
target := lis.Addr().String()
lis = nw.Listener(lis)
stopper := StartServer(ServerInfo{Type: "protobuf", Listener: lis}, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
defer stopper()
conn := NewClientConn(
target, grpc.WithInsecure(),
grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", address, timeout)
}),
)
tc := testpb.NewBenchmarkServiceClient(conn)
// Warm up connection.
stream, err := tc.StreamingCall(context.Background())
if err != nil {
b.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
for i := 0; i < 10; i++ {
streamCaller(stream, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
}
ch := make(chan struct{}, benchFeatures.MaxConcurrentCalls*4)
var (
mu sync.Mutex
wg sync.WaitGroup
)
wg.Add(benchFeatures.MaxConcurrentCalls)
// Distribute the b.N calls over maxConcurrentCalls workers.
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
stream, err := tc.StreamingCall(context.Background())
if err != nil {
b.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
go func() {
for range ch {
start := time.Now()
streamCaller(stream, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
elapse := time.Since(start)
mu.Lock()
s.Add(elapse)
mu.Unlock()
}
wg.Done()
}()
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ch <- struct{}{}
}
close(ch)
wg.Wait()
b.StopTimer()
conn.Close()
}
func unaryCaller(client testpb.BenchmarkServiceClient, reqSize, respSize int) {
if err := DoUnaryCall(client, reqSize, respSize); err != nil {
grpclog.Fatalf("DoUnaryCall failed: %v", err)
}
}
func streamCaller(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) {
if err := DoStreamingRoundTrip(stream, reqSize, respSize); err != nil {
grpclog.Fatalf("DoStreamingRoundTrip failed: %v", err)
}
}

View file

@ -0,0 +1,112 @@
// +build go1.6,!go1.7
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package benchmark
import (
"os"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark/stats"
)
func BenchmarkClientStreamc1(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false, false})
}
func BenchmarkClientStreamc8(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false, false})
}
func BenchmarkClientStreamc64(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false, false})
}
func BenchmarkClientStreamc512(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false, false})
}
func BenchmarkClientUnaryc1(b *testing.B) {
grpc.EnableTracing = true
runUnary(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false, false})
}
func BenchmarkClientUnaryc8(b *testing.B) {
grpc.EnableTracing = true
runUnary(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false, false})
}
func BenchmarkClientUnaryc64(b *testing.B) {
grpc.EnableTracing = true
runUnary(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false, false})
}
func BenchmarkClientUnaryc512(b *testing.B) {
grpc.EnableTracing = true
runUnary(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false, false})
}
func BenchmarkClientStreamNoTracec1(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false, false})
}
func BenchmarkClientStreamNoTracec8(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false, false})
}
func BenchmarkClientStreamNoTracec64(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false, false})
}
func BenchmarkClientStreamNoTracec512(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false, false})
}
func BenchmarkClientUnaryNoTracec1(b *testing.B) {
grpc.EnableTracing = false
runUnary(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false, false})
}
func BenchmarkClientUnaryNoTracec8(b *testing.B) {
grpc.EnableTracing = false
runUnary(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false, false})
}
func BenchmarkClientUnaryNoTracec64(b *testing.B) {
grpc.EnableTracing = false
runUnary(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false, false})
}
func BenchmarkClientUnaryNoTracec512(b *testing.B) {
grpc.EnableTracing = false
runUnary(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false, false})
}
func TestMain(m *testing.M) {
os.Exit(stats.RunTestMain(m))
}

View file

@ -0,0 +1,85 @@
// +build go1.7
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package benchmark
import (
"fmt"
"os"
"reflect"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark/stats"
)
func BenchmarkClient(b *testing.B) {
enableTrace := []bool{true, false} // run both enable and disable by default
// When the latency is set to 0 (no delay), the result is slower than a true no-delay run
// because the latency simulation section performs extra operations.
latency := []time.Duration{0, 40 * time.Millisecond} // if non-positive, no delay.
kbps := []int{0, 10240} // if non-positive, infinite
mtu := []int{0} // if non-positive, infinite
maxConcurrentCalls := []int{1, 8, 64, 512}
reqSizeBytes := []int{1, 1024 * 1024}
respSizeBytes := []int{1, 1024 * 1024}
featuresCurPos := make([]int, 7)
// 0:enableTracing 1:ltc 2:kbps 3:mtu 4:maxC 5:reqSize 6:respSize
featuresMaxPosition := []int{len(enableTrace), len(latency), len(kbps), len(mtu), len(maxConcurrentCalls), len(reqSizeBytes), len(respSizeBytes)}
initalPos := make([]int, len(featuresCurPos))
// run benchmarks
start := true
for !reflect.DeepEqual(featuresCurPos, initalPos) || start {
start = false
tracing := "Trace"
if !enableTrace[featuresCurPos[0]] {
tracing = "noTrace"
}
benchFeature := stats.Features{
EnableTrace: enableTrace[featuresCurPos[0]],
Latency: latency[featuresCurPos[1]],
Kbps: kbps[featuresCurPos[2]],
Mtu: mtu[featuresCurPos[3]],
MaxConcurrentCalls: maxConcurrentCalls[featuresCurPos[4]],
ReqSizeBytes: reqSizeBytes[featuresCurPos[5]],
RespSizeBytes: respSizeBytes[featuresCurPos[6]],
}
grpc.EnableTracing = enableTrace[featuresCurPos[0]]
b.Run(fmt.Sprintf("Unary-%s-%s",
tracing, benchFeature.String()), func(b *testing.B) {
runUnary(b, benchFeature)
})
b.Run(fmt.Sprintf("Stream-%s-%s",
tracing, benchFeature.String()), func(b *testing.B) {
runStream(b, benchFeature)
})
AddOne(featuresCurPos, featuresMaxPosition)
}
}
func TestMain(m *testing.M) {
os.Exit(stats.RunTestMain(m))
}

View file

@ -0,0 +1,133 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
To format the benchmark result:
go run benchmark/benchresult/main.go resultfile
To see the performance change based on an old result:
go run benchmark/benchresult/main.go resultfile_old resultfile
It will print the comparison results for the benchmarks present in both files.
*/
package main
import (
"encoding/gob"
"fmt"
"log"
"os"
"strconv"
"strings"
"time"
"google.golang.org/grpc/benchmark/stats"
)
func createMap(fileName string, m map[string]stats.BenchResults) {
f, err := os.Open(fileName)
if err != nil {
log.Fatalf("Read file %s error: %s\n", fileName, err)
}
defer f.Close()
var data []stats.BenchResults
decoder := gob.NewDecoder(f)
if err = decoder.Decode(&data); err != nil {
log.Fatalf("Decode file %s error: %s\n", fileName, err)
}
for _, d := range data {
m[d.RunMode+"-"+d.Features.String()] = d
}
}
func intChange(title string, val1, val2 int64) string {
return fmt.Sprintf("%10s %12s %12s %8.2f%%\n", title, strconv.FormatInt(val1, 10),
strconv.FormatInt(val2, 10), float64(val2-val1)*100/float64(val1))
}
func timeChange(title int, val1, val2 time.Duration) string {
return fmt.Sprintf("%10s %12s %12s %8.2f%%\n", strconv.Itoa(title)+" latency", val1.String(),
val2.String(), float64(val2-val1)*100/float64(val1))
}
func compareTwoMap(m1, m2 map[string]stats.BenchResults) {
for k2, v2 := range m2 {
if v1, ok := m1[k2]; ok {
changes := k2 + "\n"
changes += fmt.Sprintf("%10s %12s %12s %8s\n", "Title", "Before", "After", "Percentage")
changes += intChange("Bytes/op", v1.AllocedBytesPerOp, v2.AllocedBytesPerOp)
changes += intChange("Allocs/op", v1.AllocsPerOp, v2.AllocsPerOp)
changes += timeChange(v1.Latency[1].Percent, v1.Latency[1].Value, v2.Latency[1].Value)
changes += timeChange(v1.Latency[2].Percent, v1.Latency[2].Value, v2.Latency[2].Value)
fmt.Printf("%s\n", changes)
}
}
}
func compareBenchmark(file1, file2 string) {
var BenchValueFile1 map[string]stats.BenchResults
var BenchValueFile2 map[string]stats.BenchResults
BenchValueFile1 = make(map[string]stats.BenchResults)
BenchValueFile2 = make(map[string]stats.BenchResults)
createMap(file1, BenchValueFile1)
createMap(file2, BenchValueFile2)
compareTwoMap(BenchValueFile1, BenchValueFile2)
}
func printline(benchName, ltc50, ltc90, allocByte, allocsOp interface{}) {
fmt.Printf("%-80v%12v%12v%12v%12v\n", benchName, ltc50, ltc90, allocByte, allocsOp)
}
func formatBenchmark(fileName string) {
f, err := os.Open(fileName)
if err != nil {
log.Fatalf("Read file %s error: %s\n", fileName, err)
}
defer f.Close()
var data []stats.BenchResults
decoder := gob.NewDecoder(f)
if err = decoder.Decode(&data); err != nil {
log.Fatalf("Decode file %s error: %s\n", fileName, err)
}
if len(data) == 0 {
log.Fatalf("No data in file %s\n", fileName)
}
printPos := data[0].SharedPosion
fmt.Println("\nShared features:\n" + strings.Repeat("-", 20))
fmt.Print(stats.PartialPrintString(printPos, data[0].Features, true))
fmt.Println(strings.Repeat("-", 35))
for i := 0; i < len(data[0].SharedPosion); i++ {
printPos[i] = !printPos[i]
}
printline("Name", "latency-50", "latency-90", "Alloc (B)", "Alloc (#)")
for _, d := range data {
name := d.RunMode + stats.PartialPrintString(printPos, d.Features, false)
printline(name, d.Latency[1].Value.String(), d.Latency[2].Value.String(),
d.AllocedBytesPerOp, d.AllocsPerOp)
}
}
func main() {
if len(os.Args) == 2 {
formatBenchmark(os.Args[1])
} else {
compareBenchmark(os.Args[1], os.Args[2])
}
}
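// Editor's note (not part of the vendored file): a minimal sketch of how a
// result file consumed by this tool could be produced. It assumes only what
// createMap/formatBenchmark above already rely on: the file is a gob-encoded
// []stats.BenchResults, normally written by the benchmark harness.
package main

import (
	"encoding/gob"
	"log"
	"os"

	"google.golang.org/grpc/benchmark/stats"
)

// writeResults is the mirror image of the decode path above
// (gob.NewDecoder(f).Decode(&data)).
func writeResults(fileName string, data []stats.BenchResults) {
	f, err := os.Create(fileName)
	if err != nil {
		log.Fatalf("Create file %s error: %s\n", fileName, err)
	}
	defer f.Close()
	if err := gob.NewEncoder(f).Encode(data); err != nil {
		log.Fatalf("Encode file %s error: %s\n", fileName, err)
	}
}

func main() {
	// Write an (empty) result file; real entries come from a benchmark run.
	writeResults("/tmp/result.bench", []stats.BenchResults{})
}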

187
vendor/google.golang.org/grpc/benchmark/client/main.go generated vendored Normal file
View file

@ -0,0 +1,187 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"flag"
"fmt"
"os"
"runtime"
"runtime/pprof"
"sync"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
)
var (
port = flag.String("port", "50051", "Localhost port to connect to.")
numRPC = flag.Int("r", 1, "The number of concurrent RPCs on each connection.")
numConn = flag.Int("c", 1, "The number of parallel connections.")
warmupDur = flag.Int("w", 10, "Warm-up duration in seconds")
duration = flag.Int("d", 60, "Benchmark duration in seconds")
rqSize = flag.Int("req", 1, "Request message size in bytes.")
rspSize = flag.Int("resp", 1, "Response message size in bytes.")
rpcType = flag.String("rpc_type", "unary",
`Configure different client rpc type. Valid options are:
unary;
streaming.`)
testName = flag.String("test_name", "", "Name of the test used for creating profiles.")
wg sync.WaitGroup
hopts = stats.HistogramOptions{
NumBuckets: 2495,
GrowthFactor: .01,
}
mu sync.Mutex
hists []*stats.Histogram
)
func main() {
flag.Parse()
if *testName == "" {
grpclog.Fatalf("test_name not set")
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE,
ResponseSize: int32(*rspSize),
Payload: &testpb.Payload{
Type: testpb.PayloadType_COMPRESSABLE,
Body: make([]byte, *rqSize),
},
}
connectCtx, connectCancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer connectCancel()
ccs := buildConnections(connectCtx)
warmDeadline := time.Now().Add(time.Duration(*warmupDur) * time.Second)
endDeadline := warmDeadline.Add(time.Duration(*duration) * time.Second)
cf, err := os.Create("/tmp/" + *testName + ".cpu")
if err != nil {
grpclog.Fatalf("Error creating file: %v", err)
}
defer cf.Close()
pprof.StartCPUProfile(cf)
cpuBeg := syscall.GetCPUTime()
for _, cc := range ccs {
runWithConn(cc, req, warmDeadline, endDeadline)
}
wg.Wait()
cpu := time.Duration(syscall.GetCPUTime() - cpuBeg)
pprof.StopCPUProfile()
mf, err := os.Create("/tmp/" + *testName + ".mem")
if err != nil {
grpclog.Fatalf("Error creating file: %v", err)
}
defer mf.Close()
runtime.GC() // materialize all statistics
if err := pprof.WriteHeapProfile(mf); err != nil {
grpclog.Fatalf("Error writing memory profile: %v", err)
}
hist := stats.NewHistogram(hopts)
for _, h := range hists {
hist.Merge(h)
}
parseHist(hist)
fmt.Println("Client CPU utilization:", cpu)
fmt.Println("Client CPU profile:", cf.Name())
fmt.Println("Client Mem Profile:", mf.Name())
}
func buildConnections(ctx context.Context) []*grpc.ClientConn {
ccs := make([]*grpc.ClientConn, *numConn)
for i := range ccs {
ccs[i] = benchmark.NewClientConnWithContext(ctx, "localhost:"+*port, grpc.WithInsecure(), grpc.WithBlock())
}
return ccs
}
func runWithConn(cc *grpc.ClientConn, req *testpb.SimpleRequest, warmDeadline, endDeadline time.Time) {
for i := 0; i < *numRPC; i++ {
wg.Add(1)
go func() {
defer wg.Done()
caller := makeCaller(cc, req)
hist := stats.NewHistogram(hopts)
for {
start := time.Now()
if start.After(endDeadline) {
mu.Lock()
hists = append(hists, hist)
mu.Unlock()
return
}
caller()
elapsed := time.Since(start)
if start.After(warmDeadline) {
hist.Add(elapsed.Nanoseconds())
}
}
}()
}
}
func makeCaller(cc *grpc.ClientConn, req *testpb.SimpleRequest) func() {
client := testpb.NewBenchmarkServiceClient(cc)
if *rpcType == "unary" {
return func() {
if _, err := client.UnaryCall(context.Background(), req); err != nil {
grpclog.Fatalf("RPC failed: %v", err)
}
}
}
stream, err := client.StreamingCall(context.Background())
if err != nil {
grpclog.Fatalf("RPC failed: %v", err)
}
return func() {
if err := stream.Send(req); err != nil {
grpclog.Fatalf("Streaming RPC failed to send: %v", err)
}
if _, err := stream.Recv(); err != nil {
grpclog.Fatalf("Streaming RPC failed to read: %v", err)
}
}
}
func parseHist(hist *stats.Histogram) {
fmt.Println("qps:", float64(hist.Count)/float64(*duration))
fmt.Printf("Latency: (50/90/99 %%ile): %v/%v/%v\n",
time.Duration(median(.5, hist)),
time.Duration(median(.9, hist)),
time.Duration(median(.99, hist)))
}
func median(percentile float64, h *stats.Histogram) int64 {
need := int64(float64(h.Count) * percentile)
have := int64(0)
for _, bucket := range h.Buckets {
count := bucket.Count
if have+count >= need {
percent := float64(need-have) / float64(count)
return int64((1.0-percent)*bucket.LowBound + percent*bucket.LowBound*(1.0+hopts.GrowthFactor))
}
have += bucket.Count
}
panic("should have found a bound")
}
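// Editor's note (not part of the vendored file): a self-contained sketch of
// the same lower-bound interpolation median() above performs, on a toy bucket
// layout, to show how a percentile is located inside a histogram whose bucket
// bounds grow by a fixed GrowthFactor. toyBucket and percentile are
// hypothetical names introduced only for this illustration.
package main

import "fmt"

type toyBucket struct {
	lowBound float64 // inclusive lower bound of the bucket, in nanoseconds
	count    int64   // samples that fell into this bucket
}

func percentile(p float64, total int64, buckets []toyBucket, growth float64) int64 {
	need := int64(float64(total) * p)
	have := int64(0)
	for _, b := range buckets {
		if have+b.count >= need {
			frac := float64(need-have) / float64(b.count)
			// Interpolate between this bucket's lower bound and the next
			// bound, lowBound*(1+growth), exactly as median() does.
			return int64((1.0-frac)*b.lowBound + frac*b.lowBound*(1.0+growth))
		}
		have += b.count
	}
	panic("should have found a bound")
}

func main() {
	buckets := []toyBucket{{1000, 50}, {1010, 30}, {1020.1, 20}}
	// With 100 samples, the 90th percentile lands in the last bucket:
	// need=90, have=80 before it, so frac=0.5.
	fmt.Println(percentile(.9, 100, buckets, .01)) // ~1025ns
}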

File diff suppressed because it is too large

View file

@ -0,0 +1,186 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
import "payloads.proto";
import "stats.proto";
package grpc.testing;
enum ClientType {
SYNC_CLIENT = 0;
ASYNC_CLIENT = 1;
}
enum ServerType {
SYNC_SERVER = 0;
ASYNC_SERVER = 1;
ASYNC_GENERIC_SERVER = 2;
}
enum RpcType {
UNARY = 0;
STREAMING = 1;
}
// Parameters of a Poisson process distribution, which is a good representation
// of activity coming in from independent identical stationary sources.
message PoissonParams {
// The rate of arrivals (a.k.a. lambda parameter of the exp distribution).
double offered_load = 1;
}
message UniformParams {
double interarrival_lo = 1;
double interarrival_hi = 2;
}
message DeterministicParams {
double offered_load = 1;
}
message ParetoParams {
double interarrival_base = 1;
double alpha = 2;
}
// Once an RPC finishes, immediately start a new one.
// No configuration parameters needed.
message ClosedLoopParams {
}
message LoadParams {
oneof load {
ClosedLoopParams closed_loop = 1;
PoissonParams poisson = 2;
UniformParams uniform = 3;
DeterministicParams determ = 4;
ParetoParams pareto = 5;
};
}
// presence of SecurityParams implies use of TLS
message SecurityParams {
bool use_test_ca = 1;
string server_host_override = 2;
}
message ClientConfig {
// List of targets to connect to. At least one target needs to be specified.
repeated string server_targets = 1;
ClientType client_type = 2;
SecurityParams security_params = 3;
// How many concurrent RPCs to start for each channel.
// For synchronous client, use a separate thread for each outstanding RPC.
int32 outstanding_rpcs_per_channel = 4;
// Number of independent client channels to create.
// i-th channel will connect to server_target[i % server_targets.size()]
int32 client_channels = 5;
// Only for async client. Number of threads to use to start/manage RPCs.
int32 async_client_threads = 7;
RpcType rpc_type = 8;
// The requested load for the entire client (aggregated over all the threads).
LoadParams load_params = 10;
PayloadConfig payload_config = 11;
HistogramParams histogram_params = 12;
// Specify the cores we should run the client on, if desired
repeated int32 core_list = 13;
int32 core_limit = 14;
}
message ClientStatus {
ClientStats stats = 1;
}
// Request current stats
message Mark {
// if true, the stats will be reset after taking their snapshot.
bool reset = 1;
}
message ClientArgs {
oneof argtype {
ClientConfig setup = 1;
Mark mark = 2;
}
}
message ServerConfig {
ServerType server_type = 1;
SecurityParams security_params = 2;
// Port on which to listen. Zero means pick unused port.
int32 port = 4;
// Only for async server. Number of threads used to serve the requests.
int32 async_server_threads = 7;
// Specify the number of cores to limit server to, if desired
int32 core_limit = 8;
// payload config, used in generic server
PayloadConfig payload_config = 9;
// Specify the cores we should run the server on, if desired
repeated int32 core_list = 10;
}
message ServerArgs {
oneof argtype {
ServerConfig setup = 1;
Mark mark = 2;
}
}
message ServerStatus {
ServerStats stats = 1;
// the port bound by the server
int32 port = 2;
// Number of cores available to the server
int32 cores = 3;
}
message CoreRequest {
}
message CoreResponse {
// Number of cores available on the server
int32 cores = 1;
}
message Void {
}
// A single performance scenario: input to qps_json_driver
message Scenario {
// Human readable name for this scenario
string name = 1;
// Client configuration
ClientConfig client_config = 2;
// Number of clients to start for the test
int32 num_clients = 3;
// Server configuration
ServerConfig server_config = 4;
// Number of servers to start for the test
int32 num_servers = 5;
// Warmup period, in seconds
int32 warmup_seconds = 6;
// Benchmark time, in seconds
int32 benchmark_seconds = 7;
// Number of workers to spawn locally (usually zero)
int32 spawn_local_worker_count = 8;
}
// A set of scenarios to be run with qps_json_driver
message Scenarios {
repeated Scenario scenarios = 1;
}
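// Editor's note (not part of the vendored file): a sketch of what building a
// closed-loop unary Scenario for qps_json_driver could look like in Go. The
// generated control.pb.go is suppressed in this diff, so every identifier
// below (Scenario, ClientConfig, LoadParams_ClosedLoop, the enum values, ...)
// is assumed from the standard protoc-gen-go naming of the messages defined
// above, not read from the generated file.
package main

import testpb "google.golang.org/grpc/benchmark/grpc_testing"

func closedLoopScenario() *testpb.Scenario {
	return &testpb.Scenario{
		Name: "go_unary_closed_loop",
		ClientConfig: &testpb.ClientConfig{
			ServerTargets:             []string{"localhost:10000"},
			ClientType:                testpb.ClientType_SYNC_CLIENT,
			OutstandingRpcsPerChannel: 1,
			ClientChannels:            1,
			RpcType:                   testpb.RpcType_UNARY,
			// closed_loop: start a new RPC as soon as the previous one finishes.
			LoadParams: &testpb.LoadParams{
				Load: &testpb.LoadParams_ClosedLoop{ClosedLoop: &testpb.ClosedLoopParams{}},
			},
		},
		NumClients:       1,
		ServerConfig:     &testpb.ServerConfig{ServerType: testpb.ServerType_SYNC_SERVER},
		NumServers:       1,
		WarmupSeconds:    5,
		BenchmarkSeconds: 30,
	}
}

func main() {
	s := closedLoopScenario()
	_ = s // in a real run this would be handed to the benchmark driver/worker
}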

View file

@ -0,0 +1,731 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: messages.proto
package grpc_testing
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// The type of payload that should be returned.
type PayloadType int32
const (
// Compressable text format.
PayloadType_COMPRESSABLE PayloadType = 0
// Uncompressable binary format.
PayloadType_UNCOMPRESSABLE PayloadType = 1
// Randomly chosen from all other formats defined in this enum.
PayloadType_RANDOM PayloadType = 2
)
var PayloadType_name = map[int32]string{
0: "COMPRESSABLE",
1: "UNCOMPRESSABLE",
2: "RANDOM",
}
var PayloadType_value = map[string]int32{
"COMPRESSABLE": 0,
"UNCOMPRESSABLE": 1,
"RANDOM": 2,
}
func (x PayloadType) String() string {
return proto.EnumName(PayloadType_name, int32(x))
}
func (PayloadType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{0}
}
// Compression algorithms
type CompressionType int32
const (
// No compression
CompressionType_NONE CompressionType = 0
CompressionType_GZIP CompressionType = 1
CompressionType_DEFLATE CompressionType = 2
)
var CompressionType_name = map[int32]string{
0: "NONE",
1: "GZIP",
2: "DEFLATE",
}
var CompressionType_value = map[string]int32{
"NONE": 0,
"GZIP": 1,
"DEFLATE": 2,
}
func (x CompressionType) String() string {
return proto.EnumName(CompressionType_name, int32(x))
}
func (CompressionType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{1}
}
// A block of data, to simply increase gRPC message size.
type Payload struct {
// The type of data in body.
Type PayloadType `protobuf:"varint,1,opt,name=type,proto3,enum=grpc.testing.PayloadType" json:"type,omitempty"`
// Primary contents of payload.
Body []byte `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Payload) Reset() { *m = Payload{} }
func (m *Payload) String() string { return proto.CompactTextString(m) }
func (*Payload) ProtoMessage() {}
func (*Payload) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{0}
}
func (m *Payload) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Payload.Unmarshal(m, b)
}
func (m *Payload) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Payload.Marshal(b, m, deterministic)
}
func (dst *Payload) XXX_Merge(src proto.Message) {
xxx_messageInfo_Payload.Merge(dst, src)
}
func (m *Payload) XXX_Size() int {
return xxx_messageInfo_Payload.Size(m)
}
func (m *Payload) XXX_DiscardUnknown() {
xxx_messageInfo_Payload.DiscardUnknown(m)
}
var xxx_messageInfo_Payload proto.InternalMessageInfo
func (m *Payload) GetType() PayloadType {
if m != nil {
return m.Type
}
return PayloadType_COMPRESSABLE
}
func (m *Payload) GetBody() []byte {
if m != nil {
return m.Body
}
return nil
}
// A protobuf representation for grpc status. This is used by test
// clients to specify a status that the server should attempt to return.
type EchoStatus struct {
Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *EchoStatus) Reset() { *m = EchoStatus{} }
func (m *EchoStatus) String() string { return proto.CompactTextString(m) }
func (*EchoStatus) ProtoMessage() {}
func (*EchoStatus) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{1}
}
func (m *EchoStatus) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_EchoStatus.Unmarshal(m, b)
}
func (m *EchoStatus) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_EchoStatus.Marshal(b, m, deterministic)
}
func (dst *EchoStatus) XXX_Merge(src proto.Message) {
xxx_messageInfo_EchoStatus.Merge(dst, src)
}
func (m *EchoStatus) XXX_Size() int {
return xxx_messageInfo_EchoStatus.Size(m)
}
func (m *EchoStatus) XXX_DiscardUnknown() {
xxx_messageInfo_EchoStatus.DiscardUnknown(m)
}
var xxx_messageInfo_EchoStatus proto.InternalMessageInfo
func (m *EchoStatus) GetCode() int32 {
if m != nil {
return m.Code
}
return 0
}
func (m *EchoStatus) GetMessage() string {
if m != nil {
return m.Message
}
return ""
}
// Unary request.
type SimpleRequest struct {
// Desired payload type in the response from the server.
// If response_type is RANDOM, server randomly chooses one from other formats.
ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,json=responseType,proto3,enum=grpc.testing.PayloadType" json:"response_type,omitempty"`
// Desired payload size in the response from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
ResponseSize int32 `protobuf:"varint,2,opt,name=response_size,json=responseSize,proto3" json:"response_size,omitempty"`
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,3,opt,name=payload,proto3" json:"payload,omitempty"`
// Whether SimpleResponse should include username.
FillUsername bool `protobuf:"varint,4,opt,name=fill_username,json=fillUsername,proto3" json:"fill_username,omitempty"`
// Whether SimpleResponse should include OAuth scope.
FillOauthScope bool `protobuf:"varint,5,opt,name=fill_oauth_scope,json=fillOauthScope,proto3" json:"fill_oauth_scope,omitempty"`
// Compression algorithm to be used by the server for the response (stream)
ResponseCompression CompressionType `protobuf:"varint,6,opt,name=response_compression,json=responseCompression,proto3,enum=grpc.testing.CompressionType" json:"response_compression,omitempty"`
// Whether server should return a given status
ResponseStatus *EchoStatus `protobuf:"bytes,7,opt,name=response_status,json=responseStatus,proto3" json:"response_status,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleRequest) Reset() { *m = SimpleRequest{} }
func (m *SimpleRequest) String() string { return proto.CompactTextString(m) }
func (*SimpleRequest) ProtoMessage() {}
func (*SimpleRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{2}
}
func (m *SimpleRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleRequest.Unmarshal(m, b)
}
func (m *SimpleRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleRequest.Marshal(b, m, deterministic)
}
func (dst *SimpleRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleRequest.Merge(dst, src)
}
func (m *SimpleRequest) XXX_Size() int {
return xxx_messageInfo_SimpleRequest.Size(m)
}
func (m *SimpleRequest) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleRequest.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleRequest proto.InternalMessageInfo
func (m *SimpleRequest) GetResponseType() PayloadType {
if m != nil {
return m.ResponseType
}
return PayloadType_COMPRESSABLE
}
func (m *SimpleRequest) GetResponseSize() int32 {
if m != nil {
return m.ResponseSize
}
return 0
}
func (m *SimpleRequest) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *SimpleRequest) GetFillUsername() bool {
if m != nil {
return m.FillUsername
}
return false
}
func (m *SimpleRequest) GetFillOauthScope() bool {
if m != nil {
return m.FillOauthScope
}
return false
}
func (m *SimpleRequest) GetResponseCompression() CompressionType {
if m != nil {
return m.ResponseCompression
}
return CompressionType_NONE
}
func (m *SimpleRequest) GetResponseStatus() *EchoStatus {
if m != nil {
return m.ResponseStatus
}
return nil
}
// Unary response, as configured by the request.
type SimpleResponse struct {
// Payload to increase message size.
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
// The user the request came from, for verifying authentication was
// successful when the client expected it.
Username string `protobuf:"bytes,2,opt,name=username,proto3" json:"username,omitempty"`
// OAuth scope.
OauthScope string `protobuf:"bytes,3,opt,name=oauth_scope,json=oauthScope,proto3" json:"oauth_scope,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleResponse) Reset() { *m = SimpleResponse{} }
func (m *SimpleResponse) String() string { return proto.CompactTextString(m) }
func (*SimpleResponse) ProtoMessage() {}
func (*SimpleResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{3}
}
func (m *SimpleResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleResponse.Unmarshal(m, b)
}
func (m *SimpleResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleResponse.Marshal(b, m, deterministic)
}
func (dst *SimpleResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleResponse.Merge(dst, src)
}
func (m *SimpleResponse) XXX_Size() int {
return xxx_messageInfo_SimpleResponse.Size(m)
}
func (m *SimpleResponse) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleResponse.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleResponse proto.InternalMessageInfo
func (m *SimpleResponse) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *SimpleResponse) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *SimpleResponse) GetOauthScope() string {
if m != nil {
return m.OauthScope
}
return ""
}
// Client-streaming request.
type StreamingInputCallRequest struct {
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingInputCallRequest) Reset() { *m = StreamingInputCallRequest{} }
func (m *StreamingInputCallRequest) String() string { return proto.CompactTextString(m) }
func (*StreamingInputCallRequest) ProtoMessage() {}
func (*StreamingInputCallRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{4}
}
func (m *StreamingInputCallRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingInputCallRequest.Unmarshal(m, b)
}
func (m *StreamingInputCallRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingInputCallRequest.Marshal(b, m, deterministic)
}
func (dst *StreamingInputCallRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingInputCallRequest.Merge(dst, src)
}
func (m *StreamingInputCallRequest) XXX_Size() int {
return xxx_messageInfo_StreamingInputCallRequest.Size(m)
}
func (m *StreamingInputCallRequest) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingInputCallRequest.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingInputCallRequest proto.InternalMessageInfo
func (m *StreamingInputCallRequest) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
// Client-streaming response.
type StreamingInputCallResponse struct {
// Aggregated size of payloads received from the client.
AggregatedPayloadSize int32 `protobuf:"varint,1,opt,name=aggregated_payload_size,json=aggregatedPayloadSize,proto3" json:"aggregated_payload_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingInputCallResponse) Reset() { *m = StreamingInputCallResponse{} }
func (m *StreamingInputCallResponse) String() string { return proto.CompactTextString(m) }
func (*StreamingInputCallResponse) ProtoMessage() {}
func (*StreamingInputCallResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{5}
}
func (m *StreamingInputCallResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingInputCallResponse.Unmarshal(m, b)
}
func (m *StreamingInputCallResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingInputCallResponse.Marshal(b, m, deterministic)
}
func (dst *StreamingInputCallResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingInputCallResponse.Merge(dst, src)
}
func (m *StreamingInputCallResponse) XXX_Size() int {
return xxx_messageInfo_StreamingInputCallResponse.Size(m)
}
func (m *StreamingInputCallResponse) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingInputCallResponse.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingInputCallResponse proto.InternalMessageInfo
func (m *StreamingInputCallResponse) GetAggregatedPayloadSize() int32 {
if m != nil {
return m.AggregatedPayloadSize
}
return 0
}
// Configuration for a particular response.
type ResponseParameters struct {
// Desired payload sizes in responses from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
Size int32 `protobuf:"varint,1,opt,name=size,proto3" json:"size,omitempty"`
// Desired interval between consecutive responses in the response stream in
// microseconds.
IntervalUs int32 `protobuf:"varint,2,opt,name=interval_us,json=intervalUs,proto3" json:"interval_us,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ResponseParameters) Reset() { *m = ResponseParameters{} }
func (m *ResponseParameters) String() string { return proto.CompactTextString(m) }
func (*ResponseParameters) ProtoMessage() {}
func (*ResponseParameters) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{6}
}
func (m *ResponseParameters) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ResponseParameters.Unmarshal(m, b)
}
func (m *ResponseParameters) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ResponseParameters.Marshal(b, m, deterministic)
}
func (dst *ResponseParameters) XXX_Merge(src proto.Message) {
xxx_messageInfo_ResponseParameters.Merge(dst, src)
}
func (m *ResponseParameters) XXX_Size() int {
return xxx_messageInfo_ResponseParameters.Size(m)
}
func (m *ResponseParameters) XXX_DiscardUnknown() {
xxx_messageInfo_ResponseParameters.DiscardUnknown(m)
}
var xxx_messageInfo_ResponseParameters proto.InternalMessageInfo
func (m *ResponseParameters) GetSize() int32 {
if m != nil {
return m.Size
}
return 0
}
func (m *ResponseParameters) GetIntervalUs() int32 {
if m != nil {
return m.IntervalUs
}
return 0
}
// Server-streaming request.
type StreamingOutputCallRequest struct {
// Desired payload type in the response from the server.
// If response_type is RANDOM, the payload from each response in the stream
// might be of different types. This is to simulate a mixed type of payload
// stream.
ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,json=responseType,proto3,enum=grpc.testing.PayloadType" json:"response_type,omitempty"`
// Configuration for each expected response message.
ResponseParameters []*ResponseParameters `protobuf:"bytes,2,rep,name=response_parameters,json=responseParameters,proto3" json:"response_parameters,omitempty"`
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,3,opt,name=payload,proto3" json:"payload,omitempty"`
// Compression algorithm to be used by the server for the response (stream)
ResponseCompression CompressionType `protobuf:"varint,6,opt,name=response_compression,json=responseCompression,proto3,enum=grpc.testing.CompressionType" json:"response_compression,omitempty"`
// Whether server should return a given status
ResponseStatus *EchoStatus `protobuf:"bytes,7,opt,name=response_status,json=responseStatus,proto3" json:"response_status,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingOutputCallRequest) Reset() { *m = StreamingOutputCallRequest{} }
func (m *StreamingOutputCallRequest) String() string { return proto.CompactTextString(m) }
func (*StreamingOutputCallRequest) ProtoMessage() {}
func (*StreamingOutputCallRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{7}
}
func (m *StreamingOutputCallRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingOutputCallRequest.Unmarshal(m, b)
}
func (m *StreamingOutputCallRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingOutputCallRequest.Marshal(b, m, deterministic)
}
func (dst *StreamingOutputCallRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingOutputCallRequest.Merge(dst, src)
}
func (m *StreamingOutputCallRequest) XXX_Size() int {
return xxx_messageInfo_StreamingOutputCallRequest.Size(m)
}
func (m *StreamingOutputCallRequest) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingOutputCallRequest.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingOutputCallRequest proto.InternalMessageInfo
func (m *StreamingOutputCallRequest) GetResponseType() PayloadType {
if m != nil {
return m.ResponseType
}
return PayloadType_COMPRESSABLE
}
func (m *StreamingOutputCallRequest) GetResponseParameters() []*ResponseParameters {
if m != nil {
return m.ResponseParameters
}
return nil
}
func (m *StreamingOutputCallRequest) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *StreamingOutputCallRequest) GetResponseCompression() CompressionType {
if m != nil {
return m.ResponseCompression
}
return CompressionType_NONE
}
func (m *StreamingOutputCallRequest) GetResponseStatus() *EchoStatus {
if m != nil {
return m.ResponseStatus
}
return nil
}
// Server-streaming response, as configured by the request and parameters.
type StreamingOutputCallResponse struct {
// Payload to increase response size.
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingOutputCallResponse) Reset() { *m = StreamingOutputCallResponse{} }
func (m *StreamingOutputCallResponse) String() string { return proto.CompactTextString(m) }
func (*StreamingOutputCallResponse) ProtoMessage() {}
func (*StreamingOutputCallResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{8}
}
func (m *StreamingOutputCallResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingOutputCallResponse.Unmarshal(m, b)
}
func (m *StreamingOutputCallResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingOutputCallResponse.Marshal(b, m, deterministic)
}
func (dst *StreamingOutputCallResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingOutputCallResponse.Merge(dst, src)
}
func (m *StreamingOutputCallResponse) XXX_Size() int {
return xxx_messageInfo_StreamingOutputCallResponse.Size(m)
}
func (m *StreamingOutputCallResponse) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingOutputCallResponse.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingOutputCallResponse proto.InternalMessageInfo
func (m *StreamingOutputCallResponse) GetPayload() *Payload {
if m != nil {
return m.Payload
}
return nil
}
// For reconnect interop test only.
// Client tells server what reconnection parameters it used.
type ReconnectParams struct {
MaxReconnectBackoffMs int32 `protobuf:"varint,1,opt,name=max_reconnect_backoff_ms,json=maxReconnectBackoffMs,proto3" json:"max_reconnect_backoff_ms,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ReconnectParams) Reset() { *m = ReconnectParams{} }
func (m *ReconnectParams) String() string { return proto.CompactTextString(m) }
func (*ReconnectParams) ProtoMessage() {}
func (*ReconnectParams) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{9}
}
func (m *ReconnectParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReconnectParams.Unmarshal(m, b)
}
func (m *ReconnectParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ReconnectParams.Marshal(b, m, deterministic)
}
func (dst *ReconnectParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ReconnectParams.Merge(dst, src)
}
func (m *ReconnectParams) XXX_Size() int {
return xxx_messageInfo_ReconnectParams.Size(m)
}
func (m *ReconnectParams) XXX_DiscardUnknown() {
xxx_messageInfo_ReconnectParams.DiscardUnknown(m)
}
var xxx_messageInfo_ReconnectParams proto.InternalMessageInfo
func (m *ReconnectParams) GetMaxReconnectBackoffMs() int32 {
if m != nil {
return m.MaxReconnectBackoffMs
}
return 0
}
// For reconnect interop test only.
// Server tells client whether its reconnects are following the spec and the
// reconnect backoffs it saw.
type ReconnectInfo struct {
Passed bool `protobuf:"varint,1,opt,name=passed,proto3" json:"passed,omitempty"`
BackoffMs []int32 `protobuf:"varint,2,rep,packed,name=backoff_ms,json=backoffMs,proto3" json:"backoff_ms,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ReconnectInfo) Reset() { *m = ReconnectInfo{} }
func (m *ReconnectInfo) String() string { return proto.CompactTextString(m) }
func (*ReconnectInfo) ProtoMessage() {}
func (*ReconnectInfo) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{10}
}
func (m *ReconnectInfo) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReconnectInfo.Unmarshal(m, b)
}
func (m *ReconnectInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ReconnectInfo.Marshal(b, m, deterministic)
}
func (dst *ReconnectInfo) XXX_Merge(src proto.Message) {
xxx_messageInfo_ReconnectInfo.Merge(dst, src)
}
func (m *ReconnectInfo) XXX_Size() int {
return xxx_messageInfo_ReconnectInfo.Size(m)
}
func (m *ReconnectInfo) XXX_DiscardUnknown() {
xxx_messageInfo_ReconnectInfo.DiscardUnknown(m)
}
var xxx_messageInfo_ReconnectInfo proto.InternalMessageInfo
func (m *ReconnectInfo) GetPassed() bool {
if m != nil {
return m.Passed
}
return false
}
func (m *ReconnectInfo) GetBackoffMs() []int32 {
if m != nil {
return m.BackoffMs
}
return nil
}
func init() {
proto.RegisterType((*Payload)(nil), "grpc.testing.Payload")
proto.RegisterType((*EchoStatus)(nil), "grpc.testing.EchoStatus")
proto.RegisterType((*SimpleRequest)(nil), "grpc.testing.SimpleRequest")
proto.RegisterType((*SimpleResponse)(nil), "grpc.testing.SimpleResponse")
proto.RegisterType((*StreamingInputCallRequest)(nil), "grpc.testing.StreamingInputCallRequest")
proto.RegisterType((*StreamingInputCallResponse)(nil), "grpc.testing.StreamingInputCallResponse")
proto.RegisterType((*ResponseParameters)(nil), "grpc.testing.ResponseParameters")
proto.RegisterType((*StreamingOutputCallRequest)(nil), "grpc.testing.StreamingOutputCallRequest")
proto.RegisterType((*StreamingOutputCallResponse)(nil), "grpc.testing.StreamingOutputCallResponse")
proto.RegisterType((*ReconnectParams)(nil), "grpc.testing.ReconnectParams")
proto.RegisterType((*ReconnectInfo)(nil), "grpc.testing.ReconnectInfo")
proto.RegisterEnum("grpc.testing.PayloadType", PayloadType_name, PayloadType_value)
proto.RegisterEnum("grpc.testing.CompressionType", CompressionType_name, CompressionType_value)
}
func init() { proto.RegisterFile("messages.proto", fileDescriptor_messages_5c70222ad96bf232) }
var fileDescriptor_messages_5c70222ad96bf232 = []byte{
// 652 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xcc, 0x55, 0x4d, 0x6f, 0xd3, 0x40,
0x10, 0xc5, 0xf9, 0xee, 0x24, 0x4d, 0xa3, 0x85, 0x82, 0x5b, 0x54, 0x11, 0x99, 0x4b, 0x54, 0x89,
0x20, 0x05, 0x09, 0x24, 0x0e, 0xa0, 0xb4, 0x4d, 0x51, 0x50, 0x9a, 0x84, 0x75, 0x7b, 0xe1, 0x62,
0x6d, 0x9c, 0x8d, 0x6b, 0x11, 0x7b, 0x8d, 0x77, 0x8d, 0x9a, 0x1e, 0xb8, 0xf3, 0x83, 0xb9, 0xa3,
0x5d, 0x7f, 0xc4, 0x69, 0x7b, 0x68, 0xe1, 0xc2, 0x6d, 0xf7, 0xed, 0x9b, 0x97, 0x79, 0x33, 0xcf,
0x0a, 0x34, 0x3d, 0xca, 0x39, 0x71, 0x28, 0xef, 0x06, 0x21, 0x13, 0x0c, 0x35, 0x9c, 0x30, 0xb0,
0xbb, 0x82, 0x72, 0xe1, 0xfa, 0x8e, 0x31, 0x82, 0xea, 0x94, 0xac, 0x96, 0x8c, 0xcc, 0xd1, 0x2b,
0x28, 0x89, 0x55, 0x40, 0x75, 0xad, 0xad, 0x75, 0x9a, 0xbd, 0xbd, 0x6e, 0x9e, 0xd7, 0x4d, 0x48,
0xe7, 0xab, 0x80, 0x62, 0x45, 0x43, 0x08, 0x4a, 0x33, 0x36, 0x5f, 0xe9, 0x85, 0xb6, 0xd6, 0x69,
0x60, 0x75, 0x36, 0xde, 0x03, 0x0c, 0xec, 0x4b, 0x66, 0x0a, 0x22, 0x22, 0x2e, 0x19, 0x36, 0x9b,
0xc7, 0x82, 0x65, 0xac, 0xce, 0x48, 0x87, 0x6a, 0xd2, 0x8f, 0x2a, 0xdc, 0xc2, 0xe9, 0xd5, 0xf8,
0x55, 0x84, 0x6d, 0xd3, 0xf5, 0x82, 0x25, 0xc5, 0xf4, 0x7b, 0x44, 0xb9, 0x40, 0x1f, 0x60, 0x3b,
0xa4, 0x3c, 0x60, 0x3e, 0xa7, 0xd6, 0xfd, 0x3a, 0x6b, 0xa4, 0x7c, 0x79, 0x43, 0x2f, 0x73, 0xf5,
0xdc, 0xbd, 0x8e, 0x7f, 0xb1, 0xbc, 0x26, 0x99, 0xee, 0x35, 0x45, 0xaf, 0xa1, 0x1a, 0xc4, 0x0a,
0x7a, 0xb1, 0xad, 0x75, 0xea, 0xbd, 0xdd, 0x3b, 0xe5, 0x71, 0xca, 0x92, 0xaa, 0x0b, 0x77, 0xb9,
0xb4, 0x22, 0x4e, 0x43, 0x9f, 0x78, 0x54, 0x2f, 0xb5, 0xb5, 0x4e, 0x0d, 0x37, 0x24, 0x78, 0x91,
0x60, 0xa8, 0x03, 0x2d, 0x45, 0x62, 0x24, 0x12, 0x97, 0x16, 0xb7, 0x59, 0x40, 0xf5, 0xb2, 0xe2,
0x35, 0x25, 0x3e, 0x91, 0xb0, 0x29, 0x51, 0x34, 0x85, 0x27, 0x59, 0x93, 0x36, 0xf3, 0x82, 0x90,
0x72, 0xee, 0x32, 0x5f, 0xaf, 0x28, 0xaf, 0x07, 0x9b, 0xcd, 0x1c, 0xaf, 0x09, 0xca, 0xef, 0xe3,
0xb4, 0x34, 0xf7, 0x80, 0xfa, 0xb0, 0xb3, 0xb6, 0xad, 0x36, 0xa1, 0x57, 0x95, 0x33, 0x7d, 0x53,
0x6c, 0xbd, 0x29, 0xdc, 0xcc, 0x46, 0xa2, 0xee, 0xc6, 0x4f, 0x68, 0xa6, 0xab, 0x88, 0xf1, 0xfc,
0x98, 0xb4, 0x7b, 0x8d, 0x69, 0x1f, 0x6a, 0xd9, 0x84, 0xe2, 0x4d, 0x67, 0x77, 0xf4, 0x02, 0xea,
0xf9, 0xc1, 0x14, 0xd5, 0x33, 0xb0, 0x6c, 0x28, 0xc6, 0x08, 0xf6, 0x4c, 0x11, 0x52, 0xe2, 0xb9,
0xbe, 0x33, 0xf4, 0x83, 0x48, 0x1c, 0x93, 0xe5, 0x32, 0x8d, 0xc5, 0x43, 0x5b, 0x31, 0xce, 0x61,
0xff, 0x2e, 0xb5, 0xc4, 0xd9, 0x5b, 0x78, 0x46, 0x1c, 0x27, 0xa4, 0x0e, 0x11, 0x74, 0x6e, 0x25,
0x35, 0x71, 0x5e, 0xe2, 0xe0, 0xee, 0xae, 0x9f, 0x13, 0x69, 0x19, 0x1c, 0x63, 0x08, 0x28, 0xd5,
0x98, 0x92, 0x90, 0x78, 0x54, 0xd0, 0x50, 0x65, 0x3e, 0x57, 0xaa, 0xce, 0xd2, 0xae, 0xeb, 0x0b,
0x1a, 0xfe, 0x20, 0x32, 0x35, 0x49, 0x0a, 0x21, 0x85, 0x2e, 0xb8, 0xf1, 0xbb, 0x90, 0xeb, 0x70,
0x12, 0x89, 0x1b, 0x86, 0xff, 0xf5, 0x3b, 0xf8, 0x02, 0x59, 0x4e, 0xac, 0x20, 0x6b, 0x55, 0x2f,
0xb4, 0x8b, 0x9d, 0x7a, 0xaf, 0xbd, 0xa9, 0x72, 0xdb, 0x12, 0x46, 0xe1, 0x6d, 0x9b, 0x0f, 0xfe,
0x6a, 0xfe, 0xcb, 0x98, 0x8f, 0xe1, 0xf9, 0x9d, 0x63, 0xff, 0xcb, 0xcc, 0x1b, 0x9f, 0x61, 0x07,
0x53, 0x9b, 0xf9, 0x3e, 0xb5, 0x85, 0x1a, 0x16, 0x47, 0xef, 0x40, 0xf7, 0xc8, 0x95, 0x15, 0xa6,
0xb0, 0x35, 0x23, 0xf6, 0x37, 0xb6, 0x58, 0x58, 0x1e, 0x4f, 0xe3, 0xe5, 0x91, 0xab, 0xac, 0xea,
0x28, 0x7e, 0x3d, 0xe3, 0xc6, 0x29, 0x6c, 0x67, 0xe8, 0xd0, 0x5f, 0x30, 0xf4, 0x14, 0x2a, 0x01,
0xe1, 0x9c, 0xc6, 0xcd, 0xd4, 0x70, 0x72, 0x43, 0x07, 0x00, 0x39, 0x4d, 0xb9, 0xd4, 0x32, 0xde,
0x9a, 0xa5, 0x3a, 0x87, 0x1f, 0xa1, 0x9e, 0x4b, 0x06, 0x6a, 0x41, 0xe3, 0x78, 0x72, 0x36, 0xc5,
0x03, 0xd3, 0xec, 0x1f, 0x8d, 0x06, 0xad, 0x47, 0x08, 0x41, 0xf3, 0x62, 0xbc, 0x81, 0x69, 0x08,
0xa0, 0x82, 0xfb, 0xe3, 0x93, 0xc9, 0x59, 0xab, 0x70, 0xd8, 0x83, 0x9d, 0x1b, 0xfb, 0x40, 0x35,
0x28, 0x8d, 0x27, 0x63, 0x59, 0x5c, 0x83, 0xd2, 0xa7, 0xaf, 0xc3, 0x69, 0x4b, 0x43, 0x75, 0xa8,
0x9e, 0x0c, 0x4e, 0x47, 0xfd, 0xf3, 0x41, 0xab, 0x30, 0xab, 0xa8, 0xbf, 0x9a, 0x37, 0x7f, 0x02,
0x00, 0x00, 0xff, 0xff, 0xc2, 0x6a, 0xce, 0x1e, 0x7c, 0x06, 0x00, 0x00,
}
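// Editor's note (not part of the vendored file): a minimal example of building
// the generated request type above, much as the benchmark client does, using
// only fields and getters defined in this file.
package main

import (
	"fmt"

	testpb "google.golang.org/grpc/benchmark/grpc_testing"
)

func main() {
	req := &testpb.SimpleRequest{
		ResponseType: testpb.PayloadType_COMPRESSABLE,
		ResponseSize: 1024,
		Payload: &testpb.Payload{
			Type: testpb.PayloadType_COMPRESSABLE,
			Body: make([]byte, 64),
		},
	}
	// The generated getters are nil-safe, so they may be called on unset
	// sub-messages without panicking.
	fmt.Println(req.GetResponseType(), len(req.GetPayload().GetBody()))
}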

View file

@ -0,0 +1,157 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Message definitions to be used by integration test service definitions.
syntax = "proto3";
package grpc.testing;
// The type of payload that should be returned.
enum PayloadType {
// Compressable text format.
COMPRESSABLE = 0;
// Uncompressable binary format.
UNCOMPRESSABLE = 1;
// Randomly chosen from all other formats defined in this enum.
RANDOM = 2;
}
// Compression algorithms
enum CompressionType {
// No compression
NONE = 0;
GZIP = 1;
DEFLATE = 2;
}
// A block of data, to simply increase gRPC message size.
message Payload {
// The type of data in body.
PayloadType type = 1;
// Primary contents of payload.
bytes body = 2;
}
// A protobuf representation for grpc status. This is used by test
// clients to specify a status that the server should attempt to return.
message EchoStatus {
int32 code = 1;
string message = 2;
}
// Unary request.
message SimpleRequest {
// Desired payload type in the response from the server.
// If response_type is RANDOM, server randomly chooses one from other formats.
PayloadType response_type = 1;
// Desired payload size in the response from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
int32 response_size = 2;
// Optional input payload sent along with the request.
Payload payload = 3;
// Whether SimpleResponse should include username.
bool fill_username = 4;
// Whether SimpleResponse should include OAuth scope.
bool fill_oauth_scope = 5;
// Compression algorithm to be used by the server for the response (stream)
CompressionType response_compression = 6;
// Whether server should return a given status
EchoStatus response_status = 7;
}
// Unary response, as configured by the request.
message SimpleResponse {
// Payload to increase message size.
Payload payload = 1;
// The user the request came from, for verifying authentication was
// successful when the client expected it.
string username = 2;
// OAuth scope.
string oauth_scope = 3;
}
// Client-streaming request.
message StreamingInputCallRequest {
// Optional input payload sent along with the request.
Payload payload = 1;
// Not expecting any payload from the response.
}
// Client-streaming response.
message StreamingInputCallResponse {
// Aggregated size of payloads received from the client.
int32 aggregated_payload_size = 1;
}
// Configuration for a particular response.
message ResponseParameters {
// Desired payload sizes in responses from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
int32 size = 1;
// Desired interval between consecutive responses in the response stream in
// microseconds.
int32 interval_us = 2;
}
// Server-streaming request.
message StreamingOutputCallRequest {
// Desired payload type in the response from the server.
// If response_type is RANDOM, the payload from each response in the stream
// might be of different types. This is to simulate a mixed type of payload
// stream.
PayloadType response_type = 1;
// Configuration for each expected response message.
repeated ResponseParameters response_parameters = 2;
// Optional input payload sent along with the request.
Payload payload = 3;
// Compression algorithm to be used by the server for the response (stream)
CompressionType response_compression = 6;
// Whether server should return a given status
EchoStatus response_status = 7;
}
// Server-streaming response, as configured by the request and parameters.
message StreamingOutputCallResponse {
// Payload to increase response size.
Payload payload = 1;
}
// For reconnect interop test only.
// Client tells server what reconnection parameters it used.
message ReconnectParams {
int32 max_reconnect_backoff_ms = 1;
}
// For reconnect interop test only.
// Server tells client whether its reconnects are following the spec and the
// reconnect backoffs it saw.
message ReconnectInfo {
bool passed = 1;
repeated int32 backoff_ms = 2;
}

View file

@ -0,0 +1,348 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: payloads.proto
package grpc_testing
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type ByteBufferParams struct {
ReqSize int32 `protobuf:"varint,1,opt,name=req_size,json=reqSize,proto3" json:"req_size,omitempty"`
RespSize int32 `protobuf:"varint,2,opt,name=resp_size,json=respSize,proto3" json:"resp_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ByteBufferParams) Reset() { *m = ByteBufferParams{} }
func (m *ByteBufferParams) String() string { return proto.CompactTextString(m) }
func (*ByteBufferParams) ProtoMessage() {}
func (*ByteBufferParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{0}
}
func (m *ByteBufferParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ByteBufferParams.Unmarshal(m, b)
}
func (m *ByteBufferParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ByteBufferParams.Marshal(b, m, deterministic)
}
func (dst *ByteBufferParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ByteBufferParams.Merge(dst, src)
}
func (m *ByteBufferParams) XXX_Size() int {
return xxx_messageInfo_ByteBufferParams.Size(m)
}
func (m *ByteBufferParams) XXX_DiscardUnknown() {
xxx_messageInfo_ByteBufferParams.DiscardUnknown(m)
}
var xxx_messageInfo_ByteBufferParams proto.InternalMessageInfo
func (m *ByteBufferParams) GetReqSize() int32 {
if m != nil {
return m.ReqSize
}
return 0
}
func (m *ByteBufferParams) GetRespSize() int32 {
if m != nil {
return m.RespSize
}
return 0
}
type SimpleProtoParams struct {
ReqSize int32 `protobuf:"varint,1,opt,name=req_size,json=reqSize,proto3" json:"req_size,omitempty"`
RespSize int32 `protobuf:"varint,2,opt,name=resp_size,json=respSize,proto3" json:"resp_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleProtoParams) Reset() { *m = SimpleProtoParams{} }
func (m *SimpleProtoParams) String() string { return proto.CompactTextString(m) }
func (*SimpleProtoParams) ProtoMessage() {}
func (*SimpleProtoParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{1}
}
func (m *SimpleProtoParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleProtoParams.Unmarshal(m, b)
}
func (m *SimpleProtoParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleProtoParams.Marshal(b, m, deterministic)
}
func (dst *SimpleProtoParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleProtoParams.Merge(dst, src)
}
func (m *SimpleProtoParams) XXX_Size() int {
return xxx_messageInfo_SimpleProtoParams.Size(m)
}
func (m *SimpleProtoParams) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleProtoParams.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleProtoParams proto.InternalMessageInfo
func (m *SimpleProtoParams) GetReqSize() int32 {
if m != nil {
return m.ReqSize
}
return 0
}
func (m *SimpleProtoParams) GetRespSize() int32 {
if m != nil {
return m.RespSize
}
return 0
}
type ComplexProtoParams struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ComplexProtoParams) Reset() { *m = ComplexProtoParams{} }
func (m *ComplexProtoParams) String() string { return proto.CompactTextString(m) }
func (*ComplexProtoParams) ProtoMessage() {}
func (*ComplexProtoParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{2}
}
func (m *ComplexProtoParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ComplexProtoParams.Unmarshal(m, b)
}
func (m *ComplexProtoParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ComplexProtoParams.Marshal(b, m, deterministic)
}
func (dst *ComplexProtoParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ComplexProtoParams.Merge(dst, src)
}
func (m *ComplexProtoParams) XXX_Size() int {
return xxx_messageInfo_ComplexProtoParams.Size(m)
}
func (m *ComplexProtoParams) XXX_DiscardUnknown() {
xxx_messageInfo_ComplexProtoParams.DiscardUnknown(m)
}
var xxx_messageInfo_ComplexProtoParams proto.InternalMessageInfo
type PayloadConfig struct {
// Types that are valid to be assigned to Payload:
// *PayloadConfig_BytebufParams
// *PayloadConfig_SimpleParams
// *PayloadConfig_ComplexParams
Payload isPayloadConfig_Payload `protobuf_oneof:"payload"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *PayloadConfig) Reset() { *m = PayloadConfig{} }
func (m *PayloadConfig) String() string { return proto.CompactTextString(m) }
func (*PayloadConfig) ProtoMessage() {}
func (*PayloadConfig) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{3}
}
func (m *PayloadConfig) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_PayloadConfig.Unmarshal(m, b)
}
func (m *PayloadConfig) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_PayloadConfig.Marshal(b, m, deterministic)
}
func (dst *PayloadConfig) XXX_Merge(src proto.Message) {
xxx_messageInfo_PayloadConfig.Merge(dst, src)
}
func (m *PayloadConfig) XXX_Size() int {
return xxx_messageInfo_PayloadConfig.Size(m)
}
func (m *PayloadConfig) XXX_DiscardUnknown() {
xxx_messageInfo_PayloadConfig.DiscardUnknown(m)
}
var xxx_messageInfo_PayloadConfig proto.InternalMessageInfo
type isPayloadConfig_Payload interface {
isPayloadConfig_Payload()
}
type PayloadConfig_BytebufParams struct {
BytebufParams *ByteBufferParams `protobuf:"bytes,1,opt,name=bytebuf_params,json=bytebufParams,proto3,oneof"`
}
type PayloadConfig_SimpleParams struct {
SimpleParams *SimpleProtoParams `protobuf:"bytes,2,opt,name=simple_params,json=simpleParams,proto3,oneof"`
}
type PayloadConfig_ComplexParams struct {
ComplexParams *ComplexProtoParams `protobuf:"bytes,3,opt,name=complex_params,json=complexParams,proto3,oneof"`
}
func (*PayloadConfig_BytebufParams) isPayloadConfig_Payload() {}
func (*PayloadConfig_SimpleParams) isPayloadConfig_Payload() {}
func (*PayloadConfig_ComplexParams) isPayloadConfig_Payload() {}
func (m *PayloadConfig) GetPayload() isPayloadConfig_Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *PayloadConfig) GetBytebufParams() *ByteBufferParams {
if x, ok := m.GetPayload().(*PayloadConfig_BytebufParams); ok {
return x.BytebufParams
}
return nil
}
func (m *PayloadConfig) GetSimpleParams() *SimpleProtoParams {
if x, ok := m.GetPayload().(*PayloadConfig_SimpleParams); ok {
return x.SimpleParams
}
return nil
}
func (m *PayloadConfig) GetComplexParams() *ComplexProtoParams {
if x, ok := m.GetPayload().(*PayloadConfig_ComplexParams); ok {
return x.ComplexParams
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*PayloadConfig) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _PayloadConfig_OneofMarshaler, _PayloadConfig_OneofUnmarshaler, _PayloadConfig_OneofSizer, []interface{}{
(*PayloadConfig_BytebufParams)(nil),
(*PayloadConfig_SimpleParams)(nil),
(*PayloadConfig_ComplexParams)(nil),
}
}
func _PayloadConfig_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*PayloadConfig)
// payload
switch x := m.Payload.(type) {
case *PayloadConfig_BytebufParams:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.BytebufParams); err != nil {
return err
}
case *PayloadConfig_SimpleParams:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.SimpleParams); err != nil {
return err
}
case *PayloadConfig_ComplexParams:
b.EncodeVarint(3<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ComplexParams); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("PayloadConfig.Payload has unexpected type %T", x)
}
return nil
}
func _PayloadConfig_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*PayloadConfig)
switch tag {
case 1: // payload.bytebuf_params
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ByteBufferParams)
err := b.DecodeMessage(msg)
m.Payload = &PayloadConfig_BytebufParams{msg}
return true, err
case 2: // payload.simple_params
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(SimpleProtoParams)
err := b.DecodeMessage(msg)
m.Payload = &PayloadConfig_SimpleParams{msg}
return true, err
case 3: // payload.complex_params
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ComplexProtoParams)
err := b.DecodeMessage(msg)
m.Payload = &PayloadConfig_ComplexParams{msg}
return true, err
default:
return false, nil
}
}
func _PayloadConfig_OneofSizer(msg proto.Message) (n int) {
m := msg.(*PayloadConfig)
// payload
switch x := m.Payload.(type) {
case *PayloadConfig_BytebufParams:
s := proto.Size(x.BytebufParams)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *PayloadConfig_SimpleParams:
s := proto.Size(x.SimpleParams)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *PayloadConfig_ComplexParams:
s := proto.Size(x.ComplexParams)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
func init() {
proto.RegisterType((*ByteBufferParams)(nil), "grpc.testing.ByteBufferParams")
proto.RegisterType((*SimpleProtoParams)(nil), "grpc.testing.SimpleProtoParams")
proto.RegisterType((*ComplexProtoParams)(nil), "grpc.testing.ComplexProtoParams")
proto.RegisterType((*PayloadConfig)(nil), "grpc.testing.PayloadConfig")
}
func init() { proto.RegisterFile("payloads.proto", fileDescriptor_payloads_3abc71de35f06c83) }
var fileDescriptor_payloads_3abc71de35f06c83 = []byte{
// 254 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2b, 0x48, 0xac, 0xcc,
0xc9, 0x4f, 0x4c, 0x29, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x49, 0x2f, 0x2a, 0x48,
0xd6, 0x2b, 0x49, 0x2d, 0x2e, 0xc9, 0xcc, 0x4b, 0x57, 0xf2, 0xe2, 0x12, 0x70, 0xaa, 0x2c, 0x49,
0x75, 0x2a, 0x4d, 0x4b, 0x4b, 0x2d, 0x0a, 0x48, 0x2c, 0x4a, 0xcc, 0x2d, 0x16, 0x92, 0xe4, 0xe2,
0x28, 0x4a, 0x2d, 0x8c, 0x2f, 0xce, 0xac, 0x4a, 0x95, 0x60, 0x54, 0x60, 0xd4, 0x60, 0x0d, 0x62,
0x2f, 0x4a, 0x2d, 0x0c, 0xce, 0xac, 0x4a, 0x15, 0x92, 0xe6, 0xe2, 0x2c, 0x4a, 0x2d, 0x2e, 0x80,
0xc8, 0x31, 0x81, 0xe5, 0x38, 0x40, 0x02, 0x20, 0x49, 0x25, 0x6f, 0x2e, 0xc1, 0xe0, 0xcc, 0xdc,
0x82, 0x9c, 0xd4, 0x00, 0x90, 0x45, 0x14, 0x1a, 0x26, 0xc2, 0x25, 0xe4, 0x9c, 0x0f, 0x32, 0xac,
0x02, 0xc9, 0x34, 0xa5, 0x6f, 0x8c, 0x5c, 0xbc, 0x01, 0x10, 0xff, 0x38, 0xe7, 0xe7, 0xa5, 0x65,
0xa6, 0x0b, 0xb9, 0x73, 0xf1, 0x25, 0x55, 0x96, 0xa4, 0x26, 0x95, 0xa6, 0xc5, 0x17, 0x80, 0xd5,
0x80, 0x6d, 0xe1, 0x36, 0x92, 0xd3, 0x43, 0xf6, 0xa7, 0x1e, 0xba, 0x27, 0x3d, 0x18, 0x82, 0x78,
0xa1, 0xfa, 0xa0, 0x0e, 0x75, 0xe3, 0xe2, 0x2d, 0x06, 0xbb, 0x1e, 0x66, 0x0e, 0x13, 0xd8, 0x1c,
0x79, 0x54, 0x73, 0x30, 0x3c, 0xe8, 0xc1, 0x10, 0xc4, 0x03, 0xd1, 0x07, 0x35, 0xc7, 0x93, 0x8b,
0x2f, 0x19, 0xe2, 0x70, 0x98, 0x41, 0xcc, 0x60, 0x83, 0x14, 0x50, 0x0d, 0xc2, 0xf4, 0x1c, 0xc8,
0x49, 0x50, 0x9d, 0x10, 0x01, 0x27, 0x4e, 0x2e, 0x76, 0x68, 0xe4, 0x25, 0xb1, 0x81, 0x23, 0xcf,
0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0xb0, 0x8c, 0x18, 0x4e, 0xce, 0x01, 0x00, 0x00,
}

View file

@ -0,0 +1,40 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package grpc.testing;
message ByteBufferParams {
int32 req_size = 1;
int32 resp_size = 2;
}
message SimpleProtoParams {
int32 req_size = 1;
int32 resp_size = 2;
}
message ComplexProtoParams {
// TODO (vpai): Fill this in once the details of complex, representative
// protos are decided
}
message PayloadConfig {
oneof payload {
ByteBufferParams bytebuf_params = 1;
SimpleProtoParams simple_params = 2;
ComplexProtoParams complex_params = 3;
}
}
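// Editor's note (not part of the vendored file): how the payload oneof above
// is set and read through the generated wrappers from payloads.pb.go
// (PayloadConfig_SimpleParams and the type-specific getters shown earlier).
package main

import (
	"fmt"

	testpb "google.golang.org/grpc/benchmark/grpc_testing"
)

func main() {
	cfg := &testpb.PayloadConfig{
		// Exactly one oneof wrapper may be set at a time.
		Payload: &testpb.PayloadConfig_SimpleParams{
			SimpleParams: &testpb.SimpleProtoParams{ReqSize: 1, RespSize: 1},
		},
	}
	// GetSimpleParams returns nil unless the simple_params case is set.
	if p := cfg.GetSimpleParams(); p != nil {
		fmt.Println(p.GetReqSize(), p.GetRespSize())
	}
}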

View file

@ -0,0 +1,448 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: services.proto
package grpc_testing
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// BenchmarkServiceClient is the client API for BenchmarkService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type BenchmarkServiceClient interface {
// One request followed by one response.
// The server returns the client payload as-is.
UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error)
// One request followed by one response.
// The server returns the client payload as-is.
StreamingCall(ctx context.Context, opts ...grpc.CallOption) (BenchmarkService_StreamingCallClient, error)
}
type benchmarkServiceClient struct {
cc *grpc.ClientConn
}
func NewBenchmarkServiceClient(cc *grpc.ClientConn) BenchmarkServiceClient {
return &benchmarkServiceClient{cc}
}
func (c *benchmarkServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) {
out := new(SimpleResponse)
err := c.cc.Invoke(ctx, "/grpc.testing.BenchmarkService/UnaryCall", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *benchmarkServiceClient) StreamingCall(ctx context.Context, opts ...grpc.CallOption) (BenchmarkService_StreamingCallClient, error) {
stream, err := c.cc.NewStream(ctx, &_BenchmarkService_serviceDesc.Streams[0], "/grpc.testing.BenchmarkService/StreamingCall", opts...)
if err != nil {
return nil, err
}
x := &benchmarkServiceStreamingCallClient{stream}
return x, nil
}
type BenchmarkService_StreamingCallClient interface {
Send(*SimpleRequest) error
Recv() (*SimpleResponse, error)
grpc.ClientStream
}
type benchmarkServiceStreamingCallClient struct {
grpc.ClientStream
}
func (x *benchmarkServiceStreamingCallClient) Send(m *SimpleRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *benchmarkServiceStreamingCallClient) Recv() (*SimpleResponse, error) {
m := new(SimpleResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// BenchmarkServiceServer is the server API for BenchmarkService service.
type BenchmarkServiceServer interface {
// One request followed by one response.
// The server returns the client payload as-is.
UnaryCall(context.Context, *SimpleRequest) (*SimpleResponse, error)
// One request followed by one response.
// The server returns the client payload as-is.
StreamingCall(BenchmarkService_StreamingCallServer) error
}
func RegisterBenchmarkServiceServer(s *grpc.Server, srv BenchmarkServiceServer) {
s.RegisterService(&_BenchmarkService_serviceDesc, srv)
}
func _BenchmarkService_UnaryCall_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SimpleRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BenchmarkServiceServer).UnaryCall(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/grpc.testing.BenchmarkService/UnaryCall",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BenchmarkServiceServer).UnaryCall(ctx, req.(*SimpleRequest))
}
return interceptor(ctx, in, info, handler)
}
func _BenchmarkService_StreamingCall_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(BenchmarkServiceServer).StreamingCall(&benchmarkServiceStreamingCallServer{stream})
}
type BenchmarkService_StreamingCallServer interface {
Send(*SimpleResponse) error
Recv() (*SimpleRequest, error)
grpc.ServerStream
}
type benchmarkServiceStreamingCallServer struct {
grpc.ServerStream
}
func (x *benchmarkServiceStreamingCallServer) Send(m *SimpleResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *benchmarkServiceStreamingCallServer) Recv() (*SimpleRequest, error) {
m := new(SimpleRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
var _BenchmarkService_serviceDesc = grpc.ServiceDesc{
ServiceName: "grpc.testing.BenchmarkService",
HandlerType: (*BenchmarkServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "UnaryCall",
Handler: _BenchmarkService_UnaryCall_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "StreamingCall",
Handler: _BenchmarkService_StreamingCall_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "services.proto",
}
// WorkerServiceClient is the client API for WorkerService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type WorkerServiceClient interface {
// Start server with specified workload.
// First request sent specifies the ServerConfig followed by ServerStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test server
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
RunServer(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunServerClient, error)
// Start client with specified workload.
// First request sent specifies the ClientConfig followed by ClientStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test client
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
RunClient(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunClientClient, error)
// Just return the core count - unary call
CoreCount(ctx context.Context, in *CoreRequest, opts ...grpc.CallOption) (*CoreResponse, error)
// Quit this worker
QuitWorker(ctx context.Context, in *Void, opts ...grpc.CallOption) (*Void, error)
}
type workerServiceClient struct {
cc *grpc.ClientConn
}
func NewWorkerServiceClient(cc *grpc.ClientConn) WorkerServiceClient {
return &workerServiceClient{cc}
}
func (c *workerServiceClient) RunServer(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunServerClient, error) {
stream, err := c.cc.NewStream(ctx, &_WorkerService_serviceDesc.Streams[0], "/grpc.testing.WorkerService/RunServer", opts...)
if err != nil {
return nil, err
}
x := &workerServiceRunServerClient{stream}
return x, nil
}
type WorkerService_RunServerClient interface {
Send(*ServerArgs) error
Recv() (*ServerStatus, error)
grpc.ClientStream
}
type workerServiceRunServerClient struct {
grpc.ClientStream
}
func (x *workerServiceRunServerClient) Send(m *ServerArgs) error {
return x.ClientStream.SendMsg(m)
}
func (x *workerServiceRunServerClient) Recv() (*ServerStatus, error) {
m := new(ServerStatus)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *workerServiceClient) RunClient(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunClientClient, error) {
stream, err := c.cc.NewStream(ctx, &_WorkerService_serviceDesc.Streams[1], "/grpc.testing.WorkerService/RunClient", opts...)
if err != nil {
return nil, err
}
x := &workerServiceRunClientClient{stream}
return x, nil
}
type WorkerService_RunClientClient interface {
Send(*ClientArgs) error
Recv() (*ClientStatus, error)
grpc.ClientStream
}
type workerServiceRunClientClient struct {
grpc.ClientStream
}
func (x *workerServiceRunClientClient) Send(m *ClientArgs) error {
return x.ClientStream.SendMsg(m)
}
func (x *workerServiceRunClientClient) Recv() (*ClientStatus, error) {
m := new(ClientStatus)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *workerServiceClient) CoreCount(ctx context.Context, in *CoreRequest, opts ...grpc.CallOption) (*CoreResponse, error) {
out := new(CoreResponse)
err := c.cc.Invoke(ctx, "/grpc.testing.WorkerService/CoreCount", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *workerServiceClient) QuitWorker(ctx context.Context, in *Void, opts ...grpc.CallOption) (*Void, error) {
out := new(Void)
err := c.cc.Invoke(ctx, "/grpc.testing.WorkerService/QuitWorker", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// WorkerServiceServer is the server API for WorkerService service.
type WorkerServiceServer interface {
// Start server with specified workload.
// First request sent specifies the ServerConfig followed by ServerStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test server
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
RunServer(WorkerService_RunServerServer) error
// Start client with specified workload.
// First request sent specifies the ClientConfig followed by ClientStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test client
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
RunClient(WorkerService_RunClientServer) error
// Just return the core count - unary call
CoreCount(context.Context, *CoreRequest) (*CoreResponse, error)
// Quit this worker
QuitWorker(context.Context, *Void) (*Void, error)
}
func RegisterWorkerServiceServer(s *grpc.Server, srv WorkerServiceServer) {
s.RegisterService(&_WorkerService_serviceDesc, srv)
}
func _WorkerService_RunServer_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(WorkerServiceServer).RunServer(&workerServiceRunServerServer{stream})
}
type WorkerService_RunServerServer interface {
Send(*ServerStatus) error
Recv() (*ServerArgs, error)
grpc.ServerStream
}
type workerServiceRunServerServer struct {
grpc.ServerStream
}
func (x *workerServiceRunServerServer) Send(m *ServerStatus) error {
return x.ServerStream.SendMsg(m)
}
func (x *workerServiceRunServerServer) Recv() (*ServerArgs, error) {
m := new(ServerArgs)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func _WorkerService_RunClient_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(WorkerServiceServer).RunClient(&workerServiceRunClientServer{stream})
}
type WorkerService_RunClientServer interface {
Send(*ClientStatus) error
Recv() (*ClientArgs, error)
grpc.ServerStream
}
type workerServiceRunClientServer struct {
grpc.ServerStream
}
func (x *workerServiceRunClientServer) Send(m *ClientStatus) error {
return x.ServerStream.SendMsg(m)
}
func (x *workerServiceRunClientServer) Recv() (*ClientArgs, error) {
m := new(ClientArgs)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func _WorkerService_CoreCount_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CoreRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(WorkerServiceServer).CoreCount(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/grpc.testing.WorkerService/CoreCount",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(WorkerServiceServer).CoreCount(ctx, req.(*CoreRequest))
}
return interceptor(ctx, in, info, handler)
}
func _WorkerService_QuitWorker_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Void)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(WorkerServiceServer).QuitWorker(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/grpc.testing.WorkerService/QuitWorker",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(WorkerServiceServer).QuitWorker(ctx, req.(*Void))
}
return interceptor(ctx, in, info, handler)
}
var _WorkerService_serviceDesc = grpc.ServiceDesc{
ServiceName: "grpc.testing.WorkerService",
HandlerType: (*WorkerServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "CoreCount",
Handler: _WorkerService_CoreCount_Handler,
},
{
MethodName: "QuitWorker",
Handler: _WorkerService_QuitWorker_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "RunServer",
Handler: _WorkerService_RunServer_Handler,
ServerStreams: true,
ClientStreams: true,
},
{
StreamName: "RunClient",
Handler: _WorkerService_RunClient_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "services.proto",
}
func init() { proto.RegisterFile("services.proto", fileDescriptor_services_bf68f4d7cbd0e0a1) }
var fileDescriptor_services_bf68f4d7cbd0e0a1 = []byte{
// 255 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x91, 0xc1, 0x4a, 0xc4, 0x30,
0x10, 0x86, 0xa9, 0x07, 0xa1, 0xc1, 0x2e, 0x92, 0x93, 0x46, 0x1f, 0xc0, 0x53, 0x91, 0xd5, 0x17,
0x70, 0x8b, 0x1e, 0x05, 0xb7, 0xa8, 0xe7, 0x58, 0x87, 0x1a, 0x36, 0xcd, 0xd4, 0x99, 0x89, 0xe0,
0x93, 0xf8, 0x0e, 0x3e, 0xa5, 0xec, 0x66, 0x57, 0xd6, 0x92, 0x9b, 0xc7, 0xf9, 0xbf, 0xe1, 0x23,
0x7f, 0x46, 0xcd, 0x18, 0xe8, 0xc3, 0x75, 0xc0, 0xf5, 0x48, 0x28, 0xa8, 0x8f, 0x7a, 0x1a, 0xbb,
0x5a, 0x80, 0xc5, 0x85, 0xde, 0xcc, 0x06, 0x60, 0xb6, 0xfd, 0x8e, 0x9a, 0xaa, 0xc3, 0x20, 0x84,
0x3e, 0x8d, 0xf3, 0xef, 0x42, 0x1d, 0x2f, 0x20, 0x74, 0x6f, 0x83, 0xa5, 0x55, 0x9b, 0x44, 0xfa,
0x4e, 0x95, 0x8f, 0xc1, 0xd2, 0x67, 0x63, 0xbd, 0xd7, 0x67, 0xf5, 0xbe, 0xaf, 0x6e, 0xdd, 0x30,
0x7a, 0x58, 0xc2, 0x7b, 0x04, 0x16, 0x73, 0x9e, 0x87, 0x3c, 0x62, 0x60, 0xd0, 0xf7, 0xaa, 0x6a,
0x85, 0xc0, 0x0e, 0x2e, 0xf4, 0xff, 0x74, 0x5d, 0x14, 0x97, 0xc5, 0xfc, 0xeb, 0x40, 0x55, 0xcf,
0x48, 0x2b, 0xa0, 0xdd, 0x4b, 0x6f, 0x55, 0xb9, 0x8c, 0x61, 0x3d, 0x01, 0xe9, 0x93, 0x89, 0x60,
0x93, 0xde, 0x50, 0xcf, 0xc6, 0xe4, 0x48, 0x2b, 0x56, 0x22, 0xaf, 0xc5, 0x5b, 0x4d, 0xe3, 0x1d,
0x04, 0x99, 0x6a, 0x52, 0x9a, 0xd3, 0x24, 0xb2, 0xa7, 0x59, 0xa8, 0xb2, 0x41, 0x82, 0x06, 0x63,
0x10, 0x7d, 0x3a, 0x59, 0x46, 0xfa, 0x6d, 0x6a, 0x72, 0x68, 0xfb, 0x67, 0xd7, 0x4a, 0x3d, 0x44,
0x27, 0xa9, 0xa6, 0xd6, 0x7f, 0x37, 0x9f, 0xd0, 0xbd, 0x9a, 0x4c, 0xf6, 0x72, 0xb8, 0xb9, 0xe6,
0xd5, 0x4f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x3b, 0x84, 0x02, 0xe3, 0x0c, 0x02, 0x00, 0x00,
}

View file

@ -0,0 +1,56 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// An integration test service that covers all the method signature permutations
// of unary/streaming requests/responses.
syntax = "proto3";
import "messages.proto";
import "control.proto";
package grpc.testing;
service BenchmarkService {
// One request followed by one response.
// The server returns the client payload as-is.
rpc UnaryCall(SimpleRequest) returns (SimpleResponse);
// One request followed by one response.
// The server returns the client payload as-is.
rpc StreamingCall(stream SimpleRequest) returns (stream SimpleResponse);
}
service WorkerService {
// Start server with specified workload.
// First request sent specifies the ServerConfig followed by ServerStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test server
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
rpc RunServer(stream ServerArgs) returns (stream ServerStatus);
// Start client with specified workload.
// First request sent specifies the ClientConfig followed by ClientStatus
// response. After that, a "Mark" can be sent anytime to request the latest
// stats. Closing the stream will initiate shutdown of the test client
// and once the shutdown has finished, the OK status is sent to terminate
// this RPC.
rpc RunClient(stream ClientArgs) returns (stream ClientStatus);
// Just return the core count - unary call
rpc CoreCount(CoreRequest) returns (CoreResponse);
// Quit this worker
rpc QuitWorker(Void) returns (Void);
}
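For orientation, here is a minimal, hedged sketch (not part of the vendored files) of how the generated BenchmarkServiceClient above is typically driven; the import alias, target address, dial options, and the empty request are illustrative assumptions only.
package main
import (
"log"
"golang.org/x/net/context"
"google.golang.org/grpc"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
)
func main() {
// Assumed address of a benchmark server built from this package.
conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
if err != nil {
log.Fatalf("dial: %v", err)
}
defer conn.Close()
client := testpb.NewBenchmarkServiceClient(conn)
// An empty SimpleRequest for illustration; real runs set payload sizes via fields defined in messages.proto.
if _, err := client.UnaryCall(context.Background(), &testpb.SimpleRequest{}); err != nil {
log.Fatalf("UnaryCall: %v", err)
}
}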

View file

@ -0,0 +1,302 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: stats.proto
package grpc_testing
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type ServerStats struct {
// wall clock time change in seconds since last reset
TimeElapsed float64 `protobuf:"fixed64,1,opt,name=time_elapsed,json=timeElapsed,proto3" json:"time_elapsed,omitempty"`
// change in user time (in seconds) used by the server since last reset
TimeUser float64 `protobuf:"fixed64,2,opt,name=time_user,json=timeUser,proto3" json:"time_user,omitempty"`
// change in server time (in seconds) used by the server process and all
// threads since last reset
TimeSystem float64 `protobuf:"fixed64,3,opt,name=time_system,json=timeSystem,proto3" json:"time_system,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerStats) Reset() { *m = ServerStats{} }
func (m *ServerStats) String() string { return proto.CompactTextString(m) }
func (*ServerStats) ProtoMessage() {}
func (*ServerStats) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{0}
}
func (m *ServerStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerStats.Unmarshal(m, b)
}
func (m *ServerStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerStats.Marshal(b, m, deterministic)
}
func (dst *ServerStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerStats.Merge(dst, src)
}
func (m *ServerStats) XXX_Size() int {
return xxx_messageInfo_ServerStats.Size(m)
}
func (m *ServerStats) XXX_DiscardUnknown() {
xxx_messageInfo_ServerStats.DiscardUnknown(m)
}
var xxx_messageInfo_ServerStats proto.InternalMessageInfo
func (m *ServerStats) GetTimeElapsed() float64 {
if m != nil {
return m.TimeElapsed
}
return 0
}
func (m *ServerStats) GetTimeUser() float64 {
if m != nil {
return m.TimeUser
}
return 0
}
func (m *ServerStats) GetTimeSystem() float64 {
if m != nil {
return m.TimeSystem
}
return 0
}
// Histogram params based on grpc/support/histogram.c
type HistogramParams struct {
Resolution float64 `protobuf:"fixed64,1,opt,name=resolution,proto3" json:"resolution,omitempty"`
MaxPossible float64 `protobuf:"fixed64,2,opt,name=max_possible,json=maxPossible,proto3" json:"max_possible,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *HistogramParams) Reset() { *m = HistogramParams{} }
func (m *HistogramParams) String() string { return proto.CompactTextString(m) }
func (*HistogramParams) ProtoMessage() {}
func (*HistogramParams) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{1}
}
func (m *HistogramParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_HistogramParams.Unmarshal(m, b)
}
func (m *HistogramParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_HistogramParams.Marshal(b, m, deterministic)
}
func (dst *HistogramParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_HistogramParams.Merge(dst, src)
}
func (m *HistogramParams) XXX_Size() int {
return xxx_messageInfo_HistogramParams.Size(m)
}
func (m *HistogramParams) XXX_DiscardUnknown() {
xxx_messageInfo_HistogramParams.DiscardUnknown(m)
}
var xxx_messageInfo_HistogramParams proto.InternalMessageInfo
func (m *HistogramParams) GetResolution() float64 {
if m != nil {
return m.Resolution
}
return 0
}
func (m *HistogramParams) GetMaxPossible() float64 {
if m != nil {
return m.MaxPossible
}
return 0
}
// Histogram data based on grpc/support/histogram.c
type HistogramData struct {
Bucket []uint32 `protobuf:"varint,1,rep,packed,name=bucket,proto3" json:"bucket,omitempty"`
MinSeen float64 `protobuf:"fixed64,2,opt,name=min_seen,json=minSeen,proto3" json:"min_seen,omitempty"`
MaxSeen float64 `protobuf:"fixed64,3,opt,name=max_seen,json=maxSeen,proto3" json:"max_seen,omitempty"`
Sum float64 `protobuf:"fixed64,4,opt,name=sum,proto3" json:"sum,omitempty"`
SumOfSquares float64 `protobuf:"fixed64,5,opt,name=sum_of_squares,json=sumOfSquares,proto3" json:"sum_of_squares,omitempty"`
Count float64 `protobuf:"fixed64,6,opt,name=count,proto3" json:"count,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *HistogramData) Reset() { *m = HistogramData{} }
func (m *HistogramData) String() string { return proto.CompactTextString(m) }
func (*HistogramData) ProtoMessage() {}
func (*HistogramData) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{2}
}
func (m *HistogramData) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_HistogramData.Unmarshal(m, b)
}
func (m *HistogramData) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_HistogramData.Marshal(b, m, deterministic)
}
func (dst *HistogramData) XXX_Merge(src proto.Message) {
xxx_messageInfo_HistogramData.Merge(dst, src)
}
func (m *HistogramData) XXX_Size() int {
return xxx_messageInfo_HistogramData.Size(m)
}
func (m *HistogramData) XXX_DiscardUnknown() {
xxx_messageInfo_HistogramData.DiscardUnknown(m)
}
var xxx_messageInfo_HistogramData proto.InternalMessageInfo
func (m *HistogramData) GetBucket() []uint32 {
if m != nil {
return m.Bucket
}
return nil
}
func (m *HistogramData) GetMinSeen() float64 {
if m != nil {
return m.MinSeen
}
return 0
}
func (m *HistogramData) GetMaxSeen() float64 {
if m != nil {
return m.MaxSeen
}
return 0
}
func (m *HistogramData) GetSum() float64 {
if m != nil {
return m.Sum
}
return 0
}
func (m *HistogramData) GetSumOfSquares() float64 {
if m != nil {
return m.SumOfSquares
}
return 0
}
func (m *HistogramData) GetCount() float64 {
if m != nil {
return m.Count
}
return 0
}
type ClientStats struct {
// Latency histogram. Data points are in nanoseconds.
Latencies *HistogramData `protobuf:"bytes,1,opt,name=latencies,proto3" json:"latencies,omitempty"`
// See ServerStats for details.
TimeElapsed float64 `protobuf:"fixed64,2,opt,name=time_elapsed,json=timeElapsed,proto3" json:"time_elapsed,omitempty"`
TimeUser float64 `protobuf:"fixed64,3,opt,name=time_user,json=timeUser,proto3" json:"time_user,omitempty"`
TimeSystem float64 `protobuf:"fixed64,4,opt,name=time_system,json=timeSystem,proto3" json:"time_system,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStats) Reset() { *m = ClientStats{} }
func (m *ClientStats) String() string { return proto.CompactTextString(m) }
func (*ClientStats) ProtoMessage() {}
func (*ClientStats) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{3}
}
func (m *ClientStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStats.Unmarshal(m, b)
}
func (m *ClientStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStats.Marshal(b, m, deterministic)
}
func (dst *ClientStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStats.Merge(dst, src)
}
func (m *ClientStats) XXX_Size() int {
return xxx_messageInfo_ClientStats.Size(m)
}
func (m *ClientStats) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStats.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStats proto.InternalMessageInfo
func (m *ClientStats) GetLatencies() *HistogramData {
if m != nil {
return m.Latencies
}
return nil
}
func (m *ClientStats) GetTimeElapsed() float64 {
if m != nil {
return m.TimeElapsed
}
return 0
}
func (m *ClientStats) GetTimeUser() float64 {
if m != nil {
return m.TimeUser
}
return 0
}
func (m *ClientStats) GetTimeSystem() float64 {
if m != nil {
return m.TimeSystem
}
return 0
}
func init() {
proto.RegisterType((*ServerStats)(nil), "grpc.testing.ServerStats")
proto.RegisterType((*HistogramParams)(nil), "grpc.testing.HistogramParams")
proto.RegisterType((*HistogramData)(nil), "grpc.testing.HistogramData")
proto.RegisterType((*ClientStats)(nil), "grpc.testing.ClientStats")
}
func init() { proto.RegisterFile("stats.proto", fileDescriptor_stats_8ba831c0cb3c3440) }
var fileDescriptor_stats_8ba831c0cb3c3440 = []byte{
// 341 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x92, 0xc1, 0x4a, 0xeb, 0x40,
0x14, 0x86, 0x49, 0xd3, 0xf6, 0xb6, 0x27, 0xed, 0xbd, 0x97, 0x41, 0x24, 0x52, 0xd0, 0x1a, 0x5c,
0x74, 0x95, 0x85, 0xae, 0x5c, 0xab, 0xe0, 0xce, 0xd2, 0xe8, 0x3a, 0x4c, 0xe3, 0x69, 0x19, 0xcc,
0xcc, 0xc4, 0x39, 0x33, 0x12, 0x1f, 0x49, 0x7c, 0x49, 0xc9, 0x24, 0x68, 0x55, 0xd0, 0x5d, 0xe6,
0xfb, 0x7e, 0xe6, 0xe4, 0xe4, 0x0f, 0x44, 0x64, 0xb9, 0xa5, 0xb4, 0x32, 0xda, 0x6a, 0x36, 0xd9,
0x9a, 0xaa, 0x48, 0x2d, 0x92, 0x15, 0x6a, 0x9b, 0x28, 0x88, 0x32, 0x34, 0x4f, 0x68, 0xb2, 0x26,
0xc2, 0x8e, 0x61, 0x62, 0x85, 0xc4, 0x1c, 0x4b, 0x5e, 0x11, 0xde, 0xc7, 0xc1, 0x3c, 0x58, 0x04,
0xab, 0xa8, 0x61, 0x57, 0x2d, 0x62, 0x33, 0x18, 0xfb, 0x88, 0x23, 0x34, 0x71, 0xcf, 0xfb, 0x51,
0x03, 0xee, 0x08, 0x0d, 0x3b, 0x02, 0x9f, 0xcd, 0xe9, 0x99, 0x2c, 0xca, 0x38, 0xf4, 0x1a, 0x1a,
0x94, 0x79, 0x92, 0xdc, 0xc2, 0xbf, 0x6b, 0x41, 0x56, 0x6f, 0x0d, 0x97, 0x4b, 0x6e, 0xb8, 0x24,
0x76, 0x08, 0x60, 0x90, 0x74, 0xe9, 0xac, 0xd0, 0xaa, 0x9b, 0xb8, 0x43, 0x9a, 0x77, 0x92, 0xbc,
0xce, 0x2b, 0x4d, 0x24, 0xd6, 0x25, 0x76, 0x33, 0x23, 0xc9, 0xeb, 0x65, 0x87, 0x92, 0xd7, 0x00,
0xa6, 0xef, 0xd7, 0x5e, 0x72, 0xcb, 0xd9, 0x3e, 0x0c, 0xd7, 0xae, 0x78, 0x40, 0x1b, 0x07, 0xf3,
0x70, 0x31, 0x5d, 0x75, 0x27, 0x76, 0x00, 0x23, 0x29, 0x54, 0x4e, 0x88, 0xaa, 0xbb, 0xe8, 0x8f,
0x14, 0x2a, 0x43, 0x54, 0x5e, 0xf1, 0xba, 0x55, 0x61, 0xa7, 0x78, 0xed, 0xd5, 0x7f, 0x08, 0xc9,
0xc9, 0xb8, 0xef, 0x69, 0xf3, 0xc8, 0x4e, 0xe0, 0x2f, 0x39, 0x99, 0xeb, 0x4d, 0x4e, 0x8f, 0x8e,
0x1b, 0xa4, 0x78, 0xe0, 0xe5, 0x84, 0x9c, 0xbc, 0xd9, 0x64, 0x2d, 0x63, 0x7b, 0x30, 0x28, 0xb4,
0x53, 0x36, 0x1e, 0x7a, 0xd9, 0x1e, 0x92, 0x97, 0x00, 0xa2, 0x8b, 0x52, 0xa0, 0xb2, 0xed, 0x47,
0x3f, 0x87, 0x71, 0xc9, 0x2d, 0xaa, 0x42, 0x20, 0xf9, 0xfd, 0xa3, 0xd3, 0x59, 0xba, 0xdb, 0x52,
0xfa, 0x69, 0xb7, 0xd5, 0x47, 0xfa, 0x5b, 0x5f, 0xbd, 0x5f, 0xfa, 0x0a, 0x7f, 0xee, 0xab, 0xff,
0xb5, 0xaf, 0xf5, 0xd0, 0xff, 0x34, 0x67, 0x6f, 0x01, 0x00, 0x00, 0xff, 0xff, 0xea, 0x75, 0x34,
0x90, 0x43, 0x02, 0x00, 0x00,
}

View file

@ -0,0 +1,55 @@
// Copyright 2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package grpc.testing;
message ServerStats {
// wall clock time change in seconds since last reset
double time_elapsed = 1;
// change in user time (in seconds) used by the server since last reset
double time_user = 2;
// change in server time (in seconds) used by the server process and all
// threads since last reset
double time_system = 3;
}
// Histogram params based on grpc/support/histogram.c
message HistogramParams {
double resolution = 1; // first bucket is [0, 1 + resolution)
double max_possible = 2; // use enough buckets to allow this value
}
// Histogram data based on grpc/support/histogram.c
message HistogramData {
repeated uint32 bucket = 1;
double min_seen = 2;
double max_seen = 3;
double sum = 4;
double sum_of_squares = 5;
double count = 6;
}
message ClientStats {
// Latency histogram. Data points are in nanoseconds.
HistogramData latencies = 1;
// See ServerStats for details.
double time_elapsed = 2;
double time_user = 3;
double time_system = 4;
}
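To connect these messages back to the generated Go types shown earlier in stats.pb.go, a small hedged sketch follows; every value below is made up purely for illustration.
package main
import (
"fmt"
"github.com/golang/protobuf/proto"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
)
func main() {
// Field names mirror the generated stats.pb.go above; the numbers are arbitrary.
stats := &testpb.ClientStats{
Latencies: &testpb.HistogramData{
Bucket:  []uint32{1, 5, 3},
MinSeen: 0.8,
MaxSeen: 12.5,
Sum:     42,
Count:   9,
},
TimeElapsed: 10,
TimeUser:    2.5,
TimeSystem:  1.1,
}
fmt.Println(proto.CompactTextString(stats))
}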

View file

@ -0,0 +1,316 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package latency provides wrappers for net.Conn, net.Listener, and
// net.Dialers, designed to interoperate to inject real-world latency into
// network connections.
package latency
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"time"
"golang.org/x/net/context"
)
// Dialer is a function matching the signature of net.Dial.
type Dialer func(network, address string) (net.Conn, error)
// TimeoutDialer is a function matching the signature of net.DialTimeout.
type TimeoutDialer func(network, address string, timeout time.Duration) (net.Conn, error)
// ContextDialer is a function matching the signature of
// net.Dialer.DialContext.
type ContextDialer func(ctx context.Context, network, address string) (net.Conn, error)
// Network represents a network with the given bandwidth, latency, and MTU
// (Maximum Transmission Unit) configuration, and can produce wrappers of
// net.Listeners, net.Conn, and various forms of dialing functions. The
// Listeners and Dialers/Conns on both sides of connections must come from this
// package, but need not be created from the same Network. Latency is computed
// when sending (in Write), and is injected when receiving (in Read). This
// allows senders' Write calls to be non-blocking, as in real-world
// applications.
//
// Note: Latency is injected by the sender specifying the absolute time data
// should be available, and the reader delaying until that time arrives to
// provide the data. This package attempts to counter-act the effects of clock
// drift and existing network latency by measuring the delay between the
// sender's transmission time and the receiver's reception time during startup.
// No attempt is made to measure the existing bandwidth of the connection.
type Network struct {
Kbps int // Kilobits per second; if non-positive, infinite
Latency time.Duration // One-way latency (sending); if non-positive, no delay
MTU int // Bytes per packet; if non-positive, infinite
}
var (
// Local simulates a local network.
Local = Network{0, 0, 0}
// LAN simulates a local area network.
LAN = Network{100 * 1024, 2 * time.Millisecond, 1500}
// WAN simulates a wide area network.
WAN = Network{20 * 1024, 30 * time.Millisecond, 1500}
// Longhaul simulates a bad network.
Longhaul = Network{1000 * 1024, 200 * time.Millisecond, 9000}
)
// Conn returns a net.Conn that wraps c and injects n's latency into that
// connection. This function also imposes latency for connection creation.
// If n's Latency is lower than the measured latency in c, an error is
// returned.
func (n *Network) Conn(c net.Conn) (net.Conn, error) {
start := now()
nc := &conn{Conn: c, network: n, readBuf: new(bytes.Buffer)}
if err := nc.sync(); err != nil {
return nil, err
}
sleep(start.Add(nc.delay).Sub(now()))
return nc, nil
}
type conn struct {
net.Conn
network *Network
readBuf *bytes.Buffer // one packet worth of data received
lastSendEnd time.Time // time the previous Write should be fully on the wire
delay time.Duration // desired latency - measured latency
}
// header is sent before all data transmitted by the application.
type header struct {
ReadTime int64 // Time the reader is allowed to read this packet (UnixNano)
Sz int32 // Size of the data in the packet
}
func (c *conn) Write(p []byte) (n int, err error) {
tNow := now()
if c.lastSendEnd.Before(tNow) {
c.lastSendEnd = tNow
}
for len(p) > 0 {
pkt := p
if c.network.MTU > 0 && len(pkt) > c.network.MTU {
pkt = pkt[:c.network.MTU]
p = p[c.network.MTU:]
} else {
p = nil
}
if c.network.Kbps > 0 {
if congestion := c.lastSendEnd.Sub(tNow) - c.delay; congestion > 0 {
// The network is full; sleep until this packet can be sent.
sleep(congestion)
tNow = tNow.Add(congestion)
}
}
c.lastSendEnd = c.lastSendEnd.Add(c.network.pktTime(len(pkt)))
hdr := header{ReadTime: c.lastSendEnd.Add(c.delay).UnixNano(), Sz: int32(len(pkt))}
if err := binary.Write(c.Conn, binary.BigEndian, hdr); err != nil {
return n, err
}
x, err := c.Conn.Write(pkt)
n += x
if err != nil {
return n, err
}
}
return n, nil
}
func (c *conn) Read(p []byte) (n int, err error) {
if c.readBuf.Len() == 0 {
var hdr header
if err := binary.Read(c.Conn, binary.BigEndian, &hdr); err != nil {
return 0, err
}
defer func() { sleep(time.Unix(0, hdr.ReadTime).Sub(now())) }()
if _, err := io.CopyN(c.readBuf, c.Conn, int64(hdr.Sz)); err != nil {
return 0, err
}
}
// Read from readBuf.
return c.readBuf.Read(p)
}
// sync does a handshake and then measures the latency on the network in
// coordination with the other side.
func (c *conn) sync() error {
const (
pingMsg = "syncPing"
warmup = 10 // minimum number of iterations to measure latency
giveUp = 50 // maximum number of iterations to measure latency
accuracy = time.Millisecond // req'd accuracy to stop early
goodRun = 3 // stop early if latency within accuracy this many times
)
type syncMsg struct {
SendT int64 // Time sent. If zero, stop.
RecvT int64 // Time received. If zero, fill in and respond.
}
// A trivial handshake
if err := binary.Write(c.Conn, binary.BigEndian, []byte(pingMsg)); err != nil {
return err
}
var ping [8]byte
if err := binary.Read(c.Conn, binary.BigEndian, &ping); err != nil {
return err
} else if string(ping[:]) != pingMsg {
return fmt.Errorf("malformed handshake message: %v (want %q)", ping, pingMsg)
}
// Both sides are alive and syncing. Calculate network delay / clock skew.
att := 0
good := 0
var latency time.Duration
localDone, remoteDone := false, false
send := true
for !localDone || !remoteDone {
if send {
if err := binary.Write(c.Conn, binary.BigEndian, syncMsg{SendT: now().UnixNano()}); err != nil {
return err
}
att++
send = false
}
// Block until we get a syncMsg
m := syncMsg{}
if err := binary.Read(c.Conn, binary.BigEndian, &m); err != nil {
return err
}
if m.RecvT == 0 {
// Message initiated from other side.
if m.SendT == 0 {
remoteDone = true
continue
}
// Send response.
m.RecvT = now().UnixNano()
if err := binary.Write(c.Conn, binary.BigEndian, m); err != nil {
return err
}
continue
}
lag := time.Duration(m.RecvT - m.SendT)
latency += lag
avgLatency := latency / time.Duration(att)
if e := lag - avgLatency; e > -accuracy && e < accuracy {
good++
} else {
good = 0
}
if att < giveUp && (att < warmup || good < goodRun) {
send = true
continue
}
localDone = true
latency = avgLatency
// Tell the other side we're done.
if err := binary.Write(c.Conn, binary.BigEndian, syncMsg{}); err != nil {
return err
}
}
if c.network.Latency <= 0 {
return nil
}
c.delay = c.network.Latency - latency
if c.delay < 0 {
return fmt.Errorf("measured network latency (%v) higher than desired latency (%v)", latency, c.network.Latency)
}
return nil
}
// Listener returns a net.Listener that wraps l and injects n's latency in its
// connections.
func (n *Network) Listener(l net.Listener) net.Listener {
return &listener{Listener: l, network: n}
}
type listener struct {
net.Listener
network *Network
}
func (l *listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return l.network.Conn(c)
}
// Dialer returns a Dialer that wraps d and injects n's latency in its
// connections. n's Latency is also injected to the connection's creation.
func (n *Network) Dialer(d Dialer) Dialer {
return func(network, address string) (net.Conn, error) {
conn, err := d(network, address)
if err != nil {
return nil, err
}
return n.Conn(conn)
}
}
// TimeoutDialer returns a TimeoutDialer that wraps d and injects n's latency
// in its connections. n's Latency is also injected to the connection's
// creation.
func (n *Network) TimeoutDialer(d TimeoutDialer) TimeoutDialer {
return func(network, address string, timeout time.Duration) (net.Conn, error) {
conn, err := d(network, address, timeout)
if err != nil {
return nil, err
}
return n.Conn(conn)
}
}
// ContextDialer returns a ContextDialer that wraps d and injects n's latency
// in its connections. n's Latency is also injected to the connection's
// creation.
func (n *Network) ContextDialer(d ContextDialer) ContextDialer {
return func(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d(ctx, network, address)
if err != nil {
return nil, err
}
return n.Conn(conn)
}
}
// pktTime returns the time it takes to transmit one packet of data of size b
// in bytes.
func (n *Network) pktTime(b int) time.Duration {
if n.Kbps <= 0 {
return time.Duration(0)
}
return time.Duration(b) * time.Second / time.Duration(n.Kbps*(1024/8))
}
// Wrappers for testing
var now = time.Now
var sleep = time.Sleep
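A hedged usage sketch for this package: wrap both the listener and the dialer with the same Network so latency is injected on each side of the connection. The Network parameters, address, and payload below are assumptions chosen for illustration, not a prescribed configuration.
package main
import (
"log"
"net"
"time"
"google.golang.org/grpc/benchmark/latency"
)
func main() {
// Illustrative parameters: roughly a 1 Mbps link with 50ms one-way latency.
n := &latency.Network{Kbps: 1024, Latency: 50 * time.Millisecond, MTU: 1500}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("listen: %v", err)
}
lis = n.Listener(lis) // server side: wrap accepted connections
go func() {
c, err := lis.Accept()
if err != nil {
return
}
defer c.Close()
c.Write([]byte("hello"))
}()
// Client side: wrap net.Dial so the handshake and reads see the same injected latency.
conn, err := n.Dialer(net.Dial)("tcp", lis.Addr().String())
if err != nil {
log.Fatalf("dial: %v", err)
}
defer conn.Close()
buf := make([]byte, 5)
if _, err := conn.Read(buf); err != nil {
log.Fatalf("read: %v", err)
}
log.Printf("received %q after simulated latency", buf)
}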

View file

@ -0,0 +1,353 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package latency
import (
"bytes"
"fmt"
"net"
"reflect"
"sync"
"testing"
"time"
)
// bufConn is a net.Conn implemented by a bytes.Buffer (which is a ReadWriter).
type bufConn struct {
*bytes.Buffer
}
func (bufConn) Close() error { panic("unimplemented") }
func (bufConn) LocalAddr() net.Addr { panic("unimplemented") }
func (bufConn) RemoteAddr() net.Addr { panic("unimplemented") }
func (bufConn) SetDeadline(t time.Time) error { panic("unimplemented") }
func (bufConn) SetReadDeadline(t time.Time) error { panic("unimplemented") }
func (bufConn) SetWriteDeadline(t time.Time) error { panic("unimplemented") }
func restoreHooks() func() {
s := sleep
n := now
return func() {
sleep = s
now = n
}
}
func TestConn(t *testing.T) {
defer restoreHooks()()
// Constant time.
now = func() time.Time { return time.Unix(123, 456) }
// Capture sleep times for checking later.
var sleepTimes []time.Duration
sleep = func(t time.Duration) { sleepTimes = append(sleepTimes, t) }
wantSleeps := func(want ...time.Duration) {
if !reflect.DeepEqual(want, sleepTimes) {
t.Fatalf("sleepTimes = %v; want %v", sleepTimes, want)
}
sleepTimes = nil
}
// Use a fairly high latency to cause a large BDP and avoid sleeps while
// writing due to simulation of full buffers.
latency := 1 * time.Second
c, err := (&Network{Kbps: 1, Latency: latency, MTU: 5}).Conn(bufConn{&bytes.Buffer{}})
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
wantSleeps(latency) // Connection creation delay.
// 1 kbps = 128 Bps. Divides evenly by 1 second using nanos.
byteLatency := time.Duration(time.Second / 128)
write := func(b []byte) {
n, err := c.Write(b)
if n != len(b) || err != nil {
t.Fatalf("c.Write(%v) = %v, %v; want %v, nil", b, n, err, len(b))
}
}
write([]byte{1, 2, 3, 4, 5}) // One full packet
pkt1Time := latency + byteLatency*5
write([]byte{6}) // One partial packet
pkt2Time := pkt1Time + byteLatency
write([]byte{7, 8, 9, 10, 11, 12, 13}) // Two packets
pkt3Time := pkt2Time + byteLatency*5
pkt4Time := pkt3Time + byteLatency*2
// No reads, so no sleeps yet.
wantSleeps()
read := func(n int, want []byte) {
b := make([]byte, n)
if rd, err := c.Read(b); err != nil || rd != len(want) {
t.Fatalf("c.Read(<%v bytes>) = %v, %v; want %v, nil", n, rd, err, len(want))
}
if !reflect.DeepEqual(b[:len(want)], want) {
t.Fatalf("read %v; want %v", b, want)
}
}
read(1, []byte{1})
wantSleeps(pkt1Time)
read(1, []byte{2})
wantSleeps()
read(3, []byte{3, 4, 5})
wantSleeps()
read(2, []byte{6})
wantSleeps(pkt2Time)
read(2, []byte{7, 8})
wantSleeps(pkt3Time)
read(10, []byte{9, 10, 11})
wantSleeps()
read(10, []byte{12, 13})
wantSleeps(pkt4Time)
}
func TestSync(t *testing.T) {
defer restoreHooks()()
// Infinitely fast CPU: time doesn't pass unless sleep is called.
tn := time.Unix(123, 0)
now = func() time.Time { return tn }
sleep = func(d time.Duration) { tn = tn.Add(d) }
// Simulate a 20ms latency network, then run sync across that and expect to
// measure 20ms latency, or 10ms additional delay for a 30ms network.
slowConn, err := (&Network{Kbps: 0, Latency: 20 * time.Millisecond, MTU: 5}).Conn(bufConn{&bytes.Buffer{}})
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
c, err := (&Network{Latency: 30 * time.Millisecond}).Conn(slowConn)
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
if c.(*conn).delay != 10*time.Millisecond {
t.Fatalf("c.delay = %v; want 10ms", c.(*conn).delay)
}
}
func TestSyncTooSlow(t *testing.T) {
defer restoreHooks()()
// Infinitely fast CPU: time doesn't pass unless sleep is called.
tn := time.Unix(123, 0)
now = func() time.Time { return tn }
sleep = func(d time.Duration) { tn = tn.Add(d) }
// Simulate a 10ms latency network, then attempt to simulate a 5ms latency
// network and expect an error.
slowConn, err := (&Network{Kbps: 0, Latency: 10 * time.Millisecond, MTU: 5}).Conn(bufConn{&bytes.Buffer{}})
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
errWant := "measured network latency (10ms) higher than desired latency (5ms)"
if _, err := (&Network{Latency: 5 * time.Millisecond}).Conn(slowConn); err == nil || err.Error() != errWant {
t.Fatalf("Conn() = _, %q; want _, %q", err, errWant)
}
}
func TestListenerAndDialer(t *testing.T) {
defer restoreHooks()()
tn := time.Unix(123, 0)
startTime := tn
mu := &sync.Mutex{}
now = func() time.Time {
mu.Lock()
defer mu.Unlock()
return tn
}
// Use a fairly high latency to cause a large BDP and avoid sleeps while
// writing due to simulation of full buffers.
n := &Network{Kbps: 2, Latency: 1 * time.Second, MTU: 10}
// 2 kbps = .25 kBps = 256 Bps
byteLatency := func(n int) time.Duration {
return time.Duration(n) * time.Second / 256
}
// Create a real listener and wrap it.
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Unexpected error creating listener: %v", err)
}
defer l.Close()
l = n.Listener(l)
var serverConn net.Conn
var scErr error
scDone := make(chan struct{})
go func() {
serverConn, scErr = l.Accept()
close(scDone)
}()
// Create a dialer and use it.
clientConn, err := n.TimeoutDialer(net.DialTimeout)("tcp", l.Addr().String(), 2*time.Second)
if err != nil {
t.Fatalf("Unexpected error dialing: %v", err)
}
defer clientConn.Close()
// Block until server's Conn is available.
<-scDone
if scErr != nil {
t.Fatalf("Unexpected error listening: %v", scErr)
}
defer serverConn.Close()
// sleep (only) advances tn. Done after connections established so sync detects zero delay.
sleep = func(d time.Duration) {
mu.Lock()
defer mu.Unlock()
if d > 0 {
tn = tn.Add(d)
}
}
seq := func(a, b int) []byte {
buf := make([]byte, b-a)
for i := 0; i < b-a; i++ {
buf[i] = byte(i + a)
}
return buf
}
pkt1 := seq(0, 10)
pkt2 := seq(10, 30)
pkt3 := seq(30, 35)
write := func(c net.Conn, b []byte) {
n, err := c.Write(b)
if n != len(b) || err != nil {
t.Fatalf("c.Write(%v) = %v, %v; want %v, nil", b, n, err, len(b))
}
}
write(serverConn, pkt1)
write(serverConn, pkt2)
write(serverConn, pkt3)
write(clientConn, pkt3)
write(clientConn, pkt1)
write(clientConn, pkt2)
if tn != startTime {
t.Fatalf("unexpected sleep in write; tn = %v; want %v", tn, startTime)
}
read := func(c net.Conn, n int, want []byte, timeWant time.Time) {
b := make([]byte, n)
if rd, err := c.Read(b); err != nil || rd != len(want) {
t.Fatalf("c.Read(<%v bytes>) = %v, %v; want %v, nil (read: %v)", n, rd, err, len(want), b[:rd])
}
if !reflect.DeepEqual(b[:len(want)], want) {
t.Fatalf("read %v; want %v", b, want)
}
if !tn.Equal(timeWant) {
t.Errorf("tn after read(%v) = %v; want %v", want, tn, timeWant)
}
}
read(clientConn, len(pkt1)+1, pkt1, startTime.Add(n.Latency+byteLatency(len(pkt1))))
read(serverConn, len(pkt3)+1, pkt3, tn) // tn was advanced by the above read; pkt3 is shorter than pkt1
read(clientConn, len(pkt2), pkt2[:10], startTime.Add(n.Latency+byteLatency(len(pkt1)+10)))
read(clientConn, len(pkt2), pkt2[10:], startTime.Add(n.Latency+byteLatency(len(pkt1)+len(pkt2))))
read(clientConn, len(pkt3), pkt3, startTime.Add(n.Latency+byteLatency(len(pkt1)+len(pkt2)+len(pkt3))))
read(serverConn, len(pkt1), pkt1, tn) // tn already past the arrival time due to prior reads
read(serverConn, len(pkt2), pkt2[:10], tn)
read(serverConn, len(pkt2), pkt2[10:], tn)
// Sleep awhile and make sure the read happens disregarding previous writes
// (lastSendEnd handling).
sleep(10 * time.Second)
write(clientConn, pkt1)
read(serverConn, len(pkt1), pkt1, tn.Add(n.Latency+byteLatency(len(pkt1))))
// Send, sleep longer than the network delay, then make sure the read happens
// instantly.
write(serverConn, pkt1)
sleep(10 * time.Second)
read(clientConn, len(pkt1), pkt1, tn)
}
func TestBufferBloat(t *testing.T) {
defer restoreHooks()()
// Infinitely fast CPU: time doesn't pass unless sleep is called.
tn := time.Unix(123, 0)
now = func() time.Time { return tn }
// Capture sleep times for checking later.
var sleepTimes []time.Duration
sleep = func(d time.Duration) {
sleepTimes = append(sleepTimes, d)
tn = tn.Add(d)
}
wantSleeps := func(want ...time.Duration) error {
if !reflect.DeepEqual(want, sleepTimes) {
return fmt.Errorf("sleepTimes = %v; want %v", sleepTimes, want)
}
sleepTimes = nil
return nil
}
n := &Network{Kbps: 8 /* 1KBps */, Latency: time.Second, MTU: 8}
bdpBytes := (n.Kbps * 1024 / 8) * int(n.Latency/time.Second) // 1024
c, err := n.Conn(bufConn{&bytes.Buffer{}})
if err != nil {
t.Fatalf("Unexpected error creating connection: %v", err)
}
wantSleeps(n.Latency) // Connection creation delay.
write := func(n int, sleeps ...time.Duration) {
if wt, err := c.Write(make([]byte, n)); err != nil || wt != n {
t.Fatalf("c.Write(<%v bytes>) = %v, %v; want %v, nil", n, wt, err, n)
}
if err := wantSleeps(sleeps...); err != nil {
t.Fatalf("After writing %v bytes: %v", n, err)
}
}
read := func(n int, sleeps ...time.Duration) {
if rd, err := c.Read(make([]byte, n)); err != nil || rd != n {
t.Fatalf("c.Read(_) = %v, %v; want %v, nil", rd, err, n)
}
if err := wantSleeps(sleeps...); err != nil {
t.Fatalf("After reading %v bytes: %v", n, err)
}
}
write(8) // No reads and buffer not full, so no sleeps yet.
read(8, time.Second+n.pktTime(8))
write(bdpBytes) // Fill the buffer.
write(1) // We can send one extra packet even when the buffer is full.
write(n.MTU, n.pktTime(1)) // Make sure we sleep to clear the previous write.
write(1, n.pktTime(n.MTU))
write(n.MTU+1, n.pktTime(1), n.pktTime(n.MTU))
tn = tn.Add(10 * time.Second) // Wait long enough for the buffer to clear.
write(bdpBytes) // No sleeps required.
}

View file

@ -0,0 +1,135 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package primitives_test
import (
"strconv"
"testing"
"google.golang.org/grpc/codes"
)
type codeBench uint32
const (
OK codeBench = iota
Canceled
Unknown
InvalidArgument
DeadlineExceeded
NotFound
AlreadyExists
PermissionDenied
ResourceExhausted
FailedPrecondition
Aborted
OutOfRange
Unimplemented
Internal
Unavailable
DataLoss
Unauthenticated
)
// The following String() function was generated by stringer.
const _Code_name = "OKCanceledUnknownInvalidArgumentDeadlineExceededNotFoundAlreadyExistsPermissionDeniedResourceExhaustedFailedPreconditionAbortedOutOfRangeUnimplementedInternalUnavailableDataLossUnauthenticated"
var _Code_index = [...]uint8{0, 2, 10, 17, 32, 48, 56, 69, 85, 102, 120, 127, 137, 150, 158, 169, 177, 192}
func (i codeBench) String() string {
if i >= codeBench(len(_Code_index)-1) {
return "Code(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Code_name[_Code_index[i]:_Code_index[i+1]]
}
var nameMap = map[codeBench]string{
OK: "OK",
Canceled: "Canceled",
Unknown: "Unknown",
InvalidArgument: "InvalidArgument",
DeadlineExceeded: "DeadlineExceeded",
NotFound: "NotFound",
AlreadyExists: "AlreadyExists",
PermissionDenied: "PermissionDenied",
ResourceExhausted: "ResourceExhausted",
FailedPrecondition: "FailedPrecondition",
Aborted: "Aborted",
OutOfRange: "OutOfRange",
Unimplemented: "Unimplemented",
Internal: "Internal",
Unavailable: "Unavailable",
DataLoss: "DataLoss",
Unauthenticated: "Unauthenticated",
}
func (i codeBench) StringUsingMap() string {
if s, ok := nameMap[i]; ok {
return s
}
return "Code(" + strconv.FormatInt(int64(i), 10) + ")"
}
func BenchmarkCodeStringStringer(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codeBench(uint32(i % 17))
_ = c.String()
}
b.StopTimer()
}
func BenchmarkCodeStringMap(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codeBench(uint32(i % 17))
_ = c.StringUsingMap()
}
b.StopTimer()
}
// codes.Code.String() does a switch.
func BenchmarkCodeStringSwitch(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codes.Code(uint32(i % 17))
_ = c.String()
}
b.StopTimer()
}
// Testing all codes (0<=c<=16) and also one overflow (17).
func BenchmarkCodeStringStringerWithOverflow(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codeBench(uint32(i % 18))
_ = c.String()
}
b.StopTimer()
}
// Testing all codes (0<=c<=16) and also one overflow (17).
func BenchmarkCodeStringSwitchWithOverflow(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := codes.Code(uint32(i % 18))
_ = c.String()
}
b.StopTimer()
}

View file

@ -0,0 +1,120 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package primitives_test
import (
"testing"
"time"
"golang.org/x/net/context"
)
func BenchmarkCancelContextErrNoErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err != nil {
b.Fatal("error")
}
}
cancel()
}
func BenchmarkCancelContextErrGotErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err == nil {
b.Fatal("error")
}
}
}
func BenchmarkCancelContextChannelNoErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
b.Fatal("error: ctx.Done():", ctx.Err())
default:
}
}
cancel()
}
func BenchmarkCancelContextChannelGotErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
if err := ctx.Err(); err == nil {
b.Fatal("error")
}
default:
b.Fatal("error: !ctx.Done()")
}
}
}
func BenchmarkTimerContextErrNoErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err != nil {
b.Fatal("error")
}
}
cancel()
}
func BenchmarkTimerContextErrGotErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
cancel()
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err == nil {
b.Fatal("error")
}
}
}
func BenchmarkTimerContextChannelNoErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
b.Fatal("error: ctx.Done():", ctx.Err())
default:
}
}
cancel()
}
func BenchmarkTimerContextChannelGotErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
cancel()
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
if err := ctx.Err(); err == nil {
b.Fatal("error")
}
default:
b.Fatal("error: !ctx.Done()")
}
}
}

View file

@ -0,0 +1,403 @@
// +build go1.7
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package primitives_test contains benchmarks for various synchronization primitives
// available in Go.
package primitives_test
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"unsafe"
)
func BenchmarkSelectClosed(b *testing.B) {
c := make(chan struct{})
close(c)
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
select {
case <-c:
x++
default:
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkSelectOpen(b *testing.B) {
c := make(chan struct{})
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
select {
case <-c:
default:
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkAtomicBool(b *testing.B) {
c := int32(0)
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
if atomic.LoadInt32(&c) == 0 {
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkAtomicValueLoad(b *testing.B) {
c := atomic.Value{}
c.Store(0)
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
if c.Load().(int) == 0 {
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkAtomicValueStore(b *testing.B) {
c := atomic.Value{}
v := 123
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Store(v)
}
b.StopTimer()
}
func BenchmarkMutex(b *testing.B) {
c := sync.Mutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Lock()
x++
c.Unlock()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkRWMutex(b *testing.B) {
c := sync.RWMutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.RLock()
x++
c.RUnlock()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkRWMutexW(b *testing.B) {
c := sync.RWMutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Lock()
x++
c.Unlock()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkMutexWithDefer(b *testing.B) {
c := sync.Mutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
func() {
c.Lock()
defer c.Unlock()
x++
}()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkMutexWithClosureDefer(b *testing.B) {
c := sync.Mutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
func() {
c.Lock()
defer func() { c.Unlock() }()
x++
}()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkMutexWithoutDefer(b *testing.B) {
c := sync.Mutex{}
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
func() {
c.Lock()
x++
c.Unlock()
}()
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkAtomicAddInt64(b *testing.B) {
var c int64
b.ResetTimer()
for i := 0; i < b.N; i++ {
atomic.AddInt64(&c, 1)
}
b.StopTimer()
if c != int64(b.N) {
b.Fatal("error")
}
}
func BenchmarkAtomicTimeValueStore(b *testing.B) {
var c atomic.Value
t := time.Now()
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Store(t)
}
b.StopTimer()
}
func BenchmarkAtomic16BValueStore(b *testing.B) {
var c atomic.Value
t := struct {
a int64
b int64
}{
123, 123,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Store(t)
}
b.StopTimer()
}
func BenchmarkAtomic32BValueStore(b *testing.B) {
var c atomic.Value
t := struct {
a int64
b int64
c int64
d int64
}{
123, 123, 123, 123,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Store(t)
}
b.StopTimer()
}
func BenchmarkAtomicPointerStore(b *testing.B) {
t := 123
var up unsafe.Pointer
b.ResetTimer()
for i := 0; i < b.N; i++ {
atomic.StorePointer(&up, unsafe.Pointer(&t))
}
b.StopTimer()
}
func BenchmarkAtomicTimePointerStore(b *testing.B) {
t := time.Now()
var up unsafe.Pointer
b.ResetTimer()
for i := 0; i < b.N; i++ {
atomic.StorePointer(&up, unsafe.Pointer(&t))
}
b.StopTimer()
}
func BenchmarkStoreContentionWithAtomic(b *testing.B) {
t := 123
var c unsafe.Pointer
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
atomic.StorePointer(&c, unsafe.Pointer(&t))
}
})
}
func BenchmarkStoreContentionWithMutex(b *testing.B) {
t := 123
var mu sync.Mutex
var c int
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
mu.Lock()
c = t
mu.Unlock()
}
})
_ = c
}
type dummyStruct struct {
a int64
b time.Time
}
func BenchmarkStructStoreContention(b *testing.B) {
d := dummyStruct{}
dp := unsafe.Pointer(&d)
t := time.Now()
for _, j := range []int{100000000, 10000, 0} {
for _, i := range []int{100000, 10} {
b.Run(fmt.Sprintf("CAS/%v/%v", j, i), func(b *testing.B) {
b.SetParallelism(i)
b.RunParallel(func(pb *testing.PB) {
n := &dummyStruct{
b: t,
}
for pb.Next() {
for y := 0; y < j; y++ {
}
for {
v := (*dummyStruct)(atomic.LoadPointer(&dp))
n.a = v.a + 1
if atomic.CompareAndSwapPointer(&dp, unsafe.Pointer(v), unsafe.Pointer(n)) {
n = v
break
}
}
}
})
})
}
}
var mu sync.Mutex
for _, j := range []int{100000000, 10000, 0} {
for _, i := range []int{100000, 10} {
b.Run(fmt.Sprintf("Mutex/%v/%v", j, i), func(b *testing.B) {
b.SetParallelism(i)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
for y := 0; y < j; y++ {
}
mu.Lock()
d.a++
d.b = t
mu.Unlock()
}
})
})
}
}
}
type myFooer struct{}
func (myFooer) Foo() {}
type fooer interface {
Foo()
}
func BenchmarkInterfaceTypeAssertion(b *testing.B) {
// Call a separate function to avoid compiler optimizations.
runInterfaceTypeAssertion(b, myFooer{})
}
func runInterfaceTypeAssertion(b *testing.B, fer interface{}) {
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, ok := fer.(fooer); ok {
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
func BenchmarkStructTypeAssertion(b *testing.B) {
// Call a separate function to avoid compiler optimizations.
runStructTypeAssertion(b, myFooer{})
}
func runStructTypeAssertion(b *testing.B, fer interface{}) {
x := 0
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, ok := fer.(myFooer); ok {
x++
}
}
b.StopTimer()
if x != b.N {
b.Fatal("error")
}
}
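These micro-benchmarks compare the cost of common ways to check a flag. A minimal sketch of the pattern behind BenchmarkSelectClosed/BenchmarkSelectOpen, using a closed channel as a cancellation flag (illustration only, not part of the vendored file):

package main

import "fmt"

type worker struct {
	stop chan struct{} // closed to signal cancellation
}

func (w *worker) stopped() bool {
	select {
	case <-w.stop: // a closed channel is always ready to receive
		return true
	default:
		return false
	}
}

func main() {
	w := &worker{stop: make(chan struct{})}
	fmt.Println(w.stopped()) // false
	close(w.stop)
	fmt.Println(w.stopped()) // true
}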

187
vendor/google.golang.org/grpc/benchmark/run_bench.sh generated vendored Executable file
View file

@ -0,0 +1,187 @@
#!/bin/bash
rpcs=(1)
conns=(1)
warmup=10
dur=10
reqs=(1)
resps=(1)
rpc_types=(unary)
# idx[0] = idx value for rpcs
# idx[1] = idx value for conns
# idx[2] = idx value for reqs
# idx[3] = idx value for resps
# idx[4] = idx value for rpc_types
idx=(0 0 0 0 0)
idx_max=(1 1 1 1 1)
inc()
{
for i in $(seq $((${#idx[@]}-1)) -1 0); do
idx[${i}]=$((${idx[${i}]}+1))
if [ ${idx[${i}]} == ${idx_max[${i}]} ]; then
idx[${i}]=0
else
break
fi
done
local fin
fin=1
# Check to see if we have looped back to the beginning.
for v in ${idx[@]}; do
if [ ${v} != 0 ]; then
fin=0
break
fi
done
if [ ${fin} == 1 ]; then
rm -Rf ${out_dir}
clean_and_die 0
fi
}
clean_and_die() {
rm -Rf ${out_dir}
exit $1
}
run(){
local nr
nr=${rpcs[${idx[0]}]}
local nc
nc=${conns[${idx[1]}]}
req_sz=${reqs[${idx[2]}]}
resp_sz=${resps[${idx[3]}]}
r_type=${rpc_types[${idx[4]}]}
# Following runs one benchmark
base_port=50051
delta=0
test_name="r_"${nr}"_c_"${nc}"_req_"${req_sz}"_resp_"${resp_sz}"_"${r_type}"_"$(date +%s)
echo "================================================================================"
echo ${test_name}
while :
do
port=$((${base_port}+${delta}))
# Launch the server in background
${out_dir}/server --port=${port} --test_name="Server_"${test_name}&
server_pid=$(echo $!)
# Launch the client
${out_dir}/client --port=${port} --d=${dur} --w=${warmup} --r=${nr} --c=${nc} --req=${req_sz} --resp=${resp_sz} --rpc_type=${r_type} --test_name="client_"${test_name}
client_status=$(echo $?)
kill -INT ${server_pid}
wait ${server_pid}
if [ ${client_status} == 0 ]; then
break
fi
delta=$((${delta}+1))
if [ ${delta} == 10 ]; then
echo "Continuous 10 failed runs. Exiting now."
rm -Rf ${out_dir}
clean_and_die 1
fi
done
}
set_param(){
local argname=$1
shift
local idx=$1
shift
if [ $# -eq 0 ]; then
echo "${argname} not specified"
exit 1
fi
PARAM=($(echo $1 | sed 's/,/ /g'))
if [ ${idx} -lt 0 ]; then
return
fi
idx_max[${idx}]=${#PARAM[@]}
}
while [ $# -gt 0 ]; do
case "$1" in
-r)
shift
set_param "number of rpcs" 0 $1
rpcs=(${PARAM[@]})
shift
;;
-c)
shift
set_param "number of connections" 1 $1
conns=(${PARAM[@]})
shift
;;
-w)
shift
set_param "warm-up period" -1 $1
warmup=${PARAM}
shift
;;
-d)
shift
set_param "duration" -1 $1
dur=${PARAM}
shift
;;
-req)
shift
set_param "request size" 2 $1
reqs=(${PARAM[@]})
shift
;;
-resp)
shift
set_param "response size" 3 $1
resps=(${PARAM[@]})
shift
;;
-rpc_type)
shift
set_param "rpc type" 4 $1
rpc_types=(${PARAM[@]})
shift
;;
-h|--help)
echo "Following are valid options:"
echo
echo "-h, --help show brief help"
echo "-w warm-up duration in seconds, default value is 10"
echo "-d benchmark duration in seconds, default value is 60"
echo ""
echo "Each of the following can have multiple comma separated values."
echo ""
echo "-r number of RPCs, default value is 1"
echo "-c number of Connections, default value is 1"
echo "-req req size in bytes, default value is 1"
echo "-resp resp size in bytes, default value is 1"
echo "-rpc_type valid values are unary|streaming, default is unary"
;;
*)
echo "Incorrect option $1"
exit 1
;;
esac
done
# Build server and client
out_dir=$(mktemp -d oss_benchXXX)
go build -o ${out_dir}/server $GOPATH/src/google.golang.org/grpc/benchmark/server/main.go && go build -o ${out_dir}/client $GOPATH/src/google.golang.org/grpc/benchmark/client/main.go
if [ $? != 0 ]; then
clean_and_die 1
fi
while :
do
run
inc
done

81
vendor/google.golang.org/grpc/benchmark/server/main.go generated vendored Normal file
View file

@ -0,0 +1,81 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"flag"
"fmt"
"net"
_ "net/http/pprof"
"os"
"os/signal"
"runtime"
"runtime/pprof"
"time"
"google.golang.org/grpc/benchmark"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
)
var (
port = flag.String("port", "50051", "Localhost port to listen on.")
testName = flag.String("test_name", "", "Name of the test used for creating profiles.")
)
func main() {
flag.Parse()
if *testName == "" {
grpclog.Fatalf("test name not set")
}
lis, err := net.Listen("tcp", ":"+*port)
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
cf, err := os.Create("/tmp/" + *testName + ".cpu")
if err != nil {
grpclog.Fatalf("Failed to create file: %v", err)
}
defer cf.Close()
pprof.StartCPUProfile(cf)
cpuBeg := syscall.GetCPUTime()
// Launch server in a separate goroutine.
stop := benchmark.StartServer(benchmark.ServerInfo{Type: "protobuf", Listener: lis})
// Wait on OS terminate signal.
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt)
<-ch
cpu := time.Duration(syscall.GetCPUTime() - cpuBeg)
stop()
pprof.StopCPUProfile()
mf, err := os.Create("/tmp/" + *testName + ".mem")
if err != nil {
grpclog.Fatalf("Failed to create file: %v", err)
}
defer mf.Close()
runtime.GC() // materialize all statistics
if err := pprof.WriteHeapProfile(mf); err != nil {
grpclog.Fatalf("Failed to write memory profile: %v", err)
}
fmt.Println("Server CPU utilization:", cpu)
fmt.Println("Server CPU profile:", cf.Name())
fmt.Println("Server Mem Profile:", mf.Name())
}

View file

@ -0,0 +1,222 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"bytes"
"fmt"
"io"
"log"
"math"
"strconv"
"strings"
)
// Histogram accumulates values in the form of a histogram with
// exponentially increased bucket sizes.
type Histogram struct {
// Count is the total number of values added to the histogram.
Count int64
// Sum is the sum of all the values added to the histogram.
Sum int64
// SumOfSquares is the sum of squares of all values.
SumOfSquares int64
// Min is the minimum of all the values added to the histogram.
Min int64
// Max is the maximum of all the values added to the histogram.
Max int64
// Buckets contains all the buckets of the histogram.
Buckets []HistogramBucket
opts HistogramOptions
logBaseBucketSize float64
oneOverLogOnePlusGrowthFactor float64
}
// HistogramOptions contains the parameters that define the histogram's buckets.
// The first bucket of the created histogram (with index 0) contains [min, min+n)
// where n = BaseBucketSize, min = MinValue.
// Bucket i (i>=1) contains [min + n * m^(i-1), min + n * m^i), where m = 1+GrowthFactor.
// The type of the values is int64.
type HistogramOptions struct {
// NumBuckets is the number of buckets.
NumBuckets int
// GrowthFactor is the growth factor of the buckets. A value of 0.1
// indicates that bucket N+1 will be 10% larger than bucket N.
GrowthFactor float64
// BaseBucketSize is the size of the first bucket.
BaseBucketSize float64
// MinValue is the lower bound of the first bucket.
MinValue int64
}
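// Worked example (illustration only, not part of the vendored file): with
// MinValue=0, BaseBucketSize=1 and GrowthFactor=0.2 the bucket lower bounds
// are 0, 1, 1.2, 1.44, ..., i.e. bucket 0 covers [0, 1), bucket 1 covers
// [1, 1.2), bucket 2 covers [1.2, 1.44), and so on.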
// HistogramBucket represents one histogram bucket.
type HistogramBucket struct {
// LowBound is the lower bound of the bucket.
LowBound float64
// Count is the number of values in the bucket.
Count int64
}
// NewHistogram returns a pointer to a new Histogram object that was created
// with the provided options.
func NewHistogram(opts HistogramOptions) *Histogram {
if opts.NumBuckets == 0 {
opts.NumBuckets = 32
}
if opts.BaseBucketSize == 0.0 {
opts.BaseBucketSize = 1.0
}
h := Histogram{
Buckets: make([]HistogramBucket, opts.NumBuckets),
Min: math.MaxInt64,
Max: math.MinInt64,
opts: opts,
logBaseBucketSize: math.Log(opts.BaseBucketSize),
oneOverLogOnePlusGrowthFactor: 1 / math.Log(1+opts.GrowthFactor),
}
m := 1.0 + opts.GrowthFactor
delta := opts.BaseBucketSize
h.Buckets[0].LowBound = float64(opts.MinValue)
for i := 1; i < opts.NumBuckets; i++ {
h.Buckets[i].LowBound = float64(opts.MinValue) + delta
delta = delta * m
}
return &h
}
// Print writes textual output of the histogram values.
func (h *Histogram) Print(w io.Writer) {
h.PrintWithUnit(w, 1)
}
// PrintWithUnit writes textual output of the histogram values.
// The data in the histogram is divided by unit before being printed.
func (h *Histogram) PrintWithUnit(w io.Writer, unit float64) {
avg := float64(h.Sum) / float64(h.Count)
fmt.Fprintf(w, "Count: %d Min: %5.1f Max: %5.1f Avg: %.2f\n", h.Count, float64(h.Min)/unit, float64(h.Max)/unit, avg/unit)
fmt.Fprintf(w, "%s\n", strings.Repeat("-", 60))
if h.Count <= 0 {
return
}
maxBucketDigitLen := len(strconv.FormatFloat(h.Buckets[len(h.Buckets)-1].LowBound, 'f', 6, 64))
if maxBucketDigitLen < 3 {
// For "inf".
maxBucketDigitLen = 3
}
maxCountDigitLen := len(strconv.FormatInt(h.Count, 10))
percentMulti := 100 / float64(h.Count)
accCount := int64(0)
for i, b := range h.Buckets {
fmt.Fprintf(w, "[%*f, ", maxBucketDigitLen, b.LowBound/unit)
if i+1 < len(h.Buckets) {
fmt.Fprintf(w, "%*f)", maxBucketDigitLen, h.Buckets[i+1].LowBound/unit)
} else {
fmt.Fprintf(w, "%*s)", maxBucketDigitLen, "inf")
}
accCount += b.Count
fmt.Fprintf(w, " %*d %5.1f%% %5.1f%%", maxCountDigitLen, b.Count, float64(b.Count)*percentMulti, float64(accCount)*percentMulti)
const barScale = 0.1
barLength := int(float64(b.Count)*percentMulti*barScale + 0.5)
fmt.Fprintf(w, " %s\n", strings.Repeat("#", barLength))
}
}
// String returns the textual output of the histogram values as string.
func (h *Histogram) String() string {
var b bytes.Buffer
h.Print(&b)
return b.String()
}
// Clear resets all the content of histogram.
func (h *Histogram) Clear() {
h.Count = 0
h.Sum = 0
h.SumOfSquares = 0
h.Min = math.MaxInt64
h.Max = math.MinInt64
for i := range h.Buckets {
h.Buckets[i].Count = 0
}
}
// Opts returns a copy of the options used to create the Histogram.
func (h *Histogram) Opts() HistogramOptions {
return h.opts
}
// Add adds a value to the histogram.
func (h *Histogram) Add(value int64) error {
bucket, err := h.findBucket(value)
if err != nil {
return err
}
h.Buckets[bucket].Count++
h.Count++
h.Sum += value
h.SumOfSquares += value * value
if value < h.Min {
h.Min = value
}
if value > h.Max {
h.Max = value
}
return nil
}
func (h *Histogram) findBucket(value int64) (int, error) {
delta := float64(value - h.opts.MinValue)
var b int
if delta >= h.opts.BaseBucketSize {
// b = log_{1+growthFactor} (delta / baseBucketSize) + 1
// = log(delta / baseBucketSize) / log(1+growthFactor) + 1
// = (log(delta) - log(baseBucketSize)) * (1 / log(1+growthFactor)) + 1
b = int((math.Log(delta)-h.logBaseBucketSize)*h.oneOverLogOnePlusGrowthFactor + 1)
}
if b >= len(h.Buckets) {
return 0, fmt.Errorf("no bucket for value: %d", value)
}
return b, nil
}
// Merge takes another histogram h2, and merges its content into h.
// The two histograms must be created by equivalent HistogramOptions.
func (h *Histogram) Merge(h2 *Histogram) {
if h.opts != h2.opts {
log.Fatalf("failed to merge histograms, created by inequivalent options")
}
h.Count += h2.Count
h.Sum += h2.Sum
h.SumOfSquares += h2.SumOfSquares
if h2.Min < h.Min {
h.Min = h2.Min
}
if h2.Max > h.Max {
h.Max = h2.Max
}
for i, b := range h2.Buckets {
h.Buckets[i].Count += b.Count
}
}
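A minimal usage sketch of the histogram above (illustration only, not part of the vendored file; it assumes the import path google.golang.org/grpc/benchmark/stats shown in the surrounding file headers):

package main

import (
	"os"

	"google.golang.org/grpc/benchmark/stats"
)

func main() {
	// Bucket 0 covers [0, 1); each later bucket boundary grows by 20%.
	h := stats.NewHistogram(stats.HistogramOptions{
		NumBuckets:     32,
		GrowthFactor:   0.2,
		BaseBucketSize: 1.0,
		MinValue:       0,
	})
	for _, v := range []int64{1, 2, 3, 5, 8, 13, 21, 34} {
		if err := h.Add(v); err != nil {
			// Add fails only when the value falls beyond the last bucket.
			panic(err)
		}
	}
	h.Print(os.Stdout)
}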

302
vendor/google.golang.org/grpc/benchmark/stats/stats.go generated vendored Normal file
View file

@ -0,0 +1,302 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"bytes"
"fmt"
"io"
"math"
"sort"
"strconv"
"time"
)
// Features describes the configuration of a benchmark run.
type Features struct {
NetworkMode string
EnableTrace bool
Latency time.Duration
Kbps int
Mtu int
MaxConcurrentCalls int
ReqSizeBytes int
RespSizeBytes int
EnableCompressor bool
EnableChannelz bool
}
// String returns the textual output of the Features as string.
func (f Features) String() string {
return fmt.Sprintf("traceMode_%t-latency_%s-kbps_%#v-MTU_%#v-maxConcurrentCalls_"+
"%#v-reqSize_%#vB-respSize_%#vB-Compressor_%t", f.EnableTrace,
f.Latency.String(), f.Kbps, f.Mtu, f.MaxConcurrentCalls, f.ReqSizeBytes, f.RespSizeBytes, f.EnableCompressor)
}
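// For illustration (not part of the vendored file): a Features value with
// MaxConcurrentCalls=1, ReqSizeBytes=1, RespSizeBytes=1 and everything else
// zero renders as
// "traceMode_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1B-respSize_1B-Compressor_false".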
// ConciseString returns the concise textual output of the Features as string, skipping
// settings that are at their default values.
func (f Features) ConciseString() string {
noneEmptyPos := []bool{f.EnableTrace, f.Latency != 0, f.Kbps != 0, f.Mtu != 0, true, true, true, f.EnableCompressor, f.EnableChannelz}
return PartialPrintString(noneEmptyPos, f, false)
}
// PartialPrintString prints the features selected by noneEmptyPos, either as one
// "name: value" line per feature (shared) or as dash-prefixed segments.
func PartialPrintString(noneEmptyPos []bool, f Features, shared bool) string {
s := ""
var (
prefix, suffix, linker string
isNetwork bool
)
if shared {
suffix = "\n"
linker = ": "
} else {
prefix = "-"
linker = "_"
}
if noneEmptyPos[0] {
s += fmt.Sprintf("%sTrace%s%t%s", prefix, linker, f.EnableTrace, suffix)
}
if shared && f.NetworkMode != "" {
s += fmt.Sprintf("Network: %s \n", f.NetworkMode)
isNetwork = true
}
if !isNetwork {
if noneEmptyPos[1] {
s += fmt.Sprintf("%slatency%s%s%s", prefix, linker, f.Latency.String(), suffix)
}
if noneEmptyPos[2] {
s += fmt.Sprintf("%skbps%s%#v%s", prefix, linker, f.Kbps, suffix)
}
if noneEmptyPos[3] {
s += fmt.Sprintf("%sMTU%s%#v%s", prefix, linker, f.Mtu, suffix)
}
}
if noneEmptyPos[4] {
s += fmt.Sprintf("%sCallers%s%#v%s", prefix, linker, f.MaxConcurrentCalls, suffix)
}
if noneEmptyPos[5] {
s += fmt.Sprintf("%sreqSize%s%#vB%s", prefix, linker, f.ReqSizeBytes, suffix)
}
if noneEmptyPos[6] {
s += fmt.Sprintf("%srespSize%s%#vB%s", prefix, linker, f.RespSizeBytes, suffix)
}
if noneEmptyPos[7] {
s += fmt.Sprintf("%sCompressor%s%t%s", prefix, linker, f.EnableCompressor, suffix)
}
if noneEmptyPos[8] {
s += fmt.Sprintf("%sChannelz%s%t%s", prefix, linker, f.EnableChannelz, suffix)
}
return s
}
type percentLatency struct {
Percent int
Value time.Duration
}
// BenchResults records features and result of a benchmark.
type BenchResults struct {
RunMode string
Features Features
Latency []percentLatency
Operations int
NsPerOp int64
AllocedBytesPerOp int64
AllocsPerOp int64
SharedPosion []bool
}
// SetBenchmarkResult sets features of benchmark and basic results.
func (stats *Stats) SetBenchmarkResult(mode string, features Features, o int, allocdBytes, allocs int64, sharedPos []bool) {
stats.result.RunMode = mode
stats.result.Features = features
stats.result.Operations = o
stats.result.AllocedBytesPerOp = allocdBytes
stats.result.AllocsPerOp = allocs
stats.result.SharedPosion = sharedPos
}
// GetBenchmarkResults returns the results of the benchmark, including its features.
func (stats *Stats) GetBenchmarkResults() BenchResults {
return stats.result
}
// BenchString outputs the latency stats formatted as time + unit.
func (stats *Stats) BenchString() string {
stats.maybeUpdate()
s := stats.result
res := s.RunMode + "-" + s.Features.String() + ": \n"
if len(s.Latency) != 0 {
var statsUnit = s.Latency[0].Value
var timeUnit = fmt.Sprintf("%v", statsUnit)[1:]
for i := 1; i < len(s.Latency)-1; i++ {
res += fmt.Sprintf("%d_Latency: %s %s \t", s.Latency[i].Percent,
strconv.FormatFloat(float64(s.Latency[i].Value)/float64(statsUnit), 'f', 4, 64), timeUnit)
}
res += fmt.Sprintf("Avg latency: %s %s \t",
strconv.FormatFloat(float64(s.Latency[len(s.Latency)-1].Value)/float64(statsUnit), 'f', 4, 64), timeUnit)
}
res += fmt.Sprintf("Count: %v \t", s.Operations)
res += fmt.Sprintf("%v Bytes/op\t", s.AllocedBytesPerOp)
res += fmt.Sprintf("%v Allocs/op\t", s.AllocsPerOp)
return res
}
// Stats is a simple helper for gathering additional statistics, such as histograms,
// during benchmarks. It is not thread safe.
type Stats struct {
numBuckets int
unit time.Duration
min, max int64
histogram *Histogram
durations durationSlice
dirty bool
sortLatency bool
result BenchResults
}
type durationSlice []time.Duration
// NewStats creates a new Stats instance. If numBuckets is not positive,
// the default value (16) will be used.
func NewStats(numBuckets int) *Stats {
if numBuckets <= 0 {
numBuckets = 16
}
return &Stats{
// Use one more bucket for the last unbounded bucket.
numBuckets: numBuckets + 1,
durations: make(durationSlice, 0, 100000),
}
}
// Add adds an elapsed time per operation to the stats.
func (stats *Stats) Add(d time.Duration) {
stats.durations = append(stats.durations, d)
stats.dirty = true
}
// Clear resets the stats, removing all values.
func (stats *Stats) Clear() {
stats.durations = stats.durations[:0]
stats.histogram = nil
stats.dirty = false
stats.result = BenchResults{}
}
// durationSlice implements sort.Interface.
func (a durationSlice) Len() int { return len(a) }
func (a durationSlice) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a durationSlice) Less(i, j int) bool { return a[i] < a[j] }
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}
// maybeUpdate updates the internal stat data if any new stats have been
// added since the last update.
func (stats *Stats) maybeUpdate() {
if !stats.dirty {
return
}
if stats.sortLatency {
sort.Sort(stats.durations)
stats.min = int64(stats.durations[0])
stats.max = int64(stats.durations[len(stats.durations)-1])
}
stats.min = math.MaxInt64
stats.max = 0
for _, d := range stats.durations {
if stats.min > int64(d) {
stats.min = int64(d)
}
if stats.max < int64(d) {
stats.max = int64(d)
}
}
// Use the largest unit that can represent the minimum time duration.
stats.unit = time.Nanosecond
for _, u := range []time.Duration{time.Microsecond, time.Millisecond, time.Second} {
if stats.min <= int64(u) {
break
}
stats.unit = u
}
numBuckets := stats.numBuckets
if n := int(stats.max - stats.min + 1); n < numBuckets {
numBuckets = n
}
stats.histogram = NewHistogram(HistogramOptions{
NumBuckets: numBuckets,
// max-min(lower bound of last bucket) = (1 + growthFactor)^(numBuckets-2) * baseBucketSize.
GrowthFactor: math.Pow(float64(stats.max-stats.min), 1/float64(numBuckets-2)) - 1,
BaseBucketSize: 1.0,
MinValue: stats.min})
for _, d := range stats.durations {
stats.histogram.Add(int64(d))
}
stats.dirty = false
if stats.durations.Len() != 0 {
var percentToObserve = []int{50, 90, 99}
// The first record stores the minimum time unit used for the latency results.
stats.result.Latency = append(stats.result.Latency, percentLatency{Percent: -1, Value: stats.unit})
for _, position := range percentToObserve {
stats.result.Latency = append(stats.result.Latency, percentLatency{Percent: position, Value: stats.durations[max(stats.histogram.Count*int64(position)/100-1, 0)]})
}
// The last record stores the average latency.
avg := float64(stats.histogram.Sum) / float64(stats.histogram.Count)
stats.result.Latency = append(stats.result.Latency, percentLatency{Percent: -1, Value: time.Duration(avg)})
}
}
// SortLatency causes the recorded durations to be sorted before the
// percentile latencies are computed.
func (stats *Stats) SortLatency() {
stats.sortLatency = true
}
// Print writes textual output of the Stats.
func (stats *Stats) Print(w io.Writer) {
stats.maybeUpdate()
if stats.histogram == nil {
fmt.Fprint(w, "Histogram (empty)\n")
} else {
fmt.Fprintf(w, "Histogram (unit: %s)\n", fmt.Sprintf("%v", stats.unit)[1:])
stats.histogram.PrintWithUnit(w, float64(stats.unit))
}
}
// String returns the textual output of the Stats as string.
func (stats *Stats) String() string {
var b bytes.Buffer
stats.Print(&b)
return b.String()
}
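A minimal sketch of collecting latencies with the Stats helper above (illustration only, not part of the vendored file; it assumes the same google.golang.org/grpc/benchmark/stats import path):

package main

import (
	"fmt"
	"time"

	"google.golang.org/grpc/benchmark/stats"
)

func main() {
	s := stats.NewStats(16) // 16 buckets plus one unbounded bucket
	for _, d := range []time.Duration{
		120 * time.Microsecond,
		250 * time.Microsecond,
		900 * time.Microsecond,
	} {
		s.Add(d)
	}
	// String (via Print) lazily builds the histogram from the recorded durations.
	fmt.Print(s.String())
}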

208
vendor/google.golang.org/grpc/benchmark/stats/util.go generated vendored Normal file
View file

@ -0,0 +1,208 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"bufio"
"bytes"
"fmt"
"os"
"runtime"
"sort"
"strings"
"sync"
"testing"
)
var (
curB *testing.B
curBenchName string
curStats map[string]*Stats
orgStdout *os.File
nextOutPos int
injectCond *sync.Cond
injectDone chan struct{}
)
// AddStats adds a new unnamed Stats instance to the current benchmark. You need
// to run benchmarks by calling RunTestMain() to inject the stats to the
// benchmark results. If numBuckets is not positive, the default value (16) will
// be used. Please note that this calls b.ResetTimer() since it may be blocked
// until the previous benchmark's stats are printed out. So AddStats() should
// typically be called at the very beginning of each benchmark function.
func AddStats(b *testing.B, numBuckets int) *Stats {
return AddStatsWithName(b, "", numBuckets)
}
// AddStatsWithName adds a new named Stats instance to the current benchmark.
// With this, you can add multiple stats in a single benchmark. You need
// to run benchmarks by calling RunTestMain() to inject the stats to the
// benchmark results. If numBuckets is not positive, the default value (16) will
// be used. Please note that this calls b.ResetTimer() since it may be blocked
// until the previous benchmark's stats are printed out. So AddStatsWithName()
// should typically be called at the very beginning of each benchmark function.
func AddStatsWithName(b *testing.B, name string, numBuckets int) *Stats {
var benchName string
for i := 1; ; i++ {
pc, _, _, ok := runtime.Caller(i)
if !ok {
panic("benchmark function not found")
}
p := strings.Split(runtime.FuncForPC(pc).Name(), ".")
benchName = p[len(p)-1]
if strings.HasPrefix(benchName, "run") {
break
}
}
procs := runtime.GOMAXPROCS(-1)
if procs != 1 {
benchName = fmt.Sprintf("%s-%d", benchName, procs)
}
stats := NewStats(numBuckets)
if injectCond != nil {
// We need to wait until the previous benchmark's stats are printed out.
injectCond.L.Lock()
for curB != nil && curBenchName != benchName {
injectCond.Wait()
}
curB = b
curBenchName = benchName
curStats[name] = stats
injectCond.L.Unlock()
}
b.ResetTimer()
return stats
}
// RunTestMain runs the tests while enabling injection of benchmark stats. It
// returns an exit code to pass to os.Exit.
func RunTestMain(m *testing.M) int {
startStatsInjector()
defer stopStatsInjector()
return m.Run()
}
// startStatsInjector starts stats injection to benchmark results.
func startStatsInjector() {
orgStdout = os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
nextOutPos = 0
resetCurBenchStats()
injectCond = sync.NewCond(&sync.Mutex{})
injectDone = make(chan struct{})
go func() {
defer close(injectDone)
scanner := bufio.NewScanner(r)
scanner.Split(splitLines)
for scanner.Scan() {
injectStatsIfFinished(scanner.Text())
}
if err := scanner.Err(); err != nil {
panic(err)
}
}()
}
// stopStatsInjector stops stats injection and restores os.Stdout.
func stopStatsInjector() {
os.Stdout.Close()
<-injectDone
injectCond = nil
os.Stdout = orgStdout
}
// splitLines is a split function for a bufio.Scanner that returns each line
// of text, teeing the text to the original stdout even before each line ends.
func splitLines(data []byte, eof bool) (advance int, token []byte, err error) {
if eof && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
orgStdout.Write(data[nextOutPos : i+1])
nextOutPos = 0
return i + 1, data[0:i], nil
}
orgStdout.Write(data[nextOutPos:])
nextOutPos = len(data)
if eof {
// This is a final, non-terminated line. Return it.
return len(data), data, nil
}
return 0, nil, nil
}
// injectStatsIfFinished prints out the stats if the current benchmark finishes.
func injectStatsIfFinished(line string) {
injectCond.L.Lock()
defer injectCond.L.Unlock()
// We assume that the benchmark results start with "Benchmark".
if curB == nil || !strings.HasPrefix(line, "Benchmark") {
return
}
if !curB.Failed() {
// Output all stats in alphabetical order.
names := make([]string, 0, len(curStats))
for name := range curStats {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
stats := curStats[name]
// The output of stats starts with a header like "Histogram (unit: ms)"
// followed by statistical properties and the buckets. Add the stats name
// if it is a named Stats instance, and indent the lines like Go testing output.
lines := strings.Split(stats.String(), "\n")
if n := len(lines); n > 0 {
if name != "" {
name = ": " + name
}
fmt.Fprintf(orgStdout, "--- %s%s\n", lines[0], name)
for _, line := range lines[1 : n-1] {
fmt.Fprintf(orgStdout, "\t%s\n", line)
}
}
}
}
resetCurBenchStats()
injectCond.Signal()
}
// resetCurBenchStats resets the current benchmark stats.
func resetCurBenchStats() {
curB = nil
curBenchName = ""
curStats = make(map[string]*Stats)
}
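A hedged sketch of how these helpers are meant to be wired together (the package and function names below are hypothetical): the stats injector only recognizes a benchmark when AddStats is called from a helper whose name starts with "run", as in the grpc benchmarks.

package example_test // hypothetical test package

import (
	"os"
	"testing"
	"time"

	"google.golang.org/grpc/benchmark/stats"
)

func TestMain(m *testing.M) {
	// RunTestMain intercepts stdout so the collected stats can be injected
	// under each benchmark's result line.
	os.Exit(stats.RunTestMain(m))
}

// The "run" prefix is how AddStats discovers the benchmark name.
func runSleep(b *testing.B) {
	s := stats.AddStats(b, 0) // 0 falls back to the default bucket count
	for i := 0; i < b.N; i++ {
		start := time.Now()
		time.Sleep(time.Microsecond) // stand-in for the operation under test
		s.Add(time.Since(start))
	}
}

func BenchmarkSleep(b *testing.B) { runSleep(b) }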

View file

@ -0,0 +1,386 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"flag"
"math"
"runtime"
"sync"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)
var caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
type lockingHistogram struct {
mu sync.Mutex
histogram *stats.Histogram
}
func (h *lockingHistogram) add(value int64) {
h.mu.Lock()
defer h.mu.Unlock()
h.histogram.Add(value)
}
// swap sets h.histogram to o and returns its old value.
func (h *lockingHistogram) swap(o *stats.Histogram) *stats.Histogram {
h.mu.Lock()
defer h.mu.Unlock()
old := h.histogram
h.histogram = o
return old
}
func (h *lockingHistogram) mergeInto(merged *stats.Histogram) {
h.mu.Lock()
defer h.mu.Unlock()
merged.Merge(h.histogram)
}
type benchmarkClient struct {
closeConns func()
stop chan bool
lastResetTime time.Time
histogramOptions stats.HistogramOptions
lockingHistograms []lockingHistogram
rusageLastReset *syscall.Rusage
}
func printClientConfig(config *testpb.ClientConfig) {
// Some config options are ignored:
// - client type:
// will always create sync client
// - async client threads.
// - core list
grpclog.Infof(" * client type: %v (ignored, always creates sync client)", config.ClientType)
grpclog.Infof(" * async client threads: %v (ignored)", config.AsyncClientThreads)
// TODO: use cores specified by CoreList when setting list of cores is supported in go.
grpclog.Infof(" * core list: %v (ignored)", config.CoreList)
grpclog.Infof(" - security params: %v", config.SecurityParams)
grpclog.Infof(" - core limit: %v", config.CoreLimit)
grpclog.Infof(" - payload config: %v", config.PayloadConfig)
grpclog.Infof(" - rpcs per chann: %v", config.OutstandingRpcsPerChannel)
grpclog.Infof(" - channel number: %v", config.ClientChannels)
grpclog.Infof(" - load params: %v", config.LoadParams)
grpclog.Infof(" - rpc type: %v", config.RpcType)
grpclog.Infof(" - histogram params: %v", config.HistogramParams)
grpclog.Infof(" - server targets: %v", config.ServerTargets)
}
func setupClientEnv(config *testpb.ClientConfig) {
// Use all cpu cores available on machine by default.
// TODO: Revisit this for the optimal default setup.
if config.CoreLimit > 0 {
runtime.GOMAXPROCS(int(config.CoreLimit))
} else {
runtime.GOMAXPROCS(runtime.NumCPU())
}
}
// createConns creates connections according to the given config.
// It returns the connections and a corresponding function to close them.
// It returns a non-nil error if anything goes wrong.
func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error) {
var opts []grpc.DialOption
// Sanity check for client type.
switch config.ClientType {
case testpb.ClientType_SYNC_CLIENT:
case testpb.ClientType_ASYNC_CLIENT:
default:
return nil, nil, status.Errorf(codes.InvalidArgument, "unknown client type: %v", config.ClientType)
}
// Check and set security options.
if config.SecurityParams != nil {
if *caFile == "" {
*caFile = testdata.Path("ca.pem")
}
creds, err := credentials.NewClientTLSFromFile(*caFile, config.SecurityParams.ServerHostOverride)
if err != nil {
return nil, nil, status.Errorf(codes.InvalidArgument, "failed to create TLS credentials %v", err)
}
opts = append(opts, grpc.WithTransportCredentials(creds))
} else {
opts = append(opts, grpc.WithInsecure())
}
// Use byteBufCodec if it is required.
if config.PayloadConfig != nil {
switch config.PayloadConfig.Payload.(type) {
case *testpb.PayloadConfig_BytebufParams:
opts = append(opts, grpc.WithDefaultCallOptions(grpc.CallCustomCodec(byteBufCodec{})))
case *testpb.PayloadConfig_SimpleParams:
default:
return nil, nil, status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
}
// Create connections.
connCount := int(config.ClientChannels)
conns := make([]*grpc.ClientConn, connCount)
for connIndex := 0; connIndex < connCount; connIndex++ {
conns[connIndex] = benchmark.NewClientConn(config.ServerTargets[connIndex%len(config.ServerTargets)], opts...)
}
return conns, func() {
for _, conn := range conns {
conn.Close()
}
}, nil
}
func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benchmarkClient) error {
// Read payload size and type from config.
var (
payloadReqSize, payloadRespSize int
payloadType string
)
if config.PayloadConfig != nil {
switch c := config.PayloadConfig.Payload.(type) {
case *testpb.PayloadConfig_BytebufParams:
payloadReqSize = int(c.BytebufParams.ReqSize)
payloadRespSize = int(c.BytebufParams.RespSize)
payloadType = "bytebuf"
case *testpb.PayloadConfig_SimpleParams:
payloadReqSize = int(c.SimpleParams.ReqSize)
payloadRespSize = int(c.SimpleParams.RespSize)
payloadType = "protobuf"
default:
return status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
}
// TODO add open loop distribution.
switch config.LoadParams.Load.(type) {
case *testpb.LoadParams_ClosedLoop:
case *testpb.LoadParams_Poisson:
return status.Errorf(codes.Unimplemented, "unsupported load params: %v", config.LoadParams)
default:
return status.Errorf(codes.InvalidArgument, "unknown load params: %v", config.LoadParams)
}
rpcCountPerConn := int(config.OutstandingRpcsPerChannel)
switch config.RpcType {
case testpb.RpcType_UNARY:
bc.doCloseLoopUnary(conns, rpcCountPerConn, payloadReqSize, payloadRespSize)
// TODO open loop.
case testpb.RpcType_STREAMING:
bc.doCloseLoopStreaming(conns, rpcCountPerConn, payloadReqSize, payloadRespSize, payloadType)
// TODO open loop.
default:
return status.Errorf(codes.InvalidArgument, "unknown rpc type: %v", config.RpcType)
}
return nil
}
func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) {
printClientConfig(config)
// Set running environment like how many cores to use.
setupClientEnv(config)
conns, closeConns, err := createConns(config)
if err != nil {
return nil, err
}
rpcCountPerConn := int(config.OutstandingRpcsPerChannel)
bc := &benchmarkClient{
histogramOptions: stats.HistogramOptions{
NumBuckets: int(math.Log(config.HistogramParams.MaxPossible)/math.Log(1+config.HistogramParams.Resolution)) + 1,
GrowthFactor: config.HistogramParams.Resolution,
BaseBucketSize: (1 + config.HistogramParams.Resolution),
MinValue: 0,
},
lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns)),
stop: make(chan bool),
lastResetTime: time.Now(),
closeConns: closeConns,
rusageLastReset: syscall.GetRusage(),
}
if err = performRPCs(config, conns, bc); err != nil {
// Close all connections if performRPCs failed.
closeConns()
return nil, err
}
return bc, nil
}
func (bc *benchmarkClient) doCloseLoopUnary(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int) {
for ic, conn := range conns {
client := testpb.NewBenchmarkServiceClient(conn)
// For each connection, create rpcCountPerConn goroutines to do rpc.
for j := 0; j < rpcCountPerConn; j++ {
// Create histogram for each goroutine.
idx := ic*rpcCountPerConn + j
bc.lockingHistograms[idx].histogram = stats.NewHistogram(bc.histogramOptions)
// Start a goroutine that records into the lockingHistogram just created.
go func(idx int) {
// TODO: do warm up if necessary.
// For now we rely on the worker client to reserve time for warm up:
// it needs to wait for some time after the client is created, before
// starting the benchmark.
done := make(chan bool)
for {
go func() {
start := time.Now()
if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil {
select {
case <-bc.stop:
case done <- false:
}
return
}
elapse := time.Since(start)
bc.lockingHistograms[idx].add(int64(elapse))
select {
case <-bc.stop:
case done <- true:
}
}()
select {
case <-bc.stop:
return
case <-done:
}
}
}(idx)
}
}
}
func (bc *benchmarkClient) doCloseLoopStreaming(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, payloadType string) {
var doRPC func(testpb.BenchmarkService_StreamingCallClient, int, int) error
if payloadType == "bytebuf" {
doRPC = benchmark.DoByteBufStreamingRoundTrip
} else {
doRPC = benchmark.DoStreamingRoundTrip
}
for ic, conn := range conns {
// For each connection, create rpcCountPerConn goroutines to do rpc.
for j := 0; j < rpcCountPerConn; j++ {
c := testpb.NewBenchmarkServiceClient(conn)
stream, err := c.StreamingCall(context.Background())
if err != nil {
grpclog.Fatalf("%v.StreamingCall(_) = _, %v", c, err)
}
// Create histogram for each goroutine.
idx := ic*rpcCountPerConn + j
bc.lockingHistograms[idx].histogram = stats.NewHistogram(bc.histogramOptions)
// Start a goroutine that records into the lockingHistogram just created.
go func(idx int) {
// TODO: do warm up if necessary.
// For now we rely on the worker client to reserve time for warm up:
// it needs to wait for some time after the client is created, before
// starting the benchmark.
for {
start := time.Now()
if err := doRPC(stream, reqSize, respSize); err != nil {
return
}
elapse := time.Since(start)
bc.lockingHistograms[idx].add(int64(elapse))
select {
case <-bc.stop:
return
default:
}
}
}(idx)
}
}
}
// getStats returns the stats for benchmark client.
// It resets lastResetTime and all histograms if argument reset is true.
func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats {
var wallTimeElapsed, uTimeElapsed, sTimeElapsed float64
mergedHistogram := stats.NewHistogram(bc.histogramOptions)
if reset {
// Merging histogram may take some time.
// Put all histograms aside and merge later.
toMerge := make([]*stats.Histogram, len(bc.lockingHistograms))
for i := range bc.lockingHistograms {
toMerge[i] = bc.lockingHistograms[i].swap(stats.NewHistogram(bc.histogramOptions))
}
for i := 0; i < len(toMerge); i++ {
mergedHistogram.Merge(toMerge[i])
}
wallTimeElapsed = time.Since(bc.lastResetTime).Seconds()
latestRusage := syscall.GetRusage()
uTimeElapsed, sTimeElapsed = syscall.CPUTimeDiff(bc.rusageLastReset, latestRusage)
bc.rusageLastReset = latestRusage
bc.lastResetTime = time.Now()
} else {
// Merge only, not reset.
for i := range bc.lockingHistograms {
bc.lockingHistograms[i].mergeInto(mergedHistogram)
}
wallTimeElapsed = time.Since(bc.lastResetTime).Seconds()
uTimeElapsed, sTimeElapsed = syscall.CPUTimeDiff(bc.rusageLastReset, syscall.GetRusage())
}
b := make([]uint32, len(mergedHistogram.Buckets))
for i, v := range mergedHistogram.Buckets {
b[i] = uint32(v.Count)
}
return &testpb.ClientStats{
Latencies: &testpb.HistogramData{
Bucket: b,
MinSeen: float64(mergedHistogram.Min),
MaxSeen: float64(mergedHistogram.Max),
Sum: float64(mergedHistogram.Sum),
SumOfSquares: float64(mergedHistogram.SumOfSquares),
Count: float64(mergedHistogram.Count),
},
TimeElapsed: wallTimeElapsed,
TimeUser: uTimeElapsed,
TimeSystem: sTimeElapsed,
}
}
func (bc *benchmarkClient) shutdown() {
close(bc.stop)
bc.closeConns()
}

View file

@ -0,0 +1,184 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"flag"
"fmt"
"net"
"runtime"
"strconv"
"strings"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/benchmark"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)
var (
certFile = flag.String("tls_cert_file", "", "The TLS cert file")
keyFile = flag.String("tls_key_file", "", "The TLS key file")
)
type benchmarkServer struct {
port int
cores int
closeFunc func()
mu sync.RWMutex
lastResetTime time.Time
rusageLastReset *syscall.Rusage
}
func printServerConfig(config *testpb.ServerConfig) {
// Some config options are ignored:
// - server type:
// will always start sync server
// - async server threads
// - core list
grpclog.Infof(" * server type: %v (ignored, always starts sync server)", config.ServerType)
grpclog.Infof(" * async server threads: %v (ignored)", config.AsyncServerThreads)
// TODO: use cores specified by CoreList when setting list of cores is supported in go.
grpclog.Infof(" * core list: %v (ignored)", config.CoreList)
grpclog.Infof(" - security params: %v", config.SecurityParams)
grpclog.Infof(" - core limit: %v", config.CoreLimit)
grpclog.Infof(" - port: %v", config.Port)
grpclog.Infof(" - payload config: %v", config.PayloadConfig)
}
func startBenchmarkServer(config *testpb.ServerConfig, serverPort int) (*benchmarkServer, error) {
printServerConfig(config)
// Use all cpu cores available on machine by default.
// TODO: Revisit this for the optimal default setup.
numOfCores := runtime.NumCPU()
if config.CoreLimit > 0 {
numOfCores = int(config.CoreLimit)
}
runtime.GOMAXPROCS(numOfCores)
var opts []grpc.ServerOption
// Sanity check for server type.
switch config.ServerType {
case testpb.ServerType_SYNC_SERVER:
case testpb.ServerType_ASYNC_SERVER:
case testpb.ServerType_ASYNC_GENERIC_SERVER:
default:
return nil, status.Errorf(codes.InvalidArgument, "unknown server type: %v", config.ServerType)
}
// Set security options.
if config.SecurityParams != nil {
if *certFile == "" {
*certFile = testdata.Path("server1.pem")
}
if *keyFile == "" {
*keyFile = testdata.Path("server1.key")
}
creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
if err != nil {
grpclog.Fatalf("failed to generate credentials %v", err)
}
opts = append(opts, grpc.Creds(creds))
}
// Priority: config.Port > serverPort > default (0).
port := int(config.Port)
if port == 0 {
port = serverPort
}
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
addr := lis.Addr().String()
// Create different benchmark server according to config.
var closeFunc func()
if config.PayloadConfig != nil {
switch payload := config.PayloadConfig.Payload.(type) {
case *testpb.PayloadConfig_BytebufParams:
opts = append(opts, grpc.CustomCodec(byteBufCodec{}))
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "bytebuf",
Metadata: payload.BytebufParams.RespSize,
Listener: lis,
}, opts...)
case *testpb.PayloadConfig_SimpleParams:
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "protobuf",
Listener: lis,
}, opts...)
case *testpb.PayloadConfig_ComplexParams:
return nil, status.Errorf(codes.Unimplemented, "unsupported payload config: %v", config.PayloadConfig)
default:
return nil, status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
} else {
// Start protobuf server if payload config is nil.
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "protobuf",
Listener: lis,
}, opts...)
}
grpclog.Infof("benchmark server listening at %v", addr)
addrSplitted := strings.Split(addr, ":")
p, err := strconv.Atoi(addrSplitted[len(addrSplitted)-1])
if err != nil {
grpclog.Fatalf("failed to get port number from server address: %v", err)
}
return &benchmarkServer{
port: p,
cores: numOfCores,
closeFunc: closeFunc,
lastResetTime: time.Now(),
rusageLastReset: syscall.GetRusage(),
}, nil
}
// getStats returns the stats for benchmark server.
// It resets lastResetTime if argument reset is true.
func (bs *benchmarkServer) getStats(reset bool) *testpb.ServerStats {
bs.mu.RLock()
defer bs.mu.RUnlock()
wallTimeElapsed := time.Since(bs.lastResetTime).Seconds()
rusageLatest := syscall.GetRusage()
uTimeElapsed, sTimeElapsed := syscall.CPUTimeDiff(bs.rusageLastReset, rusageLatest)
if reset {
bs.lastResetTime = time.Now()
bs.rusageLastReset = rusageLatest
}
return &testpb.ServerStats{
TimeElapsed: wallTimeElapsed,
TimeUser: uTimeElapsed,
TimeSystem: sTimeElapsed,
}
}

230
vendor/google.golang.org/grpc/benchmark/worker/main.go generated vendored Normal file
View file

@ -0,0 +1,230 @@
/*
*
* Copyright 2016 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package main
import (
"flag"
"fmt"
"io"
"net"
"net/http"
_ "net/http/pprof"
"runtime"
"strconv"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"
)
var (
driverPort = flag.Int("driver_port", 10000, "port for communication with driver")
serverPort = flag.Int("server_port", 0, "port for benchmark server if not specified by server config message")
pprofPort = flag.Int("pprof_port", -1, "Port for pprof debug server to listen on. Pprof server doesn't start if unset")
blockProfRate = flag.Int("block_prof_rate", 0, "fraction of goroutine blocking events to report in blocking profile")
)
type byteBufCodec struct {
}
func (byteBufCodec) Marshal(v interface{}) ([]byte, error) {
b, ok := v.(*[]byte)
if !ok {
return nil, fmt.Errorf("failed to marshal: %v is not of type *[]byte", v)
}
return *b, nil
}
func (byteBufCodec) Unmarshal(data []byte, v interface{}) error {
b, ok := v.(*[]byte)
if !ok {
return fmt.Errorf("failed to unmarshal: %v is not of type *[]byte", v)
}
*b = data
return nil
}
func (byteBufCodec) String() string {
return "bytebuffer"
}
// workerServer implements WorkerService rpc handlers.
// It can create benchmarkServer or benchmarkClient on demand.
type workerServer struct {
stop chan<- bool
serverPort int
}
func (s *workerServer) RunServer(stream testpb.WorkerService_RunServerServer) error {
var bs *benchmarkServer
defer func() {
// Close benchmark server when stream ends.
grpclog.Infof("closing benchmark server")
if bs != nil {
bs.closeFunc()
}
}()
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
var out *testpb.ServerStatus
switch argtype := in.Argtype.(type) {
case *testpb.ServerArgs_Setup:
grpclog.Infof("server setup received:")
if bs != nil {
grpclog.Infof("server setup received when server already exists, closing the existing server")
bs.closeFunc()
}
bs, err = startBenchmarkServer(argtype.Setup, s.serverPort)
if err != nil {
return err
}
out = &testpb.ServerStatus{
Stats: bs.getStats(false),
Port: int32(bs.port),
Cores: int32(bs.cores),
}
case *testpb.ServerArgs_Mark:
grpclog.Infof("server mark received:")
grpclog.Infof(" - %v", argtype)
if bs == nil {
return status.Error(codes.InvalidArgument, "server does not exist when mark received")
}
out = &testpb.ServerStatus{
Stats: bs.getStats(argtype.Mark.Reset_),
Port: int32(bs.port),
Cores: int32(bs.cores),
}
}
if err := stream.Send(out); err != nil {
return err
}
}
}
func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) error {
var bc *benchmarkClient
defer func() {
// Shut down benchmark client when stream ends.
grpclog.Infof("shuting down benchmark client")
if bc != nil {
bc.shutdown()
}
}()
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
var out *testpb.ClientStatus
switch t := in.Argtype.(type) {
case *testpb.ClientArgs_Setup:
grpclog.Infof("client setup received:")
if bc != nil {
grpclog.Infof("client setup received when client already exists, shuting down the existing client")
bc.shutdown()
}
bc, err = startBenchmarkClient(t.Setup)
if err != nil {
return err
}
out = &testpb.ClientStatus{
Stats: bc.getStats(false),
}
case *testpb.ClientArgs_Mark:
grpclog.Infof("client mark received:")
grpclog.Infof(" - %v", t)
if bc == nil {
return status.Error(codes.InvalidArgument, "client does not exist when mark received")
}
out = &testpb.ClientStatus{
Stats: bc.getStats(t.Mark.Reset_),
}
}
if err := stream.Send(out); err != nil {
return err
}
}
}
func (s *workerServer) CoreCount(ctx context.Context, in *testpb.CoreRequest) (*testpb.CoreResponse, error) {
grpclog.Infof("core count: %v", runtime.NumCPU())
return &testpb.CoreResponse{Cores: int32(runtime.NumCPU())}, nil
}
func (s *workerServer) QuitWorker(ctx context.Context, in *testpb.Void) (*testpb.Void, error) {
grpclog.Infof("quitting worker")
s.stop <- true
return &testpb.Void{}, nil
}
func main() {
grpc.EnableTracing = false
flag.Parse()
lis, err := net.Listen("tcp", ":"+strconv.Itoa(*driverPort))
if err != nil {
grpclog.Fatalf("failed to listen: %v", err)
}
grpclog.Infof("worker listening at port %v", *driverPort)
s := grpc.NewServer()
stop := make(chan bool)
testpb.RegisterWorkerServiceServer(s, &workerServer{
stop: stop,
serverPort: *serverPort,
})
go func() {
<-stop
// Wait for 1 second before stopping the server to make sure the return value of QuitWorker is sent to client.
// TODO revise this once server graceful stop is supported in gRPC.
time.Sleep(time.Second)
s.Stop()
}()
runtime.SetBlockProfileRate(*blockProfRate)
if *pprofPort >= 0 {
go func() {
grpclog.Infoln("Starting pprof server on port " + strconv.Itoa(*pprofPort))
grpclog.Infoln(http.ListenAndServe("localhost:"+strconv.Itoa(*pprofPort), nil))
}()
}
s.Serve(lis)
}

View file

@ -27,12 +27,31 @@ import (
//
// All errors returned by Invoke are compatible with the status package.
func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
// Allow the interceptor to see all applicable call options, i.e. those
// configured as defaults via dial options as well as the per-call options.
opts = combine(cc.dopts.callOptions, opts)
if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
}
return invoke(ctx, method, args, reply, cc, opts...)
}
func combine(o1 []CallOption, o2 []CallOption) []CallOption {
// We don't use append because o1 could have extra capacity whose
// elements would be overwritten, which could cause inadvertent
// sharing (and race conditions) between concurrent calls.
if len(o1) == 0 {
return o2
} else if len(o2) == 0 {
return o1
}
ret := make([]CallOption, len(o1)+len(o2))
copy(ret, o1)
copy(ret[len(o1):], o2)
return ret
}
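// Illustration (not part of the original file): if o1 had spare capacity,
//
//	base := make([]CallOption, 2, 4)
//	o1 := base[:2]
//	_ = append(o1, o2...) // writes o2's elements into base's backing array
//
// then two concurrent RPCs sharing the same default-options slice could
// observe each other's per-call options, hence the explicit copy above.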
// Invoke sends the RPC request on the wire and returns after response is
// received. This is typically called by generated code.
//
@ -44,31 +63,12 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
var unaryStreamDesc = &StreamDesc{ServerStreams: false, ClientStreams: false}
func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
// TODO: implement retries in clientStream and make this simply
// newClientStream, SendMsg, RecvMsg.
firstAttempt := true
for {
csInt, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
if err != nil {
return err
}
cs := csInt.(*clientStream)
if err := cs.SendMsg(req); err != nil {
if !cs.c.failFast && cs.s.Unprocessed() && firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
return err
}
if err := cs.RecvMsg(reply); err != nil {
if !cs.c.failFast && cs.s.Unprocessed() && firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
return err
}
return nil
cs, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
if err != nil {
return err
}
if err := cs.SendMsg(req); err != nil {
return err
}
return cs.RecvMsg(reply)
}

View file

@ -31,9 +31,9 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/leakcheck"
"google.golang.org/grpc/transport"
)
var (
@ -105,12 +105,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
}
}
// send a response back to end the stream.
hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
data, err := encode(testCodec{}, &expectedResponse)
if err != nil {
t.Errorf("Failed to encode the response: %v", err)
return
}
h.t.Write(s, hdr, data, &transport.Options{})
hdr, payload := msgHeader(data, nil)
h.t.Write(s, hdr, payload, &transport.Options{})
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
@ -217,7 +218,7 @@ func TestInvoke(t *testing.T) {
defer leakcheck.Check(t)
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
cc.Close()
@ -229,7 +230,7 @@ func TestInvokeLargeErr(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "hello"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
@ -246,7 +247,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "weird error"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
@ -266,7 +267,7 @@ func TestInvokeCancel(t *testing.T) {
for i := 0; i < 100; i++ {
ctx, cancel := context.WithCancel(context.Background())
cancel()
Invoke(ctx, "/foo/bar", &req, &reply, cc)
cc.Invoke(ctx, "/foo/bar", &req, &reply)
}
if canceled != 0 {
t.Fatalf("received %d of 100 canceled requests", canceled)
@ -285,7 +286,7 @@ func TestInvokeCancelClosedNonFailFast(t *testing.T) {
req := "hello"
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := Invoke(ctx, "/foo/bar", &req, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, FailFast(false)); err == nil {
t.Fatalf("canceled invoke on closed connection should fail")
}
server.stop()

File diff suppressed because it is too large

View file

@ -0,0 +1,105 @@
// +build !appengine,go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"github.com/golang/protobuf/ptypes"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOption {
var opts []*channelzpb.SocketOption
if skopts.Linger != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionLinger{
Active: skopts.Linger.Onoff != 0,
Duration: convertToPtypesDuration(int64(skopts.Linger.Linger), 0),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_LINGER",
Additional: additional,
})
}
}
if skopts.RecvTimeout != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{
Duration: convertToPtypesDuration(int64(skopts.RecvTimeout.Sec), int64(skopts.RecvTimeout.Usec)),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_RCVTIMEO",
Additional: additional,
})
}
}
if skopts.SendTimeout != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{
Duration: convertToPtypesDuration(int64(skopts.SendTimeout.Sec), int64(skopts.SendTimeout.Usec)),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_SNDTIMEO",
Additional: additional,
})
}
}
if skopts.TCPInfo != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTcpInfo{
TcpiState: uint32(skopts.TCPInfo.State),
TcpiCaState: uint32(skopts.TCPInfo.Ca_state),
TcpiRetransmits: uint32(skopts.TCPInfo.Retransmits),
TcpiProbes: uint32(skopts.TCPInfo.Probes),
TcpiBackoff: uint32(skopts.TCPInfo.Backoff),
TcpiOptions: uint32(skopts.TCPInfo.Options),
// https://golang.org/pkg/syscall/#TCPInfo
// TCPInfo struct does not contain info about TcpiSndWscale and TcpiRcvWscale.
TcpiRto: skopts.TCPInfo.Rto,
TcpiAto: skopts.TCPInfo.Ato,
TcpiSndMss: skopts.TCPInfo.Snd_mss,
TcpiRcvMss: skopts.TCPInfo.Rcv_mss,
TcpiUnacked: skopts.TCPInfo.Unacked,
TcpiSacked: skopts.TCPInfo.Sacked,
TcpiLost: skopts.TCPInfo.Lost,
TcpiRetrans: skopts.TCPInfo.Retrans,
TcpiFackets: skopts.TCPInfo.Fackets,
TcpiLastDataSent: skopts.TCPInfo.Last_data_sent,
TcpiLastAckSent: skopts.TCPInfo.Last_ack_sent,
TcpiLastDataRecv: skopts.TCPInfo.Last_data_recv,
TcpiLastAckRecv: skopts.TCPInfo.Last_ack_recv,
TcpiPmtu: skopts.TCPInfo.Pmtu,
TcpiRcvSsthresh: skopts.TCPInfo.Rcv_ssthresh,
TcpiRtt: skopts.TCPInfo.Rtt,
TcpiRttvar: skopts.TCPInfo.Rttvar,
TcpiSndSsthresh: skopts.TCPInfo.Snd_ssthresh,
TcpiSndCwnd: skopts.TCPInfo.Snd_cwnd,
TcpiAdvmss: skopts.TCPInfo.Advmss,
TcpiReordering: skopts.TCPInfo.Reordering,
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "TCP_INFO",
Additional: additional,
})
}
}
return opts
}
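For illustration only, a sketch (not part of the vendored file) of feeding the converter above hand-built data from within this package; it assumes one extra import, golang.org/x/sys/unix, beyond what this file already imports.
// sockoptToProtoExample is an illustrative sketch, not vendored code.
func sockoptToProtoExample() []*channelzpb.SocketOption {
	skopts := &channelz.SocketOptionData{
		Linger:      &unix.Linger{Onoff: 1, Linger: 5},
		RecvTimeout: &unix.Timeval{Sec: 10},
	}
	// Produces entries named "SO_LINGER" and "SO_RCVTIMEO", each carrying an
	// Any-packed message in its Additional field.
	return sockoptToProto(skopts)
}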


@ -0,0 +1,30 @@
// +build !linux appengine !go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOption {
return nil
}

33
vendor/google.golang.org/grpc/channelz/service/regenerate.sh generated vendored Executable file

@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux -o pipefail
TMP=$(mktemp -d)
function finish {
rm -rf "$TMP"
}
trap finish EXIT
pushd "$TMP"
mkdir -p grpc/channelz/v1
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/channelz/v1/channelz.proto > grpc/channelz/v1/channelz.proto
protoc --go_out=plugins=grpc,paths=source_relative:. -I. grpc/channelz/v1/*.proto
popd
rm -f ../grpc_channelz_v1/*.pb.go
cp "$TMP"/grpc/channelz/v1/*.pb.go ../grpc_channelz_v1/


@ -0,0 +1,304 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate ./regenerate.sh
// Package service provides an implementation for channelz service server.
package service
import (
"net"
"time"
"github.com/golang/protobuf/ptypes"
durpb "github.com/golang/protobuf/ptypes/duration"
wrpb "github.com/golang/protobuf/ptypes/wrappers"
"golang.org/x/net/context"
"google.golang.org/grpc"
channelzgrpc "google.golang.org/grpc/channelz/grpc_channelz_v1"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
)
func init() {
channelz.TurnOn()
}
func convertToPtypesDuration(sec int64, usec int64) *durpb.Duration {
return ptypes.DurationProto(time.Duration(sec*1e9 + usec*1e3))
}
// RegisterChannelzServiceToServer registers the channelz service to the given server.
func RegisterChannelzServiceToServer(s *grpc.Server) {
channelzgrpc.RegisterChannelzServer(s, newCZServer())
}
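A minimal usage sketch, not part of the vendored file: exposing the channelz introspection API on an application server. The listen address is an assumption.
package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
	channelzsvc "google.golang.org/grpc/channelz/service"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:50051") // address is illustrative
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	s := grpc.NewServer()
	// Register the channelz service alongside the application's own services.
	channelzsvc.RegisterChannelzServiceToServer(s)
	if err := s.Serve(lis); err != nil {
		log.Fatalf("serve: %v", err)
	}
}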
func newCZServer() channelzgrpc.ChannelzServer {
return &serverImpl{}
}
type serverImpl struct{}
func connectivityStateToProto(s connectivity.State) *channelzpb.ChannelConnectivityState {
switch s {
case connectivity.Idle:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_IDLE}
case connectivity.Connecting:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_CONNECTING}
case connectivity.Ready:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_READY}
case connectivity.TransientFailure:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_TRANSIENT_FAILURE}
case connectivity.Shutdown:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_SHUTDOWN}
default:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_UNKNOWN}
}
}
func channelMetricToProto(cm *channelz.ChannelMetric) *channelzpb.Channel {
c := &channelzpb.Channel{}
c.Ref = &channelzpb.ChannelRef{ChannelId: cm.ID, Name: cm.RefName}
c.Data = &channelzpb.ChannelData{
State: connectivityStateToProto(cm.ChannelData.State),
Target: cm.ChannelData.Target,
CallsStarted: cm.ChannelData.CallsStarted,
CallsSucceeded: cm.ChannelData.CallsSucceeded,
CallsFailed: cm.ChannelData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(cm.ChannelData.LastCallStartedTimestamp); err == nil {
c.Data.LastCallStartedTimestamp = ts
}
nestedChans := make([]*channelzpb.ChannelRef, 0, len(cm.NestedChans))
for id, ref := range cm.NestedChans {
nestedChans = append(nestedChans, &channelzpb.ChannelRef{ChannelId: id, Name: ref})
}
c.ChannelRef = nestedChans
subChans := make([]*channelzpb.SubchannelRef, 0, len(cm.SubChans))
for id, ref := range cm.SubChans {
subChans = append(subChans, &channelzpb.SubchannelRef{SubchannelId: id, Name: ref})
}
c.SubchannelRef = subChans
sockets := make([]*channelzpb.SocketRef, 0, len(cm.Sockets))
for id, ref := range cm.Sockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
c.SocketRef = sockets
return c
}
func subChannelMetricToProto(cm *channelz.SubChannelMetric) *channelzpb.Subchannel {
sc := &channelzpb.Subchannel{}
sc.Ref = &channelzpb.SubchannelRef{SubchannelId: cm.ID, Name: cm.RefName}
sc.Data = &channelzpb.ChannelData{
State: connectivityStateToProto(cm.ChannelData.State),
Target: cm.ChannelData.Target,
CallsStarted: cm.ChannelData.CallsStarted,
CallsSucceeded: cm.ChannelData.CallsSucceeded,
CallsFailed: cm.ChannelData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(cm.ChannelData.LastCallStartedTimestamp); err == nil {
sc.Data.LastCallStartedTimestamp = ts
}
nestedChans := make([]*channelzpb.ChannelRef, 0, len(cm.NestedChans))
for id, ref := range cm.NestedChans {
nestedChans = append(nestedChans, &channelzpb.ChannelRef{ChannelId: id, Name: ref})
}
sc.ChannelRef = nestedChans
subChans := make([]*channelzpb.SubchannelRef, 0, len(cm.SubChans))
for id, ref := range cm.SubChans {
subChans = append(subChans, &channelzpb.SubchannelRef{SubchannelId: id, Name: ref})
}
sc.SubchannelRef = subChans
sockets := make([]*channelzpb.SocketRef, 0, len(cm.Sockets))
for id, ref := range cm.Sockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
sc.SocketRef = sockets
return sc
}
func securityToProto(se credentials.ChannelzSecurityValue) *channelzpb.Security {
switch v := se.(type) {
case *credentials.TLSChannelzSecurityValue:
return &channelzpb.Security{Model: &channelzpb.Security_Tls_{Tls: &channelzpb.Security_Tls{
CipherSuite: &channelzpb.Security_Tls_StandardName{StandardName: v.StandardName},
LocalCertificate: v.LocalCertificate,
RemoteCertificate: v.RemoteCertificate,
}}}
case *credentials.OtherChannelzSecurityValue:
otherSecurity := &channelzpb.Security_OtherSecurity{
Name: v.Name,
}
if anyval, err := ptypes.MarshalAny(v.Value); err == nil {
otherSecurity.Value = anyval
}
return &channelzpb.Security{Model: &channelzpb.Security_Other{Other: otherSecurity}}
}
return nil
}
func addrToProto(a net.Addr) *channelzpb.Address {
switch a.Network() {
case "udp":
// TODO: Address_OtherAddress{}. Need proto def for Value.
case "ip":
// Note zone info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.IPAddr).IP}}}
case "ip+net":
// Note mask info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.IPNet).IP}}}
case "tcp":
// Note zone info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.TCPAddr).IP, Port: int32(a.(*net.TCPAddr).Port)}}}
case "unix", "unixgram", "unixpacket":
return &channelzpb.Address{Address: &channelzpb.Address_UdsAddress_{UdsAddress: &channelzpb.Address_UdsAddress{Filename: a.String()}}}
default:
}
return &channelzpb.Address{}
}
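As a quick illustration (a sketch within this package, not vendored code), a TCP address maps to the TcpIpAddress form with its IP bytes and port, while a Unix address would carry only the socket path.
// addrToProtoExample is an illustrative sketch, not vendored code.
func addrToProtoExample() *channelzpb.Address {
	a := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}
	// Returns an Address wrapping Address_TcpIpAddress{IpAddress: ..., Port: 8080}.
	return addrToProto(a)
}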
func socketMetricToProto(sm *channelz.SocketMetric) *channelzpb.Socket {
s := &channelzpb.Socket{}
s.Ref = &channelzpb.SocketRef{SocketId: sm.ID, Name: sm.RefName}
s.Data = &channelzpb.SocketData{
StreamsStarted: sm.SocketData.StreamsStarted,
StreamsSucceeded: sm.SocketData.StreamsSucceeded,
StreamsFailed: sm.SocketData.StreamsFailed,
MessagesSent: sm.SocketData.MessagesSent,
MessagesReceived: sm.SocketData.MessagesReceived,
KeepAlivesSent: sm.SocketData.KeepAlivesSent,
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastLocalStreamCreatedTimestamp); err == nil {
s.Data.LastLocalStreamCreatedTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastRemoteStreamCreatedTimestamp); err == nil {
s.Data.LastRemoteStreamCreatedTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastMessageSentTimestamp); err == nil {
s.Data.LastMessageSentTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastMessageReceivedTimestamp); err == nil {
s.Data.LastMessageReceivedTimestamp = ts
}
s.Data.LocalFlowControlWindow = &wrpb.Int64Value{Value: sm.SocketData.LocalFlowControlWindow}
s.Data.RemoteFlowControlWindow = &wrpb.Int64Value{Value: sm.SocketData.RemoteFlowControlWindow}
if sm.SocketData.SocketOptions != nil {
s.Data.Option = sockoptToProto(sm.SocketData.SocketOptions)
}
if sm.SocketData.Security != nil {
s.Security = securityToProto(sm.SocketData.Security)
}
if sm.SocketData.LocalAddr != nil {
s.Local = addrToProto(sm.SocketData.LocalAddr)
}
if sm.SocketData.RemoteAddr != nil {
s.Remote = addrToProto(sm.SocketData.RemoteAddr)
}
s.RemoteName = sm.SocketData.RemoteName
return s
}
func (s *serverImpl) GetTopChannels(ctx context.Context, req *channelzpb.GetTopChannelsRequest) (*channelzpb.GetTopChannelsResponse, error) {
metrics, end := channelz.GetTopChannels(req.GetStartChannelId())
resp := &channelzpb.GetTopChannelsResponse{}
for _, m := range metrics {
resp.Channel = append(resp.Channel, channelMetricToProto(m))
}
resp.End = end
return resp, nil
}
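A hedged client-side sketch, not part of the vendored code, of querying this endpoint through the generated stub; the target address and the use of WithInsecure are assumptions.
package main

import (
	"fmt"
	"log"

	"golang.org/x/net/context"
	"google.golang.org/grpc"
	channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
)

func main() {
	conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure()) // illustrative target
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()
	client := channelzpb.NewChannelzClient(conn)
	// Ask for top channels starting at ID 0; End reports whether everything
	// at or above StartChannelId has been returned.
	resp, err := client.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
	if err != nil {
		log.Fatalf("GetTopChannels: %v", err)
	}
	for _, ch := range resp.GetChannel() {
		fmt.Println(ch.GetRef().GetChannelId(), ch.GetData().GetTarget())
	}
}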
func serverMetricToProto(sm *channelz.ServerMetric) *channelzpb.Server {
s := &channelzpb.Server{}
s.Ref = &channelzpb.ServerRef{ServerId: sm.ID, Name: sm.RefName}
s.Data = &channelzpb.ServerData{
CallsStarted: sm.ServerData.CallsStarted,
CallsSucceeded: sm.ServerData.CallsSucceeded,
CallsFailed: sm.ServerData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(sm.ServerData.LastCallStartedTimestamp); err == nil {
s.Data.LastCallStartedTimestamp = ts
}
sockets := make([]*channelzpb.SocketRef, 0, len(sm.ListenSockets))
for id, ref := range sm.ListenSockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
s.ListenSocket = sockets
return s
}
func (s *serverImpl) GetServers(ctx context.Context, req *channelzpb.GetServersRequest) (*channelzpb.GetServersResponse, error) {
metrics, end := channelz.GetServers(req.GetStartServerId())
resp := &channelzpb.GetServersResponse{}
for _, m := range metrics {
resp.Server = append(resp.Server, serverMetricToProto(m))
}
resp.End = end
return resp, nil
}
func (s *serverImpl) GetServerSockets(ctx context.Context, req *channelzpb.GetServerSocketsRequest) (*channelzpb.GetServerSocketsResponse, error) {
metrics, end := channelz.GetServerSockets(req.GetServerId(), req.GetStartSocketId())
resp := &channelzpb.GetServerSocketsResponse{}
for _, m := range metrics {
resp.SocketRef = append(resp.SocketRef, &channelzpb.SocketRef{SocketId: m.ID, Name: m.RefName})
}
resp.End = end
return resp, nil
}
func (s *serverImpl) GetChannel(ctx context.Context, req *channelzpb.GetChannelRequest) (*channelzpb.GetChannelResponse, error) {
var metric *channelz.ChannelMetric
if metric = channelz.GetChannel(req.GetChannelId()); metric == nil {
return &channelzpb.GetChannelResponse{}, nil
}
resp := &channelzpb.GetChannelResponse{Channel: channelMetricToProto(metric)}
return resp, nil
}
func (s *serverImpl) GetSubchannel(ctx context.Context, req *channelzpb.GetSubchannelRequest) (*channelzpb.GetSubchannelResponse, error) {
var metric *channelz.SubChannelMetric
if metric = channelz.GetSubChannel(req.GetSubchannelId()); metric == nil {
return &channelzpb.GetSubchannelResponse{}, nil
}
resp := &channelzpb.GetSubchannelResponse{Subchannel: subChannelMetricToProto(metric)}
return resp, nil
}
func (s *serverImpl) GetSocket(ctx context.Context, req *channelzpb.GetSocketRequest) (*channelzpb.GetSocketResponse, error) {
var metric *channelz.SocketMetric
if metric = channelz.GetSocket(req.GetSocketId()); metric == nil {
return &channelzpb.GetSocketResponse{}, nil
}
resp := &channelzpb.GetSocketResponse{Socket: socketMetricToProto(metric)}
return resp, nil
}


@ -0,0 +1,152 @@
// +build linux,!appengine,go1.7
// +build 386 amd64
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// SocketOptions is only supported on Linux systems. The functions defined in
// this file parse the socket option field, and the test specifically
// verifies the behavior of socket option parsing.
package service
import (
"reflect"
"strconv"
"testing"
"github.com/golang/protobuf/ptypes"
durpb "github.com/golang/protobuf/ptypes/duration"
"golang.org/x/net/context"
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func init() {
// Assign protoToSocketOption to protoToSocketOpt in order to enable socket option
// data conversion from proto message to channelz defined struct.
protoToSocketOpt = protoToSocketOption
}
func convertToDuration(d *durpb.Duration) (sec int64, usec int64) {
if d != nil {
if dur, err := ptypes.Duration(d); err == nil {
sec = int64(int64(dur) / 1e9)
usec = (int64(dur) - sec*1e9) / 1e3
}
}
return
}
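A worked example of the split above, written as a test-style sketch (not vendored code); it assumes the time package in addition to this file's imports.
// TestConvertToDurationExample is an illustrative sketch, not vendored code.
func TestConvertToDurationExample(t *testing.T) {
	// 2.5s becomes 2 whole seconds plus 500000 microseconds.
	sec, usec := convertToDuration(ptypes.DurationProto(2500 * time.Millisecond))
	if sec != 2 || usec != 500000 {
		t.Fatalf("got sec=%d usec=%d, want 2 and 500000", sec, usec)
	}
}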
func protoToLinger(protoLinger *channelzpb.SocketOptionLinger) *unix.Linger {
linger := &unix.Linger{}
if protoLinger.GetActive() {
linger.Onoff = 1
}
lv, _ := convertToDuration(protoLinger.GetDuration())
linger.Linger = int32(lv)
return linger
}
func protoToSocketOption(skopts []*channelzpb.SocketOption) *channelz.SocketOptionData {
skdata := &channelz.SocketOptionData{}
for _, opt := range skopts {
switch opt.GetName() {
case "SO_LINGER":
protoLinger := &channelzpb.SocketOptionLinger{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoLinger)
if err == nil {
skdata.Linger = protoToLinger(protoLinger)
}
case "SO_RCVTIMEO":
protoTimeout := &channelzpb.SocketOptionTimeout{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoTimeout)
if err == nil {
skdata.RecvTimeout = protoToTime(protoTimeout)
}
case "SO_SNDTIMEO":
protoTimeout := &channelzpb.SocketOptionTimeout{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoTimeout)
if err == nil {
skdata.SendTimeout = protoToTime(protoTimeout)
}
case "TCP_INFO":
tcpi := &channelzpb.SocketOptionTcpInfo{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), tcpi)
if err == nil {
skdata.TCPInfo = &unix.TCPInfo{
State: uint8(tcpi.TcpiState),
Ca_state: uint8(tcpi.TcpiCaState),
Retransmits: uint8(tcpi.TcpiRetransmits),
Probes: uint8(tcpi.TcpiProbes),
Backoff: uint8(tcpi.TcpiBackoff),
Options: uint8(tcpi.TcpiOptions),
Rto: tcpi.TcpiRto,
Ato: tcpi.TcpiAto,
Snd_mss: tcpi.TcpiSndMss,
Rcv_mss: tcpi.TcpiRcvMss,
Unacked: tcpi.TcpiUnacked,
Sacked: tcpi.TcpiSacked,
Lost: tcpi.TcpiLost,
Retrans: tcpi.TcpiRetrans,
Fackets: tcpi.TcpiFackets,
Last_data_sent: tcpi.TcpiLastDataSent,
Last_ack_sent: tcpi.TcpiLastAckSent,
Last_data_recv: tcpi.TcpiLastDataRecv,
Last_ack_recv: tcpi.TcpiLastAckRecv,
Pmtu: tcpi.TcpiPmtu,
Rcv_ssthresh: tcpi.TcpiRcvSsthresh,
Rtt: tcpi.TcpiRtt,
Rttvar: tcpi.TcpiRttvar,
Snd_ssthresh: tcpi.TcpiSndSsthresh,
Snd_cwnd: tcpi.TcpiSndCwnd,
Advmss: tcpi.TcpiAdvmss,
Reordering: tcpi.TcpiReordering}
}
}
}
return skdata
}
func TestGetSocketOptions(t *testing.T) {
channelz.NewChannelzStorage()
ss := []*dummySocket{
{
socketOptions: &channelz.SocketOptionData{
Linger: &unix.Linger{Onoff: 1, Linger: 2},
RecvTimeout: &unix.Timeval{Sec: 10, Usec: 1},
SendTimeout: &unix.Timeval{},
TCPInfo: &unix.TCPInfo{State: 1},
},
},
}
svr := newCZServer()
ids := make([]int64, len(ss))
svrID := channelz.RegisterServer(&dummyServer{}, "")
for i, s := range ss {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
}
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
}
}
}


@ -0,0 +1,538 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"net"
"reflect"
"strconv"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"golang.org/x/net/context"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
)
func init() {
channelz.TurnOn()
}
type protoToSocketOptFunc func([]*channelzpb.SocketOption) *channelz.SocketOptionData
// protoToSocketOpt is used in socketProtoToStruct to extract socket option
// data from the unmarshaled proto message.
// It is only defined for Linux, non-appengine environments on x86 architectures.
var protoToSocketOpt protoToSocketOptFunc
// emptyTime is used for detecting an unset time.Time value.
// For go1.7 and earlier, ptypes.Timestamp will fill in the loc field of time.Time
// with &utcLoc, whereas the loc field of a zero-value time.Time is nil.
// This mismatch makes reflect.DeepEqual fail on an unset time.Time field,
// causing a false-positive fatal error.
var emptyTime time.Time
type dummyChannel struct {
state connectivity.State
target string
callsStarted int64
callsSucceeded int64
callsFailed int64
lastCallStartedTimestamp time.Time
}
func (d *dummyChannel) ChannelzMetric() *channelz.ChannelInternalMetric {
return &channelz.ChannelInternalMetric{
State: d.state,
Target: d.target,
CallsStarted: d.callsStarted,
CallsSucceeded: d.callsSucceeded,
CallsFailed: d.callsFailed,
LastCallStartedTimestamp: d.lastCallStartedTimestamp,
}
}
type dummyServer struct {
callsStarted int64
callsSucceeded int64
callsFailed int64
lastCallStartedTimestamp time.Time
}
func (d *dummyServer) ChannelzMetric() *channelz.ServerInternalMetric {
return &channelz.ServerInternalMetric{
CallsStarted: d.callsStarted,
CallsSucceeded: d.callsSucceeded,
CallsFailed: d.callsFailed,
LastCallStartedTimestamp: d.lastCallStartedTimestamp,
}
}
type dummySocket struct {
streamsStarted int64
streamsSucceeded int64
streamsFailed int64
messagesSent int64
messagesReceived int64
keepAlivesSent int64
lastLocalStreamCreatedTimestamp time.Time
lastRemoteStreamCreatedTimestamp time.Time
lastMessageSentTimestamp time.Time
lastMessageReceivedTimestamp time.Time
localFlowControlWindow int64
remoteFlowControlWindow int64
socketOptions *channelz.SocketOptionData
localAddr net.Addr
remoteAddr net.Addr
security credentials.ChannelzSecurityValue
remoteName string
}
func (d *dummySocket) ChannelzMetric() *channelz.SocketInternalMetric {
return &channelz.SocketInternalMetric{
StreamsStarted: d.streamsStarted,
StreamsSucceeded: d.streamsSucceeded,
StreamsFailed: d.streamsFailed,
MessagesSent: d.messagesSent,
MessagesReceived: d.messagesReceived,
KeepAlivesSent: d.keepAlivesSent,
LastLocalStreamCreatedTimestamp: d.lastLocalStreamCreatedTimestamp,
LastRemoteStreamCreatedTimestamp: d.lastRemoteStreamCreatedTimestamp,
LastMessageSentTimestamp: d.lastMessageSentTimestamp,
LastMessageReceivedTimestamp: d.lastMessageReceivedTimestamp,
LocalFlowControlWindow: d.localFlowControlWindow,
RemoteFlowControlWindow: d.remoteFlowControlWindow,
SocketOptions: d.socketOptions,
LocalAddr: d.localAddr,
RemoteAddr: d.remoteAddr,
Security: d.security,
RemoteName: d.remoteName,
}
}
func channelProtoToStruct(c *channelzpb.Channel) *dummyChannel {
dc := &dummyChannel{}
pdata := c.GetData()
switch pdata.GetState().GetState() {
case channelzpb.ChannelConnectivityState_UNKNOWN:
// TODO: what should we set here?
case channelzpb.ChannelConnectivityState_IDLE:
dc.state = connectivity.Idle
case channelzpb.ChannelConnectivityState_CONNECTING:
dc.state = connectivity.Connecting
case channelzpb.ChannelConnectivityState_READY:
dc.state = connectivity.Ready
case channelzpb.ChannelConnectivityState_TRANSIENT_FAILURE:
dc.state = connectivity.TransientFailure
case channelzpb.ChannelConnectivityState_SHUTDOWN:
dc.state = connectivity.Shutdown
}
dc.target = pdata.GetTarget()
dc.callsStarted = pdata.CallsStarted
dc.callsSucceeded = pdata.CallsSucceeded
dc.callsFailed = pdata.CallsFailed
if t, err := ptypes.Timestamp(pdata.GetLastCallStartedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
dc.lastCallStartedTimestamp = t
}
}
return dc
}
func serverProtoToStruct(s *channelzpb.Server) *dummyServer {
ds := &dummyServer{}
pdata := s.GetData()
ds.callsStarted = pdata.CallsStarted
ds.callsSucceeded = pdata.CallsSucceeded
ds.callsFailed = pdata.CallsFailed
if t, err := ptypes.Timestamp(pdata.GetLastCallStartedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastCallStartedTimestamp = t
}
}
return ds
}
func socketProtoToStruct(s *channelzpb.Socket) *dummySocket {
ds := &dummySocket{}
pdata := s.GetData()
ds.streamsStarted = pdata.GetStreamsStarted()
ds.streamsSucceeded = pdata.GetStreamsSucceeded()
ds.streamsFailed = pdata.GetStreamsFailed()
ds.messagesSent = pdata.GetMessagesSent()
ds.messagesReceived = pdata.GetMessagesReceived()
ds.keepAlivesSent = pdata.GetKeepAlivesSent()
if t, err := ptypes.Timestamp(pdata.GetLastLocalStreamCreatedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastLocalStreamCreatedTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastRemoteStreamCreatedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastRemoteStreamCreatedTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastMessageSentTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastMessageSentTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastMessageReceivedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastMessageReceivedTimestamp = t
}
}
if v := pdata.GetLocalFlowControlWindow(); v != nil {
ds.localFlowControlWindow = v.Value
}
if v := pdata.GetRemoteFlowControlWindow(); v != nil {
ds.remoteFlowControlWindow = v.Value
}
if v := pdata.GetOption(); v != nil && protoToSocketOpt != nil {
ds.socketOptions = protoToSocketOpt(v)
}
if v := s.GetSecurity(); v != nil {
ds.security = protoToSecurity(v)
}
if local := s.GetLocal(); local != nil {
ds.localAddr = protoToAddr(local)
}
if remote := s.GetRemote(); remote != nil {
ds.remoteAddr = protoToAddr(remote)
}
ds.remoteName = s.GetRemoteName()
return ds
}
func protoToSecurity(protoSecurity *channelzpb.Security) credentials.ChannelzSecurityValue {
switch v := protoSecurity.Model.(type) {
case *channelzpb.Security_Tls_:
return &credentials.TLSChannelzSecurityValue{StandardName: v.Tls.GetStandardName(), LocalCertificate: v.Tls.GetLocalCertificate(), RemoteCertificate: v.Tls.GetRemoteCertificate()}
case *channelzpb.Security_Other:
sv := &credentials.OtherChannelzSecurityValue{Name: v.Other.GetName()}
var x ptypes.DynamicAny
if err := ptypes.UnmarshalAny(v.Other.GetValue(), &x); err == nil {
sv.Value = x.Message
}
return sv
}
return nil
}
func protoToAddr(a *channelzpb.Address) net.Addr {
switch v := a.Address.(type) {
case *channelzpb.Address_TcpipAddress:
if port := v.TcpipAddress.GetPort(); port != 0 {
return &net.TCPAddr{IP: v.TcpipAddress.GetIpAddress(), Port: int(port)}
}
return &net.IPAddr{IP: v.TcpipAddress.GetIpAddress()}
case *channelzpb.Address_UdsAddress_:
return &net.UnixAddr{Name: v.UdsAddress.GetFilename(), Net: "unix"}
case *channelzpb.Address_OtherAddress_:
// TODO:
}
return nil
}
func convertSocketRefSliceToMap(sktRefs []*channelzpb.SocketRef) map[int64]string {
m := make(map[int64]string)
for _, sr := range sktRefs {
m[sr.SocketId] = sr.Name
}
return m
}
type OtherSecurityValue struct {
LocalCertificate []byte `protobuf:"bytes,1,opt,name=local_certificate,json=localCertificate,proto3" json:"local_certificate,omitempty"`
RemoteCertificate []byte `protobuf:"bytes,2,opt,name=remote_certificate,json=remoteCertificate,proto3" json:"remote_certificate,omitempty"`
}
func (m *OtherSecurityValue) Reset() { *m = OtherSecurityValue{} }
func (m *OtherSecurityValue) String() string { return proto.CompactTextString(m) }
func (*OtherSecurityValue) ProtoMessage() {}
func init() {
// Ad-hoc registering the proto type here to facilitate UnmarshalAny of OtherSecurityValue.
proto.RegisterType((*OtherSecurityValue)(nil), "grpc.credentials.OtherChannelzSecurityValue")
}
func TestGetTopChannels(t *testing.T) {
tcs := []*dummyChannel{
{
state: connectivity.Connecting,
target: "test.channelz:1234",
callsStarted: 6,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
state: connectivity.Connecting,
target: "test.channelz:1234",
callsStarted: 1,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
state: connectivity.Shutdown,
target: "test.channelz:8888",
callsStarted: 0,
callsSucceeded: 0,
callsFailed: 0,
},
{},
}
channelz.NewChannelzStorage()
for _, c := range tcs {
channelz.RegisterChannel(c, 0, "")
}
s := newCZServer()
resp, _ := s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
for i, c := range resp.GetChannel() {
if !reflect.DeepEqual(channelProtoToStruct(c), tcs[i]) {
t.Fatalf("dummyChannel: %d, want: %#v, got: %#v", i, tcs[i], channelProtoToStruct(c))
}
}
for i := 0; i < 50; i++ {
channelz.RegisterChannel(tcs[0], 0, "")
}
resp, _ = s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
func TestGetServers(t *testing.T) {
ss := []*dummyServer{
{
callsStarted: 6,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
callsStarted: 1,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
callsStarted: 1,
callsSucceeded: 0,
callsFailed: 0,
lastCallStartedTimestamp: time.Now().UTC(),
},
}
channelz.NewChannelzStorage()
for _, s := range ss {
channelz.RegisterServer(s, "")
}
svr := newCZServer()
resp, _ := svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
for i, s := range resp.GetServer() {
if !reflect.DeepEqual(serverProtoToStruct(s), ss[i]) {
t.Fatalf("dummyServer: %d, want: %#v, got: %#v", i, ss[i], serverProtoToStruct(s))
}
}
for i := 0; i < 50; i++ {
channelz.RegisterServer(ss[0], "")
}
resp, _ = svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
func TestGetServerSockets(t *testing.T) {
channelz.NewChannelzStorage()
svrID := channelz.RegisterServer(&dummyServer{}, "")
refNames := []string{"listen socket 1", "normal socket 1", "normal socket 2"}
ids := make([]int64, 3)
ids[0] = channelz.RegisterListenSocket(&dummySocket{}, svrID, refNames[0])
ids[1] = channelz.RegisterNormalSocket(&dummySocket{}, svrID, refNames[1])
ids[2] = channelz.RegisterNormalSocket(&dummySocket{}, svrID, refNames[2])
svr := newCZServer()
resp, _ := svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want: true, got: %v", resp.GetEnd())
}
// GetServerSockets only returns normal sockets.
want := map[int64]string{
ids[1]: refNames[1],
ids[2]: refNames[2],
}
if !reflect.DeepEqual(convertSocketRefSliceToMap(resp.GetSocketRef()), want) {
t.Fatalf("GetServerSockets want: %#v, got: %#v", want, resp.GetSocketRef())
}
for i := 0; i < 50; i++ {
channelz.RegisterNormalSocket(&dummySocket{}, svrID, "")
}
resp, _ = svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
func TestGetChannel(t *testing.T) {
channelz.NewChannelzStorage()
refNames := []string{"top channel 1", "nested channel 1", "nested channel 2", "nested channel 3"}
ids := make([]int64, 4)
ids[0] = channelz.RegisterChannel(&dummyChannel{}, 0, refNames[0])
ids[1] = channelz.RegisterChannel(&dummyChannel{}, ids[0], refNames[1])
ids[2] = channelz.RegisterSubChannel(&dummyChannel{}, ids[0], refNames[2])
ids[3] = channelz.RegisterChannel(&dummyChannel{}, ids[1], refNames[3])
svr := newCZServer()
resp, _ := svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[0]})
metrics := resp.GetChannel()
subChans := metrics.GetSubchannelRef()
if len(subChans) != 1 || subChans[0].GetName() != refNames[2] || subChans[0].GetSubchannelId() != ids[2] {
t.Fatalf("GetSubChannelRef() want %#v, got %#v", []*channelzpb.SubchannelRef{{SubchannelId: ids[2], Name: refNames[2]}}, subChans)
}
nestedChans := metrics.GetChannelRef()
if len(nestedChans) != 1 || nestedChans[0].GetName() != refNames[1] || nestedChans[0].GetChannelId() != ids[1] {
t.Fatalf("GetChannelRef() want %#v, got %#v", []*channelzpb.ChannelRef{{ChannelId: ids[1], Name: refNames[1]}}, nestedChans)
}
resp, _ = svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[1]})
metrics = resp.GetChannel()
nestedChans = metrics.GetChannelRef()
if len(nestedChans) != 1 || nestedChans[0].GetName() != refNames[3] || nestedChans[0].GetChannelId() != ids[3] {
t.Fatalf("GetChannelRef() want %#v, got %#v", []*channelzpb.ChannelRef{{ChannelId: ids[3], Name: refNames[3]}}, nestedChans)
}
}
func TestGetSubChannel(t *testing.T) {
channelz.NewChannelzStorage()
refNames := []string{"top channel 1", "sub channel 1", "socket 1", "socket 2"}
ids := make([]int64, 4)
ids[0] = channelz.RegisterChannel(&dummyChannel{}, 0, refNames[0])
ids[1] = channelz.RegisterSubChannel(&dummyChannel{}, ids[0], refNames[1])
ids[2] = channelz.RegisterNormalSocket(&dummySocket{}, ids[1], refNames[2])
ids[3] = channelz.RegisterNormalSocket(&dummySocket{}, ids[1], refNames[3])
svr := newCZServer()
resp, _ := svr.GetSubchannel(context.Background(), &channelzpb.GetSubchannelRequest{SubchannelId: ids[1]})
metrics := resp.GetSubchannel()
want := map[int64]string{
ids[2]: refNames[2],
ids[3]: refNames[3],
}
if !reflect.DeepEqual(convertSocketRefSliceToMap(metrics.GetSocketRef()), want) {
t.Fatalf("GetSocketRef() want %#v: got: %#v", want, metrics.GetSocketRef())
}
}
func TestGetSocket(t *testing.T) {
channelz.NewChannelzStorage()
ss := []*dummySocket{
{
streamsStarted: 10,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastLocalStreamCreatedTimestamp: time.Now().UTC(),
lastRemoteStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 1024,
localAddr: &net.TCPAddr{IP: net.ParseIP("1.0.0.1"), Port: 10001},
remoteAddr: &net.TCPAddr{IP: net.ParseIP("12.0.0.1"), Port: 10002},
remoteName: "remote.remote",
},
{
streamsStarted: 10,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastRemoteStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 1024,
localAddr: &net.UnixAddr{Name: "file.path", Net: "unix"},
remoteAddr: &net.UnixAddr{Name: "another.path", Net: "unix"},
remoteName: "remote.remote",
},
{
streamsStarted: 5,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastLocalStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 10240,
localAddr: &net.IPAddr{IP: net.ParseIP("1.0.0.1")},
remoteAddr: &net.IPAddr{IP: net.ParseIP("9.0.0.1")},
remoteName: "",
},
{
localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 10001},
},
{
security: &credentials.TLSChannelzSecurityValue{
StandardName: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
RemoteCertificate: []byte{48, 130, 2, 156, 48, 130, 2, 5, 160},
},
},
{
security: &credentials.OtherChannelzSecurityValue{
Name: "XXXX",
},
},
{
security: &credentials.OtherChannelzSecurityValue{
Name: "YYYY",
Value: &OtherSecurityValue{LocalCertificate: []byte{1, 2, 3}, RemoteCertificate: []byte{4, 5, 6}},
},
},
}
svr := newCZServer()
ids := make([]int64, len(ss))
svrID := channelz.RegisterServer(&dummyServer{}, "")
for i, s := range ss {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
}
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
}
}
}


@ -0,0 +1,33 @@
// +build 386,linux,!appengine,go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
)
func protoToTime(protoTime *channelzpb.SocketOptionTimeout) *unix.Timeval {
timeout := &unix.Timeval{}
sec, usec := convertToDuration(protoTime.GetDuration())
timeout.Sec, timeout.Usec = int32(sec), int32(usec)
return timeout
}


@ -0,0 +1,32 @@
// +build amd64,linux,!appengine,go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
)
func protoToTime(protoTime *channelzpb.SocketOptionTimeout) *unix.Timeval {
timeout := &unix.Timeval{}
timeout.Sec, timeout.Usec = convertToDuration(protoTime.GetDuration())
return timeout
}


@ -26,6 +26,7 @@ import (
"reflect"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/context"
@ -36,13 +37,21 @@ import (
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver"
_ "google.golang.org/grpc/resolver/dns" // To register dns resolver.
_ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver.
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
const (
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
// must match grpclbName in grpclb/grpclb.go
grpclbName = "grpclb"
)
var (
@ -56,12 +65,13 @@ var (
errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing")
// errConnUnavailable indicates that the connection is unavailable.
errConnUnavailable = errors.New("grpc: the connection is unavailable")
// errBalancerClosed indicates that the balancer is closed.
errBalancerClosed = errors.New("grpc: balancer is closed")
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
// We use an accessor so that minConnectTimeout can be
// atomically read and updated while testing.
getMinConnectTimeout = func() time.Duration {
return minConnectTimeout
}
)
// The following errors are returned from Dial and DialContext
@ -77,342 +87,59 @@ var (
// errCredentialsConflict indicates that grpc.WithTransportCredentials()
// and grpc.WithInsecure() are both called for a connection.
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
// errNetworkIO indicates that the connection is down due to some network I/O error.
errNetworkIO = errors.New("grpc: failed with network I/O error")
)
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
cp Compressor
dc Decompressor
bs backoffStrategy
block bool
insecure bool
timeout time.Duration
scChan <-chan ServiceConfig
copts transport.ConnectOptions
callOptions []CallOption
// This is used by v1 balancer dial option WithBalancer to support v1
// balancer, and also by WithBalancerName dial option.
balancerBuilder balancer.Builder
// This is to support grpclb.
resolverBuilder resolver.Builder
waitForHandshake bool
}
const (
defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4
defaultClientMaxSendMessageSize = math.MaxInt32
// http2IOBufSize specifies the buffer size for sending frames.
defaultWriteBufSize = 32 * 1024
defaultReadBufSize = 32 * 1024
)
// DialOption configures how we set up the connection.
type DialOption func(*dialOptions)
// WithWaitForHandshake blocks until the initial settings frame is received from the
// server before assigning RPCs to the connection.
// Experimental API.
func WithWaitForHandshake() DialOption {
return func(o *dialOptions) {
o.waitForHandshake = true
}
}
// WithWriteBufferSize lets you set the size of the write buffer; this determines how much data can be batched
// before doing a write on the wire.
func WithWriteBufferSize(s int) DialOption {
return func(o *dialOptions) {
o.copts.WriteBufferSize = s
}
}
// WithReadBufferSize lets you set the size of the read buffer; this determines how much data can be read at most
// for each read syscall.
func WithReadBufferSize(s int) DialOption {
return func(o *dialOptions) {
o.copts.ReadBufferSize = s
}
}
// WithInitialWindowSize returns a DialOption which sets the value for initial window size on a stream.
// The lower bound for window size is 64K and any value smaller than that will be ignored.
func WithInitialWindowSize(s int32) DialOption {
return func(o *dialOptions) {
o.copts.InitialWindowSize = s
}
}
// WithInitialConnWindowSize returns a DialOption which sets the value for initial window size on a connection.
// The lower bound for window size is 64K and any value smaller than that will be ignored.
func WithInitialConnWindowSize(s int32) DialOption {
return func(o *dialOptions) {
o.copts.InitialConnWindowSize = s
}
}
// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead.
func WithMaxMsgSize(s int) DialOption {
return WithDefaultCallOptions(MaxCallRecvMsgSize(s))
}
// WithDefaultCallOptions returns a DialOption which sets the default CallOptions for calls over the connection.
func WithDefaultCallOptions(cos ...CallOption) DialOption {
return func(o *dialOptions) {
o.callOptions = append(o.callOptions, cos...)
}
}
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
//
// Deprecated: use WithDefaultCallOptions(CallCustomCodec(c)) instead.
func WithCodec(c Codec) DialOption {
return WithDefaultCallOptions(CallCustomCodec(c))
}
// WithCompressor returns a DialOption which sets a Compressor to use for
// message compression. It has lower priority than the compressor set by
// the UseCompressor CallOption.
//
// Deprecated: use UseCompressor instead.
func WithCompressor(cp Compressor) DialOption {
return func(o *dialOptions) {
o.cp = cp
}
}
// WithDecompressor returns a DialOption which sets a Decompressor to use for
// incoming message decompression. If incoming response messages are encoded
// using the decompressor's Type(), it will be used. Otherwise, the message
// encoding will be used to look up the compressor registered via
// encoding.RegisterCompressor, which will then be used to decompress the
// message. If no compressor is registered for the encoding, an Unimplemented
// status error will be returned.
//
// Deprecated: use encoding.RegisterCompressor instead.
func WithDecompressor(dc Decompressor) DialOption {
return func(o *dialOptions) {
o.dc = dc
}
}
// WithBalancer returns a DialOption which sets a load balancer with the v1 API.
// Name resolver will be ignored if this DialOption is specified.
//
// Deprecated: use the new balancer APIs in balancer package and WithBalancerName.
func WithBalancer(b Balancer) DialOption {
return func(o *dialOptions) {
o.balancerBuilder = &balancerWrapperBuilder{
b: b,
}
}
}
// WithBalancerName sets the balancer that the ClientConn will be initialized
// with. Balancer registered with balancerName will be used. This function
// panics if no balancer was registered by balancerName.
//
// The balancer cannot be overridden by balancer option specified by service
// config.
//
// This is an EXPERIMENTAL API.
func WithBalancerName(balancerName string) DialOption {
builder := balancer.Get(balancerName)
if builder == nil {
panic(fmt.Sprintf("grpc.WithBalancerName: no balancer is registered for name %v", balancerName))
}
return func(o *dialOptions) {
o.balancerBuilder = builder
}
}
// withResolverBuilder is only for grpclb.
func withResolverBuilder(b resolver.Builder) DialOption {
return func(o *dialOptions) {
o.resolverBuilder = b
}
}
// WithServiceConfig returns a DialOption which has a channel to read the service configuration.
// DEPRECATED: service config should be received through name resolver, as specified here.
// https://github.com/grpc/grpc/blob/master/doc/service_config.md
func WithServiceConfig(c <-chan ServiceConfig) DialOption {
return func(o *dialOptions) {
o.scChan = c
}
}
// WithBackoffMaxDelay configures the dialer to use the provided maximum delay
// when backing off after failed connection attempts.
func WithBackoffMaxDelay(md time.Duration) DialOption {
return WithBackoffConfig(BackoffConfig{MaxDelay: md})
}
// WithBackoffConfig configures the dialer to use the provided backoff
// parameters after connection failures.
//
// Use WithBackoffMaxDelay until more parameters on BackoffConfig are opened up
// for use.
func WithBackoffConfig(b BackoffConfig) DialOption {
// Set defaults to ensure that provided BackoffConfig is valid and
// unexported fields get default values.
setDefaults(&b)
return withBackoff(b)
}
// withBackoff sets the backoff strategy used for connectRetryNum after a
// failed connection attempt.
//
// This can be exported if arbitrary backoff strategies are allowed by gRPC.
func withBackoff(bs backoffStrategy) DialOption {
return func(o *dialOptions) {
o.bs = bs
}
}
// WithBlock returns a DialOption which makes the caller of Dial block until the underlying
// connection is up. Without this, Dial returns immediately and connecting to the server
// happens in the background.
func WithBlock() DialOption {
return func(o *dialOptions) {
o.block = true
}
}
// WithInsecure returns a DialOption which disables transport security for this ClientConn.
// Note that transport security is required unless WithInsecure is set.
func WithInsecure() DialOption {
return func(o *dialOptions) {
o.insecure = true
}
}
// WithTransportCredentials returns a DialOption which configures
// connection-level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportCredentials) DialOption {
return func(o *dialOptions) {
o.copts.TransportCredentials = creds
}
}
// WithPerRPCCredentials returns a DialOption which sets
// credentials and places auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
return func(o *dialOptions) {
o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds)
}
}
// WithTimeout returns a DialOption that configures a timeout for dialing a ClientConn
// initially. This is valid if and only if WithBlock() is present.
// Deprecated: use DialContext and context.WithTimeout instead.
func WithTimeout(d time.Duration) DialOption {
return func(o *dialOptions) {
o.timeout = d
}
}
func withContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption {
return func(o *dialOptions) {
o.copts.Dialer = f
}
}
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
// If FailOnNonTempDialError() is set to true, and an error is returned by f, gRPC checks the error's
// Temporary() method to decide if it should try to reconnect to the network address.
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return withContextDialer(
func(ctx context.Context, addr string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
return f(addr, deadline.Sub(time.Now()))
}
return f(addr, 0)
})
}
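A brief sketch of pointing the dialer at a Unix domain socket; the socket path and the use of WithInsecure are assumptions, and it relies on the net, time, and grpc imports.
// dialOverUnixSocket is an illustrative sketch, not vendored code.
func dialOverUnixSocket() (*grpc.ClientConn, error) {
	return grpc.Dial(
		"unused-target", // ignored: the custom dialer below decides where to connect
		grpc.WithInsecure(),
		grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
			return net.DialTimeout("unix", "/tmp/example.sock", timeout)
		}),
	)
}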
// WithStatsHandler returns a DialOption that specifies the stats handler
// for all the RPCs and underlying network connections in this ClientConn.
func WithStatsHandler(h stats.Handler) DialOption {
return func(o *dialOptions) {
o.copts.StatsHandler = h
}
}
// FailOnNonTempDialError returns a DialOption that specifies if gRPC fails on non-temporary dial errors.
// If f is true, and dialer returns a non-temporary error, gRPC will fail the connection to the network
// address and won't try to reconnect.
// The default value of FailOnNonTempDialError is false.
// This is an EXPERIMENTAL API.
func FailOnNonTempDialError(f bool) DialOption {
return func(o *dialOptions) {
o.copts.FailOnNonTempDialError = f
}
}
// WithUserAgent returns a DialOption that specifies a user agent string for all the RPCs.
func WithUserAgent(s string) DialOption {
return func(o *dialOptions) {
o.copts.UserAgent = s
}
}
// WithKeepaliveParams returns a DialOption that specifies keepalive parameters for the client transport.
func WithKeepaliveParams(kp keepalive.ClientParameters) DialOption {
return func(o *dialOptions) {
o.copts.KeepaliveParams = kp
}
}
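For instance, a hedged sketch of enabling client-side keepalive pings; the durations are assumptions, and the keepalive and time imports are presumed.
// keepaliveOption is an illustrative sketch, not vendored code.
func keepaliveOption() grpc.DialOption {
	return grpc.WithKeepaliveParams(keepalive.ClientParameters{
		Time:                30 * time.Second, // ping after this much idle time
		Timeout:             10 * time.Second, // wait this long for a ping ack
		PermitWithoutStream: true,             // ping even with no active RPCs
	})
}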
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs.
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
return func(o *dialOptions) {
o.unaryInt = f
}
}
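A minimal client-side logging interceptor sketch (not vendored code); the log format is an assumption, and it presumes the context, log, time, and grpc imports.
// loggingUnaryInterceptor is an illustrative sketch, not vendored code.
func loggingUnaryInterceptor(ctx context.Context, method string, req, reply interface{},
	cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	start := time.Now()
	err := invoker(ctx, method, req, reply, cc, opts...)
	log.Printf("rpc %s took %v, err=%v", method, time.Since(start), err)
	return err
}
// It would be installed with, for example:
//	grpc.Dial(target, grpc.WithUnaryInterceptor(loggingUnaryInterceptor), grpc.WithInsecure())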
// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs.
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
return func(o *dialOptions) {
o.streamInt = f
}
}
// WithAuthority returns a DialOption that specifies the value to be used as
// the :authority pseudo-header. This value only works with WithInsecure and
// has no effect if TransportCredentials are present.
func WithAuthority(a string) DialOption {
return func(o *dialOptions) {
o.copts.Authority = a
}
}
// Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
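A hedged sketch of both dialing styles; the target and the use of WithInsecure are assumptions, and the context, time, and grpc imports are presumed. The first call returns immediately and connects in the background; the second blocks until the connection is ready or the context expires.
// dialExamples is an illustrative sketch, not vendored code.
func dialExamples() {
	// Non-blocking: Dial returns at once; connecting happens in the background.
	conn, err := grpc.Dial("dns:///localhost:50051", grpc.WithInsecure())
	_, _ = conn, err

	// Blocking: wait up to five seconds for the connection to become READY.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	blockingConn, err := grpc.DialContext(ctx, "dns:///localhost:50051",
		grpc.WithInsecure(), grpc.WithBlock())
	_, _ = blockingConn, err
}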
// DialContext creates a client connection to the given target. ctx can be used to
// cancel or expire the pending connection. Once this function returns, the
// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close
// to terminate all the pending operations after this function returns.
// DialContext creates a client connection to the given target. By default, it's
// a non-blocking dial (the function won't wait for connections to be
// established, and connecting happens in the background). To make it a blocking
// dial, use WithBlock() dial option.
//
// In the non-blocking case, the ctx does not act against the connection. It
// only controls the setup steps.
//
// In the blocking case, ctx can be used to cancel or expire the pending
// connection. Once this function returns, the cancellation and expiration of
// ctx will be noop. Users should call ClientConn.Close to terminate all the
// pending operations after this function returns.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
// e.g. to use dns resolver, a "dns:///" prefix should be applied to the target.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{
target: target,
csMgr: &connectivityStateManager{},
conns: make(map[*addrConn]struct{}),
target: target,
csMgr: &connectivityStateManager{},
conns: make(map[*addrConn]struct{}),
dopts: defaultDialOptions(),
blockingpicker: newPickerWrapper(),
czData: new(channelzData),
}
cc.retryThrottler.Store((*retryThrottler)(nil))
cc.ctx, cc.cancel = context.WithCancel(context.Background())
for _, opt := range opts {
opt(&cc.dopts)
opt.apply(&cc.dopts)
}
if channelz.IsOn() {
if cc.dopts.channelzParentID != 0 {
cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, cc.dopts.channelzParentID, target)
} else {
cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, 0, target)
}
}
if !cc.dopts.insecure {
@ -435,7 +162,8 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
if cc.dopts.copts.Dialer == nil {
cc.dopts.copts.Dialer = newProxyDialer(
func(ctx context.Context, addr string) (net.Conn, error) {
return dialContext(ctx, "tcp", addr)
network, addr := parseDialTarget(addr)
return dialContext(ctx, network, addr)
},
)
}
@ -477,14 +205,34 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
}
}
if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig
cc.dopts.bs = backoff.Exponential{
MaxDelay: DefaultBackoffConfig.MaxDelay,
}
}
if cc.dopts.resolverBuilder == nil {
// Only try to parse target when resolver builder is not already set.
cc.parsedTarget = parseTarget(cc.target)
grpclog.Infof("parsed scheme: %q", cc.parsedTarget.Scheme)
cc.dopts.resolverBuilder = resolver.Get(cc.parsedTarget.Scheme)
if cc.dopts.resolverBuilder == nil {
// If resolver builder is still nil, the parse target's scheme is
// not registered. Fallback to default resolver and set Endpoint to
// the original unparsed target.
grpclog.Infof("scheme %q not registered, fallback to default scheme", cc.parsedTarget.Scheme)
cc.parsedTarget = resolver.Target{
Scheme: resolver.GetDefaultScheme(),
Endpoint: target,
}
cc.dopts.resolverBuilder = resolver.Get(cc.parsedTarget.Scheme)
}
} else {
cc.parsedTarget = resolver.Target{Endpoint: target}
}
cc.parsedTarget = parseTarget(cc.target)
creds := cc.dopts.copts.TransportCredentials
if creds != nil && creds.Info().ServerName != "" {
cc.authority = creds.Info().ServerName
} else if cc.dopts.insecure && cc.dopts.copts.Authority != "" {
cc.authority = cc.dopts.copts.Authority
} else if cc.dopts.insecure && cc.dopts.authority != "" {
cc.authority = cc.dopts.authority
} else {
// Use endpoint from "scheme://authority/endpoint" as the default
// authority for ClientConn.
@ -511,8 +259,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
credsClone = creds.Clone()
}
cc.balancerBuildOpts = balancer.BuildOptions{
DialCreds: credsClone,
Dialer: cc.dopts.copts.Dialer,
DialCreds: credsClone,
Dialer: cc.dopts.copts.Dialer,
ChannelzParentID: cc.channelzID,
}
// Build the resolver.
@ -535,6 +284,13 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
s := cc.GetState()
if s == connectivity.Ready {
break
} else if cc.dopts.copts.FailOnNonTempDialError && s == connectivity.TransientFailure {
if err = cc.blockingpicker.connectionError(); err != nil {
terr, ok := err.(interface{ Temporary() bool })
if ok && !terr.Temporary() {
return nil, err
}
}
}
if !cc.WaitForStateChange(ctx, s) {
// ctx got timeout or canceled.
@ -614,6 +370,10 @@ type ClientConn struct {
preBalancerName string // previous balancer name.
curAddresses []resolver.Address
balancerWrapper *ccBalancerWrapper
retryThrottler atomic.Value
channelzID int64 // channelz unique identification number
czData *channelzData
}
// WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or
@ -766,9 +526,11 @@ func (cc *ClientConn) handleSubConnStateChange(sc balancer.SubConn, s connectivi
// Caller needs to make sure len(addrs) > 0.
func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) {
ac := &addrConn{
cc: cc,
addrs: addrs,
dopts: cc.dopts,
cc: cc,
addrs: addrs,
dopts: cc.dopts,
czData: new(channelzData),
resetBackoff: make(chan struct{}),
}
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Track ac in cc. This needs to be done before any getTransport(...) is called.
@ -777,6 +539,9 @@ func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) {
cc.mu.Unlock()
return nil, ErrClientConnClosing
}
if channelz.IsOn() {
ac.channelzID = channelz.RegisterSubChannel(ac, cc.channelzID, "")
}
cc.conns[ac] = struct{}{}
cc.mu.Unlock()
return ac, nil
@ -795,6 +560,36 @@ func (cc *ClientConn) removeAddrConn(ac *addrConn, err error) {
ac.tearDown(err)
}
func (cc *ClientConn) channelzMetric() *channelz.ChannelInternalMetric {
return &channelz.ChannelInternalMetric{
State: cc.GetState(),
Target: cc.target,
CallsStarted: atomic.LoadInt64(&cc.czData.callsStarted),
CallsSucceeded: atomic.LoadInt64(&cc.czData.callsSucceeded),
CallsFailed: atomic.LoadInt64(&cc.czData.callsFailed),
LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&cc.czData.lastCallStartedTime)),
}
}
// Target returns the target string of the ClientConn.
// This is an EXPERIMENTAL API.
func (cc *ClientConn) Target() string {
return cc.target
}
func (cc *ClientConn) incrCallsStarted() {
atomic.AddInt64(&cc.czData.callsStarted, 1)
atomic.StoreInt64(&cc.czData.lastCallStartedTime, time.Now().UnixNano())
}
func (cc *ClientConn) incrCallsSucceeded() {
atomic.AddInt64(&cc.czData.callsSucceeded, 1)
}
func (cc *ClientConn) incrCallsFailed() {
atomic.AddInt64(&cc.czData.callsFailed, 1)
}
// connect starts to creating transport and also starts the transport monitor
// goroutine for this ac.
// It does nothing if the ac is not IDLE.
@ -865,7 +660,7 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool {
// the corresponding MethodConfig.
// If there isn't an exact match for the input method, we look for the default config
// under the service (i.e /service/). If there is a default MethodConfig for
// the serivce, we return it.
// the service, we return it.
// Otherwise, we return an empty MethodConfig.
func (cc *ClientConn) GetMethodConfig(method string) MethodConfig {
// TODO: Avoid the locking here.
@ -874,13 +669,15 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig {
m, ok := cc.sc.Methods[method]
if !ok {
i := strings.LastIndex(method, "/")
m, _ = cc.sc.Methods[method[:i+1]]
m = cc.sc.Methods[method[:i+1]]
}
return m
}
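// Illustrative sketch (cc is a dialed *ClientConn; the method name is a
// placeholder, not from this diff): looking up the effective MethodConfig
// for a call. If "/foo.Bar/Baz" has no exact entry in the service config,
// the per-service default "/foo.Bar/" is returned, and an empty
// MethodConfig otherwise.
//
//	mc := cc.GetMethodConfig("/foo.Bar/Baz")
//	if mc.WaitForReady != nil && *mc.WaitForReady {
//		// the service config requests wait-for-ready semantics for this method
//	}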
func (cc *ClientConn) getTransport(ctx context.Context, failfast bool) (transport.ClientTransport, func(balancer.DoneInfo), error) {
t, done, err := cc.blockingpicker.pick(ctx, failfast, balancer.PickOptions{})
func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, func(balancer.DoneInfo), error) {
t, done, err := cc.blockingpicker.pick(ctx, failfast, balancer.PickOptions{
FullMethodName: method,
})
if err != nil {
return nil, nil, toRPCErr(err)
}
@ -890,6 +687,9 @@ func (cc *ClientConn) getTransport(ctx context.Context, failfast bool) (transpor
// handleServiceConfig parses the service config string in JSON format to Go native
// struct ServiceConfig, and store both the struct and the JSON string in ClientConn.
func (cc *ClientConn) handleServiceConfig(js string) error {
if cc.dopts.disableServiceConfig {
return nil
}
sc, err := parseServiceConfig(js)
if err != nil {
return err
@ -897,6 +697,19 @@ func (cc *ClientConn) handleServiceConfig(js string) error {
cc.mu.Lock()
cc.scRaw = js
cc.sc = sc
if sc.retryThrottling != nil {
newThrottler := &retryThrottler{
tokens: sc.retryThrottling.MaxTokens,
max: sc.retryThrottling.MaxTokens,
thresh: sc.retryThrottling.MaxTokens / 2,
ratio: sc.retryThrottling.TokenRatio,
}
cc.retryThrottler.Store(newThrottler)
} else {
cc.retryThrottler.Store((*retryThrottler)(nil))
}
if sc.LB != nil && *sc.LB != grpclbName { // "grpclb" is not a valid balancer option in service config.
if cc.curBalancerName == grpclbName {
// If current balancer is grpclb, there's at least one grpclb
@ -910,23 +723,42 @@ func (cc *ClientConn) handleServiceConfig(js string) error {
cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil)
}
}
cc.mu.Unlock()
return nil
}
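// Illustrative sketch (field names follow the gRPC retry design, gRFC A6;
// the numbers are placeholders): a service config fragment that enables the
// retry throttling handled above. With these values the throttler starts
// with 10 tokens and refunds 0.1 per successful RPC.
//
//	{
//	  "retryThrottling": {
//	    "maxTokens": 10,
//	    "tokenRatio": 0.1
//	  }
//	}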
func (cc *ClientConn) resolveNow(o resolver.ResolveNowOption) {
cc.mu.Lock()
cc.mu.RLock()
r := cc.resolverWrapper
cc.mu.Unlock()
cc.mu.RUnlock()
if r == nil {
return
}
go r.resolveNow(o)
}
// ResetConnectBackoff wakes up all subchannels in transient failure and causes
// them to attempt another connection immediately. It also resets the backoff
// times used for subsequent attempts regardless of the current state.
//
// In general, this function should not be used. Typical service or network
// outages result in a reasonable client reconnection strategy by default.
// However, if a previously unavailable network becomes available, this may be
// used to trigger an immediate reconnect.
//
// This API is EXPERIMENTAL.
func (cc *ClientConn) ResetConnectBackoff() {
cc.mu.Lock()
defer cc.mu.Unlock()
for ac := range cc.conns {
ac.resetConnectBackoff()
}
}
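// Illustrative sketch (networkBecameAvailable is a hypothetical signal, not
// part of gRPC): reacting to an OS-level "network is back" notification so
// that subchannels stuck in backoff retry immediately.
//
//	go func() {
//		for range networkBecameAvailable {
//			cc.ResetConnectBackoff()
//		}
//	}()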
// Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error {
cc.cancel()
defer cc.cancel()
cc.mu.Lock()
if cc.conns == nil {
@ -942,16 +774,22 @@ func (cc *ClientConn) Close() error {
bWrapper := cc.balancerWrapper
cc.balancerWrapper = nil
cc.mu.Unlock()
cc.blockingpicker.close()
if rWrapper != nil {
rWrapper.close()
}
if bWrapper != nil {
bWrapper.close()
}
for ac := range conns {
ac.tearDown(ErrClientConnClosing)
}
if channelz.IsOn() {
channelz.RemoveEntry(cc.channelzID)
}
return nil
}
@ -985,6 +823,11 @@ type addrConn struct {
// connectDeadline is the time by which all connection
// negotiations must complete.
connectDeadline time.Time
resetBackoff chan struct{}
channelzID int64 // channelz unique identification number
czData *channelzData
}
// adjustParams updates parameters used to create transports upon
@ -1009,18 +852,10 @@ func (ac *addrConn) printf(format string, a ...interface{}) {
}
}
// errorf records an error in ac's event log, unless ac has been closed.
// REQUIRES ac.mu is held.
func (ac *addrConn) errorf(format string, a ...interface{}) {
if ac.events != nil {
ac.events.Errorf(format, a...)
}
}
// resetTransport recreates a transport to the address for ac. The old
// transport will close itself on error or when the clientconn is closed.
// The created transport must receive initial settings frame from the server.
// In case that doesnt happen, transportMonitor will kill the newly created
// In case that doesn't happen, transportMonitor will kill the newly created
// transport after connectDeadline has expired.
// In case there was an error on the transport before the settings frame was
// received, resetTransport resumes connecting to backends after the one that
@ -1047,15 +882,17 @@ func (ac *addrConn) resetTransport() error {
ac.dopts.copts.KeepaliveParams = ac.cc.mkp
ac.cc.mu.RUnlock()
var backoffDeadline, connectDeadline time.Time
var resetBackoff chan struct{}
for connectRetryNum := 0; ; connectRetryNum++ {
ac.mu.Lock()
if ac.backoffDeadline.IsZero() {
// This means either a successful HTTP2 connection was established
// or this is the first time this addrConn is trying to establish a
// connection.
backoffFor := ac.dopts.bs.backoff(connectRetryNum) // time.Duration.
backoffFor := ac.dopts.bs.Backoff(connectRetryNum) // time.Duration.
resetBackoff = ac.resetBackoff
// This will be the duration that dial gets to finish.
dialDuration := minConnectTimeout
dialDuration := getMinConnectTimeout()
if backoffFor > dialDuration {
// Give dial more time as we keep failing to connect.
dialDuration = backoffFor
@ -1065,7 +902,7 @@ func (ac *addrConn) resetTransport() error {
connectDeadline = start.Add(dialDuration)
ridx = 0 // Start connecting from the beginning.
} else {
// Continue trying to conect with the same deadlines.
// Continue trying to connect with the same deadlines.
connectRetryNum = ac.connectRetryNum
backoffDeadline = ac.backoffDeadline
connectDeadline = ac.connectDeadline
@ -1087,7 +924,7 @@ func (ac *addrConn) resetTransport() error {
copy(addrsIter, ac.addrs)
copts := ac.dopts.copts
ac.mu.Unlock()
connected, err := ac.createTransport(connectRetryNum, ridx, backoffDeadline, connectDeadline, addrsIter, copts)
connected, err := ac.createTransport(connectRetryNum, ridx, backoffDeadline, connectDeadline, addrsIter, copts, resetBackoff)
if err != nil {
return err
}
@ -1099,7 +936,7 @@ func (ac *addrConn) resetTransport() error {
// createTransport creates a connection to one of the backends in addrs.
// It returns true if a connection was established.
func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, connectDeadline time.Time, addrs []resolver.Address, copts transport.ConnectOptions) (bool, error) {
func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, connectDeadline time.Time, addrs []resolver.Address, copts transport.ConnectOptions, resetBackoff chan struct{}) (bool, error) {
for i := ridx; i < len(addrs); i++ {
addr := addrs[i]
target := transport.TargetInfo{
@ -1126,18 +963,13 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline,
// Do not cancel in the success path because of
// this issue in Go1.6: https://github.com/golang/go/issues/15078.
connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline)
if channelz.IsOn() {
copts.ChannelzParentID = ac.channelzID
}
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, target, copts, onPrefaceReceipt)
if err != nil {
cancel()
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
ac.mu.Lock()
if ac.state != connectivity.Shutdown {
ac.state = connectivity.TransientFailure
ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
}
ac.mu.Unlock()
return false, err
}
ac.cc.blockingpicker.updateConnectionError(err)
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
// ac.tearDown(...) has been invoked.
@ -1155,7 +987,7 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline,
// Didn't receive server preface, must kill this new transport now.
grpclog.Warningf("grpc: addrConn.createTransport failed to receive server preface before deadline.")
newTr.Close()
break
continue
case <-ac.ctx.Done():
}
}
@ -1189,6 +1021,10 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline,
return true, nil
}
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
return false, errConnClosing
}
ac.state = connectivity.TransientFailure
ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
ac.cc.resolveNow(resolver.ResolveNowOption{})
@ -1200,6 +1036,8 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline,
timer := time.NewTimer(backoffDeadline.Sub(time.Now()))
select {
case <-timer.C:
case <-resetBackoff:
timer.Stop()
case <-ac.ctx.Done():
timer.Stop()
return false, ac.ctx.Err()
@ -1207,6 +1045,14 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline,
return false, nil
}
func (ac *addrConn) resetConnectBackoff() {
ac.mu.Lock()
close(ac.resetBackoff)
ac.resetBackoff = make(chan struct{})
ac.connectRetryNum = 0
ac.mu.Unlock()
}
// Run in a goroutine to track the error in transport and create the
// new transport if an error happens. It returns when the channel is closing.
func (ac *addrConn) transportMonitor() {
@ -1223,7 +1069,20 @@ func (ac *addrConn) transportMonitor() {
// Block until we receive a goaway or an error occurs.
select {
case <-t.GoAway():
done := t.Error()
cleanup := t.Close
// Since this transport will be orphaned (won't have a transportMonitor)
// we need to launch a goroutine to keep track of clientConn.Close()
// happening since it might not be noticed by any other goroutine for a while.
go func() {
<-done
cleanup()
}()
case <-t.Error():
// In case this is triggered because clientConn.Close()
// was called, we want to immediately close the transport
// since no other goroutine might notice it for a while.
t.Close()
case <-cdeadline:
ac.mu.Lock()
// This implies that client received server preface.
@ -1274,46 +1133,6 @@ func (ac *addrConn) transportMonitor() {
}
}
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
// iv) transport is in connectivity.TransientFailure and there is a balancer/failfast is true.
func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) {
for {
ac.mu.Lock()
switch {
case ac.state == connectivity.Shutdown:
if failfast || !hasBalancer {
// RPC is failfast or balancer is nil. This RPC should fail with ac.tearDownErr.
err := ac.tearDownErr
ac.mu.Unlock()
return nil, err
}
ac.mu.Unlock()
return nil, errConnClosing
case ac.state == connectivity.Ready:
ct := ac.transport
ac.mu.Unlock()
return ct, nil
case ac.state == connectivity.TransientFailure:
if failfast || hasBalancer {
ac.mu.Unlock()
return nil, errConnUnavailable
}
}
ready := ac.ready
if ready == nil {
ready = make(chan struct{})
ac.ready = ready
}
ac.mu.Unlock()
select {
case <-ctx.Done():
return nil, toRPCErr(ctx.Err())
// Wait until the new transport is ready or failed.
case <-ready:
}
}
}
// getReadyTransport returns the transport if ac's state is READY.
// Otherwise it returns nil, false.
// If ac's state is IDLE, it will trigger ac to connect.
@ -1367,7 +1186,9 @@ func (ac *addrConn) tearDown(err error) {
close(ac.ready)
ac.ready = nil
}
return
if channelz.IsOn() {
channelz.RemoveEntry(ac.channelzID)
}
}
func (ac *addrConn) getState() connectivity.State {
@ -1376,6 +1197,78 @@ func (ac *addrConn) getState() connectivity.State {
return ac.state
}
func (ac *addrConn) ChannelzMetric() *channelz.ChannelInternalMetric {
ac.mu.Lock()
addr := ac.curAddr.Addr
ac.mu.Unlock()
return &channelz.ChannelInternalMetric{
State: ac.getState(),
Target: addr,
CallsStarted: atomic.LoadInt64(&ac.czData.callsStarted),
CallsSucceeded: atomic.LoadInt64(&ac.czData.callsSucceeded),
CallsFailed: atomic.LoadInt64(&ac.czData.callsFailed),
LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&ac.czData.lastCallStartedTime)),
}
}
func (ac *addrConn) incrCallsStarted() {
atomic.AddInt64(&ac.czData.callsStarted, 1)
atomic.StoreInt64(&ac.czData.lastCallStartedTime, time.Now().UnixNano())
}
func (ac *addrConn) incrCallsSucceeded() {
atomic.AddInt64(&ac.czData.callsSucceeded, 1)
}
func (ac *addrConn) incrCallsFailed() {
atomic.AddInt64(&ac.czData.callsFailed, 1)
}
type retryThrottler struct {
max float64
thresh float64
ratio float64
mu sync.Mutex
tokens float64 // TODO(dfawley): replace with atomic and remove lock.
}
// throttle subtracts a retry token from the pool and returns whether a retry
// should be throttled (disallowed) based upon the retry throttling policy in
// the service config.
func (rt *retryThrottler) throttle() bool {
if rt == nil {
return false
}
rt.mu.Lock()
defer rt.mu.Unlock()
rt.tokens--
if rt.tokens < 0 {
rt.tokens = 0
}
return rt.tokens <= rt.thresh
}
func (rt *retryThrottler) successfulRPC() {
if rt == nil {
return
}
rt.mu.Lock()
defer rt.mu.Unlock()
rt.tokens += rt.ratio
if rt.tokens > rt.max {
rt.tokens = rt.max
}
}
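// Worked example (assumes a policy of maxTokens=10 and tokenRatio=0.1, so
// thresh=5): the throttler starts with 10 tokens and each throttle() call
// spends one. Once the balance falls to 5 or below, throttle() returns true
// and retries are disallowed until enough successfulRPC() calls (each
// refunding 0.1) push it back above the threshold.
//
//	rt := &retryThrottler{max: 10, thresh: 5, ratio: 0.1, tokens: 10}
//	rt.throttle()      // tokens: 9, returns false (retry allowed)
//	// three more failed attempts bring tokens down to 6
//	rt.throttle()      // tokens: 5, returns true (retry throttled)
//	rt.successfulRPC() // tokens: 5.1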
type channelzChannel struct {
cc *ClientConn
}
func (c *channelzChannel) ChannelzMetric() *channelz.ChannelInternalMetric {
return c.cc.channelzMetric()
}
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout.
//


@ -19,26 +19,38 @@
package grpc
import (
"io"
"errors"
"math"
"net"
"sync/atomic"
"testing"
"time"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
_ "google.golang.org/grpc/resolver/passthrough"
"google.golang.org/grpc/test/leakcheck"
"google.golang.org/grpc/testdata"
)
var (
mutableMinConnectTimeout = time.Second * 20
)
func init() {
getMinConnectTimeout = func() time.Duration {
return time.Duration(atomic.LoadInt64((*int64)(&mutableMinConnectTimeout)))
}
}
func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.State, bool) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@ -139,9 +151,9 @@ func TestDialWaitsForServerSettings(t *testing.T) {
return
}
defer conn.Close()
// Sleep so that if the test were to fail it
// will fail more often than not.
time.Sleep(100 * time.Millisecond)
// Sleep for a little bit to make sure that Dial on client
// side blocks until settings are received.
time.Sleep(500 * time.Millisecond)
framer := http2.NewFramer(conn, conn)
close(sent)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
@ -150,7 +162,7 @@ func TestDialWaitsForServerSettings(t *testing.T) {
}
<-dialDone // Close conn only after dial returns.
}()
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client, err := DialContext(ctx, server.Addr().String(), WithInsecure(), WithWaitForHandshake(), WithBlock())
close(dialDone)
@ -169,102 +181,87 @@ func TestDialWaitsForServerSettings(t *testing.T) {
}
func TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) {
mctBkp := minConnectTimeout
// 1. Client connects to a server that doesn't send preface.
// 2. After minConnectTimeout(500 ms here), client disconnects and retries.
// 3. The new server sends its preface.
// 4. Client doesn't kill the connection this time.
mctBkp := getMinConnectTimeout()
// Call this only after transportMonitor goroutine has ended.
defer func() {
minConnectTimeout = mctBkp
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp))
}()
defer leakcheck.Check(t)
minConnectTimeout = time.Millisecond * 500
server, err := net.Listen("tcp", "localhost:0")
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*500)
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer server.Close()
var (
conn2 net.Conn
over uint32
)
defer func() {
lis.Close()
// conn2 shouldn't be closed until the client has
// observed a successful test.
if conn2 != nil {
conn2.Close()
}
}()
done := make(chan struct{})
clientDone := make(chan struct{})
accepted := make(chan struct{})
go func() { // Launch the server.
defer func() {
if done != nil {
close(done)
}
}()
conn1, err := server.Accept()
defer close(done)
conn1, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
}
defer conn1.Close()
// Don't send server settings and make sure the connection is closed.
time.Sleep(time.Millisecond * 1500) // Since the first backoff is for a second.
conn1.SetDeadline(time.Now().Add(time.Second))
b := make([]byte, 24)
for {
// Make sure the connection was closed by client.
_, err = conn1.Read(b)
if err == nil {
continue
}
if err != io.EOF {
t.Errorf(" conn1.Read(_) = _, %v, want _, io.EOF", err)
return
}
break
}
conn2, err := server.Accept() // Accept a reconnection request from client.
// Don't send server settings and the client should close the connection and try again.
conn2, err = lis.Accept() // Accept a reconnection request from client.
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
}
defer conn2.Close()
close(accepted)
framer := http2.NewFramer(conn2, conn2)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
if err = framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings. Err: %v", err)
return
}
time.Sleep(time.Millisecond * 1500) // Since the first backoff is for a second.
conn2.SetDeadline(time.Now().Add(time.Millisecond * 500))
b := make([]byte, 8)
for {
// Make sure the connection stays open and is closed
// only by connection timeout.
_, err = conn2.Read(b)
if err == nil {
continue
}
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
if atomic.LoadUint32(&over) == 1 {
// The connection stayed alive for the timer.
// Success.
return
}
t.Errorf("Unexpected error while reading. Err: %v, want timeout error", err)
break
}
close(done)
done = nil
<-clientDone
}()
client, err := Dial(server.Addr().String(), WithInsecure())
client, err := Dial(lis.Addr().String(), WithInsecure())
if err != nil {
t.Fatalf("Error while dialing. Err: %v", err)
}
<-done
// TODO: The code from BEGIN to END should be deleted once issue
// https://github.com/grpc/grpc-go/issues/1750 is fixed.
// BEGIN
// Set underlying addrConns state to Shutdown so that no reconnect
// attempts take place and thereby resetting minConnectTimeout is
// race free.
client.mu.Lock()
addrConns := client.conns
client.mu.Unlock()
for ac := range addrConns {
ac.mu.Lock()
ac.state = connectivity.Shutdown
ac.mu.Unlock()
// wait for connection to be accepted on the server.
timer := time.NewTimer(time.Second * 10)
select {
case <-accepted:
case <-timer.C:
t.Fatalf("Client didn't make another connection request in time.")
}
// END
// Make sure the connection stays alive for sometime.
time.Sleep(time.Second * 2)
atomic.StoreUint32(&over, 1)
client.Close()
close(clientDone)
<-done
}
func TestBackoffWhenNoServerPrefaceReceived(t *testing.T) {
@ -351,7 +348,7 @@ func TestConnectivityStates(t *testing.T) {
}
func TestDialTimeout(t *testing.T) {
func TestWithTimeout(t *testing.T) {
defer leakcheck.Check(t)
conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure())
if err == nil {
@ -362,13 +359,15 @@ func TestDialTimeout(t *testing.T) {
}
}
func TestTLSDialTimeout(t *testing.T) {
func TestWithTransportCredentialsTLS(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
defer leakcheck.Check(t)
creds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock())
conn, err := DialContext(ctx, "passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithBlock())
if err == nil {
conn.Close()
}
@ -446,6 +445,26 @@ func TestDialContextCancel(t *testing.T) {
}
}
type failFastError struct{}
func (failFastError) Error() string { return "failfast" }
func (failFastError) Temporary() bool { return false }
func TestDialContextFailFast(t *testing.T) {
defer leakcheck.Check(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
failErr := failFastError{}
dialer := func(string, time.Duration) (net.Conn, error) {
return nil, failErr
}
_, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure(), WithDialer(dialer), FailOnNonTempDialError(true))
if terr, ok := err.(transport.ConnectionError); !ok || terr.Origin() != failErr {
t.Fatalf("DialContext() = _, %v, want _, %v", err, failErr)
}
}
// blockingBalancer mimics the behavior of balancers whose initialization takes a long time.
// In this test, reading from blockingBalancer.Notify() blocks forever.
type blockingBalancer struct {
@ -520,7 +539,6 @@ func TestWithBackoffConfig(t *testing.T) {
defer leakcheck.Check(t)
b := BackoffConfig{MaxDelay: DefaultBackoffConfig.MaxDelay / 2}
expected := b
setDefaults(&expected) // defaults should be set
testBackoffConfigSet(t, &expected, WithBackoffConfig(b))
}
@ -528,7 +546,6 @@ func TestWithBackoffMaxDelay(t *testing.T) {
defer leakcheck.Check(t)
md := DefaultBackoffConfig.MaxDelay / 2
expected := BackoffConfig{MaxDelay: md}
setDefaults(&expected)
testBackoffConfigSet(t, &expected, WithBackoffMaxDelay(md))
}
@ -544,12 +561,15 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt
t.Fatalf("backoff config not set")
}
actual, ok := conn.dopts.bs.(BackoffConfig)
actual, ok := conn.dopts.bs.(backoff.Exponential)
if !ok {
t.Fatalf("unexpected type of backoff config: %#v", conn.dopts.bs)
}
if actual != *expected {
expectedValue := backoff.Exponential{
MaxDelay: expected.MaxDelay,
}
if actual != expectedValue {
t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected)
}
}
@ -662,3 +682,81 @@ func TestClientUpdatesParamsAfterGoAway(t *testing.T) {
t.Fatalf("cc.dopts.copts.Keepalive.Time = %v , want 100ms", v)
}
}
func TestDisableServiceConfigOption(t *testing.T) {
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
addr := r.Scheme() + ":///non.existent"
cc, err := Dial(addr, WithInsecure(), WithDisableServiceConfig())
if err != nil {
t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
}
defer cc.Close()
r.NewServiceConfig(`{
"methodConfig": [
{
"name": [
{
"service": "foo",
"method": "Bar"
}
],
"waitForReady": true
}
]
}`)
time.Sleep(1 * time.Second)
m := cc.GetMethodConfig("/foo/Bar")
if m.WaitForReady != nil {
t.Fatalf("want: method (\"/foo/bar/\") config to be empty, got: %v", m)
}
}
func TestGetClientConnTarget(t *testing.T) {
addr := "nonexist:///non.existent"
cc, err := Dial(addr, WithInsecure())
if err != nil {
t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
}
defer cc.Close()
if cc.Target() != addr {
t.Fatalf("Target() = %s, want %s", cc.Target(), addr)
}
}
type backoffForever struct{}
func (b backoffForever) Backoff(int) time.Duration { return time.Duration(math.MaxInt64) }
func TestResetConnectBackoff(t *testing.T) {
defer leakcheck.Check(t)
dials := make(chan struct{})
dialer := func(string, time.Duration) (net.Conn, error) {
dials <- struct{}{}
return nil, errors.New("failed to fake dial")
}
cc, err := Dial("any", WithInsecure(), WithDialer(dialer), withBackoff(backoffForever{}))
if err != nil {
t.Fatalf("Dial() = _, %v; want _, nil", err)
}
defer cc.Close()
select {
case <-dials:
case <-time.NewTimer(10 * time.Second).C:
t.Fatal("Failed to call dial within 10s")
}
select {
case <-dials:
t.Fatal("Dial called unexpectedly before resetting backoff")
case <-time.NewTimer(100 * time.Millisecond).C:
}
cc.ResetConnectBackoff()
select {
case <-dials:
case <-time.NewTimer(10 * time.Second).C:
t.Fatal("Failed to call dial within 10s after resetting backoff")
}
}


@ -22,6 +22,7 @@ package codes // import "google.golang.org/grpc/codes"
import (
"fmt"
"strconv"
)
// A Code is an unsigned 32-bit error code as defined in the gRPC spec.
@ -143,6 +144,8 @@ const (
// Unauthenticated indicates the request does not have valid
// authentication credentials for the operation.
Unauthenticated Code = 16
_maxCode = 17
)
var strToCode = map[string]Code{
@ -176,6 +179,16 @@ func (c *Code) UnmarshalJSON(b []byte) error {
if c == nil {
return fmt.Errorf("nil receiver passed to UnmarshalJSON")
}
if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil {
if ci >= _maxCode {
return fmt.Errorf("invalid code: %q", ci)
}
*c = Code(ci)
return nil
}
if jc, ok := strToCode[string(b)]; ok {
*c = jc
return nil


@ -62,3 +62,23 @@ func TestUnmarshalJSON_UnknownInput(t *testing.T) {
}
}
}
func TestUnmarshalJSON_MarshalUnmarshal(t *testing.T) {
for i := 0; i < _maxCode; i++ {
var cUnMarshaled Code
c := Code(i)
cJSON, err := json.Marshal(c)
if err != nil {
t.Errorf("marshalling %q failed: %v", c, err)
}
if err := json.Unmarshal(cJSON, &cUnMarshaled); err != nil {
t.Errorf("unmarshalling code failed: %s", err)
}
if c != cUnMarshaled {
t.Errorf("code is %q after marshalling/unmarshalling, expected %q", cUnMarshaled, c)
}
}
}


@ -0,0 +1,72 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package connectivity defines connectivity semantics.
// For details, see https://github.com/grpc/grpc/blob/master/doc/connectivity-semantics-and-api.md.
// All APIs in this package are experimental.
package connectivity
import (
"golang.org/x/net/context"
"google.golang.org/grpc/grpclog"
)
// State indicates the state of connectivity.
// It can be the state of a ClientConn or SubConn.
type State int
func (s State) String() string {
switch s {
case Idle:
return "IDLE"
case Connecting:
return "CONNECTING"
case Ready:
return "READY"
case TransientFailure:
return "TRANSIENT_FAILURE"
case Shutdown:
return "SHUTDOWN"
default:
grpclog.Errorf("unknown connectivity state: %d", s)
return "Invalid-State"
}
}
const (
// Idle indicates the ClientConn is idle.
Idle State = iota
// Connecting indicates the ClientConn is connecting.
Connecting
// Ready indicates the ClientConn is ready for work.
Ready
// TransientFailure indicates the ClientConn has seen a failure but expects to recover.
TransientFailure
// Shutdown indicates the ClientConn has started shutting down.
Shutdown
)
// Reporter reports the connectivity states.
type Reporter interface {
// CurrentState returns the current state of the reporter.
CurrentState() State
// WaitForStateChange blocks until the reporter's state is different from the given state,
// and returns true.
// It returns false if <-ctx.Done() can proceed (ctx got timeout or got canceled).
WaitForStateChange(context.Context, State) bool
}
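// Illustrative sketch (r is any Reporter implementation and ctx is a
// caller-supplied context; both are placeholders): waiting until the
// reporter leaves the Connecting state or the context expires.
//
//	for s := r.CurrentState(); s == Connecting; s = r.CurrentState() {
//		if !r.WaitForStateChange(ctx, s) {
//			break // ctx timed out or was canceled
//		}
//	}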

329
vendor/google.golang.org/grpc/credentials/alts/alts.go generated vendored Normal file

@ -0,0 +1,329 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package alts implements the ALTS credential support by the gRPC library, which
// encapsulates all the state needed by a client to authenticate with a server
// using ALTS and make various assertions, e.g., about the client's identity,
// role, or whether it is authorized to make a particular call.
// This package is experimental.
package alts
import (
"errors"
"fmt"
"net"
"sync"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials"
core "google.golang.org/grpc/credentials/alts/internal"
"google.golang.org/grpc/credentials/alts/internal/handshaker"
"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/grpclog"
)
const (
// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
// handshaker service address in the hypervisor.
hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
// defaultTimeout specifies the server handshake timeout.
defaultTimeout = 30.0 * time.Second
// The following constants specify the minimum and maximum acceptable
// protocol versions.
protocolVersionMaxMajor = 2
protocolVersionMaxMinor = 1
protocolVersionMinMajor = 2
protocolVersionMinMinor = 1
)
var (
once sync.Once
maxRPCVersion = &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMaxMajor,
Minor: protocolVersionMaxMinor,
}
minRPCVersion = &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMinMajor,
Minor: protocolVersionMinMinor,
}
// ErrUntrustedPlatform is returned from ClientHandshake and
// ServerHandshake when running on a platform where the trustworthiness of
// the handshaker service is not guaranteed.
ErrUntrustedPlatform = errors.New("untrusted platform")
)
// AuthInfo exposes security information from the ALTS handshake to the
// application. This interface is to be implemented by ALTS. Users should not
// need a brand new implementation of this interface. For situations like
// testing, any new implementation should embed this interface. This allows
// ALTS to add new methods to this interface.
type AuthInfo interface {
// ApplicationProtocol returns application protocol negotiated for the
// ALTS connection.
ApplicationProtocol() string
// RecordProtocol returns the record protocol negotiated for the ALTS
// connection.
RecordProtocol() string
// SecurityLevel returns the security level of the created ALTS secure
// channel.
SecurityLevel() altspb.SecurityLevel
// PeerServiceAccount returns the peer service account.
PeerServiceAccount() string
// LocalServiceAccount returns the local service account.
LocalServiceAccount() string
// PeerRPCVersions returns the RPC version supported by the peer.
PeerRPCVersions() *altspb.RpcProtocolVersions
}
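// Illustrative sketch (server-handler usage, assumed and not part of this
// file): recovering the ALTS AuthInfo for an incoming RPC through the
// google.golang.org/grpc/peer package and inspecting the caller's identity.
//
//	if p, ok := peer.FromContext(ctx); ok {
//		if ai, ok := p.AuthInfo.(AuthInfo); ok {
//			grpclog.Infof("peer service account: %v", ai.PeerServiceAccount())
//		}
//	}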
// ClientOptions contains the client-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ClientOptions struct {
// TargetServiceAccounts contains a list of expected target service
// accounts.
TargetServiceAccounts []string
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultClientOptions creates a new ClientOptions object with the default
// values.
func DefaultClientOptions() *ClientOptions {
return &ClientOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultServerOptions creates a new ServerOptions object with the default
// values.
func DefaultServerOptions() *ServerOptions {
return &ServerOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
type altsTC struct {
info *credentials.ProtocolInfo
side core.Side
accounts []string
hsAddress string
}
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
}
// NewServerCreds constructs a server-side ALTS TransportCredentials object.
func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
}
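// Illustrative sketch (the target address and service account are
// placeholders; application code would import this package as "alts"):
// building a client connection secured with ALTS. TargetServiceAccounts
// restricts which server identities the handshake will accept.
//
//	opts := DefaultClientOptions()
//	opts.TargetServiceAccounts = []string{"expected-server@example.iam.gserviceaccount.com"}
//	creds := NewClientCreds(opts)
//	conn, err := grpc.Dial("service.example.com:443", grpc.WithTransportCredentials(creds))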
func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
once.Do(func() {
vmOnGCP = isRunningOnGCP()
})
if hsAddress == "" {
hsAddress = hypervisorHandshakerServiceAddress
}
return &altsTC{
info: &credentials.ProtocolInfo{
SecurityProtocol: "alts",
SecurityVersion: "1.0",
},
side: side,
accounts: accounts,
hsAddress: hsAddress,
}
}
// ClientHandshake implements the client side handshake protocol.
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
if !vmOnGCP {
return nil, nil, ErrUntrustedPlatform
}
// Connecting to ALTS handshaker service.
hsConn, err := service.Dial(g.hsAddress)
if err != nil {
return nil, nil, err
}
// Do not close hsConn since it is shared with other handshakes.
// Possible context leak:
// The cancel function for the child context we create will only be
// called when a non-nil error is returned.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer func() {
if err != nil {
cancel()
}
}()
opts := handshaker.DefaultClientHandshakerOptions()
opts.TargetServiceAccounts = g.accounts
opts.RPCVersions = &altspb.RpcProtocolVersions{
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
defer func() {
if err != nil {
chs.Close()
}
}()
if err != nil {
return nil, nil, err
}
secConn, authInfo, err := chs.ClientHandshake(ctx)
if err != nil {
return nil, nil, err
}
altsAuthInfo, ok := authInfo.(AuthInfo)
if !ok {
return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
}
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
if !match {
return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
}
return secConn, authInfo, nil
}
// ServerHandshake implements the server side ALTS handshaker.
func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
if !vmOnGCP {
return nil, nil, ErrUntrustedPlatform
}
// Connecting to ALTS handshaker service.
hsConn, err := service.Dial(g.hsAddress)
if err != nil {
return nil, nil, err
}
// Do not close hsConn since it's shared with other handshakes.
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
opts := handshaker.DefaultServerHandshakerOptions()
opts.RPCVersions = &altspb.RpcProtocolVersions{
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
defer func() {
if err != nil {
shs.Close()
}
}()
if err != nil {
return nil, nil, err
}
secConn, authInfo, err := shs.ServerHandshake(ctx)
if err != nil {
return nil, nil, err
}
altsAuthInfo, ok := authInfo.(AuthInfo)
if !ok {
return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
}
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
if !match {
return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
}
return secConn, authInfo, nil
}
func (g *altsTC) Info() credentials.ProtocolInfo {
return *g.info
}
func (g *altsTC) Clone() credentials.TransportCredentials {
info := *g.info
var accounts []string
if g.accounts != nil {
accounts = make([]string, len(g.accounts))
copy(accounts, g.accounts)
}
return &altsTC{
info: &info,
side: g.side,
hsAddress: g.hsAddress,
accounts: accounts,
}
}
func (g *altsTC) OverrideServerName(serverNameOverride string) error {
g.info.ServerName = serverNameOverride
return nil
}
// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
switch {
case v1.GetMajor() > v2.GetMajor(),
v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
return 1
case v1.GetMajor() < v2.GetMajor(),
v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
return -1
}
return 0
}
// checkRPCVersions performs a version check between local and peer rpc protocol
// versions. This function returns true if the check passes which means both
// parties agreed on a common rpc protocol to use, and false otherwise. The
// function also returns the highest common RPC protocol version both parties
// agreed on.
func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
if local == nil || peer == nil {
grpclog.Error("invalid checkRPCVersions argument, either local or peer is nil.")
return false, nil
}
// maxCommonVersion is MIN(local.max, peer.max).
maxCommonVersion := local.GetMaxRpcVersion()
if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
maxCommonVersion = peer.GetMaxRpcVersion()
}
// minCommonVersion is MAX(local.min, peer.min).
minCommonVersion := peer.GetMinRpcVersion()
if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
minCommonVersion = local.GetMinRpcVersion()
}
if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
return false, nil
}
return true, maxCommonVersion
}


@ -0,0 +1,290 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package alts
import (
"reflect"
"testing"
"github.com/golang/protobuf/proto"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
func TestInfoServerName(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
alts := NewServerCreds(DefaultServerOptions())
if got, want := alts.Info().ServerName, ""; got != want {
t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
}
}
func TestOverrideServerName(t *testing.T) {
wantServerName := "server.name"
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName)
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
}
}
func TestCloneClient(t *testing.T) {
wantServerName := "server.name"
opt := DefaultClientOptions()
opt.TargetServiceAccounts = []string{"not", "empty"}
c := NewClientCreds(opt)
c.OverrideServerName(wantServerName)
cc := c.Clone()
if got, want := cc.Info().ServerName, wantServerName; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
cc.OverrideServerName("")
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
}
if got, want := cc.Info().ServerName, ""; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
ct := c.(*altsTC)
cct := cc.(*altsTC)
if ct.side != cct.side {
t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
}
if ct.hsAddress != cct.hsAddress {
t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
}
if !reflect.DeepEqual(ct.accounts, cct.accounts) {
t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
}
}
func TestCloneServer(t *testing.T) {
wantServerName := "server.name"
c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName)
cc := c.Clone()
if got, want := cc.Info().ServerName, wantServerName; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
cc.OverrideServerName("")
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
}
if got, want := cc.Info().ServerName, ""; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
ct := c.(*altsTC)
cct := cc.(*altsTC)
if ct.side != cct.side {
t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
}
if ct.hsAddress != cct.hsAddress {
t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
}
if !reflect.DeepEqual(ct.accounts, cct.accounts) {
t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
}
}
func TestInfo(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds(DefaultServerOptions())
info := c.Info()
if got, want := info.ProtocolVersion, ""; got != want {
t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
}
if got, want := info.SecurityProtocol, "alts"; got != want {
t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
}
if got, want := info.SecurityVersion, "1.0"; got != want {
t.Errorf("info.SecurityVersion=%v, want %v", got, want)
}
if got, want := info.ServerName, ""; got != want {
t.Errorf("info.ServerName=%v, want %v", got, want)
}
}
func TestCompareRPCVersions(t *testing.T) {
for _, tc := range []struct {
v1 *altspb.RpcProtocolVersions_Version
v2 *altspb.RpcProtocolVersions_Version
output int
}{
{
version(3, 2),
version(2, 1),
1,
},
{
version(3, 2),
version(3, 1),
1,
},
{
version(2, 1),
version(3, 2),
-1,
},
{
version(3, 1),
version(3, 2),
-1,
},
{
version(3, 2),
version(3, 2),
0,
},
} {
if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
}
}
}
func TestCheckRPCVersions(t *testing.T) {
for _, tc := range []struct {
desc string
local *altspb.RpcProtocolVersions
peer *altspb.RpcProtocolVersions
output bool
maxCommonVersion *altspb.RpcProtocolVersions_Version
}{
{
"local.max > peer.max and local.min > peer.min",
versions(2, 1, 3, 2),
versions(1, 2, 2, 1),
true,
version(2, 1),
},
{
"local.max > peer.max and local.min < peer.min",
versions(1, 2, 3, 2),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"local.max > peer.max and local.min = peer.min",
versions(2, 1, 3, 2),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min > peer.min",
versions(2, 1, 2, 1),
versions(1, 2, 3, 2),
true,
version(2, 1),
},
{
"local.max = peer.max and local.min > peer.min",
versions(2, 1, 2, 1),
versions(1, 2, 2, 1),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min < peer.min",
versions(1, 2, 2, 1),
versions(2, 1, 3, 2),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min = peer.min",
versions(1, 2, 2, 1),
versions(1, 2, 3, 2),
true,
version(2, 1),
},
{
"local.max = peer.max and local.min < peer.min",
versions(1, 2, 2, 1),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"all equal",
versions(2, 1, 2, 1),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"max is smaller than min",
versions(2, 1, 1, 2),
versions(2, 1, 1, 2),
false,
nil,
},
{
"no overlap, local > peer",
versions(4, 3, 6, 5),
versions(1, 0, 2, 1),
false,
nil,
},
{
"no overlap, local < peer",
versions(1, 0, 2, 1),
versions(4, 3, 6, 5),
false,
nil,
},
{
"no overlap, max < min",
versions(6, 5, 4, 3),
versions(2, 1, 1, 0),
false,
nil,
},
} {
output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
if got, want := output, tc.output; got != want {
t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
}
if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
}
}
}
func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
return &altspb.RpcProtocolVersions_Version{
Major: major,
Minor: minor,
}
}
func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
return &altspb.RpcProtocolVersions{
MinRpcVersion: version(minMajor, minMinor),
MaxRpcVersion: version(maxMajor, maxMinor),
}
}


@ -0,0 +1,87 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package authinfo provides authentication information returned by handshakers.
package authinfo
import (
"google.golang.org/grpc/credentials"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
var _ credentials.AuthInfo = (*altsAuthInfo)(nil)
// altsAuthInfo exposes security information from the ALTS handshake to the
// application. altsAuthInfo is immutable and implements credentials.AuthInfo.
type altsAuthInfo struct {
p *altspb.AltsContext
}
// New returns a new altsAuthInfo object given handshaker results.
func New(result *altspb.HandshakerResult) credentials.AuthInfo {
return newAuthInfo(result)
}
func newAuthInfo(result *altspb.HandshakerResult) *altsAuthInfo {
return &altsAuthInfo{
p: &altspb.AltsContext{
ApplicationProtocol: result.GetApplicationProtocol(),
RecordProtocol: result.GetRecordProtocol(),
// TODO: assign security level from result.
SecurityLevel: altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
PeerServiceAccount: result.GetPeerIdentity().GetServiceAccount(),
LocalServiceAccount: result.GetLocalIdentity().GetServiceAccount(),
PeerRpcVersions: result.GetPeerRpcVersions(),
},
}
}
// AuthType identifies the context as providing ALTS authentication information.
func (s *altsAuthInfo) AuthType() string {
return "alts"
}
// ApplicationProtocol returns the context's application protocol.
func (s *altsAuthInfo) ApplicationProtocol() string {
return s.p.GetApplicationProtocol()
}
// RecordProtocol returns the context's record protocol.
func (s *altsAuthInfo) RecordProtocol() string {
return s.p.GetRecordProtocol()
}
// SecurityLevel returns the context's security level.
func (s *altsAuthInfo) SecurityLevel() altspb.SecurityLevel {
return s.p.GetSecurityLevel()
}
// PeerServiceAccount returns the context's peer service account.
func (s *altsAuthInfo) PeerServiceAccount() string {
return s.p.GetPeerServiceAccount()
}
// LocalServiceAccount returns the context's local service account.
func (s *altsAuthInfo) LocalServiceAccount() string {
return s.p.GetLocalServiceAccount()
}
// PeerRPCVersions returns the context's peer RPC versions.
func (s *altsAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions {
return s.p.GetPeerRpcVersions()
}


@ -0,0 +1,134 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package authinfo
import (
"reflect"
"testing"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
const (
testAppProtocol = "my_app"
testRecordProtocol = "very_secure_protocol"
testPeerAccount = "peer_service_account"
testLocalAccount = "local_service_account"
testPeerHostname = "peer_hostname"
testLocalHostname = "local_hostname"
)
func TestALTSAuthInfo(t *testing.T) {
for _, tc := range []struct {
result *altspb.HandshakerResult
outAppProtocol string
outRecordProtocol string
outSecurityLevel altspb.SecurityLevel
outPeerAccount string
outLocalAccount string
outPeerRPCVersions *altspb.RpcProtocolVersions
}{
{
&altspb.HandshakerResult{
ApplicationProtocol: testAppProtocol,
RecordProtocol: testRecordProtocol,
PeerIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: testPeerAccount,
},
},
LocalIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: testLocalAccount,
},
},
},
testAppProtocol,
testRecordProtocol,
altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
testPeerAccount,
testLocalAccount,
nil,
},
{
&altspb.HandshakerResult{
ApplicationProtocol: testAppProtocol,
RecordProtocol: testRecordProtocol,
PeerIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: testPeerHostname,
},
},
LocalIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: testLocalHostname,
},
},
PeerRpcVersions: &altspb.RpcProtocolVersions{
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 20,
Minor: 21,
},
MinRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 10,
Minor: 11,
},
},
},
testAppProtocol,
testRecordProtocol,
altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
"",
"",
&altspb.RpcProtocolVersions{
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 20,
Minor: 21,
},
MinRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 10,
Minor: 11,
},
},
},
} {
authInfo := newAuthInfo(tc.result)
if got, want := authInfo.AuthType(), "alts"; got != want {
t.Errorf("authInfo.AuthType()=%v, want %v", got, want)
}
if got, want := authInfo.ApplicationProtocol(), tc.outAppProtocol; got != want {
t.Errorf("authInfo.ApplicationProtocol()=%v, want %v", got, want)
}
if got, want := authInfo.RecordProtocol(), tc.outRecordProtocol; got != want {
t.Errorf("authInfo.RecordProtocol()=%v, want %v", got, want)
}
if got, want := authInfo.SecurityLevel(), tc.outSecurityLevel; got != want {
t.Errorf("authInfo.SecurityLevel()=%v, want %v", got, want)
}
if got, want := authInfo.PeerServiceAccount(), tc.outPeerAccount; got != want {
t.Errorf("authInfo.PeerServiceAccount()=%v, want %v", got, want)
}
if got, want := authInfo.LocalServiceAccount(), tc.outLocalAccount; got != want {
t.Errorf("authInfo.LocalServiceAccount()=%v, want %v", got, want)
}
if got, want := authInfo.PeerRPCVersions(), tc.outPeerRPCVersions; !reflect.DeepEqual(got, want) {
t.Errorf("authinfo.PeerRpcVersions()=%v, want %v", got, want)
}
}
}


@ -0,0 +1,69 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate ./regenerate.sh
// Package internal contains common core functionality for ALTS.
package internal
import (
"net"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials"
)
const (
// ClientSide identifies the client in this communication.
ClientSide Side = iota
// ServerSide identifies the server in this communication.
ServerSide
)
// PeerNotRespondingError is returned when a peer server is not responding
// after a channel has been established. It is treated as a temporary connection
// error and re-connection to the server should be attempted.
var PeerNotRespondingError = &peerNotRespondingError{}
// Side identifies the party's role: client or server.
type Side int
type peerNotRespondingError struct{}
// Return an error message for the purpose of logging.
func (e *peerNotRespondingError) Error() string {
return "peer server is not responding and re-connection should be attempted."
}
// Temporary indicates if this connection error is temporary or fatal.
func (e *peerNotRespondingError) Temporary() bool {
return true
}
// Handshaker defines an ALTS handshaker interface.
type Handshaker interface {
// ClientHandshake starts and completes a client-side handshaking and
// returns a secure connection and corresponding auth information.
ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
// ServerHandshake starts and completes a server-side handshaking and
// returns a secure connection and corresponding auth information.
ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
// Close terminates the Handshaker. It should be called when the caller
// obtains the secure connection.
Close()
}
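// Editorial sketch, not part of the vendored file: a minimal, hypothetical
// Handshaker that satisfies the interface by handing back the raw connection
// unchanged. It only illustrates the shape of the contract; a real ALTS
// handshaker drives the handshake protocol over the wire and returns the
// negotiated AuthInfo.
type noopHandshaker struct {
	conn net.Conn
	info credentials.AuthInfo
}

var _ Handshaker = (*noopHandshaker)(nil)

func (h *noopHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
	// A real implementation would perform the client side of the handshake here.
	return h.conn, h.info, nil
}

func (h *noopHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
	// A real implementation would perform the server side of the handshake here.
	return h.conn, h.info, nil
}

func (h *noopHandshaker) Close() {}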

View file

@ -0,0 +1,131 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"fmt"
"strconv"
)
// rekeyAEAD holds the necessary information for an AEAD based on
// AES-GCM that performs nonce-based key derivation and XORs the
// nonce with a random mask.
type rekeyAEAD struct {
kdfKey []byte
kdfCounter []byte
nonceMask []byte
nonceBuf []byte
gcmAEAD cipher.AEAD
}
// KeySizeError signals that the given key does not have the correct size.
type KeySizeError int
func (k KeySizeError) Error() string {
return "alts/conn: invalid key size " + strconv.Itoa(int(k))
}
// newRekeyAEAD creates a new instance of aes128gcm with rekeying.
// The key argument should be 44 bytes; the first 32 bytes are used as a key
// for HKDF-expand and the remaining 12 bytes are used as a random mask for
// the counter.
func newRekeyAEAD(key []byte) (*rekeyAEAD, error) {
k := len(key)
if k != kdfKeyLen+nonceLen {
return nil, KeySizeError(k)
}
return &rekeyAEAD{
kdfKey: key[:kdfKeyLen],
kdfCounter: make([]byte, kdfCounterLen),
nonceMask: key[kdfKeyLen:],
nonceBuf: make([]byte, nonceLen),
gcmAEAD: nil,
}, nil
}
// Seal rekeys if nonce[2:8] is different than in the last call, masks the nonce,
// and calls Seal for aes128gcm.
func (s *rekeyAEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
if err := s.rekeyIfRequired(nonce); err != nil {
panic(fmt.Sprintf("Rekeying failed with: %s", err.Error()))
}
maskNonce(s.nonceBuf, nonce, s.nonceMask)
return s.gcmAEAD.Seal(dst, s.nonceBuf, plaintext, additionalData)
}
// Open rekeys if nonce[2:8] is different than in the last call, masks the nonce,
// and calls Open for aes128gcm.
func (s *rekeyAEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if err := s.rekeyIfRequired(nonce); err != nil {
return nil, err
}
maskNonce(s.nonceBuf, nonce, s.nonceMask)
return s.gcmAEAD.Open(dst, s.nonceBuf, ciphertext, additionalData)
}
// rekeyIfRequired creates a new aes128gcm AEAD if the existing AEAD is nil
// or cannot be used with given nonce.
func (s *rekeyAEAD) rekeyIfRequired(nonce []byte) error {
newKdfCounter := nonce[kdfCounterOffset : kdfCounterOffset+kdfCounterLen]
if s.gcmAEAD != nil && bytes.Equal(newKdfCounter, s.kdfCounter) {
return nil
}
copy(s.kdfCounter, newKdfCounter)
a, err := aes.NewCipher(hkdfExpand(s.kdfKey, s.kdfCounter))
if err != nil {
return err
}
s.gcmAEAD, err = cipher.NewGCM(a)
return err
}
// maskNonce XORs the given nonce with the mask and stores the result in dst.
func maskNonce(dst, nonce, mask []byte) {
nonce1 := binary.LittleEndian.Uint64(nonce[:sizeUint64])
nonce2 := binary.LittleEndian.Uint32(nonce[sizeUint64:])
mask1 := binary.LittleEndian.Uint64(mask[:sizeUint64])
mask2 := binary.LittleEndian.Uint32(mask[sizeUint64:])
binary.LittleEndian.PutUint64(dst[:sizeUint64], nonce1^mask1)
binary.LittleEndian.PutUint32(dst[sizeUint64:], nonce2^mask2)
}
// NonceSize returns the required nonce size.
func (s *rekeyAEAD) NonceSize() int {
return s.gcmAEAD.NonceSize()
}
// Overhead returns the ciphertext overhead.
func (s *rekeyAEAD) Overhead() int {
return s.gcmAEAD.Overhead()
}
// hkdfExpand computes the first 16 bytes of the HKDF-expand function
// defined in RFC5869.
func hkdfExpand(key, info []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(info)
mac.Write([]byte{0x01})
return mac.Sum(nil)[:aeadKeyLen]
}
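// Editorial sketch, not part of the vendored file: rekeyAEAD derives a fresh
// AES-GCM key whenever the KDF counter bytes nonce[2:8] change, so two Seal
// calls whose nonces differ in that range use different derived keys (and
// different masked nonces) and produce unrelated ciphertexts. The all-zero
// 44-byte key below is an arbitrary value chosen only for illustration.
func exampleRekeyOnCounterChange() bool {
	aead, err := newRekeyAEAD(make([]byte, kdfKeyLen+nonceLen)) // 32 + 12 = 44 bytes
	if err != nil {
		panic(err)
	}
	plaintext := []byte("example payload")
	nonce1 := make([]byte, nonceLen)
	nonce2 := make([]byte, nonceLen)
	nonce2[kdfCounterOffset]++ // change the KDF counter: forces a rekey
	ct1 := aead.Seal(nil, nonce1, plaintext, nil)
	ct2 := aead.Seal(nil, nonce2, plaintext, nil)
	return !bytes.Equal(ct1, ct2) // true: different derived key and nonce
}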

View file

@ -0,0 +1,263 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"encoding/hex"
"testing"
)
// rekeyAEADTestVector is a struct for a rekey test vector
type rekeyAEADTestVector struct {
desc string
key, nonce, plaintext, aad, ciphertext []byte
}
// Test encrypt and decrypt using (adapted) test vectors for AES-GCM.
func TestAES128GCMRekeyEncrypt(t *testing.T) {
for _, test := range []rekeyAEADTestVector{
// NIST vectors from:
// http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf
//
// IEEE vectors from:
// http://www.ieee802.org/1/files/public/docs2011/bn-randall-test-vectors-0511-v1.pdf
//
// Key expanded by setting
// expandedKey = (key ||
// key ^ {0x01,..,0x01} ||
// key ^ {0x02,..,0x02})[0:44].
{
desc: "Derived from NIST test vector 1",
key: dehex("0000000000000000000000000000000001010101010101010101010101010101020202020202020202020202"),
nonce: dehex("000000000000000000000000"),
aad: dehex(""),
plaintext: dehex(""),
ciphertext: dehex("85e873e002f6ebdc4060954eb8675508"),
},
{
desc: "Derived from NIST test vector 2",
key: dehex("0000000000000000000000000000000001010101010101010101010101010101020202020202020202020202"),
nonce: dehex("000000000000000000000000"),
aad: dehex(""),
plaintext: dehex("00000000000000000000000000000000"),
ciphertext: dehex("51e9a8cb23ca2512c8256afff8e72d681aca19a1148ac115e83df4888cc00d11"),
},
{
desc: "Derived from NIST test vector 3",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddecaf888"),
aad: dehex(""),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b391aafd255"),
ciphertext: dehex("1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4ac8cf09afb1663daa7b4017e6fc2c177c0c087c0df1162129952213cee1bc6e9c8495dd705e1f3d"),
},
{
desc: "Derived from NIST test vector 4",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4ac8cf09afb1663daa7b4017e6fc2c177c0c087c4764565d077e9124001ddb27fc0848c5"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 15)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("ca7ebabefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("e650d3c0fb879327f2d03287fa93cd07342b136215adbca00c3bd5099ec41832b1d18e0423ed26bb12c6cd09debb29230a94c0cee15903656f85edb6fc509b1b28216382172ecbcc31e1e9b1"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 16)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebbbefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("c0121e6c954d0767f96630c33450999791b2da2ad05c4190169ccad9ac86ff1c721e3d82f2ad22ab463bab4a0754b7dd68ca4de7ea2531b625eda01f89312b2ab957d5c7f8568dd95fcdcd1f"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 63)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedb2ddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("8af37ea5684a4d81d4fd817261fd9743099e7e6a025eaacf8e54b124fb5743149e05cb89f4a49467fe2e5e5965f29a19f99416b0016b54585d12553783ba59e9f782e82e097c336bf7989f08"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 64)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddfcaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("fbd528448d0346bfa878634864d407a35a039de9db2f1feb8e965b3ae9356ce6289441d77f8f0df294891f37ea438b223e3bf2bdc53d4c5a74fb680bb312a8dec6f7252cbcd7f5799750ad78"),
},
{
desc: "Derived from IEEE 2.1.1 54-byte auth",
key: dehex("ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d600dde"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001"),
plaintext: dehex(""),
ciphertext: dehex("3ea0b584f3c85e93f9320ea591699efb"),
},
{
desc: "Derived from IEEE 2.1.2 54-byte auth",
key: dehex("e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97a50755"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001"),
plaintext: dehex(""),
ciphertext: dehex("294e028bf1fe6f14c4e8f7305c933eb5"),
},
{
desc: "Derived from IEEE 2.2.1 60-byte crypt",
key: dehex("ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d600dde"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0002"),
ciphertext: dehex("db3d25719c6b0a3ca6145c159d5c6ed9aff9c6e0b79f17019ea923b8665ddf52137ad611f0d1bf417a7ca85e45afe106ff9c7569d335d086ae6c03f00987ccd6"),
},
{
desc: "Derived from IEEE 2.2.2 60-byte crypt",
key: dehex("e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97a50755"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0002"),
ciphertext: dehex("1641f28ec13afcc8f7903389787201051644914933e9202bb9d06aa020c2a67ef51dfe7bc00a856c55b8f8133e77f659132502bad63f5713d57d0c11e0f871ed"),
},
{
desc: "Derived from IEEE 2.3.1 60-byte auth",
key: dehex("071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fccecd3f07"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0003"),
plaintext: dehex(""),
ciphertext: dehex("58837a10562b0f1f8edbe58ca55811d3"),
},
{
desc: "Derived from IEEE 2.3.2 60-byte auth",
key: dehex("691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365ff1ea2"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0003"),
plaintext: dehex(""),
ciphertext: dehex("c2722ff6ca29a257718a529d1f0c6a3b"),
},
{
desc: "Derived from IEEE 2.4.1 54-byte crypt",
key: dehex("071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fccecd3f07"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340004"),
ciphertext: dehex("fd96b715b93a13346af51e8acdf792cdc7b2686f8574c70e6b0cbf16291ded427ad73fec48cd298e0528a1f4c644a949fc31dc9279706ddba33f"),
},
{
desc: "Derived from IEEE 2.4.2 54-byte crypt",
key: dehex("691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365ff1ea2"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340004"),
ciphertext: dehex("b68f6300c2e9ae833bdc070e24021a3477118e78ccf84e11a485d861476c300f175353d5cdf92008a4f878e6cc3577768085c50a0e98fda6cbb8"),
},
{
desc: "Derived from IEEE 2.5.1 65-byte auth",
key: dehex("013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d846f0eb9"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f0005"),
plaintext: dehex(""),
ciphertext: dehex("cca20eecda6283f09bb3543dd99edb9b"),
},
{
desc: "Derived from IEEE 2.5.2 65-byte auth",
key: dehex("83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2d89068"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f0005"),
plaintext: dehex(""),
ciphertext: dehex("b232cc1da5117bf15003734fa599d271"),
},
{
desc: "Derived from IEEE 2.6.1 61-byte crypt",
key: dehex("013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d846f0eb9"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b0006"),
ciphertext: dehex("ff1910d35ad7e5657890c7c560146fd038707f204b66edbc3d161f8ace244b985921023c436e3a1c3532ecd5d09a056d70be583f0d10829d9387d07d33d872e490"),
},
{
desc: "Derived from IEEE 2.6.2 61-byte crypt",
key: dehex("83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2d89068"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b0006"),
ciphertext: dehex("0db4cf956b5f97eca4eab82a6955307f9ae02a32dd7d93f83d66ad04e1cfdc5182ad12abdea5bbb619a1bd5fb9a573590fba908e9c7a46c1f7ba0905d1b55ffda4"),
},
{
desc: "Derived from IEEE 2.7.1 79-byte crypt",
key: dehex("88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f47058ab"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d0007"),
plaintext: dehex(""),
ciphertext: dehex("813f0e630f96fb2d030f58d83f5cdfd0"),
},
{
desc: "Derived from IEEE 2.7.2 79-byte crypt",
key: dehex("4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476fab7ba"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d0007"),
plaintext: dehex(""),
ciphertext: dehex("77e5a44c21eb07188aacbd74d1980e97"),
},
{
desc: "Derived from IEEE 2.8.1 61-byte crypt",
key: dehex("88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f47058ab"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e54d002e58495c"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748490008"),
ciphertext: dehex("958ec3f6d60afeda99efd888f175e5fcd4c87b9bcc5c2f5426253a8b506296c8c43309ab2adb5939462541d95e80811e04e706b1498f2c407c7fb234f8cc01a647550ee6b557b35a7e3945381821f4"),
},
{
desc: "Derived from IEEE 2.8.2 61-byte crypt",
key: dehex("4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476fab7ba"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e54d002e58495c"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748490008"),
ciphertext: dehex("b44d072011cd36d272a9b7a98db9aa90cbc5c67b93ddce67c854503214e2e896ec7e9db649ed4bcf6f850aac0223d0cf92c83db80795c3a17ecc1248bb00591712b1ae71e268164196252162810b00"),
}} {
aead, err := newRekeyAEAD(test.key)
if err != nil {
t.Fatal("unexpected failure in newRekeyAEAD: ", err.Error())
}
if got := aead.Seal(nil, test.nonce, test.plaintext, test.aad); !bytes.Equal(got, test.ciphertext) {
t.Errorf("Unexpected ciphertext for test vector '%s':\nciphertext=%s\nwant= %s",
test.desc, hex.EncodeToString(got), hex.EncodeToString(test.ciphertext))
}
if got, err := aead.Open(nil, test.nonce, test.ciphertext, test.aad); err != nil || !bytes.Equal(got, test.plaintext) {
t.Errorf("Unexpected plaintext for test vector '%s':\nplaintext=%s (err=%v)\nwant= %s",
test.desc, hex.EncodeToString(got), err, hex.EncodeToString(test.plaintext))
}
}
}
func dehex(s string) []byte {
if len(s) == 0 {
return make([]byte, 0)
}
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}
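// Editorial sketch, not part of the vendored file: the 44-byte "expanded"
// keys in the vectors above are built from the original 16-byte AES key as
// described in the comment at the top of the table. A hypothetical helper
// performing that expansion could look like this.
func expandKeyForRekeyVectors(key []byte) []byte {
	expanded := make([]byte, 0, 3*len(key))
	for _, pad := range []byte{0x00, 0x01, 0x02} {
		for _, b := range key {
			expanded = append(expanded, b^pad)
		}
	}
	// Keep kdfKeyLen (32) + nonceLen (12) = 44 bytes.
	return expanded[:44]
}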

View file

@ -0,0 +1,105 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"crypto/aes"
"crypto/cipher"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
// Overflow length n in bytes; never encrypt more than 2^(n*8) frames (in
// each direction).
overflowLenAES128GCM = 5
)
// aes128gcm is the struct that holds necessary information for ALTS record.
// The counter value is NOT included in the payload during the encryption and
// decryption operations.
type aes128gcm struct {
// inCounter is used in ALTS record to check that incoming counters are
// as expected, since ALTS record guarantees that messages are unwrapped
// in the same order that the peer wrapped them.
inCounter Counter
outCounter Counter
aead cipher.AEAD
}
// NewAES128GCM creates an instance that uses aes128gcm for ALTS record.
func NewAES128GCM(side core.Side, key []byte) (ALTSRecordCrypto, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
a, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
return &aes128gcm{
inCounter: NewInCounter(side, overflowLenAES128GCM),
outCounter: NewOutCounter(side, overflowLenAES128GCM),
aead: a,
}, nil
}
// Encrypt is the encryption function. dst can contain bytes at the beginning of
// the ciphertext that will not be encrypted but will be authenticated. If dst
// has enough capacity to hold these bytes, the ciphertext and the tag, no
// allocation and copy operations will be performed. dst and plaintext do not
// overlap.
func (s *aes128gcm) Encrypt(dst, plaintext []byte) ([]byte, error) {
// If we need to allocate an output buffer, we want to include space for
// GCM tag to avoid forcing ALTS record to reallocate as well.
dlen := len(dst)
dst, out := SliceForAppend(dst, len(plaintext)+GcmTagSize)
seq, err := s.outCounter.Value()
if err != nil {
return nil, err
}
data := out[:len(plaintext)]
copy(data, plaintext) // data may alias plaintext
// Seal appends the ciphertext and the tag to its first argument and
// returns the updated slice. However, SliceForAppend above ensures that
// dst has enough capacity to avoid a reallocation and copy due to the
// append.
dst = s.aead.Seal(dst[:dlen], seq, data, nil)
s.outCounter.Inc()
return dst, nil
}
func (s *aes128gcm) EncryptionOverhead() int {
return GcmTagSize
}
func (s *aes128gcm) Decrypt(dst, ciphertext []byte) ([]byte, error) {
seq, err := s.inCounter.Value()
if err != nil {
return nil, err
}
// If dst is equal to ciphertext[:0], ciphertext storage is reused.
plaintext, err := s.aead.Open(dst, seq, ciphertext, nil)
if err != nil {
return nil, ErrAuth
}
s.inCounter.Inc()
return plaintext, nil
}

View file

@ -0,0 +1,223 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
// cryptoTestVector is a struct for a GCM test vector
type cryptoTestVector struct {
key, counter, plaintext, ciphertext, tag []byte
allocateDst bool
}
// getGCMCryptoPair returns a client/server pair using aes128gcm.
func getGCMCryptoPair(key []byte, counter []byte, t *testing.T) (ALTSRecordCrypto, ALTSRecordCrypto) {
client, err := NewAES128GCM(core.ClientSide, key)
if err != nil {
t.Fatalf("NewAES128GCM(ClientSide, key) = %v", err)
}
server, err := NewAES128GCM(core.ServerSide, key)
if err != nil {
t.Fatalf("NewAES128GCM(ServerSide, key) = %v", err)
}
// set counter if provided.
if counter != nil {
if CounterSide(counter) == core.ClientSide {
client.(*aes128gcm).outCounter = CounterFromValue(counter, overflowLenAES128GCM)
server.(*aes128gcm).inCounter = CounterFromValue(counter, overflowLenAES128GCM)
} else {
server.(*aes128gcm).outCounter = CounterFromValue(counter, overflowLenAES128GCM)
client.(*aes128gcm).inCounter = CounterFromValue(counter, overflowLenAES128GCM)
}
}
return client, server
}
func testGCMEncryptionDecryption(sender ALTSRecordCrypto, receiver ALTSRecordCrypto, test *cryptoTestVector, withCounter bool, t *testing.T) {
// Ciphertext is: counter + encrypted text + tag.
ciphertext := []byte(nil)
if withCounter {
ciphertext = append(ciphertext, test.counter...)
}
ciphertext = append(ciphertext, test.ciphertext...)
ciphertext = append(ciphertext, test.tag...)
// Decrypt.
if got, err := receiver.Decrypt(nil, ciphertext); err != nil || !bytes.Equal(got, test.plaintext) {
t.Errorf("key=%v\ncounter=%v\ntag=%v\nciphertext=%v\nDecrypt = %v, %v\nwant: %v",
test.key, test.counter, test.tag, test.ciphertext, got, err, test.plaintext)
}
// Encrypt.
var dst []byte
if test.allocateDst {
dst = make([]byte, len(test.plaintext)+sender.EncryptionOverhead())
}
if got, err := sender.Encrypt(dst[:0], test.plaintext); err != nil || !bytes.Equal(got, ciphertext) {
t.Errorf("key=%v\ncounter=%v\nplaintext=%v\nEncrypt = %v, %v\nwant: %v",
test.key, test.counter, test.plaintext, got, err, ciphertext)
}
}
// Test encrypt and decrypt using test vectors for aes128gcm.
func TestAES128GCMEncrypt(t *testing.T) {
for _, test := range []cryptoTestVector{
{
key: dehex("11754cd72aec309bf52f7687212e8957"),
counter: dehex("3c819d9a9bed087615030b65"),
plaintext: nil,
ciphertext: nil,
tag: dehex("250327c674aaf477aef2675748cf6971"),
allocateDst: false,
},
{
key: dehex("ca47248ac0b6f8372a97ac43508308ed"),
counter: dehex("ffd2b598feabc9019262d2be"),
plaintext: nil,
ciphertext: nil,
tag: dehex("60d20404af527d248d893ae495707d1a"),
allocateDst: false,
},
{
key: dehex("7fddb57453c241d03efbed3ac44e371c"),
counter: dehex("ee283a3fc75575e33efd4887"),
plaintext: dehex("d5de42b461646c255c87bd2962d3b9a2"),
ciphertext: dehex("2ccda4a5415cb91e135c2a0f78c9b2fd"),
tag: dehex("b36d1df9b9d5e596f83e8b7f52971cb3"),
allocateDst: false,
},
{
key: dehex("ab72c77b97cb5fe9a382d9fe81ffdbed"),
counter: dehex("54cc7dc2c37ec006bcc6d1da"),
plaintext: dehex("007c5e5b3e59df24a7c355584fc1518d"),
ciphertext: dehex("0e1bde206a07a9c2c1b65300f8c64997"),
tag: dehex("2b4401346697138c7a4891ee59867d0c"),
allocateDst: false,
},
{
key: dehex("11754cd72aec309bf52f7687212e8957"),
counter: dehex("3c819d9a9bed087615030b65"),
plaintext: nil,
ciphertext: nil,
tag: dehex("250327c674aaf477aef2675748cf6971"),
allocateDst: true,
},
{
key: dehex("ca47248ac0b6f8372a97ac43508308ed"),
counter: dehex("ffd2b598feabc9019262d2be"),
plaintext: nil,
ciphertext: nil,
tag: dehex("60d20404af527d248d893ae495707d1a"),
allocateDst: true,
},
{
key: dehex("7fddb57453c241d03efbed3ac44e371c"),
counter: dehex("ee283a3fc75575e33efd4887"),
plaintext: dehex("d5de42b461646c255c87bd2962d3b9a2"),
ciphertext: dehex("2ccda4a5415cb91e135c2a0f78c9b2fd"),
tag: dehex("b36d1df9b9d5e596f83e8b7f52971cb3"),
allocateDst: true,
},
{
key: dehex("ab72c77b97cb5fe9a382d9fe81ffdbed"),
counter: dehex("54cc7dc2c37ec006bcc6d1da"),
plaintext: dehex("007c5e5b3e59df24a7c355584fc1518d"),
ciphertext: dehex("0e1bde206a07a9c2c1b65300f8c64997"),
tag: dehex("2b4401346697138c7a4891ee59867d0c"),
allocateDst: true,
},
} {
// Test encryption and decryption for aes128gcm.
client, server := getGCMCryptoPair(test.key, test.counter, t)
if CounterSide(test.counter) == core.ClientSide {
testGCMEncryptionDecryption(client, server, &test, false, t)
} else {
testGCMEncryptionDecryption(server, client, &test, false, t)
}
}
}
func testGCMEncryptRoundtrip(client ALTSRecordCrypto, server ALTSRecordCrypto, t *testing.T) {
// Encrypt.
const plaintext = "This is plaintext."
var err error
buf := []byte(plaintext)
buf, err = client.Encrypt(buf[:0], buf)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext))
}
// Encrypt a second message.
const plaintext2 = "This is a second plaintext."
buf2 := []byte(plaintext2)
buf2, err = client.Encrypt(buf2[:0], buf2)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext2))
}
// Decryption fails: cannot decrypt second message before first.
if got, err := server.Decrypt(nil, buf2); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext2), "\n",
" Ciphertext:", buf2, "\n",
" Decrypted plaintext:", got)
}
// Decryption fails: wrong counter space.
if got, err := client.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want counter space error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
// Decrypt first message.
ciphertext := append([]byte(nil), buf...)
buf, err = server.Decrypt(buf[:0], buf)
if err != nil || string(buf) != plaintext {
t.Fatal("Decrypting client-side ciphertext with a server-side context did not produce original content:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", ciphertext, "\n",
" Decryption error:", err, "\n",
" Decrypted plaintext:", buf)
}
// Decryption fails: replay attack.
if got, err := server.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
}
// Test encrypt and decrypt on roundtrip messages for aes128gcm.
func TestAES128GCMEncryptRoundtrip(t *testing.T) {
// Test for aes128gcm.
key := make([]byte, 16)
client, server := getGCMCryptoPair(key, nil, t)
testGCMEncryptRoundtrip(client, server, t)
}

View file

@ -0,0 +1,116 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"crypto/cipher"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
// Overflow length n in bytes; never encrypt more than 2^(n*8) frames (in
// each direction).
overflowLenAES128GCMRekey = 8
nonceLen = 12
aeadKeyLen = 16
kdfKeyLen = 32
kdfCounterOffset = 2
kdfCounterLen = 6
sizeUint64 = 8
)
// aes128gcmRekey is the struct that holds necessary information for ALTS record.
// The counter value is NOT included in the payload during the encryption and
// decryption operations.
type aes128gcmRekey struct {
// inCounter is used in ALTS record to check that incoming counters are
// as expected, since ALTS record guarantees that messages are unwrapped
// in the same order that the peer wrapped them.
inCounter Counter
outCounter Counter
inAEAD cipher.AEAD
outAEAD cipher.AEAD
}
// NewAES128GCMRekey creates an instance that uses aes128gcm with rekeying
// for ALTS record. The key argument should be 44 bytes; the first 32 bytes
// are used as a key for HKDF-expand and the remaining 12 bytes are used
// as a random mask for the counter.
func NewAES128GCMRekey(side core.Side, key []byte) (ALTSRecordCrypto, error) {
inCounter := NewInCounter(side, overflowLenAES128GCMRekey)
outCounter := NewOutCounter(side, overflowLenAES128GCMRekey)
inAEAD, err := newRekeyAEAD(key)
if err != nil {
return nil, err
}
outAEAD, err := newRekeyAEAD(key)
if err != nil {
return nil, err
}
return &aes128gcmRekey{
inCounter,
outCounter,
inAEAD,
outAEAD,
}, nil
}
// Encrypt is the encryption function. dst can contain bytes at the beginning of
// the ciphertext that will not be encrypted but will be authenticated. If dst
// has enough capacity to hold these bytes, the ciphertext and the tag, no
// allocation and copy operations will be performed. dst and plaintext do not
// overlap.
func (s *aes128gcmRekey) Encrypt(dst, plaintext []byte) ([]byte, error) {
// If we need to allocate an output buffer, we want to include space for
// GCM tag to avoid forcing ALTS record to reallocate as well.
dlen := len(dst)
dst, out := SliceForAppend(dst, len(plaintext)+GcmTagSize)
seq, err := s.outCounter.Value()
if err != nil {
return nil, err
}
data := out[:len(plaintext)]
copy(data, plaintext) // data may alias plaintext
// Seal appends the ciphertext and the tag to its first argument and
// returns the updated slice. However, SliceForAppend above ensures that
// dst has enough capacity to avoid a reallocation and copy due to the
// append.
dst = s.outAEAD.Seal(dst[:dlen], seq, data, nil)
s.outCounter.Inc()
return dst, nil
}
func (s *aes128gcmRekey) EncryptionOverhead() int {
return GcmTagSize
}
func (s *aes128gcmRekey) Decrypt(dst, ciphertext []byte) ([]byte, error) {
seq, err := s.inCounter.Value()
if err != nil {
return nil, err
}
plaintext, err := s.inAEAD.Open(dst, seq, ciphertext, nil)
if err != nil {
return nil, ErrAuth
}
s.inCounter.Inc()
return plaintext, nil
}

View file

@ -0,0 +1,117 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
// rekeyTestVector is a struct for a rekey test vector
type rekeyTestVector struct {
key, nonce, plaintext, ciphertext []byte
}
// getRekeyCryptoPair returns a client/server pair using aes128gcmRekey.
func getRekeyCryptoPair(key []byte, counter []byte, t *testing.T) (ALTSRecordCrypto, ALTSRecordCrypto) {
client, err := NewAES128GCMRekey(core.ClientSide, key)
if err != nil {
t.Fatalf("NewAES128GCMRekey(ClientSide, key) = %v", err)
}
server, err := NewAES128GCMRekey(core.ServerSide, key)
if err != nil {
t.Fatalf("NewAES128GCMRekey(ServerSide, key) = %v", err)
}
// set counter if provided.
if counter != nil {
if CounterSide(counter) == core.ClientSide {
client.(*aes128gcmRekey).outCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
server.(*aes128gcmRekey).inCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
} else {
server.(*aes128gcmRekey).outCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
client.(*aes128gcmRekey).inCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
}
}
return client, server
}
func testRekeyEncryptRoundtrip(client ALTSRecordCrypto, server ALTSRecordCrypto, t *testing.T) {
// Encrypt.
const plaintext = "This is plaintext."
var err error
buf := []byte(plaintext)
buf, err = client.Encrypt(buf[:0], buf)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext))
}
// Encrypt a second message.
const plaintext2 = "This is a second plaintext."
buf2 := []byte(plaintext2)
buf2, err = client.Encrypt(buf2[:0], buf2)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext2))
}
// Decryption fails: cannot decrypt second message before first.
if got, err := server.Decrypt(nil, buf2); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext2), "\n",
" Ciphertext:", buf2, "\n",
" Decrypted plaintext:", got)
}
// Decryption fails: wrong counter space.
if got, err := client.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want counter space error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
// Decrypt first message.
ciphertext := append([]byte(nil), buf...)
buf, err = server.Decrypt(buf[:0], buf)
if err != nil || string(buf) != plaintext {
t.Fatal("Decrypting client-side ciphertext with a server-side context did not produce original content:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", ciphertext, "\n",
" Decryption error:", err, "\n",
" Decrypted plaintext:", buf)
}
// Decryption fails: replay attack.
if got, err := server.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
}
// Test encrypt and decrypt on roundtrip messages for aes128gcmRekey.
func TestAES128GCMRekeyEncryptRoundtrip(t *testing.T) {
// Test for aes128gcmRekey.
key := make([]byte, 44)
client, server := getRekeyCryptoPair(key, nil, t)
testRekeyEncryptRoundtrip(client, server, t)
}

View file

@ -0,0 +1,70 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"encoding/binary"
"errors"
"fmt"
)
const (
// GcmTagSize is the GCM tag size is the difference in length between
// plaintext and ciphertext. From crypto/cipher/gcm.go in Go crypto
// library.
GcmTagSize = 16
)
// ErrAuth occurs on authentication failure.
var ErrAuth = errors.New("message authentication failed")
// SliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
func SliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return head, tail
}
// ParseFramedMsg parses the provided buffer and returns a frame of the format
// msgLength+msg and any remaining bytes in that buffer.
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
// If the size field is not complete, return the provided buffer as
// remaining buffer.
if len(b) < MsgLenFieldSize {
return nil, b, nil
}
msgLenField := b[:MsgLenFieldSize]
length := binary.LittleEndian.Uint32(msgLenField)
if length > maxLen {
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
}
if len(b) < int(length)+4 { // account for the first 4 msg length bytes.
// Frame is not complete yet.
return nil, b, nil
}
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
}
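// Editorial sketch, not part of the vendored file: a hypothetical walk-through
// of ParseFramedMsg on a buffer holding one complete 3-byte message followed
// by a partial length prefix, plus the capacity-reuse behavior of
// SliceForAppend. The byte values are arbitrary illustration data.
func exampleFraming() {
	buf := []byte{
		0x03, 0x00, 0x00, 0x00, // little-endian length field = 3
		0xaa, 0xbb, 0xcc, // 3-byte message body
		0x05, 0x00, // incomplete length field of the next frame
	}
	frame, remaining, err := ParseFramedMsg(buf, altsRecordLengthLimit)
	// frame == buf[:7] (length field + body), remaining == buf[7:], err == nil.
	_, _, _ = frame, remaining, err

	// SliceForAppend reuses capacity when it can: head shares backing storage
	// with in, and tail is the 2 extra bytes appended after the existing data.
	in := make([]byte, 3, 8)
	head, tail := SliceForAppend(in, 2)
	_, _ = head, tail
}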

View file

@ -0,0 +1,62 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"errors"
)
const counterLen = 12
var (
errInvalidCounter = errors.New("invalid counter")
)
// Counter is a 96-bit, little-endian counter.
type Counter struct {
value [counterLen]byte
invalid bool
overflowLen int
}
// Value returns the current value of the counter as a byte slice.
func (c *Counter) Value() ([]byte, error) {
if c.invalid {
return nil, errInvalidCounter
}
return c.value[:], nil
}
// Inc increments the counter and checks for overflow.
func (c *Counter) Inc() {
// If the counter is already invalid, there is no need to increment it.
if c.invalid {
return
}
i := 0
for ; i < c.overflowLen; i++ {
c.value[i]++
if c.value[i] != 0 {
break
}
}
if i == c.overflowLen {
c.invalid = true
}
}
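// Editorial sketch, not part of the vendored file: the counter increments
// little-endian within its overflow window and invalidates itself once that
// window wraps. CounterFromValue is the constructor used by the tests in this
// package; the 2-byte overflow length below is an arbitrary choice made only
// for illustration.
func exampleCounterOverflow() {
	c := CounterFromValue(
		[]byte{0xfe, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
		2, // hypothetical overflow length: only the first 2 bytes may carry
	)
	c.Inc()             // value is now ff ff 00 ... : still valid
	_, err := c.Value() // err == nil
	c.Inc()             // both overflow bytes wrap: counter becomes invalid
	_, err = c.Value()  // err == errInvalidCounter
	_ = err
}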

View file

@ -0,0 +1,141 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
testOverflowLen = 5
)
func TestCounterSides(t *testing.T) {
for _, side := range []core.Side{core.ClientSide, core.ServerSide} {
outCounter := NewOutCounter(side, testOverflowLen)
inCounter := NewInCounter(side, testOverflowLen)
for i := 0; i < 1024; i++ {
value, _ := outCounter.Value()
if g, w := CounterSide(value), side; g != w {
t.Errorf("after %d iterations, CounterSide(outCounter.Value()) = %v, want %v", i, g, w)
break
}
value, _ = inCounter.Value()
if g, w := CounterSide(value), side; g == w {
t.Errorf("after %d iterations, CounterSide(inCounter.Value()) = %v, want %v", i, g, w)
break
}
outCounter.Inc()
inCounter.Inc()
}
}
}
func TestCounterInc(t *testing.T) {
for _, test := range []struct {
counter []byte
want []byte
}{
{
counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80},
want: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80},
},
{
counter: []byte{0xff, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x00, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0x42, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x43, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
{
counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
},
} {
c := CounterFromValue(test.counter, overflowLenAES128GCM)
c.Inc()
value, _ := c.Value()
if g, w := value, test.want; !bytes.Equal(g, w) || c.invalid {
t.Errorf("counter(%v).Inc() =\n%v, want\n%v", test.counter, g, w)
}
}
}
func TestRolloverCounter(t *testing.T) {
for _, test := range []struct {
desc string
value []byte
overflowLen int
}{
{
desc: "testing overflow without rekeying 1",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
overflowLen: 5,
},
{
desc: "testing overflow without rekeying 2",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
overflowLen: 5,
},
{
desc: "testing overflow for rekeying mode 1",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x80},
overflowLen: 8,
},
{
desc: "testing overflow for rekeying mode 2",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00},
overflowLen: 8,
},
} {
c := CounterFromValue(test.value, test.overflowLen)
// First Inc() + Value() should work.
c.Inc()
_, err := c.Value()
if err != nil {
t.Errorf("%v: first Inc() + Value() unexpectedly failed: %v, want <nil> error", test.desc, err)
}
// Second Inc() + Value() should fail.
c.Inc()
_, err = c.Value()
if err != errInvalidCounter {
t.Errorf("%v: second Inc() + Value() unexpectedly succeeded: want %v", test.desc, errInvalidCounter)
}
// Third Inc() + Value() should also fail because the counter is
// already in an invalid state.
c.Inc()
_, err = c.Value()
if err != errInvalidCounter {
t.Errorf("%v: Third Inc() + Value() unexpectedly succeeded: want %v", test.desc, errInvalidCounter)
}
}
}

View file

@ -0,0 +1,271 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package conn contains an implementation of a secure channel created by gRPC
// handshakers.
package conn
import (
"encoding/binary"
"fmt"
"math"
"net"
core "google.golang.org/grpc/credentials/alts/internal"
)
// ALTSRecordCrypto is the interface for gRPC ALTS record protocol.
type ALTSRecordCrypto interface {
// Encrypt encrypts the plaintext and computes the tag (if any) of dst
// and plaintext. dst and plaintext do not overlap.
Encrypt(dst, plaintext []byte) ([]byte, error)
// EncryptionOverhead returns the tag size (if any) in bytes.
EncryptionOverhead() int
// Decrypt decrypts ciphertext and verifies the tag (if any). dst and
// ciphertext may alias exactly or not at all. To reuse ciphertext's
// storage for the decrypted output, use ciphertext[:0] as dst.
Decrypt(dst, ciphertext []byte) ([]byte, error)
}
// ALTSRecordFunc is a function type for factory functions that create
// ALTSRecordCrypto instances.
type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error)
const (
// MsgLenFieldSize is the byte size of the frame length field of a
// framed message.
MsgLenFieldSize = 4
// The byte size of the message type field of a framed message.
msgTypeFieldSize = 4
// The byte size limit for an ALTS record message.
altsRecordLengthLimit = 1024 * 1024 // 1 MiB
// The default byte size of an ALTS record message.
altsRecordDefaultLength = 4 * 1024 // 4KiB
// Message type value included in ALTS record framing.
altsRecordMsgType = uint32(0x06)
// The initial write buffer size.
altsWriteBufferInitialSize = 32 * 1024 // 32KiB
// The maximum write buffer size. This *must* be a multiple of
// altsRecordDefaultLength.
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
)
var (
protocols = make(map[string]ALTSRecordFunc)
)
// RegisterProtocol registers an ALTS record encryption protocol.
func RegisterProtocol(protocol string, f ALTSRecordFunc) error {
if _, ok := protocols[protocol]; ok {
return fmt.Errorf("protocol %v is already registered", protocol)
}
protocols[protocol] = f
return nil
}
// conn represents a secured connection. It implements the net.Conn interface.
type conn struct {
net.Conn
crypto ALTSRecordCrypto
// buf holds data that has been read from the connection and decrypted,
// but has not yet been returned by Read.
buf []byte
payloadLengthLimit int
// protected holds data read from the network but has not yet been
// decrypted. This data might not compose a complete frame.
protected []byte
// writeBuf is a buffer used to contain encrypted frames before being
// written to the network.
writeBuf []byte
// nextFrame stores information about the next frame (within the protected buffer).
nextFrame []byte
// overhead is the calculated overhead of each frame.
overhead int
}
// NewConn creates a new secure channel instance given the other party's role and
// the handshaking result.
func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) {
newCrypto := protocols[recordProtocol]
if newCrypto == nil {
return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol)
}
crypto, err := newCrypto(side, key)
if err != nil {
return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err)
}
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
payloadLengthLimit := altsRecordDefaultLength - overhead
if protected == nil {
// We pre-allocate protected to be of size
// 2*altsRecordDefaultLength-1 during initialization. We only
// read from the network into protected when protected does not
// contain a complete frame, which is at most
// altsRecordDefaultLength-1 (bytes). And we read at most
// altsRecordDefaultLength (bytes) data into protected at one
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
// to buffer data read from the network.
protected = make([]byte, 0, 2*altsRecordDefaultLength-1)
}
altsConn := &conn{
Conn: c,
crypto: crypto,
payloadLengthLimit: payloadLengthLimit,
protected: protected,
writeBuf: make([]byte, altsWriteBufferInitialSize),
nextFrame: protected,
overhead: overhead,
}
return altsConn, nil
}
// Read reads and decrypts a frame from the underlying connection, and copies the
// decrypted payload into b. If the size of the payload is greater than len(b),
// Read retains the remaining bytes in an internal buffer, and subsequent calls
// to Read will read from this buffer until it is exhausted.
func (p *conn) Read(b []byte) (n int, err error) {
if len(p.buf) == 0 {
var framedMsg []byte
framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit)
if err != nil {
return n, err
}
// Check whether the next frame to be decrypted has been
// completely received yet.
if len(framedMsg) == 0 {
copy(p.protected, p.nextFrame)
p.protected = p.protected[:len(p.nextFrame)]
// Always copy next incomplete frame to the beginning of
// the protected buffer and reset nextFrame to it.
p.nextFrame = p.protected
}
// Check whether a complete frame has been received yet.
for len(framedMsg) == 0 {
if len(p.protected) == cap(p.protected) {
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
copy(tmp, p.protected)
p.protected = tmp
}
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
if err != nil {
return 0, err
}
p.protected = p.protected[:len(p.protected)+n]
framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit)
if err != nil {
return 0, err
}
}
// Now that we have a complete frame, decrypt it.
msg := framedMsg[MsgLenFieldSize:]
msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize])
if msgType&0xff != altsRecordMsgType {
return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v",
msgType, altsRecordMsgType)
}
ciphertext := msg[msgTypeFieldSize:]
// Decrypt requires that if the dst and ciphertext alias, they
// must alias exactly. Code here used to use msg[:0], but msg
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
// ciphertext, so they alias inexactly. Using ciphertext[:0]
// arranges the appropriate aliasing without needing to copy
// ciphertext or use a separate destination buffer. For more info
// check: https://golang.org/pkg/crypto/cipher/#AEAD.
p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext)
if err != nil {
return 0, err
}
}
n = copy(b, p.buf)
p.buf = p.buf[n:]
return n, nil
}
// Write encrypts, frames, and writes bytes from b to the underlying connection.
func (p *conn) Write(b []byte) (n int, err error) {
n = len(b)
// Calculate the output buffer size with framing and encryption overhead.
numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit)))
size := len(b) + numOfFrames*p.overhead
// If writeBuf is too small, increase its size up to the maximum size.
partialBSize := len(b)
if size > altsWriteBufferMaxSize {
size = altsWriteBufferMaxSize
const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength
partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit
}
if len(p.writeBuf) < size {
p.writeBuf = make([]byte, size)
}
for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize {
partialBEnd := partialBStart + partialBSize
if partialBEnd > len(b) {
partialBEnd = len(b)
}
partialB := b[partialBStart:partialBEnd]
writeBufIndex := 0
for len(partialB) > 0 {
payloadLen := len(partialB)
if payloadLen > p.payloadLengthLimit {
payloadLen = p.payloadLengthLimit
}
buf := partialB[:payloadLen]
partialB = partialB[payloadLen:]
// Write buffer contains: length, type, payload, and tag
// if any.
// 1. Fill in type field.
msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:]
binary.LittleEndian.PutUint32(msg, altsRecordMsgType)
// 2. Encrypt the payload and create a tag if any.
msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf)
if err != nil {
return n, err
}
// 3. Fill in the size field.
binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg)))
// 4. Increase writeBufIndex.
writeBufIndex += len(buf) + p.overhead
}
nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex])
if err != nil {
// We need to calculate the actual data size that was
// written. This means we need to remove the header,
// encryption overhead, and any partially-written
// frame data.
numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength)))
return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err
}
}
return n, nil
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
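// Editorial sketch, not part of the vendored file: how a record protocol is
// registered and how NewConn wraps both ends of a transport. The protocol
// name "EXAMPLE_GCM_AES128", the all-zero 16-byte key, and the use of
// net.Pipe are illustrative assumptions; real code obtains the protocol name
// and key material from the ALTS handshake.
func exampleRecordConn() error {
	if err := RegisterProtocol("EXAMPLE_GCM_AES128", func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) {
		return NewAES128GCM(s, keyData)
	}); err != nil {
		return err
	}
	key := make([]byte, 16)
	rawClient, rawServer := net.Pipe()
	client, err := NewConn(rawClient, core.ClientSide, "EXAMPLE_GCM_AES128", key, nil)
	if err != nil {
		return err
	}
	server, err := NewConn(rawServer, core.ServerSide, "EXAMPLE_GCM_AES128", key, nil)
	if err != nil {
		return err
	}
	go client.Write([]byte("ping")) // net.Pipe is synchronous, so write concurrently
	buf := make([]byte, 4)
	_, err = server.Read(buf) // buf now holds "ping"
	return err
}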

View file

@ -0,0 +1,274 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math"
"net"
"reflect"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
var (
nextProtocols = []string{"ALTSRP_GCM_AES128"}
altsRecordFuncs = map[string]ALTSRecordFunc{
// ALTS handshaker protocols.
"ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) {
return NewAES128GCM(s, keyData)
},
}
)
func init() {
for protocol, f := range altsRecordFuncs {
if err := RegisterProtocol(protocol, f); err != nil {
panic(err)
}
}
}
// testConn mimics a net.Conn to the peer.
type testConn struct {
net.Conn
in *bytes.Buffer
out *bytes.Buffer
}
func (c *testConn) Read(b []byte) (n int, err error) {
return c.in.Read(b)
}
func (c *testConn) Write(b []byte) (n int, err error) {
return c.out.Write(b)
}
func (c *testConn) Close() error {
return nil
}
func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *conn {
key := []byte{
// 16 arbitrary bytes.
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
tc := testConn{
in: in,
out: out,
}
c, err := NewConn(&tc, side, np, key, nil)
if err != nil {
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
}
return c.(*conn)
}
func newConnPair(np string) (client, server *conn) {
clientBuf := new(bytes.Buffer)
serverBuf := new(bytes.Buffer)
clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np)
serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np)
return clientConn, serverConn
}
func testPingPong(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
clientMsg := []byte("Client Message")
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
}
rcvClientMsg := make([]byte, len(clientMsg))
if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
}
if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
}
serverMsg := []byte("Server Message")
if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil {
t.Fatalf("Server Write() = %v, %v; want %v, <nil>", n, err, len(serverMsg))
}
rcvServerMsg := make([]byte, len(serverMsg))
if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil {
t.Fatalf("Client Read() = %v, %v; want %v, <nil>", n, err, len(rcvServerMsg))
}
if !reflect.DeepEqual(serverMsg, rcvServerMsg) {
t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg)
}
}
func TestPingPong(t *testing.T) {
for _, np := range nextProtocols {
testPingPong(t, np)
}
}
func testSmallReadBuffer(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
msg := []byte("Very Important Message")
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
n := 2 // Arbitrary index to break rcvMsg in two.
rcvMsg1 := rcvMsg[:n]
rcvMsg2 := rcvMsg[n:]
if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg1))
}
if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg2))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg)
}
}
func TestSmallReadBuffer(t *testing.T) {
for _, np := range nextProtocols {
testSmallReadBuffer(t, np)
}
}
func testLargeMsg(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
// msgLen is such that the length in the framing is larger than the
// default size of one frame.
msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
msg := make([]byte, msgLen)
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
}
}
func TestLargeMsg(t *testing.T) {
for _, np := range nextProtocols {
testLargeMsg(t, np)
}
}
func testIncorrectMsgType(t *testing.T, np string) {
// framedMsg is an empty ciphertext with correct framing but wrong
// message type.
framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize)
binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize)
wrongMsgType := uint32(0x22)
binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType)
in := bytes.NewBuffer(framedMsg)
c := newTestALTSRecordConn(in, nil, core.ClientSide, np)
b := make([]byte, 1)
if n, err := c.Read(b); n != 0 || err == nil {
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType))
}
}
func TestIncorrectMsgType(t *testing.T) {
for _, np := range nextProtocols {
testIncorrectMsgType(t, np)
}
}
func testFrameTooLarge(t *testing.T, np string) {
buf := new(bytes.Buffer)
clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np)
serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np)
// payloadLen is such that the length in the framing is larger than
// allowed in one frame.
payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
payload := make([]byte, payloadLen)
c, err := clientConn.crypto.Encrypt(nil, payload)
if err != nil {
t.Fatalf("Error encrypting message: %v", err)
}
msgLen := msgTypeFieldSize + len(c)
framedMsg := make([]byte, MsgLenFieldSize+msgLen)
binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c)))
msg := framedMsg[MsgLenFieldSize:]
binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType)
copy(msg[msgTypeFieldSize:], c)
if _, err = buf.Write(framedMsg); err != nil {
t.Fatalf("Unexpected error writing to buffer: %v", err)
}
b := make([]byte, 1)
if n, err := serverConn.Read(b); n != 0 || err == nil {
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit))
}
}
func TestFrameTooLarge(t *testing.T) {
for _, np := range nextProtocols {
testFrameTooLarge(t, np)
}
}
func testWriteLargeData(t *testing.T, np string) {
// Test sending and receiving messages larger than the maximum write
// buffer size.
clientConn, serverConn := newConnPair(np)
// Message size is intentionally chosen not to be a multiple of
// payloadLengthLimit.
msgSize := altsWriteBufferMaxSize + (100 * 1024)
clientMsg := make([]byte, msgSize)
for i := 0; i < msgSize; i++ {
clientMsg[i] = 0xAA
}
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
}
// We need to keep reading until the entire message is received. Every
// byte of the message is set to a non-zero value so that the received
// data cannot be confused with the zero-initialized contents of the
// rcvClientMsg buffer.
rcvClientMsg := make([]byte, 0, msgSize)
numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit)))
for i := 0; i < numberOfExpectedFrames; i++ {
expectedRcvSize := serverConn.payloadLengthLimit
if i == numberOfExpectedFrames-1 {
// Last frame might be smaller.
expectedRcvSize = msgSize % serverConn.payloadLengthLimit
}
tmpBuf := make([]byte, expectedRcvSize)
if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil {
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(tmpBuf))
}
rcvClientMsg = append(rcvClientMsg, tmpBuf...)
}
if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
}
}
func TestWriteLargeData(t *testing.T) {
for _, np := range nextProtocols {
testWriteLargeData(t, np)
}
}
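// A minimal sketch of the framing exercised above, assuming only the package
// constants already used in these tests (MsgLenFieldSize, msgTypeFieldSize,
// altsRecordMsgType): each ALTS record frame is a 4-byte little-endian length
// covering the message type and the ciphertext, followed by a 4-byte message
// type and the ciphertext itself. makeFrameSketch is an illustrative helper
// name, not part of the package API.
func makeFrameSketch(ciphertext []byte) []byte {
	frame := make([]byte, MsgLenFieldSize+msgTypeFieldSize+len(ciphertext))
	// Length field: covers the message type plus the ciphertext.
	binary.LittleEndian.PutUint32(frame[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(ciphertext)))
	// Message type field.
	binary.LittleEndian.PutUint32(frame[MsgLenFieldSize:MsgLenFieldSize+msgTypeFieldSize], altsRecordMsgType)
	// Encrypted payload.
	copy(frame[MsgLenFieldSize+msgTypeFieldSize:], ciphertext)
	return frame
}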

View file

@ -0,0 +1,63 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import core "google.golang.org/grpc/credentials/alts/internal"
// NewOutCounter returns an outgoing counter initialized to the starting sequence
// number for the client/server side of a connection.
func NewOutCounter(s core.Side, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
if s == core.ServerSide {
// Server counters in ALTS record have the little-endian high bit
// set.
c.value[counterLen-1] = 0x80
}
return
}
// NewInCounter returns an incoming counter initialized to the starting sequence
// number for the client/server side of a connection. This is used in ALTS record
// to check that incoming counters are as expected, since ALTS record guarantees
// that messages are unwrapped in the same order that the peer wrapped them.
func NewInCounter(s core.Side, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
if s == core.ClientSide {
// The client-side incoming counter mirrors the server's outgoing
// counter, which has the little-endian high bit set.
c.value[counterLen-1] = 0x80
}
return
}
// CounterFromValue creates a new counter given an initial value.
func CounterFromValue(value []byte, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
copy(c.value[:], value)
return
}
// CounterSide returns the connection side (client/server) a sequence counter is
// associated with.
func CounterSide(c []byte) core.Side {
if c[counterLen-1]&0x80 == 0x80 {
return core.ServerSide
}
return core.ClientSide
}
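// A minimal usage sketch, assuming only the functions above: the server's
// outgoing counter and the client's incoming counter start from the same
// value, with the little-endian high bit of the last byte set, so CounterSide
// can recover the side from a raw counter value. counterSidesSketch is an
// illustrative helper name, not part of the package API.
func counterSidesSketch(overflowLen int) (core.Side, core.Side) {
	out := NewOutCounter(core.ServerSide, overflowLen)
	in := NewInCounter(core.ClientSide, overflowLen)
	// Both values carry the server-side marker bit, so both report ServerSide.
	return CounterSide(out.value[:]), CounterSide(in.value[:])
}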

View file

@ -0,0 +1,365 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package handshaker provides ALTS handshaking functionality for GCP.
package handshaker
import (
"errors"
"fmt"
"io"
"net"
"sync"
"golang.org/x/net/context"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
core "google.golang.org/grpc/credentials/alts/internal"
"google.golang.org/grpc/credentials/alts/internal/authinfo"
"google.golang.org/grpc/credentials/alts/internal/conn"
altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
const (
// The maximum byte size of receive frames.
frameLimit = 64 * 1024 // 64 KB
rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
// maxPendingHandshakes represents the maximum number of concurrent
// handshakes.
maxPendingHandshakes = 100
)
var (
hsProtocol = altspb.HandshakeProtocol_ALTS
appProtocols = []string{"grpc"}
recordProtocols = []string{rekeyRecordProtocolName}
keyLength = map[string]int{
rekeyRecordProtocolName: 44,
}
altsRecordFuncs = map[string]conn.ALTSRecordFunc{
// ALTS handshaker protocols.
rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) {
return conn.NewAES128GCMRekey(s, keyData)
},
}
// mu guards concurrentHandshakes, the number of created (but not yet closed) handshakers.
mu sync.Mutex
concurrentHandshakes = int64(0)
// errDropped occurs when maxPendingHandshakes is reached.
errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
)
func init() {
for protocol, f := range altsRecordFuncs {
if err := conn.RegisterProtocol(protocol, f); err != nil {
panic(err)
}
}
}
func acquire(n int64) bool {
mu.Lock()
success := maxPendingHandshakes-concurrentHandshakes >= n
if success {
concurrentHandshakes += n
}
mu.Unlock()
return success
}
func release(n int64) {
mu.Lock()
concurrentHandshakes -= n
if concurrentHandshakes < 0 {
mu.Unlock()
panic("bad release")
}
mu.Unlock()
}
// ClientHandshakerOptions contains the client handshaker options that can be
// provided by the caller.
type ClientHandshakerOptions struct {
// ClientIdentity is the handshaker client local identity.
ClientIdentity *altspb.Identity
// TargetName is the server service account name for secure name
// checking.
TargetName string
// TargetServiceAccounts contains a list of expected target service
// accounts. One of these accounts should match one of the accounts in
// the handshaker results. Otherwise, the handshake fails.
TargetServiceAccounts []string
// RPCVersions specifies the gRPC versions accepted by the client.
RPCVersions *altspb.RpcProtocolVersions
}
// ServerHandshakerOptions contains the server handshaker options that can be
// provided by the caller.
type ServerHandshakerOptions struct {
// RPCVersions specifies the gRPC versions accepted by the server.
RPCVersions *altspb.RpcProtocolVersions
}
// DefaultClientHandshakerOptions returns the default client handshaker options.
func DefaultClientHandshakerOptions() *ClientHandshakerOptions {
return &ClientHandshakerOptions{}
}
// DefaultServerHandshakerOptions returns the default server handshaker options.
func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
return &ServerHandshakerOptions{}
}
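// A hedged sketch of how a caller might populate ClientHandshakerOptions,
// assuming only the fields defined above; the identity, service account and
// target name below are placeholder values, not real accounts.
// exampleClientHandshakerOptions is an illustrative name only.
func exampleClientHandshakerOptions() *ClientHandshakerOptions {
	opts := DefaultClientHandshakerOptions()
	opts.ClientIdentity = &altspb.Identity{
		IdentityOneof: &altspb.Identity_ServiceAccount{
			ServiceAccount: "client-account@example.com",
		},
	}
	// The handshake fails unless one of these accounts matches an account in
	// the handshaker results.
	opts.TargetServiceAccounts = []string{"server-account@example.com"}
	opts.TargetName = "example-target-service"
	return opts
}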
// TODO: add support for local and remote endpoints in both client options and
// server options. The corresponding fields should be added once callers can
// provide endpoints.
// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server.
type altsHandshaker struct {
// RPC stream used to access the ALTS Handshaker service.
stream altsgrpc.HandshakerService_DoHandshakeClient
// the connection to the peer.
conn net.Conn
// client handshake options.
clientOpts *ClientHandshakerOptions
// server handshake options.
serverOpts *ServerHandshakerOptions
// defines the side doing the handshake, client or server.
side core.Side
}
// NewClientHandshaker creates an ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// service in the metadata server.
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false))
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
conn: c,
clientOpts: opts,
side: core.ClientSide,
}, nil
}
// NewServerHandshaker creates an ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// service in the metadata server.
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false))
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
conn: c,
serverOpts: opts,
side: core.ServerSide,
}, nil
}
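// A minimal wiring sketch, assuming only the constructors above: hsConn is a
// *grpc.ClientConn to the ALTS handshaker service (for example obtained from
// the service package's Dial helper) and rawConn is the plain net.Conn to the
// peer. newClientHandshakerSketch and the service account below are
// illustrative placeholders only.
func newClientHandshakerSketch(ctx context.Context, hsConn *grpc.ClientConn, rawConn net.Conn) (core.Handshaker, error) {
	opts := DefaultClientHandshakerOptions()
	opts.TargetServiceAccounts = []string{"expected-server-account@example.com"}
	return NewClientHandshaker(ctx, hsConn, rawConn, opts)
}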
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire(1) {
return nil, nil, errDropped
}
defer release(1)
if h.side != core.ClientSide {
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
}
// Create target identities from service account list.
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
for _, account := range h.clientOpts.TargetServiceAccounts {
targetIdentities = append(targetIdentities, &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: account,
},
})
}
req := &altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_ClientStart{
ClientStart: &altspb.StartClientHandshakeReq{
HandshakeSecurityProtocol: hsProtocol,
ApplicationProtocols: appProtocols,
RecordProtocols: recordProtocols,
TargetIdentities: targetIdentities,
LocalIdentity: h.clientOpts.ClientIdentity,
TargetName: h.clientOpts.TargetName,
RpcVersions: h.clientOpts.RPCVersions,
},
},
}
conn, result, err := h.doHandshake(req)
if err != nil {
return nil, nil, err
}
authInfo := authinfo.New(result)
return conn, authInfo, nil
}
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire(1) {
return nil, nil, errDropped
}
defer release(1)
if h.side != core.ServerSide {
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
}
p := make([]byte, frameLimit)
n, err := h.conn.Read(p)
if err != nil {
return nil, nil, err
}
// Prepare server parameters.
// TODO: currently only ALTS parameters are provided. Might need to use
// more options in the future.
params := make(map[int32]*altspb.ServerHandshakeParameters)
params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{
RecordProtocols: recordProtocols,
}
req := &altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_ServerStart{
ServerStart: &altspb.StartServerHandshakeReq{
ApplicationProtocols: appProtocols,
HandshakeParameters: params,
InBytes: p[:n],
RpcVersions: h.serverOpts.RPCVersions,
},
},
}
conn, result, err := h.doHandshake(req)
if err != nil {
return nil, nil, err
}
authInfo := authinfo.New(result)
return conn, authInfo, nil
}
func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) {
resp, err := h.accessHandshakerService(req)
if err != nil {
return nil, nil, err
}
// Check if the returned status is an error.
if resp.GetStatus() != nil {
if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
}
}
var extra []byte
if req.GetServerStart() != nil {
extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
}
result, extra, err := h.processUntilDone(resp, extra)
if err != nil {
return nil, nil, err
}
// The handshaker returns a 128-byte key. It should be truncated based
// on the returned record protocol.
keyLen, ok := keyLength[result.RecordProtocol]
if !ok {
return nil, nil, fmt.Errorf("unknown resulted record protocol %v", result.RecordProtocol)
}
sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra)
if err != nil {
return nil, nil, err
}
return sc, result, nil
}
func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) {
if err := h.stream.Send(req); err != nil {
return nil, err
}
resp, err := h.stream.Recv()
if err != nil {
return nil, err
}
return resp, nil
}
// processUntilDone processes the handshake until the handshaker service returns
// the results. The handshaker service takes care of frame parsing, so we read
// whatever is received from the network and send it to the handshaker service.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
for {
if len(resp.OutFrames) > 0 {
if _, err := h.conn.Write(resp.OutFrames); err != nil {
return nil, nil, err
}
}
if resp.Result != nil {
return resp.Result, extra, nil
}
buf := make([]byte, frameLimit)
n, err := h.conn.Read(buf)
if err != nil && err != io.EOF {
return nil, nil, err
}
// If there is nothing to send to the handshaker service, and
// nothing is received from the peer, then we are stuck.
// This covers the case when the peer is not responding. Note
// that handshaker service connection issues are caught in
// accessHandshakerService before we even get here.
if len(resp.OutFrames) == 0 && n == 0 {
return nil, nil, core.PeerNotRespondingError
}
// Append the bytes read from the connection to the extra bytes left over
// from the previous interaction with the handshaker service.
p := append(extra, buf[:n]...)
resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_Next{
Next: &altspb.NextHandshakeMessageReq{
InBytes: p,
},
},
})
if err != nil {
return nil, nil, err
}
// Set extra based on handshaker service response.
if n == 0 {
extra = nil
} else {
extra = buf[resp.GetBytesConsumed():n]
}
}
}
// Close terminates the Handshaker. It should be called when the caller obtains
// the secure connection.
func (h *altsHandshaker) Close() {
h.stream.CloseSend()
}

View file

@ -0,0 +1,261 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package handshaker
import (
"bytes"
"testing"
"time"
"golang.org/x/net/context"
grpc "google.golang.org/grpc"
core "google.golang.org/grpc/credentials/alts/internal"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/credentials/alts/internal/testutil"
)
var (
testAppProtocols = []string{"grpc"}
testRecordProtocol = rekeyRecordProtocolName
testKey = []byte{
// 44 arbitrary bytes.
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49,
0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b,
0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2,
}
testServiceAccount = "test_service_account"
testTargetServiceAccounts = []string{testServiceAccount}
testClientIdentity = &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: "i_am_a_client",
},
}
)
// testRPCStream mimics an altspb.HandshakerService_DoHandshakeClient object.
type testRPCStream struct {
grpc.ClientStream
t *testing.T
isClient bool
// The resp expected to be returned by Recv(). Make sure this is set to
// the content the test requires before Recv() is invoked.
recvBuf *altspb.HandshakerResp
// false until the first access to the handshaker service has been made.
first bool
// useful for testing concurrent calls.
delay time.Duration
}
func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) {
resp := t.recvBuf
t.recvBuf = nil
return resp, nil
}
func (t *testRPCStream) Send(req *altspb.HandshakerReq) error {
var resp *altspb.HandshakerResp
if !t.first {
// Generate the bytes to be returned by Recv() for the initial
// handshaking.
t.first = true
if t.isClient {
resp = &altspb.HandshakerResp{
OutFrames: testutil.MakeFrame("ClientInit"),
// Simulate consuming ServerInit.
BytesConsumed: 14,
}
} else {
resp = &altspb.HandshakerResp{
OutFrames: testutil.MakeFrame("ServerInit"),
// Simulate consuming ClientInit.
BytesConsumed: 14,
}
}
} else {
// Add delay to test concurrent calls.
cleanup := stat.Update()
defer cleanup()
time.Sleep(t.delay)
// Generate the response to be returned by Recv() for the
// follow-up handshaking.
result := &altspb.HandshakerResult{
RecordProtocol: testRecordProtocol,
KeyData: testKey,
}
resp = &altspb.HandshakerResp{
Result: result,
// Simulate consuming ClientFinished or ServerFinished.
BytesConsumed: 18,
}
}
t.recvBuf = resp
return nil
}
func (t *testRPCStream) CloseSend() error {
return nil
}
var stat testutil.Stats
func TestClientHandshake(t *testing.T) {
for _, testCase := range []struct {
delay time.Duration
numberOfHandshakes int
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * maxPendingHandshakes},
} {
errc := make(chan error)
stat.Reset()
for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
isClient: true,
}
// Preload the inbound frames.
f1 := testutil.MakeFrame("ServerInit")
f2 := testutil.MakeFrame("ServerFinished")
in := bytes.NewBuffer(f1)
in.Write(f2)
out := new(bytes.Buffer)
tc := testutil.NewTestConn(in, out)
chs := &altsHandshaker{
stream: stream,
conn: tc,
clientOpts: &ClientHandshakerOptions{
TargetServiceAccounts: testTargetServiceAccounts,
ClientIdentity: testClientIdentity,
},
side: core.ClientSide,
}
go func() {
_, context, err := chs.ClientHandshake(context.Background())
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
errc <- err
chs.Close()
}()
}
// Ensure all errors are expected.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
}
}
// Ensure that there are no concurrent calls more than the limit.
if stat.MaxConcurrentCalls > maxPendingHandshakes {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
}
}
}
func TestServerHandshake(t *testing.T) {
for _, testCase := range []struct {
delay time.Duration
numberOfHandshakes int
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * maxPendingHandshakes},
} {
errc := make(chan error)
stat.Reset()
for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
isClient: false,
}
// Preload the inbound frames.
f1 := testutil.MakeFrame("ClientInit")
f2 := testutil.MakeFrame("ClientFinished")
in := bytes.NewBuffer(f1)
in.Write(f2)
out := new(bytes.Buffer)
tc := testutil.NewTestConn(in, out)
shs := &altsHandshaker{
stream: stream,
conn: tc,
serverOpts: DefaultServerHandshakerOptions(),
side: core.ServerSide,
}
go func() {
_, context, err := shs.ServerHandshake(context.Background())
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
errc <- err
shs.Close()
}()
}
// Ensure all errors are expected.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
}
}
// Ensure that there are no concurrent calls more than the limit.
if stat.MaxConcurrentCalls > maxPendingHandshakes {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
}
}
}
// testUnresponsiveRPCStream is used for testing the PeerNotResponding case.
type testUnresponsiveRPCStream struct {
grpc.ClientStream
}
func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) {
return &altspb.HandshakerResp{}, nil
}
func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error {
return nil
}
func (t *testUnresponsiveRPCStream) CloseSend() error {
return nil
}
func TestPeerNotResponding(t *testing.T) {
stream := &testUnresponsiveRPCStream{}
chs := &altsHandshaker{
stream: stream,
conn: testutil.NewUnresponsiveTestConn(),
clientOpts: &ClientHandshakerOptions{
TargetServiceAccounts: testTargetServiceAccounts,
ClientIdentity: testClientIdentity,
},
side: core.ClientSide,
}
_, context, err := chs.ClientHandshake(context.Background())
chs.Close()
if context != nil {
t.Error("expected non-nil ALTS context")
}
if got, want := err, core.PeerNotRespondingError; got != want {
t.Errorf("ClientHandshake() = %v, want %v", got, want)
}
}

View file

@ -0,0 +1,56 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package service manages connections between the VM application and the ALTS
// handshaker service.
package service
import (
"sync"
grpc "google.golang.org/grpc"
)
var (
// hsConn represents a connection to the hypervisor handshaker service.
hsConn *grpc.ClientConn
mu sync.Mutex
// hsDialer will be reassigned in tests.
hsDialer = grpc.Dial
)
type dialer func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error)
// Dial dials the handshake service in the hypervisor. If a connection has
// already been established, this function returns it. Otherwise, a new
// connection is created.
func Dial(hsAddress string) (*grpc.ClientConn, error) {
mu.Lock()
defer mu.Unlock()
if hsConn == nil {
// Create a new connection to the handshaker service. Note that
// this connection stays open until the application is closed.
var err error
hsConn, err = hsDialer(hsAddress, grpc.WithInsecure())
if err != nil {
return nil, err
}
}
return hsConn, nil
}
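// A minimal usage sketch, assuming only Dial above: because hsConn is cached,
// repeated calls return the same *grpc.ClientConn, regardless of the address
// passed after the first call. dialTwiceSketch and the address are
// illustrative placeholders only.
func dialTwiceSketch() (bool, error) {
	conn1, err := Dial("handshaker-service-address")
	if err != nil {
		return false, err
	}
	conn2, err := Dial("handshaker-service-address")
	if err != nil {
		return false, err
	}
	// Dial caches the connection in hsConn, so both calls return the same one.
	return conn1 == conn2, nil
}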

View file

@ -0,0 +1,69 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"testing"
grpc "google.golang.org/grpc"
)
const (
// The address is irrelevant in this test.
testAddress = "some_address"
)
func TestDial(t *testing.T) {
defer func() func() {
temp := hsDialer
hsDialer = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return &grpc.ClientConn{}, nil
}
return func() {
hsDialer = temp
}
}()()
// Ensure that hsConn is nil at first.
hsConn = nil
// The first call to Dial should create a connection and set hsConn.
conn1, err := Dial(testAddress)
if err != nil {
t.Fatalf("first call to Dial failed: %v", err)
}
if conn1 == nil {
t.Fatal("first call to Dial(_)=(nil, _), want not nil")
}
if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want)
}
// Second call to Dial should return conn1 above.
conn2, err := Dial(testAddress)
if err != nil {
t.Fatalf("second call to Dial(_) failed: %v", err)
}
if got, want := conn2, conn1; got != want {
t.Fatalf("second call to Dial(_)=(%v, _), want (%v,. _)", got, want)
}
if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want)
}
}

Some files were not shown because too many files have changed in this diff.