Merge pull request #159 from mlaventure/use-fix-hashes-for-vendoring

Update vendor.sh to use fix hashes
This commit is contained in:
Michael Crosby 2016-03-25 10:32:18 -07:00
commit ef821ea93f
79 changed files with 3291 additions and 991 deletions

View file

@ -21,9 +21,4 @@ RUN go get github.com/golang/lint/golint \
COPY . /go/src/github.com/docker/containerd COPY . /go/src/github.com/docker/containerd
# get deps, until they are in vendor
# TODO: remomve this when there is a dep tool
RUN go get -d -v github.com/docker/containerd/ctr \
&& go get -d -v github.com/docker/containerd/containerd
WORKDIR /go/src/github.com/docker/containerd WORKDIR /go/src/github.com/docker/containerd

View file

@ -4,25 +4,25 @@ set -e
rm -rf vendor/ rm -rf vendor/
source 'hack/.vendor-helpers.sh' source 'hack/.vendor-helpers.sh'
clone git github.com/Sirupsen/logrus master clone git github.com/Sirupsen/logrus 4b6ea7319e214d98c938f12692336f7ca9348d6b
clone git github.com/cloudfoundry/gosigar master clone git github.com/cloudfoundry/gosigar 3ed7c74352dae6dc00bdc8c74045375352e3ec05
clone git github.com/codegangsta/cli master clone git github.com/codegangsta/cli 9fec0fad02befc9209347cc6d620e68e1b45f74d
clone git github.com/coreos/go-systemd master clone git github.com/coreos/go-systemd 7b2428fec40033549c68f54e26e89e7ca9a9ce31
clone git github.com/cyberdelia/go-metrics-graphite master clone git github.com/cyberdelia/go-metrics-graphite 7e54b5c2aa6eaff4286c44129c3def899dff528c
clone git github.com/docker/docker master clone git github.com/docker/docker 9ff767bcc06c924fd669d881a34847aa4fbaab5e
clone git github.com/docker/go-units master clone git github.com/docker/go-units 5d2041e26a699eaca682e2ea41c8f891e1060444
clone git github.com/godbus/dbus master clone git github.com/godbus/dbus e2cf28118e66a6a63db46cf6088a35d2054d3bb0
clone git github.com/golang/glog master clone git github.com/golang/glog 23def4e6c14b4da8ac2ed8007337bc5eb5007998
clone git github.com/golang/protobuf master clone git github.com/golang/protobuf 8d92cf5fc15a4382f8964b08e1f42a75c0591aa3
clone git github.com/opencontainers/runc 5f182ce7380f41b8c60a2ecaec14996d7e9cfd4a clone git github.com/opencontainers/runc 5f182ce7380f41b8c60a2ecaec14996d7e9cfd4a
clone git github.com/opencontainers/specs/specs-go 3ce138b1934bf227a418e241ead496c383eaba1c clone git github.com/opencontainers/specs 3ce138b1934bf227a418e241ead496c383eaba1c
clone git github.com/rcrowley/go-metrics master clone git github.com/rcrowley/go-metrics eeba7bd0dd01ace6e690fa833b3f22aaec29af43
clone git github.com/satori/go.uuid master clone git github.com/satori/go.uuid f9ab0dce87d815821e221626b772e3475a0d2749
clone git github.com/syndtr/gocapability master clone git github.com/syndtr/gocapability 2c00daeb6c3b45114c80ac44119e7b8801fdd852
clone git github.com/vishvananda/netlink master clone git github.com/vishvananda/netlink adb0f53af689dd38f1443eba79489feaacf0b22e
clone git github.com/Azure/go-ansiterm master clone git github.com/Azure/go-ansiterm 70b2c90b260171e829f1ebd7c17f600c11858dbe
clone git golang.org/x/net master https://github.com/golang/net.git clone git golang.org/x/net 991d3e32f76f19ee6d9caadb3a22eae8d23315f7 https://github.com/golang/net.git
clone git google.golang.org/grpc master https://github.com/grpc/grpc-go.git clone git google.golang.org/grpc a22b6611561e9f0a3e0919690dd2caf48f14c517 https://github.com/grpc/grpc-go.git
clone git github.com/seccomp/libseccomp-golang master clone git github.com/seccomp/libseccomp-golang 1b506fc7c24eec5a3693cdcbed40d9c226cfc6a1
clean clean

View file

@ -2,6 +2,8 @@ language: go
go: go:
- 1.3 - 1.3
- 1.4 - 1.4
- 1.5
- tip - tip
install: install:
- go get -t ./... - go get -t ./...
script: GOMAXPROCS=4 GORACE="halt_on_error=1" go test -race -v ./...

View file

@ -1,10 +1,21 @@
# 0.9.0 (Unreleased) # 0.10.0
* feature: Add a test hook (#180)
* feature: `ParseLevel` is now case-insensitive (#326)
* feature: `FieldLogger` interface that generalizes `Logger` and `Entry` (#308)
* performance: avoid re-allocations on `WithFields` (#335)
# 0.9.0
* logrus/text_formatter: don't emit empty msg * logrus/text_formatter: don't emit empty msg
* logrus/hooks/airbrake: move out of main repository * logrus/hooks/airbrake: move out of main repository
* logrus/hooks/sentry: move out of main repository * logrus/hooks/sentry: move out of main repository
* logrus/hooks/papertrail: move out of main repository * logrus/hooks/papertrail: move out of main repository
* logrus/hooks/bugsnag: move out of main repository * logrus/hooks/bugsnag: move out of main repository
* logrus/core: run tests with `-race`
* logrus/core: detect TTY based on `stderr`
* logrus/core: support `WithError` on logger
* logrus/core: Solaris support
# 0.8.7 # 0.8.7

View file

@ -1,4 +1,4 @@
# Logrus <img src="http://i.imgur.com/hTeVwmJ.png" width="40" height="40" alt=":walrus:" class="emoji" title=":walrus:"/>&nbsp;[![Build Status](https://travis-ci.org/Sirupsen/logrus.svg?branch=master)](https://travis-ci.org/Sirupsen/logrus)&nbsp;[![godoc reference](https://godoc.org/github.com/Sirupsen/logrus?status.png)][godoc] # Logrus <img src="http://i.imgur.com/hTeVwmJ.png" width="40" height="40" alt=":walrus:" class="emoji" title=":walrus:"/>&nbsp;[![Build Status](https://travis-ci.org/Sirupsen/logrus.svg?branch=master)](https://travis-ci.org/Sirupsen/logrus)&nbsp;[![GoDoc](https://godoc.org/github.com/Sirupsen/logrus?status.svg)](https://godoc.org/github.com/Sirupsen/logrus)
Logrus is a structured logger for Go (golang), completely API compatible with Logrus is a structured logger for Go (golang), completely API compatible with
the standard library logger. [Godoc][godoc]. **Please note the Logrus API is not the standard library logger. [Godoc][godoc]. **Please note the Logrus API is not
@ -12,7 +12,7 @@ plain text):
![Colored](http://i.imgur.com/PY7qMwd.png) ![Colored](http://i.imgur.com/PY7qMwd.png)
With `log.Formatter = new(logrus.JSONFormatter)`, for easy parsing by logstash With `log.SetFormatter(&log.JSONFormatter{})`, for easy parsing by logstash
or Splunk: or Splunk:
```json ```json
@ -32,7 +32,7 @@ ocean","size":10,"time":"2014-03-10 19:57:38.562264131 -0400 EDT"}
"time":"2014-03-10 19:57:38.562543128 -0400 EDT"} "time":"2014-03-10 19:57:38.562543128 -0400 EDT"}
``` ```
With the default `log.Formatter = new(&log.TextFormatter{})` when a TTY is not With the default `log.SetFormatter(&log.TextFormatter{})` when a TTY is not
attached, the output is compatible with the attached, the output is compatible with the
[logfmt](http://godoc.org/github.com/kr/logfmt) format: [logfmt](http://godoc.org/github.com/kr/logfmt) format:
@ -221,6 +221,12 @@ Note: Syslog hook also support connecting to local syslog (Ex. "/dev/log" or "/v
| [InfluxDB](https://github.com/Abramovic/logrus_influxdb) | Hook for logging to influxdb | | [InfluxDB](https://github.com/Abramovic/logrus_influxdb) | Hook for logging to influxdb |
| [Octokit](https://github.com/dorajistyle/logrus-octokit-hook) | Hook for logging to github via octokit | | [Octokit](https://github.com/dorajistyle/logrus-octokit-hook) | Hook for logging to github via octokit |
| [DeferPanic](https://github.com/deferpanic/dp-logrus) | Hook for logging to DeferPanic | | [DeferPanic](https://github.com/deferpanic/dp-logrus) | Hook for logging to DeferPanic |
| [Redis-Hook](https://github.com/rogierlommers/logrus-redis-hook) | Hook for logging to a ELK stack (through Redis) |
| [Amqp-Hook](https://github.com/vladoatanasov/logrus_amqp) | Hook for logging to Amqp broker (Like RabbitMQ) |
| [KafkaLogrus](https://github.com/goibibo/KafkaLogrus) | Hook for logging to kafka |
| [Typetalk](https://github.com/dragon3/logrus-typetalk-hook) | Hook for logging to [Typetalk](https://www.typetalk.in/) |
| [ElasticSearch](https://github.com/sohlich/elogrus) | Hook for logging to ElasticSearch|
#### Level logging #### Level logging
@ -362,4 +368,21 @@ entries. It should not be a feature of the application-level logger.
| ---- | ----------- | | ---- | ----------- |
|[Logrus Mate](https://github.com/gogap/logrus_mate)|Logrus mate is a tool for Logrus to manage loggers, you can initial logger's level, hook and formatter by config file, the logger will generated with different config at different environment.| |[Logrus Mate](https://github.com/gogap/logrus_mate)|Logrus mate is a tool for Logrus to manage loggers, you can initial logger's level, hook and formatter by config file, the logger will generated with different config at different environment.|
[godoc]: https://godoc.org/github.com/Sirupsen/logrus #### Testing
Logrus has a built in facility for asserting the presence of log messages. This is implemented through the `test` hook and provides:
* decorators for existing logger (`test.NewLocal` and `test.NewGlobal`) which basically just add the `test` hook
* a test logger (`test.NewNullLogger`) that just records log messages (and does not output any):
```go
logger, hook := NewNullLogger()
logger.Error("Hello error")
assert.Equal(1, len(hook.Entries))
assert.Equal(logrus.ErrorLevel, hook.LastEntry().Level)
assert.Equal("Hello error", hook.LastEntry().Message)
hook.Reset()
assert.Nil(hook.LastEntry())
```

View file

@ -68,7 +68,7 @@ func (entry *Entry) WithField(key string, value interface{}) *Entry {
// Add a map of fields to the Entry. // Add a map of fields to the Entry.
func (entry *Entry) WithFields(fields Fields) *Entry { func (entry *Entry) WithFields(fields Fields) *Entry {
data := Fields{} data := make(Fields, len(entry.Data)+len(fields))
for k, v := range entry.Data { for k, v := range entry.Data {
data[k] = v data[k] = v
} }

View file

@ -3,6 +3,7 @@ package logrus
import ( import (
"fmt" "fmt"
"log" "log"
"strings"
) )
// Fields type, used to pass to `WithFields`. // Fields type, used to pass to `WithFields`.
@ -33,7 +34,7 @@ func (level Level) String() string {
// ParseLevel takes a string level and returns the Logrus log level constant. // ParseLevel takes a string level and returns the Logrus log level constant.
func ParseLevel(lvl string) (Level, error) { func ParseLevel(lvl string) (Level, error) {
switch lvl { switch strings.ToLower(lvl) {
case "panic": case "panic":
return PanicLevel, nil return PanicLevel, nil
case "fatal": case "fatal":
@ -52,6 +53,16 @@ func ParseLevel(lvl string) (Level, error) {
return l, fmt.Errorf("not a valid logrus Level: %q", lvl) return l, fmt.Errorf("not a valid logrus Level: %q", lvl)
} }
// A constant exposing all logging levels
var AllLevels = []Level{
PanicLevel,
FatalLevel,
ErrorLevel,
WarnLevel,
InfoLevel,
DebugLevel,
}
// These are the different logging levels. You can set the logging level to log // These are the different logging levels. You can set the logging level to log
// on your instance of logger, obtained with `logrus.New()`. // on your instance of logger, obtained with `logrus.New()`.
const ( const (
@ -96,3 +107,37 @@ type StdLogger interface {
Panicf(string, ...interface{}) Panicf(string, ...interface{})
Panicln(...interface{}) Panicln(...interface{})
} }
// The FieldLogger interface generalizes the Entry and Logger types
type FieldLogger interface {
WithField(key string, value interface{}) *Entry
WithFields(fields Fields) *Entry
WithError(err error) *Entry
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Printf(format string, args ...interface{})
Warnf(format string, args ...interface{})
Warningf(format string, args ...interface{})
Errorf(format string, args ...interface{})
Fatalf(format string, args ...interface{})
Panicf(format string, args ...interface{})
Debug(args ...interface{})
Info(args ...interface{})
Print(args ...interface{})
Warn(args ...interface{})
Warning(args ...interface{})
Error(args ...interface{})
Fatal(args ...interface{})
Panic(args ...interface{})
Debugln(args ...interface{})
Infoln(args ...interface{})
Println(args ...interface{})
Warnln(args ...interface{})
Warningln(args ...interface{})
Errorln(args ...interface{})
Fatalln(args ...interface{})
Panicln(args ...interface{})
}

View file

@ -2,7 +2,6 @@ language: go
sudo: false sudo: false
go: go:
- 1.0.3
- 1.1.2 - 1.1.2
- 1.2.2 - 1.2.2
- 1.3.3 - 1.3.3

View file

@ -1,16 +1,20 @@
[![Coverage](http://gocover.io/_badge/github.com/codegangsta/cli?0)](http://gocover.io/github.com/codegangsta/cli) [![Coverage](http://gocover.io/_badge/github.com/codegangsta/cli?0)](http://gocover.io/github.com/codegangsta/cli)
[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli) [![Build Status](https://travis-ci.org/codegangsta/cli.svg?branch=master)](https://travis-ci.org/codegangsta/cli)
[![GoDoc](https://godoc.org/github.com/codegangsta/cli?status.svg)](https://godoc.org/github.com/codegangsta/cli) [![GoDoc](https://godoc.org/github.com/codegangsta/cli?status.svg)](https://godoc.org/github.com/codegangsta/cli)
[![codebeat](https://codebeat.co/badges/0a8f30aa-f975-404b-b878-5fab3ae1cc5f)](https://codebeat.co/projects/github-com-codegangsta-cli)
# cli.go # cli.go
`cli.go` is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way. `cli.go` is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
## Overview ## Overview
Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app. Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app.
**This is where `cli.go` comes into play.** `cli.go` makes command line programming fun, organized, and expressive! **This is where `cli.go` comes into play.** `cli.go` makes command line programming fun, organized, and expressive!
## Installation ## Installation
Make sure you have a working Go environment (go 1.1+ is *required*). [See the install instructions](http://golang.org/doc/install.html). Make sure you have a working Go environment (go 1.1+ is *required*). [See the install instructions](http://golang.org/doc/install.html).
To install `cli.go`, simply run: To install `cli.go`, simply run:
@ -24,7 +28,8 @@ export PATH=$PATH:$GOPATH/bin
``` ```
## Getting Started ## Getting Started
One of the philosophies behind `cli.go` is that an API should be playful and full of discovery. So a `cli.go` app can be as little as one line of code in `main()`.
One of the philosophies behind `cli.go` is that an API should be playful and full of discovery. So a `cli.go` app can be as little as one line of code in `main()`.
``` go ``` go
package main package main
@ -56,7 +61,7 @@ func main() {
app.Action = func(c *cli.Context) { app.Action = func(c *cli.Context) {
println("boom! I say!") println("boom! I say!")
} }
app.Run(os.Args) app.Run(os.Args)
} }
``` ```
@ -123,6 +128,7 @@ GLOBAL OPTIONS
``` ```
### Arguments ### Arguments
You can lookup arguments by calling the `Args` function on `cli.Context`. You can lookup arguments by calling the `Args` function on `cli.Context`.
``` go ``` go
@ -134,7 +140,9 @@ app.Action = func(c *cli.Context) {
``` ```
### Flags ### Flags
Setting and querying flags is simple. Setting and querying flags is simple.
``` go ``` go
... ...
app.Flags = []cli.Flag { app.Flags = []cli.Flag {
@ -146,7 +154,7 @@ app.Flags = []cli.Flag {
} }
app.Action = func(c *cli.Context) { app.Action = func(c *cli.Context) {
name := "someone" name := "someone"
if len(c.Args()) > 0 { if c.NArg() > 0 {
name = c.Args()[0] name = c.Args()[0]
} }
if c.String("lang") == "spanish" { if c.String("lang") == "spanish" {
@ -159,6 +167,7 @@ app.Action = func(c *cli.Context) {
``` ```
You can also set a destination variable for a flag, to which the content will be scanned. You can also set a destination variable for a flag, to which the content will be scanned.
``` go ``` go
... ...
var language string var language string
@ -172,7 +181,7 @@ app.Flags = []cli.Flag {
} }
app.Action = func(c *cli.Context) { app.Action = func(c *cli.Context) {
name := "someone" name := "someone"
if len(c.Args()) > 0 { if c.NArg() > 0 {
name = c.Args()[0] name = c.Args()[0]
} }
if language == "spanish" { if language == "spanish" {
@ -230,9 +239,52 @@ app.Flags = []cli.Flag {
} }
``` ```
#### Values from alternate input sources (YAML and others)
There is a separate package altsrc that adds support for getting flag values from other input sources like YAML.
In order to get values for a flag from an alternate input source the following code would be added to wrap an existing cli.Flag like below:
``` go
altsrc.NewIntFlag(cli.IntFlag{Name: "test"})
```
Initialization must also occur for these flags. Below is an example initializing getting data from a yaml file below.
``` go
command.Before = altsrc.InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load"))
```
The code above will use the "load" string as a flag name to get the file name of a yaml file from the cli.Context.
It will then use that file name to initialize the yaml input source for any flags that are defined on that command.
As a note the "load" flag used would also have to be defined on the command flags in order for this code snipped to work.
Currently only YAML files are supported but developers can add support for other input sources by implementing the
altsrc.InputSourceContext for their given sources.
Here is a more complete sample of a command using YAML support:
``` go
command := &cli.Command{
Name: "test-cmd",
Aliases: []string{"tc"},
Usage: "this is for testing",
Description: "testing",
Action: func(c *cli.Context) {
// Action to run
},
Flags: []cli.Flag{
NewIntFlag(cli.IntFlag{Name: "test"}),
cli.StringFlag{Name: "load"}},
}
command.Before = InitInputSourceWithContext(command.Flags, NewYamlSourceFromFlagFunc("load"))
err := command.Run(c)
```
### Subcommands ### Subcommands
Subcommands can be defined for a more git-like command line app. Subcommands can be defined for a more git-like command line app.
```go ```go
... ...
app.Commands = []cli.Command{ app.Commands = []cli.Command{
@ -283,6 +335,7 @@ You can enable completion commands by setting the `EnableBashCompletion`
flag on the `App` object. By default, this setting will only auto-complete to flag on the `App` object. By default, this setting will only auto-complete to
show an app's subcommands, but you can write your own completion methods for show an app's subcommands, but you can write your own completion methods for
the App or its subcommands. the App or its subcommands.
```go ```go
... ...
var tasks = []string{"cook", "clean", "laundry", "eat", "sleep", "code"} var tasks = []string{"cook", "clean", "laundry", "eat", "sleep", "code"}
@ -298,7 +351,7 @@ app.Commands = []cli.Command{
}, },
BashComplete: func(c *cli.Context) { BashComplete: func(c *cli.Context) {
// This will complete if no args are passed // This will complete if no args are passed
if len(c.Args()) > 0 { if c.NArg() > 0 {
return return
} }
for _, t := range tasks { for _, t := range tasks {
@ -325,8 +378,8 @@ automatically install it there if you are distributing a package). Don't forget
to source the file to make it active in the current shell. to source the file to make it active in the current shell.
``` ```
sudo cp src/bash_autocomplete /etc/bash_completion.d/<myprogram> sudo cp src/bash_autocomplete /etc/bash_completion.d/<myprogram>
source /etc/bash_completion.d/<myprogram> source /etc/bash_completion.d/<myprogram>
``` ```
Alternatively, you can just document that users should source the generic Alternatively, you can just document that users should source the generic
@ -334,6 +387,7 @@ Alternatively, you can just document that users should source the generic
to the name of their program (as above). to the name of their program (as above).
## Contribution Guidelines ## Contribution Guidelines
Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch. Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch.
If you have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together. If you have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.

View file

@ -9,7 +9,7 @@ import (
"time" "time"
) )
// App is the main structure of a cli application. It is recomended that // App is the main structure of a cli application. It is recommended that
// an app be created with the cli.NewApp() function // an app be created with the cli.NewApp() function
type App struct { type App struct {
// The name of the program. Defaults to path.Base(os.Args[0]) // The name of the program. Defaults to path.Base(os.Args[0])
@ -18,6 +18,8 @@ type App struct {
HelpName string HelpName string
// Description of the program. // Description of the program.
Usage string Usage string
// Text to override the USAGE section of help
UsageText string
// Description of the program argument format. // Description of the program argument format.
ArgsUsage string ArgsUsage string
// Version of the program // Version of the program
@ -30,7 +32,7 @@ type App struct {
EnableBashCompletion bool EnableBashCompletion bool
// Boolean to hide built-in help command // Boolean to hide built-in help command
HideHelp bool HideHelp bool
// Boolean to hide built-in version flag // Boolean to hide built-in version flag and the VERSION section of help
HideVersion bool HideVersion bool
// An action to execute when the bash-completion flag is set // An action to execute when the bash-completion flag is set
BashComplete func(context *Context) BashComplete func(context *Context)
@ -44,6 +46,10 @@ type App struct {
Action func(context *Context) Action func(context *Context)
// Execute this function if the proper command cannot be found // Execute this function if the proper command cannot be found
CommandNotFound func(context *Context, command string) CommandNotFound func(context *Context, command string)
// Execute this function, if an usage error occurs. This is useful for displaying customized usage error messages.
// This function is able to replace the original error messages.
// If this function is not set, the "Incorrect usage" is displayed and the execution is interrupted.
OnUsageError func(context *Context, err error, isSubcommand bool) error
// Compilation date // Compilation date
Compiled time.Time Compiled time.Time
// List of all authors who contributed // List of all authors who contributed
@ -74,6 +80,7 @@ func NewApp() *App {
Name: path.Base(os.Args[0]), Name: path.Base(os.Args[0]),
HelpName: path.Base(os.Args[0]), HelpName: path.Base(os.Args[0]),
Usage: "A new cli application", Usage: "A new cli application",
UsageText: "",
Version: "0.0.0", Version: "0.0.0",
BashComplete: DefaultAppComplete, BashComplete: DefaultAppComplete,
Action: helpCommand.Action, Action: helpCommand.Action,
@ -119,23 +126,26 @@ func (a *App) Run(arguments []string) (err error) {
set.SetOutput(ioutil.Discard) set.SetOutput(ioutil.Discard)
err = set.Parse(arguments[1:]) err = set.Parse(arguments[1:])
nerr := normalizeFlags(a.Flags, set) nerr := normalizeFlags(a.Flags, set)
context := NewContext(a, set, nil)
if nerr != nil { if nerr != nil {
fmt.Fprintln(a.Writer, nerr) fmt.Fprintln(a.Writer, nerr)
context := NewContext(a, set, nil)
ShowAppHelp(context) ShowAppHelp(context)
return nerr return nerr
} }
context := NewContext(a, set, nil)
if checkCompletions(context) { if checkCompletions(context) {
return nil return nil
} }
if err != nil { if err != nil {
fmt.Fprintln(a.Writer, "Incorrect Usage.") if a.OnUsageError != nil {
fmt.Fprintln(a.Writer) err := a.OnUsageError(context, err, false)
ShowAppHelp(context) return err
return err } else {
fmt.Fprintf(a.Writer, "%s\n\n", "Incorrect Usage.")
ShowAppHelp(context)
return err
}
} }
if !a.HideHelp && checkHelp(context) { if !a.HideHelp && checkHelp(context) {
@ -150,8 +160,7 @@ func (a *App) Run(arguments []string) (err error) {
if a.After != nil { if a.After != nil {
defer func() { defer func() {
afterErr := a.After(context) if afterErr := a.After(context); afterErr != nil {
if afterErr != nil {
if err != nil { if err != nil {
err = NewMultiError(err, afterErr) err = NewMultiError(err, afterErr)
} else { } else {
@ -162,8 +171,10 @@ func (a *App) Run(arguments []string) (err error) {
} }
if a.Before != nil { if a.Before != nil {
err := a.Before(context) err = a.Before(context)
if err != nil { if err != nil {
fmt.Fprintf(a.Writer, "%v\n\n", err)
ShowAppHelp(context)
return err return err
} }
} }
@ -239,10 +250,14 @@ func (a *App) RunAsSubcommand(ctx *Context) (err error) {
} }
if err != nil { if err != nil {
fmt.Fprintln(a.Writer, "Incorrect Usage.") if a.OnUsageError != nil {
fmt.Fprintln(a.Writer) err = a.OnUsageError(context, err, true)
ShowSubcommandHelp(context) return err
return err } else {
fmt.Fprintf(a.Writer, "%s\n\n", "Incorrect Usage.")
ShowSubcommandHelp(context)
return err
}
} }
if len(a.Commands) > 0 { if len(a.Commands) > 0 {

View file

@ -0,0 +1,16 @@
version: "{build}"
os: Windows Server 2012 R2
install:
- go version
- go env
build_script:
- cd %APPVEYOR_BUILD_FOLDER%
- go vet ./...
- go test -v ./...
test: off
deploy: off

View file

@ -16,6 +16,8 @@ type Command struct {
Aliases []string Aliases []string
// A short description of the usage of this command // A short description of the usage of this command
Usage string Usage string
// Custom text to show on USAGE section of help
UsageText string
// A longer explanation of how the command works // A longer explanation of how the command works
Description string Description string
// A short description of the arguments of this command // A short description of the arguments of this command
@ -25,11 +27,15 @@ type Command struct {
// An action to execute before any sub-subcommands are run, but after the context is ready // An action to execute before any sub-subcommands are run, but after the context is ready
// If a non-nil error is returned, no sub-subcommands are run // If a non-nil error is returned, no sub-subcommands are run
Before func(context *Context) error Before func(context *Context) error
// An action to execute after any subcommands are run, but after the subcommand has finished // An action to execute after any subcommands are run, but before the subcommand has finished
// It is run even if Action() panics // It is run even if Action() panics
After func(context *Context) error After func(context *Context) error
// The function to call when this command is invoked // The function to call when this command is invoked
Action func(context *Context) Action func(context *Context)
// Execute this function, if an usage error occurs. This is useful for displaying customized usage error messages.
// This function is able to replace the original error messages.
// If this function is not set, the "Incorrect usage" is displayed and the execution is interrupted.
OnUsageError func(context *Context, err error) error
// List of child commands // List of child commands
Subcommands []Command Subcommands []Command
// List of flags to parse // List of flags to parse
@ -54,8 +60,8 @@ func (c Command) FullName() string {
} }
// Invokes the command given the context, parses ctx.Args() to generate command-specific flags // Invokes the command given the context, parses ctx.Args() to generate command-specific flags
func (c Command) Run(ctx *Context) error { func (c Command) Run(ctx *Context) (err error) {
if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil { if len(c.Subcommands) > 0 {
return c.startApp(ctx) return c.startApp(ctx)
} }
@ -74,7 +80,6 @@ func (c Command) Run(ctx *Context) error {
set := flagSet(c.Name, c.Flags) set := flagSet(c.Name, c.Flags)
set.SetOutput(ioutil.Discard) set.SetOutput(ioutil.Discard)
var err error
if !c.SkipFlagParsing { if !c.SkipFlagParsing {
firstFlagIndex := -1 firstFlagIndex := -1
terminatorIndex := -1 terminatorIndex := -1
@ -82,6 +87,9 @@ func (c Command) Run(ctx *Context) error {
if arg == "--" { if arg == "--" {
terminatorIndex = index terminatorIndex = index
break break
} else if arg == "-" {
// Do nothing. A dash alone is not really a flag.
continue
} else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 { } else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 {
firstFlagIndex = index firstFlagIndex = index
} }
@ -111,10 +119,15 @@ func (c Command) Run(ctx *Context) error {
} }
if err != nil { if err != nil {
fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.") if c.OnUsageError != nil {
fmt.Fprintln(ctx.App.Writer) err := c.OnUsageError(ctx, err)
ShowCommandHelp(ctx, c.Name) return err
return err } else {
fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.")
fmt.Fprintln(ctx.App.Writer)
ShowCommandHelp(ctx, c.Name)
return err
}
} }
nerr := normalizeFlags(c.Flags, set) nerr := normalizeFlags(c.Flags, set)
@ -133,6 +146,30 @@ func (c Command) Run(ctx *Context) error {
if checkCommandHelp(context, c.Name) { if checkCommandHelp(context, c.Name) {
return nil return nil
} }
if c.After != nil {
defer func() {
afterErr := c.After(context)
if afterErr != nil {
if err != nil {
err = NewMultiError(err, afterErr)
} else {
err = afterErr
}
}
}()
}
if c.Before != nil {
err := c.Before(context)
if err != nil {
fmt.Fprintln(ctx.App.Writer, err)
fmt.Fprintln(ctx.App.Writer)
ShowCommandHelp(ctx, c.Name)
return err
}
}
context.Command = c context.Command = c
c.Action(context) c.Action(context)
return nil return nil
@ -166,7 +203,7 @@ func (c Command) startApp(ctx *Context) error {
if c.HelpName == "" { if c.HelpName == "" {
app.HelpName = c.HelpName app.HelpName = c.HelpName
} else { } else {
app.HelpName = fmt.Sprintf("%s %s", ctx.App.Name, c.Name) app.HelpName = app.Name
} }
if c.Description != "" { if c.Description != "" {
@ -205,12 +242,9 @@ func (c Command) startApp(ctx *Context) error {
app.Action = helpSubcommand.Action app.Action = helpSubcommand.Action
} }
var newCmds []Command for index, cc := range app.Commands {
for _, cc := range app.Commands { app.Commands[index].commandNamePath = []string{c.Name, cc.Name}
cc.commandNamePath = []string{c.Name, cc.Name}
newCmds = append(newCmds, cc)
} }
app.Commands = newCmds
return app.RunAsSubcommand(ctx) return app.RunAsSubcommand(ctx)
} }

View file

@ -197,6 +197,11 @@ func (c *Context) Args() Args {
return args return args
} }
// Returns the number of the command line arguments.
func (c *Context) NArg() int {
return len(c.Args())
}
// Returns the nth argument, or else a blank string // Returns the nth argument, or else a blank string
func (a Args) Get(n int) string { func (a Args) Get(n int) string {
if len(a) > n { if len(a) > n {

View file

@ -4,6 +4,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -29,7 +30,7 @@ var HelpFlag = BoolFlag{
} }
// Flag is a common interface related to parsing flags in cli. // Flag is a common interface related to parsing flags in cli.
// For more advanced flag parsing techniques, it is recomended that // For more advanced flag parsing techniques, it is recommended that
// this interface be implemented. // this interface be implemented.
type Flag interface { type Flag interface {
fmt.Stringer fmt.Stringer
@ -73,7 +74,18 @@ type GenericFlag struct {
// help text to the user (uses the String() method of the generic flag to show // help text to the user (uses the String() method of the generic flag to show
// the value) // the value)
func (f GenericFlag) String() string { func (f GenericFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s \"%v\"\t%v", prefixFor(f.Name), f.Name, f.Value, f.Usage)) return withEnvHint(f.EnvVar, fmt.Sprintf("%s %v\t%v", prefixedNames(f.Name), f.FormatValueHelp(), f.Usage))
}
func (f GenericFlag) FormatValueHelp() string {
if f.Value == nil {
return ""
}
s := f.Value.String()
if len(s) == 0 {
return ""
}
return fmt.Sprintf("\"%s\"", s)
} }
// Apply takes the flagset and calls Set on the generic flag with the value // Apply takes the flagset and calls Set on the generic flag with the value
@ -331,16 +343,15 @@ type StringFlag struct {
// String returns the usage // String returns the usage
func (f StringFlag) String() string { func (f StringFlag) String() string {
var fmtString string return withEnvHint(f.EnvVar, fmt.Sprintf("%s %v\t%v", prefixedNames(f.Name), f.FormatValueHelp(), f.Usage))
fmtString = "%s %v\t%v" }
if len(f.Value) > 0 { func (f StringFlag) FormatValueHelp() string {
fmtString = "%s \"%v\"\t%v" s := f.Value
} else { if len(s) == 0 {
fmtString = "%s %v\t%v" return ""
} }
return fmt.Sprintf("\"%s\"", s)
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
} }
// Apply populates the flag given the flag set and environment // Apply populates the flag given the flag set and environment
@ -521,7 +532,15 @@ func prefixedNames(fullName string) (prefixed string) {
func withEnvHint(envVar, str string) string { func withEnvHint(envVar, str string) string {
envText := "" envText := ""
if envVar != "" { if envVar != "" {
envText = fmt.Sprintf(" [$%s]", strings.Join(strings.Split(envVar, ","), ", $")) prefix := "$"
suffix := ""
sep := ", $"
if runtime.GOOS == "windows" {
prefix = "%"
suffix = "%"
sep = "%, %"
}
envText = fmt.Sprintf(" [%s%s%s]", prefix, strings.Join(strings.Split(envVar, ","), sep), suffix)
} }
return str + envText return str + envText
} }

View file

@ -15,11 +15,11 @@ var AppHelpTemplate = `NAME:
{{.Name}} - {{.Usage}} {{.Name}} - {{.Usage}}
USAGE: USAGE:
{{.HelpName}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}} {{if .UsageText}}{{.UsageText}}{{else}}{{.HelpName}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{end}}
{{if .Version}} {{if .Version}}{{if not .HideVersion}}
VERSION: VERSION:
{{.Version}} {{.Version}}
{{end}}{{if len .Authors}} {{end}}{{end}}{{if len .Authors}}
AUTHOR(S): AUTHOR(S):
{{range .Authors}}{{ . }}{{end}} {{range .Authors}}{{ . }}{{end}}
{{end}}{{if .Commands}} {{end}}{{if .Commands}}
@ -180,7 +180,9 @@ func printHelp(out io.Writer, templ string, data interface{}) {
t := template.Must(template.New("help").Funcs(funcMap).Parse(templ)) t := template.Must(template.New("help").Funcs(funcMap).Parse(templ))
err := t.Execute(w, data) err := t.Execute(w, data)
if err != nil { if err != nil {
panic(err) // If the writer is closed, t.Execute will fail, and there's nothing
// we can do to recover. We could send this to os.Stderr if we need.
return
} }
w.Flush() w.Flush()
} }

View file

@ -1,4 +1,4 @@
// +build !linux,!freebsd freebsd,!cgo // +build !windows,!linux,!freebsd freebsd,!cgo
package mount package mount

View file

@ -0,0 +1,6 @@
package mount
func parseMountTable() ([]*Info, error) {
// Do NOT return an error!
return nil, nil
}

View file

@ -61,8 +61,7 @@ func ensureMountedAs(mountPoint, options string) error {
return err return err
} }
} }
mounted, err = Mounted(mountPoint) if _, err = Mounted(mountPoint); err != nil {
if err != nil {
return err return err
} }

View file

@ -23,8 +23,7 @@ func toShort(path string) (string, error) {
} }
if n > uint32(len(b)) { if n > uint32(len(b)) {
b = make([]uint16, n) b = make([]uint16, n)
n, err = syscall.GetShortPathName(&p[0], &b[0], uint32(len(b))) if _, err = syscall.GetShortPathName(&p[0], &b[0], uint32(len(b))); err != nil {
if err != nil {
return "", err return "", err
} }
} }

View file

@ -43,5 +43,10 @@ func Chtimes(name string, atime time.Time, mtime time.Time) error {
return err return err
} }
// Take platform specific action for setting create time.
if err := setCTime(name, mtime); err != nil {
return err
}
return nil return nil
} }

View file

@ -0,0 +1,14 @@
// +build !windows
package system
import (
"time"
)
//setCTime will set the create time on a file. On Unix, the create
//time is updated as a side effect of setting the modified time, so
//no action is required.
func setCTime(path string, ctime time.Time) error {
return nil
}

View file

@ -0,0 +1,27 @@
// +build windows
package system
import (
"syscall"
"time"
)
//setCTime will set the create time on a file. On Windows, this requires
//calling SetFileTime and explicitly including the create time.
func setCTime(path string, ctime time.Time) error {
ctimespec := syscall.NsecToTimespec(ctime.UnixNano())
pathp, e := syscall.UTF16PtrFromString(path)
if e != nil {
return e
}
h, e := syscall.CreateFile(pathp,
syscall.FILE_WRITE_ATTRIBUTES, syscall.FILE_SHARE_WRITE, nil,
syscall.OPEN_EXISTING, syscall.FILE_FLAG_BACKUP_SEMANTICS, 0)
if e != nil {
return e
}
defer syscall.Close(h)
c := syscall.NsecToFiletime(syscall.TimespecToNsec(ctimespec))
return syscall.SetFileTime(h, &c, nil, nil)
}

View file

@ -11,7 +11,7 @@ import (
) )
// ReadMemInfo retrieves memory statistics of the host system and returns a // ReadMemInfo retrieves memory statistics of the host system and returns a
// MemInfo type. // MemInfo type.
func ReadMemInfo() (*MemInfo, error) { func ReadMemInfo() (*MemInfo, error) {
file, err := os.Open("/proc/meminfo") file, err := os.Open("/proc/meminfo")
if err != nil { if err != nil {
@ -22,8 +22,7 @@ func ReadMemInfo() (*MemInfo, error) {
} }
// parseMemInfo parses the /proc/meminfo file into // parseMemInfo parses the /proc/meminfo file into
// a MemInfo object given a io.Reader to the file. // a MemInfo object given an io.Reader to the file.
//
// Throws error if there are problems reading from the file // Throws error if there are problems reading from the file
func parseMemInfo(reader io.Reader) (*MemInfo, error) { func parseMemInfo(reader io.Reader) (*MemInfo, error) {
meminfo := &MemInfo{} meminfo := &MemInfo{}

View file

@ -0,0 +1,15 @@
package system
import (
"syscall"
)
// fromStatT creates a system.StatT type from a syscall.Stat_t type
func fromStatT(s *syscall.Stat_t) (*StatT, error) {
return &StatT{size: s.Size,
mode: uint32(s.Mode),
uid: s.Uid,
gid: s.Gid,
rdev: uint64(s.Rdev),
mtim: s.Mtim}, nil
}

View file

@ -1,4 +1,4 @@
// +build !linux,!windows,!freebsd,!solaris // +build !linux,!windows,!freebsd,!solaris,!openbsd
package system package system

View file

@ -9,3 +9,9 @@ import "syscall"
func Unmount(dest string) error { func Unmount(dest string) error {
return syscall.Unmount(dest, 0) return syscall.Unmount(dest, 0)
} }
// CommandLineToArgv should not be used on Unix.
// It simply returns commandLine in the only element in the returned array.
func CommandLineToArgv(commandLine string) ([]string, error) {
return []string{commandLine}, nil
}

View file

@ -3,6 +3,7 @@ package system
import ( import (
"fmt" "fmt"
"syscall" "syscall"
"unsafe"
) )
// OSVersion is a wrapper for Windows version information // OSVersion is a wrapper for Windows version information
@ -34,3 +35,26 @@ func GetOSVersion() (OSVersion, error) {
func Unmount(dest string) error { func Unmount(dest string) error {
return nil return nil
} }
// CommandLineToArgv wraps the Windows syscall to turn a commandline into an argument array.
func CommandLineToArgv(commandLine string) ([]string, error) {
var argc int32
argsPtr, err := syscall.UTF16PtrFromString(commandLine)
if err != nil {
return nil, err
}
argv, err := syscall.CommandLineToArgv(argsPtr, &argc)
if err != nil {
return nil, err
}
defer syscall.LocalFree(syscall.Handle(uintptr(unsafe.Pointer(argv))))
newArgs := make([]string, argc)
for i, v := range (*argv)[:argc] {
newArgs[i] = string(syscall.UTF16ToString((*v)[:]))
}
return newArgs, nil
}

View file

@ -27,7 +27,6 @@ func MakeRaw(fd uintptr) (*State, error) {
newState := oldState.termios newState := oldState.termios
C.cfmakeraw((*C.struct_termios)(unsafe.Pointer(&newState))) C.cfmakeraw((*C.struct_termios)(unsafe.Pointer(&newState)))
newState.Oflag = newState.Oflag | C.OPOST
if err := tcset(fd, &newState); err != 0 { if err := tcset(fd, &newState); err != 0 {
return nil, err return nil, err
} }

View file

@ -127,6 +127,5 @@ func handleInterrupt(fd uintptr, state *State) {
go func() { go func() {
_ = <-sigchan _ = <-sigchan
RestoreTerminal(fd, state) RestoreTerminal(fd, state)
os.Exit(0)
}() }()
} }

View file

@ -3,21 +3,20 @@
package term package term
import ( import (
"fmt"
"io" "io"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/Azure/go-ansiterm/winterm" "github.com/Azure/go-ansiterm/winterm"
"github.com/Sirupsen/logrus"
"github.com/docker/docker/pkg/system" "github.com/docker/docker/pkg/system"
"github.com/docker/docker/pkg/term/windows" "github.com/docker/docker/pkg/term/windows"
) )
// State holds the console mode for the terminal. // State holds the console mode for the terminal.
type State struct { type State struct {
mode uint32 inMode, outMode uint32
inHandle, outHandle syscall.Handle
} }
// Winsize is used for window size. // Winsize is used for window size.
@ -28,17 +27,27 @@ type Winsize struct {
y uint16 y uint16
} }
const (
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms683167(v=vs.85).aspx
enableVirtualTerminalInput = 0x0200
enableVirtualTerminalProcessing = 0x0004
)
// usingNativeConsole is true if we are using the Windows native console
var usingNativeConsole bool
// StdStreams returns the standard streams (stdin, stdout, stedrr). // StdStreams returns the standard streams (stdin, stdout, stedrr).
func StdStreams() (stdIn io.ReadCloser, stdOut, stdErr io.Writer) { func StdStreams() (stdIn io.ReadCloser, stdOut, stdErr io.Writer) {
switch { switch {
case os.Getenv("ConEmuANSI") == "ON": case os.Getenv("ConEmuANSI") == "ON":
// The ConEmu shell emulates ANSI well by default. // The ConEmu terminal emulates ANSI on output streams well.
return os.Stdin, os.Stdout, os.Stderr return windows.ConEmuStreams()
case os.Getenv("MSYSTEM") != "": case os.Getenv("MSYSTEM") != "":
// MSYS (mingw) does not emulate ANSI well. // MSYS (mingw) does not emulate ANSI well.
return windows.ConsoleStreams() return windows.ConsoleStreams()
default: default:
if useNativeConsole() { if useNativeConsole() {
usingNativeConsole = true
return os.Stdin, os.Stdout, os.Stderr return os.Stdin, os.Stdout, os.Stderr
} }
return windows.ConsoleStreams() return windows.ConsoleStreams()
@ -54,7 +63,7 @@ func useNativeConsole() bool {
return false return false
} }
// Native console is not available major version 10 // Native console is not available before major version 10
if osv.MajorVersion < 10 { if osv.MajorVersion < 10 {
return false return false
} }
@ -64,6 +73,17 @@ func useNativeConsole() bool {
return false return false
} }
// Get the console modes. If this fails, we can't use the native console
state, err := getNativeConsole()
if err != nil {
return false
}
// Probe the console to see if it can be enabled.
if nil != probeNativeConsole(state) {
return false
}
// Environment variable override // Environment variable override
if e := os.Getenv("USE_NATIVE_CONSOLE"); e != "" { if e := os.Getenv("USE_NATIVE_CONSOLE"); e != "" {
if e == "1" { if e == "1" {
@ -72,32 +92,86 @@ func useNativeConsole() bool {
return false return false
} }
// Get the handle to stdout // TODO Windows. The native emulator still has issues which
stdOutHandle, err := syscall.GetStdHandle(syscall.STD_OUTPUT_HANDLE)
if err != nil {
return false
}
// Get the console mode from the consoles stdout handle
var mode uint32
if err := syscall.GetConsoleMode(stdOutHandle, &mode); err != nil {
return false
}
// Legacy mode does not have native ANSI emulation.
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms683167(v=vs.85).aspx
const enableVirtualTerminalProcessing = 0x0004
if mode&enableVirtualTerminalProcessing == 0 {
return false
}
// TODO Windows (Post TP4). The native emulator still has issues which
// mean it shouldn't be enabled for everyone. Change this next line to true // mean it shouldn't be enabled for everyone. Change this next line to true
// to change the default to "enable if available". In the meantime, users // to change the default to "enable if available". In the meantime, users
// can still try it out by using USE_NATIVE_CONSOLE env variable. // can still try it out by using USE_NATIVE_CONSOLE env variable.
return false return false
} }
// getNativeConsole returns the console modes ('state') for the native Windows console
func getNativeConsole() (State, error) {
var (
err error
state State
)
// Get the handle to stdout
if state.outHandle, err = syscall.GetStdHandle(syscall.STD_OUTPUT_HANDLE); err != nil {
return state, err
}
// Get the console mode from the consoles stdout handle
if err = syscall.GetConsoleMode(state.outHandle, &state.outMode); err != nil {
return state, err
}
// Get the handle to stdin
if state.inHandle, err = syscall.GetStdHandle(syscall.STD_INPUT_HANDLE); err != nil {
return state, err
}
// Get the console mode from the consoles stdin handle
if err = syscall.GetConsoleMode(state.inHandle, &state.inMode); err != nil {
return state, err
}
return state, nil
}
// probeNativeConsole probes the console to determine if native can be supported,
func probeNativeConsole(state State) error {
if err := winterm.SetConsoleMode(uintptr(state.outHandle), state.outMode|enableVirtualTerminalProcessing); err != nil {
return err
}
defer winterm.SetConsoleMode(uintptr(state.outHandle), state.outMode)
if err := winterm.SetConsoleMode(uintptr(state.inHandle), state.inMode|enableVirtualTerminalInput); err != nil {
return err
}
defer winterm.SetConsoleMode(uintptr(state.inHandle), state.inMode)
return nil
}
// enableNativeConsole turns on native console mode
func enableNativeConsole(state State) error {
if err := winterm.SetConsoleMode(uintptr(state.outHandle), state.outMode|enableVirtualTerminalProcessing); err != nil {
return err
}
if err := winterm.SetConsoleMode(uintptr(state.inHandle), state.inMode|enableVirtualTerminalInput); err != nil {
winterm.SetConsoleMode(uintptr(state.outHandle), state.outMode) // restore out if we can
return err
}
return nil
}
// disableNativeConsole turns off native console mode
func disableNativeConsole(state *State) error {
// Try and restore both in an out before error checking.
errout := winterm.SetConsoleMode(uintptr(state.outHandle), state.outMode)
errin := winterm.SetConsoleMode(uintptr(state.inHandle), state.inMode)
if errout != nil {
return errout
}
if errin != nil {
return errin
}
return nil
}
// GetFdInfo returns the file descriptor for an os.File and indicates whether the file represents a terminal. // GetFdInfo returns the file descriptor for an os.File and indicates whether the file represents a terminal.
func GetFdInfo(in interface{}) (uintptr, bool) { func GetFdInfo(in interface{}) (uintptr, bool) {
return windows.GetHandleInfo(in) return windows.GetHandleInfo(in)
@ -105,7 +179,6 @@ func GetFdInfo(in interface{}) (uintptr, bool) {
// GetWinsize returns the window size based on the specified file descriptor. // GetWinsize returns the window size based on the specified file descriptor.
func GetWinsize(fd uintptr) (*Winsize, error) { func GetWinsize(fd uintptr) (*Winsize, error) {
info, err := winterm.GetConsoleScreenBufferInfo(fd) info, err := winterm.GetConsoleScreenBufferInfo(fd)
if err != nil { if err != nil {
return nil, err return nil, err
@ -117,58 +190,9 @@ func GetWinsize(fd uintptr) (*Winsize, error) {
x: 0, x: 0,
y: 0} y: 0}
// Note: GetWinsize is called frequently -- uncomment only for excessive details
// logrus.Debugf("[windows] GetWinsize: Console(%v)", info.String())
// logrus.Debugf("[windows] GetWinsize: Width(%v), Height(%v), x(%v), y(%v)", winsize.Width, winsize.Height, winsize.x, winsize.y)
return winsize, nil return winsize, nil
} }
// SetWinsize tries to set the specified window size for the specified file descriptor.
func SetWinsize(fd uintptr, ws *Winsize) error {
// Ensure the requested dimensions are no larger than the maximum window size
info, err := winterm.GetConsoleScreenBufferInfo(fd)
if err != nil {
return err
}
if ws.Width == 0 || ws.Height == 0 || ws.Width > uint16(info.MaximumWindowSize.X) || ws.Height > uint16(info.MaximumWindowSize.Y) {
return fmt.Errorf("Illegal window size: (%v,%v) -- Maximum allow: (%v,%v)",
ws.Width, ws.Height, info.MaximumWindowSize.X, info.MaximumWindowSize.Y)
}
// Narrow the sizes to that used by Windows
width := winterm.SHORT(ws.Width)
height := winterm.SHORT(ws.Height)
// Set the dimensions while ensuring they remain within the bounds of the backing console buffer
// -- Shrinking will always succeed. Growing may push the edges past the buffer boundary. When that occurs,
// shift the upper left just enough to keep the new window within the buffer.
rect := info.Window
if width < rect.Right-rect.Left+1 {
rect.Right = rect.Left + width - 1
} else if width > rect.Right-rect.Left+1 {
rect.Right = rect.Left + width - 1
if rect.Right >= info.Size.X {
rect.Left = info.Size.X - width
rect.Right = info.Size.X - 1
}
}
if height < rect.Bottom-rect.Top+1 {
rect.Bottom = rect.Top + height - 1
} else if height > rect.Bottom-rect.Top+1 {
rect.Bottom = rect.Top + height - 1
if rect.Bottom >= info.Size.Y {
rect.Top = info.Size.Y - height
rect.Bottom = info.Size.Y - 1
}
}
logrus.Debugf("[windows] SetWinsize: Requested((%v,%v)) Actual(%v)", ws.Width, ws.Height, rect)
return winterm.SetConsoleWindowInfo(fd, true, rect)
}
// IsTerminal returns true if the given file descriptor is a terminal. // IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd uintptr) bool { func IsTerminal(fd uintptr) bool {
return windows.IsConsole(fd) return windows.IsConsole(fd)
@ -177,25 +201,36 @@ func IsTerminal(fd uintptr) bool {
// RestoreTerminal restores the terminal connected to the given file descriptor // RestoreTerminal restores the terminal connected to the given file descriptor
// to a previous state. // to a previous state.
func RestoreTerminal(fd uintptr, state *State) error { func RestoreTerminal(fd uintptr, state *State) error {
return winterm.SetConsoleMode(fd, state.mode) if usingNativeConsole {
return disableNativeConsole(state)
}
return winterm.SetConsoleMode(fd, state.outMode)
} }
// SaveState saves the state of the terminal connected to the given file descriptor. // SaveState saves the state of the terminal connected to the given file descriptor.
func SaveState(fd uintptr) (*State, error) { func SaveState(fd uintptr) (*State, error) {
if usingNativeConsole {
state, err := getNativeConsole()
if err != nil {
return nil, err
}
return &state, nil
}
mode, e := winterm.GetConsoleMode(fd) mode, e := winterm.GetConsoleMode(fd)
if e != nil { if e != nil {
return nil, e return nil, e
} }
return &State{mode}, nil
return &State{outMode: mode}, nil
} }
// DisableEcho disables echo for the terminal connected to the given file descriptor. // DisableEcho disables echo for the terminal connected to the given file descriptor.
// -- See https://msdn.microsoft.com/en-us/library/windows/desktop/ms683462(v=vs.85).aspx // -- See https://msdn.microsoft.com/en-us/library/windows/desktop/ms683462(v=vs.85).aspx
func DisableEcho(fd uintptr, state *State) error { func DisableEcho(fd uintptr, state *State) error {
mode := state.mode mode := state.inMode
mode &^= winterm.ENABLE_ECHO_INPUT mode &^= winterm.ENABLE_ECHO_INPUT
mode |= winterm.ENABLE_PROCESSED_INPUT | winterm.ENABLE_LINE_INPUT mode |= winterm.ENABLE_PROCESSED_INPUT | winterm.ENABLE_LINE_INPUT
err := winterm.SetConsoleMode(fd, mode) err := winterm.SetConsoleMode(fd, mode)
if err != nil { if err != nil {
return err return err
@ -227,10 +262,17 @@ func MakeRaw(fd uintptr) (*State, error) {
return nil, err return nil, err
} }
mode := state.inMode
if usingNativeConsole {
if err := enableNativeConsole(*state); err != nil {
return nil, err
}
mode |= enableVirtualTerminalInput
}
// See // See
// -- https://msdn.microsoft.com/en-us/library/windows/desktop/ms686033(v=vs.85).aspx // -- https://msdn.microsoft.com/en-us/library/windows/desktop/ms686033(v=vs.85).aspx
// -- https://msdn.microsoft.com/en-us/library/windows/desktop/ms683462(v=vs.85).aspx // -- https://msdn.microsoft.com/en-us/library/windows/desktop/ms683462(v=vs.85).aspx
mode := state.mode
// Disable these modes // Disable these modes
mode &^= winterm.ENABLE_ECHO_INPUT mode &^= winterm.ENABLE_ECHO_INPUT

View file

@ -0,0 +1,69 @@
package term
import (
"syscall"
"unsafe"
)
const (
getTermios = syscall.TIOCGETA
setTermios = syscall.TIOCSETA
)
// Termios magic numbers, passthrough to the ones defined in syscall.
const (
IGNBRK = syscall.IGNBRK
PARMRK = syscall.PARMRK
INLCR = syscall.INLCR
IGNCR = syscall.IGNCR
ECHONL = syscall.ECHONL
CSIZE = syscall.CSIZE
ICRNL = syscall.ICRNL
ISTRIP = syscall.ISTRIP
PARENB = syscall.PARENB
ECHO = syscall.ECHO
ICANON = syscall.ICANON
ISIG = syscall.ISIG
IXON = syscall.IXON
BRKINT = syscall.BRKINT
INPCK = syscall.INPCK
OPOST = syscall.OPOST
CS8 = syscall.CS8
IEXTEN = syscall.IEXTEN
)
// Termios is the Unix API for terminal I/O.
type Termios struct {
Iflag uint32
Oflag uint32
Cflag uint32
Lflag uint32
Cc [20]byte
Ispeed uint32
Ospeed uint32
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd uintptr) (*State, error) {
var oldState State
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, uintptr(getTermios), uintptr(unsafe.Pointer(&oldState.termios))); err != 0 {
return nil, err
}
newState := oldState.termios
newState.Iflag &^= (IGNBRK | BRKINT | PARMRK | ISTRIP | INLCR | IGNCR | ICRNL | IXON)
newState.Oflag &^= OPOST
newState.Lflag &^= (ECHO | ECHONL | ICANON | ISIG | IEXTEN)
newState.Cflag &^= (CSIZE | PARENB)
newState.Cflag |= CS8
newState.Cc[syscall.VMIN] = 1
newState.Cc[syscall.VTIME] = 0
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, uintptr(setTermios), uintptr(unsafe.Pointer(&newState))); err != 0 {
return nil, err
}
return &oldState, nil
}

View file

@ -8,8 +8,44 @@ import (
"syscall" "syscall"
"github.com/Azure/go-ansiterm/winterm" "github.com/Azure/go-ansiterm/winterm"
ansiterm "github.com/Azure/go-ansiterm"
"github.com/Sirupsen/logrus"
"io/ioutil"
) )
// ConEmuStreams returns prepared versions of console streams,
// for proper use in ConEmu terminal.
// The ConEmu terminal emulates ANSI on output streams well by default.
func ConEmuStreams() (stdIn io.ReadCloser, stdOut, stdErr io.Writer) {
if IsConsole(os.Stdin.Fd()) {
stdIn = newAnsiReader(syscall.STD_INPUT_HANDLE)
} else {
stdIn = os.Stdin
}
stdOut = os.Stdout
stdErr = os.Stderr
// WARNING (BEGIN): sourced from newAnsiWriter
logFile := ioutil.Discard
if isDebugEnv := os.Getenv(ansiterm.LogEnv); isDebugEnv == "1" {
logFile, _ = os.Create("ansiReaderWriter.log")
}
logger = &logrus.Logger{
Out: logFile,
Formatter: new(logrus.TextFormatter),
Level: logrus.DebugLevel,
}
// WARNING (END): sourced from newAnsiWriter
return stdIn, stdOut, stdErr
}
// ConsoleStreams returns a wrapped version for each standard stream referencing a console, // ConsoleStreams returns a wrapped version for each standard stream referencing a console,
// that handles ANSI character sequences. // that handles ANSI character sequences.
func ConsoleStreams() (stdIn io.ReadCloser, stdOut, stdErr io.Writer) { func ConsoleStreams() (stdIn io.ReadCloser, stdOut, stdErr io.Writer) {

View file

@ -31,7 +31,7 @@ type unitMap map[string]int64
var ( var (
decimalMap = unitMap{"k": KB, "m": MB, "g": GB, "t": TB, "p": PB} decimalMap = unitMap{"k": KB, "m": MB, "g": GB, "t": TB, "p": PB}
binaryMap = unitMap{"k": KiB, "m": MiB, "g": GiB, "t": TiB, "p": PiB} binaryMap = unitMap{"k": KiB, "m": MiB, "g": GiB, "t": TiB, "p": PiB}
sizeRegex = regexp.MustCompile(`^(\d+)([kKmMgGtTpP])?[bB]?$`) sizeRegex = regexp.MustCompile(`^(\d+(\.\d+)*) ?([kKmMgGtTpP])?[bB]?$`)
) )
var decimapAbbrs = []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"} var decimapAbbrs = []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}
@ -77,19 +77,19 @@ func RAMInBytes(size string) (int64, error) {
// Parses the human-readable size string into the amount it represents. // Parses the human-readable size string into the amount it represents.
func parseSize(sizeStr string, uMap unitMap) (int64, error) { func parseSize(sizeStr string, uMap unitMap) (int64, error) {
matches := sizeRegex.FindStringSubmatch(sizeStr) matches := sizeRegex.FindStringSubmatch(sizeStr)
if len(matches) != 3 { if len(matches) != 4 {
return -1, fmt.Errorf("invalid size: '%s'", sizeStr) return -1, fmt.Errorf("invalid size: '%s'", sizeStr)
} }
size, err := strconv.ParseInt(matches[1], 10, 0) size, err := strconv.ParseFloat(matches[1], 64)
if err != nil { if err != nil {
return -1, err return -1, err
} }
unitPrefix := strings.ToLower(matches[2]) unitPrefix := strings.ToLower(matches[3])
if mul, ok := uMap[unitPrefix]; ok { if mul, ok := uMap[unitPrefix]; ok {
size *= mul size *= float64(mul)
} }
return size, nil return int64(size), nil
} }

View file

@ -46,7 +46,7 @@ type Conn struct {
calls map[uint32]*Call calls map[uint32]*Call
callsLck sync.RWMutex callsLck sync.RWMutex
handlers map[ObjectPath]map[string]exportWithMapping handlers map[ObjectPath]map[string]exportedObj
handlersLck sync.RWMutex handlersLck sync.RWMutex
out chan *Message out chan *Message
@ -157,7 +157,7 @@ func newConn(tr transport) (*Conn, error) {
conn.transport = tr conn.transport = tr
conn.calls = make(map[uint32]*Call) conn.calls = make(map[uint32]*Call)
conn.out = make(chan *Message, 10) conn.out = make(chan *Message, 10)
conn.handlers = make(map[ObjectPath]map[string]exportWithMapping) conn.handlers = make(map[ObjectPath]map[string]exportedObj)
conn.nextSerial = 1 conn.nextSerial = 1
conn.serialUsed = map[uint32]bool{0: true} conn.serialUsed = map[uint32]bool{0: true}
conn.busObj = conn.Object("org.freedesktop.DBus", "/org/freedesktop/DBus") conn.busObj = conn.Object("org.freedesktop.DBus", "/org/freedesktop/DBus")
@ -499,9 +499,7 @@ func (conn *Conn) sendReply(dest string, serial uint32, values ...interface{}) {
// The caller has to make sure that ch is sufficiently buffered; if a message // The caller has to make sure that ch is sufficiently buffered; if a message
// arrives when a write to c is not possible, it is discarded. // arrives when a write to c is not possible, it is discarded.
// //
// Multiple of these channels can be registered at the same time. Passing a // Multiple of these channels can be registered at the same time.
// channel that already is registered will remove it from the list of the
// registered channels.
// //
// These channels are "overwritten" by Eavesdrop; i.e., if there currently is a // These channels are "overwritten" by Eavesdrop; i.e., if there currently is a
// channel for eavesdropped messages, this channel receives all signals, and // channel for eavesdropped messages, this channel receives all signals, and
@ -512,6 +510,19 @@ func (conn *Conn) Signal(ch chan<- *Signal) {
conn.signalsLck.Unlock() conn.signalsLck.Unlock()
} }
// RemoveSignal removes the given channel from the list of the registered channels.
func (conn *Conn) RemoveSignal(ch chan<- *Signal) {
conn.signalsLck.Lock()
for i := len(conn.signals) - 1; i >= 0; i-- {
if ch == conn.signals[i] {
copy(conn.signals[i:], conn.signals[i+1:])
conn.signals[len(conn.signals)-1] = nil
conn.signals = conn.signals[:len(conn.signals)-1]
}
}
conn.signalsLck.Unlock()
}
// SupportsUnixFDs returns whether the underlying transport supports passing of // SupportsUnixFDs returns whether the underlying transport supports passing of
// unix file descriptors. If this is false, method calls containing unix file // unix file descriptors. If this is false, method calls containing unix file
// descriptors will return an error and emitted signals containing them will // descriptors will return an error and emitted signals containing them will

View file

@ -1,6 +1,7 @@
package dbus package dbus
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -22,67 +23,60 @@ var (
} }
) )
// exportWithMapping represents an exported struct along with a method name // exportedObj represents an exported object. It stores a precomputed
// mapping to allow for exporting lower-case methods, etc. // method table that represents the methods exported on the bus.
type exportWithMapping struct { type exportedObj struct {
export interface{} methods map[string]reflect.Value
// Method name mapping; key -> struct method, value -> dbus method.
mapping map[string]string
// Whether or not this export is for the entire subtree // Whether or not this export is for the entire subtree
includeSubtree bool includeSubtree bool
} }
func (obj exportedObj) Method(name string) (reflect.Value, bool) {
out, exists := obj.methods[name]
return out, exists
}
// Sender is a type which can be used in exported methods to receive the message // Sender is a type which can be used in exported methods to receive the message
// sender. // sender.
type Sender string type Sender string
func exportedMethod(export exportWithMapping, name string) reflect.Value { func computeMethodName(name string, mapping map[string]string) string {
if export.export == nil { newname, ok := mapping[name]
return reflect.Value{} if ok {
name = newname
} }
return name
}
// If a mapping was included in the export, check the map to see if we func getMethods(in interface{}, mapping map[string]string) map[string]reflect.Value {
// should be looking for a different method in the export. if in == nil {
if export.mapping != nil { return nil
for key, value := range export.mapping { }
if value == name { methods := make(map[string]reflect.Value)
name = key val := reflect.ValueOf(in)
break typ := val.Type()
} for i := 0; i < typ.NumMethod(); i++ {
methtype := typ.Method(i)
// Catch the case where a method is aliased but the client is calling method := val.Method(i)
// the original, e.g. the "Foo" method was exported mapped to t := method.Type()
// "foo," and dbus client called the original "Foo." // only track valid methods must return *Error as last arg
if key == name { // and must be exported
return reflect.Value{} if t.NumOut() == 0 ||
} t.Out(t.NumOut()-1) != reflect.TypeOf(&errmsgInvalidArg) ||
methtype.PkgPath != "" {
continue
} }
// map names while building table
methods[computeMethodName(methtype.Name, mapping)] = method
} }
return methods
value := reflect.ValueOf(export.export)
m := value.MethodByName(name)
// Catch the case of attempting to call an unexported method
method, ok := value.Type().MethodByName(name)
if !m.IsValid() || !ok || method.PkgPath != "" {
return reflect.Value{}
}
t := m.Type()
if t.NumOut() == 0 ||
t.Out(t.NumOut()-1) != reflect.TypeOf(&errmsgInvalidArg) {
return reflect.Value{}
}
return m
} }
// searchHandlers will look through all registered handlers looking for one // searchHandlers will look through all registered handlers looking for one
// to handle the given path. If a verbatim one isn't found, it will check for // to handle the given path. If a verbatim one isn't found, it will check for
// a subtree registration for the path as well. // a subtree registration for the path as well.
func (conn *Conn) searchHandlers(path ObjectPath) (map[string]exportWithMapping, bool) { func (conn *Conn) searchHandlers(path ObjectPath) (map[string]exportedObj, bool) {
conn.handlersLck.RLock() conn.handlersLck.RLock()
defer conn.handlersLck.RUnlock() defer conn.handlersLck.RUnlock()
@ -93,10 +87,10 @@ func (conn *Conn) searchHandlers(path ObjectPath) (map[string]exportWithMapping,
// If handlers weren't found for this exact path, look for a matching subtree // If handlers weren't found for this exact path, look for a matching subtree
// registration // registration
handlers = make(map[string]exportWithMapping) handlers = make(map[string]exportedObj)
path = path[:strings.LastIndex(string(path), "/")] path = path[:strings.LastIndex(string(path), "/")]
for len(path) > 0 { for len(path) > 0 {
var subtreeHandlers map[string]exportWithMapping var subtreeHandlers map[string]exportedObj
subtreeHandlers, ok = conn.handlers[path] subtreeHandlers, ok = conn.handlers[path]
if ok { if ok {
for iface, handler := range subtreeHandlers { for iface, handler := range subtreeHandlers {
@ -133,6 +127,28 @@ func (conn *Conn) handleCall(msg *Message) {
conn.sendError(errmsgUnknownMethod, sender, serial) conn.sendError(errmsgUnknownMethod, sender, serial)
} }
return return
} else if ifaceName == "org.freedesktop.DBus.Introspectable" && name == "Introspect" {
if _, ok := conn.handlers[path]; !ok {
subpath := make(map[string]struct{})
var xml bytes.Buffer
xml.WriteString("<node>")
for h, _ := range conn.handlers {
p := string(path)
if p != "/" {
p += "/"
}
if strings.HasPrefix(string(h), p) {
node_name := strings.Split(string(h[len(p):]), "/")[0]
subpath[node_name] = struct{}{}
}
}
for s, _ := range subpath {
xml.WriteString("\n\t<node name=\"" + s + "\"/>")
}
xml.WriteString("\n</node>")
conn.sendReply(sender, serial, xml.String())
return
}
} }
if len(name) == 0 { if len(name) == 0 {
conn.sendError(errmsgUnknownMethod, sender, serial) conn.sendError(errmsgUnknownMethod, sender, serial)
@ -146,19 +162,20 @@ func (conn *Conn) handleCall(msg *Message) {
} }
var m reflect.Value var m reflect.Value
var exists bool
if hasIface { if hasIface {
iface := handlers[ifaceName] iface := handlers[ifaceName]
m = exportedMethod(iface, name) m, exists = iface.Method(name)
} else { } else {
for _, v := range handlers { for _, v := range handlers {
m = exportedMethod(v, name) m, exists = v.Method(name)
if m.IsValid() { if exists {
break break
} }
} }
} }
if !m.IsValid() { if !exists {
conn.sendError(errmsgUnknownMethod, sender, serial) conn.sendError(errmsgUnknownMethod, sender, serial)
return return
} }
@ -303,7 +320,7 @@ func (conn *Conn) Export(v interface{}, path ObjectPath, iface string) error {
// The keys in the map are the real method names (exported on the struct), and // The keys in the map are the real method names (exported on the struct), and
// the values are the method names to be exported on DBus. // the values are the method names to be exported on DBus.
func (conn *Conn) ExportWithMap(v interface{}, mapping map[string]string, path ObjectPath, iface string) error { func (conn *Conn) ExportWithMap(v interface{}, mapping map[string]string, path ObjectPath, iface string) error {
return conn.exportWithMap(v, mapping, path, iface, false) return conn.export(getMethods(v, mapping), path, iface, false)
} }
// ExportSubtree works exactly like Export but registers the given value for // ExportSubtree works exactly like Export but registers the given value for
@ -326,11 +343,48 @@ func (conn *Conn) ExportSubtree(v interface{}, path ObjectPath, iface string) er
// The keys in the map are the real method names (exported on the struct), and // The keys in the map are the real method names (exported on the struct), and
// the values are the method names to be exported on DBus. // the values are the method names to be exported on DBus.
func (conn *Conn) ExportSubtreeWithMap(v interface{}, mapping map[string]string, path ObjectPath, iface string) error { func (conn *Conn) ExportSubtreeWithMap(v interface{}, mapping map[string]string, path ObjectPath, iface string) error {
return conn.exportWithMap(v, mapping, path, iface, true) return conn.export(getMethods(v, mapping), path, iface, true)
}
// ExportMethodTable like Export registers the given methods as an object
// on the message bus. Unlike Export the it uses a method table to define
// the object instead of a native go object.
//
// The method table is a map from method name to function closure
// representing the method. This allows an object exported on the bus to not
// necessarily be a native go object. It can be useful for generating exposed
// methods on the fly.
//
// Any non-function objects in the method table are ignored.
func (conn *Conn) ExportMethodTable(methods map[string]interface{}, path ObjectPath, iface string) error {
return conn.exportMethodTable(methods, path, iface, false)
}
// Like ExportSubtree, but with the same caveats as ExportMethodTable.
func (conn *Conn) ExportSubtreeMethodTable(methods map[string]interface{}, path ObjectPath, iface string) error {
return conn.exportMethodTable(methods, path, iface, true)
}
func (conn *Conn) exportMethodTable(methods map[string]interface{}, path ObjectPath, iface string, includeSubtree bool) error {
out := make(map[string]reflect.Value)
for name, method := range methods {
rval := reflect.ValueOf(method)
if rval.Kind() != reflect.Func {
continue
}
t := rval.Type()
// only track valid methods must return *Error as last arg
if t.NumOut() == 0 ||
t.Out(t.NumOut()-1) != reflect.TypeOf(&errmsgInvalidArg) {
continue
}
out[name] = rval
}
return conn.export(out, path, iface, includeSubtree)
} }
// exportWithMap is the worker function for all exports/registrations. // exportWithMap is the worker function for all exports/registrations.
func (conn *Conn) exportWithMap(v interface{}, mapping map[string]string, path ObjectPath, iface string, includeSubtree bool) error { func (conn *Conn) export(methods map[string]reflect.Value, path ObjectPath, iface string, includeSubtree bool) error {
if !path.IsValid() { if !path.IsValid() {
return fmt.Errorf(`dbus: Invalid path name: "%s"`, path) return fmt.Errorf(`dbus: Invalid path name: "%s"`, path)
} }
@ -339,7 +393,7 @@ func (conn *Conn) exportWithMap(v interface{}, mapping map[string]string, path O
defer conn.handlersLck.Unlock() defer conn.handlersLck.Unlock()
// Remove a previous export if the interface is nil // Remove a previous export if the interface is nil
if v == nil { if methods == nil {
if _, ok := conn.handlers[path]; ok { if _, ok := conn.handlers[path]; ok {
delete(conn.handlers[path], iface) delete(conn.handlers[path], iface)
if len(conn.handlers[path]) == 0 { if len(conn.handlers[path]) == 0 {
@ -353,11 +407,14 @@ func (conn *Conn) exportWithMap(v interface{}, mapping map[string]string, path O
// If this is the first handler for this path, make a new map to hold all // If this is the first handler for this path, make a new map to hold all
// handlers for this path. // handlers for this path.
if _, ok := conn.handlers[path]; !ok { if _, ok := conn.handlers[path]; !ok {
conn.handlers[path] = make(map[string]exportWithMapping) conn.handlers[path] = make(map[string]exportedObj)
} }
// Finally, save this handler // Finally, save this handler
conn.handlers[path][iface] = exportWithMapping{export: v, mapping: mapping, includeSubtree: includeSubtree} conn.handlers[path][iface] = exportedObj{
methods: methods,
includeSubtree: includeSubtree,
}
return nil return nil
} }

View file

@ -1,4 +1,4 @@
//+build !windows //+build !windows,!solaris
package dbus package dbus

View file

@ -39,5 +39,5 @@ test: install generate-test-pbs
generate-test-pbs: generate-test-pbs:
make install make install
make -C testdata make -C testdata
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata:. proto3_proto/proto3.proto protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata,Mgoogle/protobuf/any.proto=github.com/golang/protobuf/ptypes/any:. proto3_proto/proto3.proto
make make

View file

@ -768,10 +768,11 @@ func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
} }
} }
keyelem, valelem := keyptr.Elem(), valptr.Elem() keyelem, valelem := keyptr.Elem(), valptr.Elem()
if !keyelem.IsValid() || !valelem.IsValid() { if !keyelem.IsValid() {
// We did not decode the key or the value in the map entry. keyelem = reflect.Zero(p.mtype.Key())
// Either way, it's an invalid map entry. }
return fmt.Errorf("proto: bad map data: missing key/val") if !valelem.IsValid() {
valelem = reflect.Zero(p.mtype.Elem())
} }
v.SetMapIndex(keyelem, valelem) v.SetMapIndex(keyelem, valelem)

View file

@ -64,6 +64,10 @@ var (
// a struct with a repeated field containing a nil element. // a struct with a repeated field containing a nil element.
errRepeatedHasNil = errors.New("proto: repeated field has nil element") errRepeatedHasNil = errors.New("proto: repeated field has nil element")
// errOneofHasNil is the error returned if Marshal is called with
// a struct with a oneof field containing a nil element.
errOneofHasNil = errors.New("proto: oneof field has nil value")
// ErrNil is the error returned if Marshal is called with nil. // ErrNil is the error returned if Marshal is called with nil.
ErrNil = errors.New("proto: Marshal called with nil") ErrNil = errors.New("proto: Marshal called with nil")
) )
@ -1222,7 +1226,9 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
// Do oneof fields. // Do oneof fields.
if prop.oneofMarshaler != nil { if prop.oneofMarshaler != nil {
m := structPointer_Interface(base, prop.stype).(Message) m := structPointer_Interface(base, prop.stype).(Message)
if err := prop.oneofMarshaler(m, o); err != nil { if err := prop.oneofMarshaler(m, o); err == ErrNil {
return errOneofHasNil
} else if err != nil {
return err return err
} }
} }

View file

@ -235,6 +235,7 @@ To create and play with a Test object:
test := &pb.Test{ test := &pb.Test{
Label: proto.String("hello"), Label: proto.String("hello"),
Type: proto.Int32(17), Type: proto.Int32(17),
Reps: []int64{1, 2, 3},
Optionalgroup: &pb.Test_OptionalGroup{ Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"), RequiredField: proto.String("good bye"),
}, },
@ -887,3 +888,7 @@ func isProto3Zero(v reflect.Value) bool {
} }
return false return false
} }
// ProtoPackageIsVersion1 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the proto package.
const ProtoPackageIsVersion1 = true

View file

@ -173,6 +173,7 @@ func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order
type Properties struct { type Properties struct {
Name string // name of the field, for error messages Name string // name of the field, for error messages
OrigName string // original name before protocol compiler (always set) OrigName string // original name before protocol compiler (always set)
JSONName string // name to use for JSON; determined by protoc
Wire string Wire string
WireType int WireType int
Tag int Tag int
@ -229,8 +230,9 @@ func (p *Properties) String() string {
if p.Packed { if p.Packed {
s += ",packed" s += ",packed"
} }
if p.OrigName != p.Name { s += ",name=" + p.OrigName
s += ",name=" + p.OrigName if p.JSONName != p.OrigName {
s += ",json=" + p.JSONName
} }
if p.proto3 { if p.proto3 {
s += ",proto3" s += ",proto3"
@ -310,6 +312,8 @@ func (p *Properties) Parse(s string) {
p.Packed = true p.Packed = true
case strings.HasPrefix(f, "name="): case strings.HasPrefix(f, "name="):
p.OrigName = f[5:] p.OrigName = f[5:]
case strings.HasPrefix(f, "json="):
p.JSONName = f[5:]
case strings.HasPrefix(f, "enum="): case strings.HasPrefix(f, "enum="):
p.Enum = f[5:] p.Enum = f[5:]
case f == "proto3": case f == "proto3":

View file

@ -175,7 +175,93 @@ type raw interface {
Bytes() []byte Bytes() []byte
} }
func writeStruct(w *textWriter, sv reflect.Value) error { func requiresQuotes(u string) bool {
// When type URL contains any characters except [0-9A-Za-z./\-]*, it must be quoted.
for _, ch := range u {
switch {
case ch == '.' || ch == '/' || ch == '_':
continue
case '0' <= ch && ch <= '9':
continue
case 'A' <= ch && ch <= 'Z':
continue
case 'a' <= ch && ch <= 'z':
continue
default:
return true
}
}
return false
}
// isAny reports whether sv is a google.protobuf.Any message
func isAny(sv reflect.Value) bool {
type wkt interface {
XXX_WellKnownType() string
}
t, ok := sv.Addr().Interface().(wkt)
return ok && t.XXX_WellKnownType() == "Any"
}
// writeProto3Any writes an expanded google.protobuf.Any message.
//
// It returns (false, nil) if sv value can't be unmarshaled (e.g. because
// required messages are not linked in).
//
// It returns (true, error) when sv was written in expanded format or an error
// was encountered.
func (tm *TextMarshaler) writeProto3Any(w *textWriter, sv reflect.Value) (bool, error) {
turl := sv.FieldByName("TypeUrl")
val := sv.FieldByName("Value")
if !turl.IsValid() || !val.IsValid() {
return true, errors.New("proto: invalid google.protobuf.Any message")
}
b, ok := val.Interface().([]byte)
if !ok {
return true, errors.New("proto: invalid google.protobuf.Any message")
}
parts := strings.Split(turl.String(), "/")
mt := MessageType(parts[len(parts)-1])
if mt == nil {
return false, nil
}
m := reflect.New(mt.Elem())
if err := Unmarshal(b, m.Interface().(Message)); err != nil {
return false, nil
}
w.Write([]byte("["))
u := turl.String()
if requiresQuotes(u) {
writeString(w, u)
} else {
w.Write([]byte(u))
}
if w.compact {
w.Write([]byte("]:<"))
} else {
w.Write([]byte("]: <\n"))
w.ind++
}
if err := tm.writeStruct(w, m.Elem()); err != nil {
return true, err
}
if w.compact {
w.Write([]byte("> "))
} else {
w.ind--
w.Write([]byte(">\n"))
}
return true, nil
}
func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
if tm.ExpandAny && isAny(sv) {
if canExpand, err := tm.writeProto3Any(w, sv); canExpand {
return err
}
}
st := sv.Type() st := sv.Type()
sprops := GetProperties(st) sprops := GetProperties(st)
for i := 0; i < sv.NumField(); i++ { for i := 0; i < sv.NumField(); i++ {
@ -227,7 +313,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
} }
continue continue
} }
if err := writeAny(w, v, props); err != nil { if err := tm.writeAny(w, v, props); err != nil {
return err return err
} }
if err := w.WriteByte('\n'); err != nil { if err := w.WriteByte('\n'); err != nil {
@ -269,7 +355,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
return err return err
} }
} }
if err := writeAny(w, key, props.mkeyprop); err != nil { if err := tm.writeAny(w, key, props.mkeyprop); err != nil {
return err return err
} }
if err := w.WriteByte('\n'); err != nil { if err := w.WriteByte('\n'); err != nil {
@ -286,7 +372,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
return err return err
} }
} }
if err := writeAny(w, val, props.mvalprop); err != nil { if err := tm.writeAny(w, val, props.mvalprop); err != nil {
return err return err
} }
if err := w.WriteByte('\n'); err != nil { if err := w.WriteByte('\n'); err != nil {
@ -358,7 +444,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
} }
// Enums have a String method, so writeAny will work fine. // Enums have a String method, so writeAny will work fine.
if err := writeAny(w, fv, props); err != nil { if err := tm.writeAny(w, fv, props); err != nil {
return err return err
} }
@ -370,7 +456,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
// Extensions (the XXX_extensions field). // Extensions (the XXX_extensions field).
pv := sv.Addr() pv := sv.Addr()
if pv.Type().Implements(extendableProtoType) { if pv.Type().Implements(extendableProtoType) {
if err := writeExtensions(w, pv); err != nil { if err := tm.writeExtensions(w, pv); err != nil {
return err return err
} }
} }
@ -400,7 +486,7 @@ func writeRaw(w *textWriter, b []byte) error {
} }
// writeAny writes an arbitrary field. // writeAny writes an arbitrary field.
func writeAny(w *textWriter, v reflect.Value, props *Properties) error { func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error {
v = reflect.Indirect(v) v = reflect.Indirect(v)
// Floats have special cases. // Floats have special cases.
@ -449,15 +535,15 @@ func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
} }
} }
w.indent() w.indent()
if tm, ok := v.Interface().(encoding.TextMarshaler); ok { if etm, ok := v.Interface().(encoding.TextMarshaler); ok {
text, err := tm.MarshalText() text, err := etm.MarshalText()
if err != nil { if err != nil {
return err return err
} }
if _, err = w.Write(text); err != nil { if _, err = w.Write(text); err != nil {
return err return err
} }
} else if err := writeStruct(w, v); err != nil { } else if err := tm.writeStruct(w, v); err != nil {
return err return err
} }
w.unindent() w.unindent()
@ -601,7 +687,7 @@ func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// writeExtensions writes all the extensions in pv. // writeExtensions writes all the extensions in pv.
// pv is assumed to be a pointer to a protocol message struct that is extendable. // pv is assumed to be a pointer to a protocol message struct that is extendable.
func writeExtensions(w *textWriter, pv reflect.Value) error { func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error {
emap := extensionMaps[pv.Type().Elem()] emap := extensionMaps[pv.Type().Elem()]
ep := pv.Interface().(extendableProto) ep := pv.Interface().(extendableProto)
@ -636,13 +722,13 @@ func writeExtensions(w *textWriter, pv reflect.Value) error {
// Repeated extensions will appear as a slice. // Repeated extensions will appear as a slice.
if !desc.repeated() { if !desc.repeated() {
if err := writeExtension(w, desc.Name, pb); err != nil { if err := tm.writeExtension(w, desc.Name, pb); err != nil {
return err return err
} }
} else { } else {
v := reflect.ValueOf(pb) v := reflect.ValueOf(pb)
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
if err := writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil { if err := tm.writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil {
return err return err
} }
} }
@ -651,7 +737,7 @@ func writeExtensions(w *textWriter, pv reflect.Value) error {
return nil return nil
} }
func writeExtension(w *textWriter, name string, pb interface{}) error { func (tm *TextMarshaler) writeExtension(w *textWriter, name string, pb interface{}) error {
if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil { if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil {
return err return err
} }
@ -660,7 +746,7 @@ func writeExtension(w *textWriter, name string, pb interface{}) error {
return err return err
} }
} }
if err := writeAny(w, reflect.ValueOf(pb), nil); err != nil { if err := tm.writeAny(w, reflect.ValueOf(pb), nil); err != nil {
return err return err
} }
if err := w.WriteByte('\n'); err != nil { if err := w.WriteByte('\n'); err != nil {
@ -685,7 +771,15 @@ func (w *textWriter) writeIndent() {
w.complete = false w.complete = false
} }
func marshalText(w io.Writer, pb Message, compact bool) error { // TextMarshaler is a configurable text format marshaler.
type TextMarshaler struct {
Compact bool // use compact text format (one line).
ExpandAny bool // expand google.protobuf.Any messages of known types
}
// Marshal writes a given protocol buffer in text format.
// The only errors returned are from w.
func (tm *TextMarshaler) Marshal(w io.Writer, pb Message) error {
val := reflect.ValueOf(pb) val := reflect.ValueOf(pb)
if pb == nil || val.IsNil() { if pb == nil || val.IsNil() {
w.Write([]byte("<nil>")) w.Write([]byte("<nil>"))
@ -700,11 +794,11 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
aw := &textWriter{ aw := &textWriter{
w: ww, w: ww,
complete: true, complete: true,
compact: compact, compact: tm.Compact,
} }
if tm, ok := pb.(encoding.TextMarshaler); ok { if etm, ok := pb.(encoding.TextMarshaler); ok {
text, err := tm.MarshalText() text, err := etm.MarshalText()
if err != nil { if err != nil {
return err return err
} }
@ -718,7 +812,7 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
} }
// Dereference the received pointer so we don't have outer < and >. // Dereference the received pointer so we don't have outer < and >.
v := reflect.Indirect(val) v := reflect.Indirect(val)
if err := writeStruct(aw, v); err != nil { if err := tm.writeStruct(aw, v); err != nil {
return err return err
} }
if bw != nil { if bw != nil {
@ -727,25 +821,29 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
return nil return nil
} }
// Text is the same as Marshal, but returns the string directly.
func (tm *TextMarshaler) Text(pb Message) string {
var buf bytes.Buffer
tm.Marshal(&buf, pb)
return buf.String()
}
var (
defaultTextMarshaler = TextMarshaler{}
compactTextMarshaler = TextMarshaler{Compact: true}
)
// TODO: consider removing some of the Marshal functions below.
// MarshalText writes a given protocol buffer in text format. // MarshalText writes a given protocol buffer in text format.
// The only errors returned are from w. // The only errors returned are from w.
func MarshalText(w io.Writer, pb Message) error { func MarshalText(w io.Writer, pb Message) error { return defaultTextMarshaler.Marshal(w, pb) }
return marshalText(w, pb, false)
}
// MarshalTextString is the same as MarshalText, but returns the string directly. // MarshalTextString is the same as MarshalText, but returns the string directly.
func MarshalTextString(pb Message) string { func MarshalTextString(pb Message) string { return defaultTextMarshaler.Text(pb) }
var buf bytes.Buffer
marshalText(&buf, pb, false)
return buf.String()
}
// CompactText writes a given protocol buffer in compact text format (one line). // CompactText writes a given protocol buffer in compact text format (one line).
func CompactText(w io.Writer, pb Message) error { return marshalText(w, pb, true) } func CompactText(w io.Writer, pb Message) error { return compactTextMarshaler.Marshal(w, pb) }
// CompactTextString is the same as CompactText, but returns the string directly. // CompactTextString is the same as CompactText, but returns the string directly.
func CompactTextString(pb Message) string { func CompactTextString(pb Message) string { return compactTextMarshaler.Text(pb) }
var buf bytes.Buffer
marshalText(&buf, pb, true)
return buf.String()
}

View file

@ -119,6 +119,14 @@ func isWhitespace(c byte) bool {
return false return false
} }
func isQuote(c byte) bool {
switch c {
case '"', '\'':
return true
}
return false
}
func (p *textParser) skipWhitespace() { func (p *textParser) skipWhitespace() {
i := 0 i := 0
for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') { for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') {
@ -155,7 +163,7 @@ func (p *textParser) advance() {
p.cur.offset, p.cur.line = p.offset, p.line p.cur.offset, p.cur.line = p.offset, p.line
p.cur.unquoted = "" p.cur.unquoted = ""
switch p.s[0] { switch p.s[0] {
case '<', '>', '{', '}', ':', '[', ']', ';', ',': case '<', '>', '{', '}', ':', '[', ']', ';', ',', '/':
// Single symbol // Single symbol
p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)] p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)]
case '"', '\'': case '"', '\'':
@ -333,13 +341,13 @@ func (p *textParser) next() *token {
p.advance() p.advance()
if p.done { if p.done {
p.cur.value = "" p.cur.value = ""
} else if len(p.cur.value) > 0 && p.cur.value[0] == '"' { } else if len(p.cur.value) > 0 && isQuote(p.cur.value[0]) {
// Look for multiple quoted strings separated by whitespace, // Look for multiple quoted strings separated by whitespace,
// and concatenate them. // and concatenate them.
cat := p.cur cat := p.cur
for { for {
p.skipWhitespace() p.skipWhitespace()
if p.done || p.s[0] != '"' { if p.done || !isQuote(p.s[0]) {
break break
} }
p.advance() p.advance()
@ -443,7 +451,10 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
fieldSet := make(map[string]bool) fieldSet := make(map[string]bool)
// A struct is a sequence of "name: value", terminated by one of // A struct is a sequence of "name: value", terminated by one of
// '>' or '}', or the end of the input. A name may also be // '>' or '}', or the end of the input. A name may also be
// "[extension]". // "[extension]" or "[type/url]".
//
// The whole struct can also be an expanded Any message, like:
// [type/url] < ... struct contents ... >
for { for {
tok := p.next() tok := p.next()
if tok.err != nil { if tok.err != nil {
@ -453,33 +464,66 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
break break
} }
if tok.value == "[" { if tok.value == "[" {
// Looks like an extension. // Looks like an extension or an Any.
// //
// TODO: Check whether we need to handle // TODO: Check whether we need to handle
// namespace rooted names (e.g. ".something.Foo"). // namespace rooted names (e.g. ".something.Foo").
tok = p.next() extName, err := p.consumeExtName()
if tok.err != nil { if err != nil {
return tok.err return err
} }
if s := strings.LastIndex(extName, "/"); s >= 0 {
// If it contains a slash, it's an Any type URL.
messageName := extName[s+1:]
mt := MessageType(messageName)
if mt == nil {
return p.errorf("unrecognized message %q in google.protobuf.Any", messageName)
}
tok = p.next()
if tok.err != nil {
return tok.err
}
// consume an optional colon
if tok.value == ":" {
tok = p.next()
if tok.err != nil {
return tok.err
}
}
var terminator string
switch tok.value {
case "<":
terminator = ">"
case "{":
terminator = "}"
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
v := reflect.New(mt.Elem())
if pe := p.readStruct(v.Elem(), terminator); pe != nil {
return pe
}
b, err := Marshal(v.Interface().(Message))
if err != nil {
return p.errorf("failed to marshal message of type %q: %v", messageName, err)
}
sv.FieldByName("TypeUrl").SetString(extName)
sv.FieldByName("Value").SetBytes(b)
continue
}
var desc *ExtensionDesc var desc *ExtensionDesc
// This could be faster, but it's functional. // This could be faster, but it's functional.
// TODO: Do something smarter than a linear scan. // TODO: Do something smarter than a linear scan.
for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) { for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) {
if d.Name == tok.value { if d.Name == extName {
desc = d desc = d
break break
} }
} }
if desc == nil { if desc == nil {
return p.errorf("unrecognized extension %q", tok.value) return p.errorf("unrecognized extension %q", extName)
}
// Check the extension terminator.
tok = p.next()
if tok.err != nil {
return tok.err
}
if tok.value != "]" {
return p.errorf("unrecognized extension terminator %q", tok.value)
} }
props := &Properties{} props := &Properties{}
@ -635,6 +679,35 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
return reqFieldErr return reqFieldErr
} }
// consumeExtName consumes extension name or expanded Any type URL and the
// following ']'. It returns the name or URL consumed.
func (p *textParser) consumeExtName() (string, error) {
tok := p.next()
if tok.err != nil {
return "", tok.err
}
// If extension name or type url is quoted, it's a single token.
if len(tok.value) > 2 && isQuote(tok.value[0]) && tok.value[len(tok.value)-1] == tok.value[0] {
name, err := unquoteC(tok.value[1:len(tok.value)-1], rune(tok.value[0]))
if err != nil {
return "", err
}
return name, p.consumeToken("]")
}
// Consume everything up to "]"
var parts []string
for tok.value != "]" {
parts = append(parts, tok.value)
tok = p.next()
if tok.err != nil {
return "", p.errorf("unrecognized type_url or extension name: %s", tok.err)
}
}
return strings.Join(parts, ""), nil
}
// consumeOptionalSeparator consumes an optional semicolon or comma. // consumeOptionalSeparator consumes an optional semicolon or comma.
// It is used in readStruct to provide backward compatibility. // It is used in readStruct to provide backward compatibility.
func (p *textParser) consumeOptionalSeparator() error { func (p *textParser) consumeOptionalSeparator() error {

View file

@ -103,6 +103,19 @@ import "github.com/rcrowley/go-metrics/stathat"
go stathat.Stathat(metrics.DefaultRegistry, 10e9, "example@example.com") go stathat.Stathat(metrics.DefaultRegistry, 10e9, "example@example.com")
``` ```
Maintain all metrics along with expvars at `/debug/metrics`:
This uses the same mechanism as [the official expvar](http://golang.org/pkg/expvar/)
but exposed under `/debug/metrics`, which shows a json representation of all your usual expvars
as well as all your go-metrics.
```go
import "github.com/rcrowley/go-metrics/exp"
exp.Exp(metrics.DefaultRegistry)
```
Installation Installation
------------ ------------

View file

@ -2,6 +2,7 @@ package metrics
import ( import (
"runtime" "runtime"
"runtime/pprof"
"time" "time"
) )
@ -39,6 +40,7 @@ var (
} }
NumCgoCall Gauge NumCgoCall Gauge
NumGoroutine Gauge NumGoroutine Gauge
NumThread Gauge
ReadMemStats Timer ReadMemStats Timer
} }
frees uint64 frees uint64
@ -46,6 +48,8 @@ var (
mallocs uint64 mallocs uint64
numGC uint32 numGC uint32
numCgoCalls int64 numCgoCalls int64
threadCreateProfile = pprof.Lookup("threadcreate")
) )
// Capture new values for the Go runtime statistics exported in // Capture new values for the Go runtime statistics exported in
@ -134,6 +138,8 @@ func CaptureRuntimeMemStatsOnce(r Registry) {
numCgoCalls = currentNumCgoCalls numCgoCalls = currentNumCgoCalls
runtimeMetrics.NumGoroutine.Update(int64(runtime.NumGoroutine())) runtimeMetrics.NumGoroutine.Update(int64(runtime.NumGoroutine()))
runtimeMetrics.NumThread.Update(int64(threadCreateProfile.Count()))
} }
// Register runtimeMetrics for the Go runtime statistics exported in runtime and // Register runtimeMetrics for the Go runtime statistics exported in runtime and
@ -169,6 +175,7 @@ func RegisterRuntimeMemStats(r Registry) {
runtimeMetrics.MemStats.TotalAlloc = NewGauge() runtimeMetrics.MemStats.TotalAlloc = NewGauge()
runtimeMetrics.NumCgoCall = NewGauge() runtimeMetrics.NumCgoCall = NewGauge()
runtimeMetrics.NumGoroutine = NewGauge() runtimeMetrics.NumGoroutine = NewGauge()
runtimeMetrics.NumThread = NewGauge()
runtimeMetrics.ReadMemStats = NewTimer() runtimeMetrics.ReadMemStats = NewTimer()
r.Register("runtime.MemStats.Alloc", runtimeMetrics.MemStats.Alloc) r.Register("runtime.MemStats.Alloc", runtimeMetrics.MemStats.Alloc)
@ -200,5 +207,6 @@ func RegisterRuntimeMemStats(r Registry) {
r.Register("runtime.MemStats.TotalAlloc", runtimeMetrics.MemStats.TotalAlloc) r.Register("runtime.MemStats.TotalAlloc", runtimeMetrics.MemStats.TotalAlloc)
r.Register("runtime.NumCgoCall", runtimeMetrics.NumCgoCall) r.Register("runtime.NumCgoCall", runtimeMetrics.NumCgoCall)
r.Register("runtime.NumGoroutine", runtimeMetrics.NumGoroutine) r.Register("runtime.NumGoroutine", runtimeMetrics.NumGoroutine)
r.Register("runtime.NumThread", runtimeMetrics.NumThread)
r.Register("runtime.ReadMemStats", runtimeMetrics.ReadMemStats) r.Register("runtime.ReadMemStats", runtimeMetrics.ReadMemStats)
} }

View file

@ -1,4 +1,4 @@
Copyright (C) 2013-2015 by Maxim Bublis <b@codemonkey.ru> Copyright (C) 2013-2016 by Maxim Bublis <b@codemonkey.ru>
Permission is hereby granted, free of charge, to any person obtaining Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the a copy of this software and associated documentation files (the

View file

@ -11,7 +11,7 @@ goroot = $(addprefix ../../../,$(1))
unroot = $(subst ../../../,,$(1)) unroot = $(subst ../../../,,$(1))
fmt = $(addprefix fmt-,$(1)) fmt = $(addprefix fmt-,$(1))
all: fmt test all: test
$(call goroot,$(DEPS)): $(call goroot,$(DEPS)):
go get $(call unroot,$@) go get $(call unroot,$@)

View file

@ -2,6 +2,7 @@ package netlink
import ( import (
"fmt" "fmt"
"log"
"net" "net"
"strings" "strings"
"syscall" "syscall"
@ -85,58 +86,124 @@ func AddrList(link Link, family int) ([]Addr, error) {
return nil, err return nil, err
} }
index := 0 indexFilter := 0
if link != nil { if link != nil {
base := link.Attrs() base := link.Attrs()
ensureIndex(base) ensureIndex(base)
index = base.Index indexFilter = base.Index
} }
var res []Addr var res []Addr
for _, m := range msgs { for _, m := range msgs {
msg := nl.DeserializeIfAddrmsg(m) addr, family, ifindex, err := parseAddr(m)
if err != nil {
return res, err
}
if link != nil && msg.Index != uint32(index) { if link != nil && ifindex != indexFilter {
// Ignore messages from other interfaces // Ignore messages from other interfaces
continue continue
} }
attrs, err := nl.ParseRouteAttr(m[msg.Len():]) if family != FAMILY_ALL && msg.Family != uint8(family) {
if err != nil { continue
return nil, err
} }
var local, dst *net.IPNet
var addr Addr
for _, attr := range attrs {
switch attr.Attr.Type {
case syscall.IFA_ADDRESS:
dst = &net.IPNet{
IP: attr.Value,
Mask: net.CIDRMask(int(msg.Prefixlen), 8*len(attr.Value)),
}
case syscall.IFA_LOCAL:
local = &net.IPNet{
IP: attr.Value,
Mask: net.CIDRMask(int(msg.Prefixlen), 8*len(attr.Value)),
}
case syscall.IFA_LABEL:
addr.Label = string(attr.Value[:len(attr.Value)-1])
case IFA_FLAGS:
addr.Flags = int(native.Uint32(attr.Value[0:4]))
}
}
// IFA_LOCAL should be there but if not, fall back to IFA_ADDRESS
if local != nil {
addr.IPNet = local
} else {
addr.IPNet = dst
}
addr.Scope = int(msg.Scope)
res = append(res, addr) res = append(res, addr)
} }
return res, nil return res, nil
} }
func parseAddr(m []byte) (addr Addr, family, index int, err error) {
msg := nl.DeserializeIfAddrmsg(m)
family = -1
index = -1
attrs, err1 := nl.ParseRouteAttr(m[msg.Len():])
if err1 != nil {
err = err1
return
}
index = int(msg.Index)
var local, dst *net.IPNet
for _, attr := range attrs {
switch attr.Attr.Type {
case syscall.IFA_ADDRESS:
dst = &net.IPNet{
IP: attr.Value,
Mask: net.CIDRMask(int(msg.Prefixlen), 8*len(attr.Value)),
}
case syscall.IFA_LOCAL:
local = &net.IPNet{
IP: attr.Value,
Mask: net.CIDRMask(int(msg.Prefixlen), 8*len(attr.Value)),
}
case syscall.IFA_LABEL:
addr.Label = string(attr.Value[:len(attr.Value)-1])
case IFA_FLAGS:
addr.Flags = int(native.Uint32(attr.Value[0:4]))
}
}
// IFA_LOCAL should be there but if not, fall back to IFA_ADDRESS
if local != nil {
addr.IPNet = local
} else {
addr.IPNet = dst
}
addr.Scope = int(msg.Scope)
return
}
type AddrUpdate struct {
LinkAddress net.IPNet
LinkIndex int
NewAddr bool // true=added false=deleted
}
// AddrSubscribe takes a chan down which notifications will be sent
// when addresses change. Close the 'done' chan to stop subscription.
func AddrSubscribe(ch chan<- AddrUpdate, done <-chan struct{}) error {
s, err := nl.Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_IPV4_IFADDR, syscall.RTNLGRP_IPV6_IFADDR)
if err != nil {
return err
}
if done != nil {
go func() {
<-done
s.Close()
}()
}
go func() {
defer close(ch)
for {
msgs, err := s.Receive()
if err != nil {
log.Printf("netlink.AddrSubscribe: Receive() error: %v", err)
return
}
for _, m := range msgs {
msgType := m.Header.Type
if msgType != syscall.RTM_NEWADDR && msgType != syscall.RTM_DELADDR {
log.Printf("netlink.AddrSubscribe: bad message type: %d", msgType)
continue
}
addr, _, ifindex, err := parseAddr(m.Data)
if err != nil {
log.Printf("netlink.AddrSubscribe: could not parse address: %v", err)
continue
}
ch <- AddrUpdate{LinkAddress: *addr.IPNet, LinkIndex: ifindex, NewAddr: msgType == syscall.RTM_NEWADDR}
}
}
}()
return nil
}

View file

@ -0,0 +1,60 @@
package netlink
/*
#include <asm/types.h>
#include <asm/unistd.h>
#include <errno.h>
#include <stdio.h>
#include <stdint.h>
#include <unistd.h>
static int load_simple_bpf(int prog_type) {
#ifdef __NR_bpf
// { return 1; }
__u64 __attribute__((aligned(8))) insns[] = {
0x00000001000000b7ull,
0x0000000000000095ull,
};
__u8 __attribute__((aligned(8))) license[] = "ASL2";
// Copied from a header file since libc is notoriously slow to update.
// The call will succeed or fail and that will be our indication on
// whether or not it is supported.
struct {
__u32 prog_type;
__u32 insn_cnt;
__u64 insns;
__u64 license;
__u32 log_level;
__u32 log_size;
__u64 log_buf;
__u32 kern_version;
} __attribute__((aligned(8))) attr = {
.prog_type = prog_type,
.insn_cnt = 2,
.insns = (uintptr_t)&insns,
.license = (uintptr_t)&license,
};
return syscall(__NR_bpf, 5, &attr, sizeof(attr));
#else
errno = EINVAL;
return -1;
#endif
}
*/
import "C"
type BpfProgType C.int
const (
BPF_PROG_TYPE_UNSPEC BpfProgType = iota
BPF_PROG_TYPE_SOCKET_FILTER
BPF_PROG_TYPE_KPROBE
BPF_PROG_TYPE_SCHED_CLS
BPF_PROG_TYPE_SCHED_ACT
)
// loadSimpleBpf loads a trivial bpf program for testing purposes
func loadSimpleBpf(progType BpfProgType) (int, error) {
fd, err := C.load_simple_bpf(C.int(progType))
return int(fd), err
}

View file

@ -9,7 +9,7 @@ type Class interface {
Type() string Type() string
} }
// Class represents a netlink class. A filter is associated with a link, // ClassAttrs represents a netlink class. A filter is associated with a link,
// has a handle and a parent. The root filter of a device should have a // has a handle and a parent. The root filter of a device should have a
// parent == HANDLE_ROOT. // parent == HANDLE_ROOT.
type ClassAttrs struct { type ClassAttrs struct {
@ -20,7 +20,7 @@ type ClassAttrs struct {
} }
func (q ClassAttrs) String() string { func (q ClassAttrs) String() string {
return fmt.Sprintf("{LinkIndex: %d, Handle: %s, Parent: %s, Leaf: %s}", q.LinkIndex, HandleStr(q.Handle), HandleStr(q.Parent), q.Leaf) return fmt.Sprintf("{LinkIndex: %d, Handle: %s, Parent: %s, Leaf: %d}", q.LinkIndex, HandleStr(q.Handle), HandleStr(q.Parent), q.Leaf)
} }
type HtbClassAttrs struct { type HtbClassAttrs struct {
@ -38,7 +38,7 @@ func (q HtbClassAttrs) String() string {
return fmt.Sprintf("{Rate: %d, Ceil: %d, Buffer: %d, Cbuffer: %d}", q.Rate, q.Ceil, q.Buffer, q.Cbuffer) return fmt.Sprintf("{Rate: %d, Ceil: %d, Buffer: %d, Cbuffer: %d}", q.Rate, q.Ceil, q.Buffer, q.Cbuffer)
} }
// Htb class // HtbClass represents an Htb class
type HtbClass struct { type HtbClass struct {
ClassAttrs ClassAttrs
Rate uint64 Rate uint64
@ -56,6 +56,7 @@ func NewHtbClass(attrs ClassAttrs, cattrs HtbClassAttrs) *HtbClass {
ceil := cattrs.Ceil / 8 ceil := cattrs.Ceil / 8
buffer := cattrs.Buffer buffer := cattrs.Buffer
cbuffer := cattrs.Cbuffer cbuffer := cattrs.Cbuffer
if ceil == 0 { if ceil == 0 {
ceil = rate ceil = rate
} }
@ -86,11 +87,11 @@ func (q HtbClass) String() string {
return fmt.Sprintf("{Rate: %d, Ceil: %d, Buffer: %d, Cbuffer: %d}", q.Rate, q.Ceil, q.Buffer, q.Cbuffer) return fmt.Sprintf("{Rate: %d, Ceil: %d, Buffer: %d, Cbuffer: %d}", q.Rate, q.Ceil, q.Buffer, q.Cbuffer)
} }
func (class *HtbClass) Attrs() *ClassAttrs { func (q *HtbClass) Attrs() *ClassAttrs {
return &class.ClassAttrs return &q.ClassAttrs
} }
func (class *HtbClass) Type() string { func (q *HtbClass) Type() string {
return "htb" return "htb"
} }

View file

@ -1,6 +1,7 @@
package netlink package netlink
import ( import (
"errors"
"syscall" "syscall"
"github.com/vishvananda/netlink/nl" "github.com/vishvananda/netlink/nl"
@ -65,15 +66,32 @@ func classPayload(req *nl.NetlinkRequest, class Class) error {
options := nl.NewRtAttr(nl.TCA_OPTIONS, nil) options := nl.NewRtAttr(nl.TCA_OPTIONS, nil)
if htb, ok := class.(*HtbClass); ok { if htb, ok := class.(*HtbClass); ok {
opt := nl.TcHtbCopt{} opt := nl.TcHtbCopt{}
opt.Rate.Rate = uint32(htb.Rate)
opt.Ceil.Rate = uint32(htb.Ceil)
opt.Buffer = htb.Buffer opt.Buffer = htb.Buffer
opt.Cbuffer = htb.Cbuffer opt.Cbuffer = htb.Cbuffer
opt.Quantum = htb.Quantum opt.Quantum = htb.Quantum
opt.Level = htb.Level opt.Level = htb.Level
opt.Prio = htb.Prio opt.Prio = htb.Prio
// TODO: Handle Debug properly. For now default to 0 // TODO: Handle Debug properly. For now default to 0
/* Calculate {R,C}Tab and set Rate and Ceil */
cellLog := -1
ccellLog := -1
linklayer := nl.LINKLAYER_ETHERNET
mtu := 1600
var rtab [256]uint32
var ctab [256]uint32
tcrate := nl.TcRateSpec{Rate: uint32(htb.Rate)}
if CalcRtable(&tcrate, rtab, cellLog, uint32(mtu), linklayer) < 0 {
return errors.New("HTB: failed to calculate rate table")
}
opt.Rate = tcrate
tcceil := nl.TcRateSpec{Rate: uint32(htb.Ceil)}
if CalcRtable(&tcceil, ctab, ccellLog, uint32(mtu), linklayer) < 0 {
return errors.New("HTB: failed to calculate ceil rate table")
}
opt.Ceil = tcceil
nl.NewRtAttrChild(options, nl.TCA_HTB_PARMS, opt.Serialize()) nl.NewRtAttrChild(options, nl.TCA_HTB_PARMS, opt.Serialize())
nl.NewRtAttrChild(options, nl.TCA_HTB_RTAB, SerializeRtab(rtab))
nl.NewRtAttrChild(options, nl.TCA_HTB_CTAB, SerializeRtab(ctab))
} }
req.AddData(options) req.AddData(options)
return nil return nil

View file

@ -11,7 +11,7 @@ type Filter interface {
Type() string Type() string
} }
// Filter represents a netlink filter. A filter is associated with a link, // FilterAttrs represents a netlink filter. A filter is associated with a link,
// has a handle and a parent. The root filter of a device should have a // has a handle and a parent. The root filter of a device should have a
// parent == HANDLE_ROOT. // parent == HANDLE_ROOT.
type FilterAttrs struct { type FilterAttrs struct {
@ -26,11 +26,45 @@ func (q FilterAttrs) String() string {
return fmt.Sprintf("{LinkIndex: %d, Handle: %s, Parent: %s, Priority: %d, Protocol: %d}", q.LinkIndex, HandleStr(q.Handle), HandleStr(q.Parent), q.Priority, q.Protocol) return fmt.Sprintf("{LinkIndex: %d, Handle: %s, Parent: %s, Priority: %d, Protocol: %d}", q.LinkIndex, HandleStr(q.Handle), HandleStr(q.Parent), q.Priority, q.Protocol)
} }
// Action represents an action in any supported filter.
type Action interface {
Type() string
}
type BpfAction struct {
nl.TcActBpf
Fd int
Name string
}
func (action *BpfAction) Type() string {
return "bpf"
}
type MirredAction struct {
nl.TcMirred
}
func (action *MirredAction) Type() string {
return "mirred"
}
func NewMirredAction(redirIndex int) *MirredAction {
return &MirredAction{
TcMirred: nl.TcMirred{
TcGen: nl.TcGen{Action: nl.TC_ACT_STOLEN},
Eaction: nl.TCA_EGRESS_REDIR,
Ifindex: uint32(redirIndex),
},
}
}
// U32 filters on many packet related properties // U32 filters on many packet related properties
type U32 struct { type U32 struct {
FilterAttrs FilterAttrs
// Currently only supports redirecting to another interface ClassId uint32
RedirIndex int RedirIndex int
Actions []Action
} }
func (filter *U32) Attrs() *FilterAttrs { func (filter *U32) Attrs() *FilterAttrs {
@ -57,7 +91,7 @@ type FilterFwAttrs struct {
LinkLayer int LinkLayer int
} }
// FwFilter filters on firewall marks // Fw filter filters on firewall marks
type Fw struct { type Fw struct {
FilterAttrs FilterAttrs
ClassId uint32 ClassId uint32
@ -73,8 +107,8 @@ type Fw struct {
func NewFw(attrs FilterAttrs, fattrs FilterFwAttrs) (*Fw, error) { func NewFw(attrs FilterAttrs, fattrs FilterFwAttrs) (*Fw, error) {
var rtab [256]uint32 var rtab [256]uint32
var ptab [256]uint32 var ptab [256]uint32
rcell_log := -1 rcellLog := -1
pcell_log := -1 pcellLog := -1
avrate := fattrs.AvRate / 8 avrate := fattrs.AvRate / 8
police := nl.TcPolice{} police := nl.TcPolice{}
police.Rate.Rate = fattrs.Rate / 8 police.Rate.Rate = fattrs.Rate / 8
@ -90,8 +124,8 @@ func NewFw(attrs FilterAttrs, fattrs FilterFwAttrs) (*Fw, error) {
if police.Rate.Rate != 0 { if police.Rate.Rate != 0 {
police.Rate.Mpu = fattrs.Mpu police.Rate.Mpu = fattrs.Mpu
police.Rate.Overhead = fattrs.Overhead police.Rate.Overhead = fattrs.Overhead
if CalcRtable(&police.Rate, rtab, rcell_log, fattrs.Mtu, linklayer) < 0 { if CalcRtable(&police.Rate, rtab, rcellLog, fattrs.Mtu, linklayer) < 0 {
return nil, errors.New("TBF: failed to calculate rate table.") return nil, errors.New("TBF: failed to calculate rate table")
} }
police.Burst = uint32(Xmittime(uint64(police.Rate.Rate), uint32(buffer))) police.Burst = uint32(Xmittime(uint64(police.Rate.Rate), uint32(buffer)))
} }
@ -99,8 +133,8 @@ func NewFw(attrs FilterAttrs, fattrs FilterFwAttrs) (*Fw, error) {
if police.PeakRate.Rate != 0 { if police.PeakRate.Rate != 0 {
police.PeakRate.Mpu = fattrs.Mpu police.PeakRate.Mpu = fattrs.Mpu
police.PeakRate.Overhead = fattrs.Overhead police.PeakRate.Overhead = fattrs.Overhead
if CalcRtable(&police.PeakRate, ptab, pcell_log, fattrs.Mtu, linklayer) < 0 { if CalcRtable(&police.PeakRate, ptab, pcellLog, fattrs.Mtu, linklayer) < 0 {
return nil, errors.New("POLICE: failed to calculate peak rate table.") return nil, errors.New("POLICE: failed to calculate peak rate table")
} }
} }
@ -124,6 +158,22 @@ func (filter *Fw) Type() string {
return "fw" return "fw"
} }
type BpfFilter struct {
FilterAttrs
ClassId uint32
Fd int
Name string
DirectAction bool
}
func (filter *BpfFilter) Type() string {
return "bpf"
}
func (filter *BpfFilter) Attrs() *FilterAttrs {
return &filter.FilterAttrs
}
// GenericFilter filters represent types that are not currently understood // GenericFilter filters represent types that are not currently understood
// by this netlink library. // by this netlink library.
type GenericFilter struct { type GenericFilter struct {

View file

@ -52,17 +52,17 @@ func FilterAdd(filter Filter) error {
} }
sel.Keys = append(sel.Keys, nl.TcU32Key{}) sel.Keys = append(sel.Keys, nl.TcU32Key{})
nl.NewRtAttrChild(options, nl.TCA_U32_SEL, sel.Serialize()) nl.NewRtAttrChild(options, nl.TCA_U32_SEL, sel.Serialize())
actions := nl.NewRtAttrChild(options, nl.TCA_U32_ACT, nil) if u32.ClassId != 0 {
table := nl.NewRtAttrChild(actions, nl.TCA_ACT_TAB, nil) nl.NewRtAttrChild(options, nl.TCA_U32_CLASSID, nl.Uint32Attr(u32.ClassId))
nl.NewRtAttrChild(table, nl.TCA_KIND, nl.ZeroTerminated("mirred")) }
// redirect to other interface actionsAttr := nl.NewRtAttrChild(options, nl.TCA_U32_ACT, nil)
mir := nl.TcMirred{ // backwards compatibility
Action: nl.TC_ACT_STOLEN, if u32.RedirIndex != 0 {
Eaction: nl.TCA_EGRESS_REDIR, u32.Actions = append([]Action{NewMirredAction(u32.RedirIndex)}, u32.Actions...)
Ifindex: uint32(u32.RedirIndex), }
if err := encodeActions(actionsAttr, u32.Actions); err != nil {
return err
} }
aopts := nl.NewRtAttrChild(table, nl.TCA_OPTIONS, nil)
nl.NewRtAttrChild(aopts, nl.TCA_MIRRED_PARMS, mir.Serialize())
} else if fw, ok := filter.(*Fw); ok { } else if fw, ok := filter.(*Fw); ok {
if fw.Mask != 0 { if fw.Mask != 0 {
b := make([]byte, 4) b := make([]byte, 4)
@ -90,6 +90,21 @@ func FilterAdd(filter Filter) error {
native.PutUint32(b, fw.ClassId) native.PutUint32(b, fw.ClassId)
nl.NewRtAttrChild(options, nl.TCA_FW_CLASSID, b) nl.NewRtAttrChild(options, nl.TCA_FW_CLASSID, b)
} }
} else if bpf, ok := filter.(*BpfFilter); ok {
var bpfFlags uint32
if bpf.ClassId != 0 {
nl.NewRtAttrChild(options, nl.TCA_BPF_CLASSID, nl.Uint32Attr(bpf.ClassId))
}
if bpf.Fd >= 0 {
nl.NewRtAttrChild(options, nl.TCA_BPF_FD, nl.Uint32Attr((uint32(bpf.Fd))))
}
if bpf.Name != "" {
nl.NewRtAttrChild(options, nl.TCA_BPF_NAME, nl.ZeroTerminated(bpf.Name))
}
if bpf.DirectAction {
bpfFlags |= nl.TCA_BPF_FLAG_ACT_DIRECT
}
nl.NewRtAttrChild(options, nl.TCA_BPF_FLAGS, nl.Uint32Attr(bpfFlags))
} }
req.AddData(options) req.AddData(options)
@ -147,26 +162,29 @@ func FilterList(link Link, parent uint32) ([]Filter, error) {
filter = &U32{} filter = &U32{}
case "fw": case "fw":
filter = &Fw{} filter = &Fw{}
case "bpf":
filter = &BpfFilter{}
default: default:
filter = &GenericFilter{FilterType: filterType} filter = &GenericFilter{FilterType: filterType}
} }
case nl.TCA_OPTIONS: case nl.TCA_OPTIONS:
data, err := nl.ParseRouteAttr(attr.Value)
if err != nil {
return nil, err
}
switch filterType { switch filterType {
case "u32": case "u32":
data, err := nl.ParseRouteAttr(attr.Value)
if err != nil {
return nil, err
}
detailed, err = parseU32Data(filter, data) detailed, err = parseU32Data(filter, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
case "fw": case "fw":
data, err := nl.ParseRouteAttr(attr.Value) detailed, err = parseFwData(filter, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
detailed, err = parseFwData(filter, data) case "bpf":
detailed, err = parseBpfData(filter, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -183,6 +201,85 @@ func FilterList(link Link, parent uint32) ([]Filter, error) {
return res, nil return res, nil
} }
func encodeActions(attr *nl.RtAttr, actions []Action) error {
tabIndex := int(nl.TCA_ACT_TAB)
for _, action := range actions {
switch action := action.(type) {
default:
return fmt.Errorf("unknown action type %s", action.Type())
case *MirredAction:
table := nl.NewRtAttrChild(attr, tabIndex, nil)
tabIndex++
nl.NewRtAttrChild(table, nl.TCA_ACT_KIND, nl.ZeroTerminated("mirred"))
aopts := nl.NewRtAttrChild(table, nl.TCA_ACT_OPTIONS, nil)
nl.NewRtAttrChild(aopts, nl.TCA_MIRRED_PARMS, action.Serialize())
case *BpfAction:
table := nl.NewRtAttrChild(attr, tabIndex, nil)
tabIndex++
nl.NewRtAttrChild(table, nl.TCA_ACT_KIND, nl.ZeroTerminated("bpf"))
aopts := nl.NewRtAttrChild(table, nl.TCA_ACT_OPTIONS, nil)
nl.NewRtAttrChild(aopts, nl.TCA_ACT_BPF_PARMS, action.Serialize())
nl.NewRtAttrChild(aopts, nl.TCA_ACT_BPF_FD, nl.Uint32Attr(uint32(action.Fd)))
nl.NewRtAttrChild(aopts, nl.TCA_ACT_BPF_NAME, nl.ZeroTerminated(action.Name))
}
}
return nil
}
func parseActions(tables []syscall.NetlinkRouteAttr) ([]Action, error) {
var actions []Action
for _, table := range tables {
var action Action
var actionType string
aattrs, err := nl.ParseRouteAttr(table.Value)
if err != nil {
return nil, err
}
nextattr:
for _, aattr := range aattrs {
switch aattr.Attr.Type {
case nl.TCA_KIND:
actionType = string(aattr.Value[:len(aattr.Value)-1])
// only parse if the action is mirred or bpf
switch actionType {
case "mirred":
action = &MirredAction{}
case "bpf":
action = &BpfAction{}
default:
break nextattr
}
case nl.TCA_OPTIONS:
adata, err := nl.ParseRouteAttr(aattr.Value)
if err != nil {
return nil, err
}
for _, adatum := range adata {
switch actionType {
case "mirred":
switch adatum.Attr.Type {
case nl.TCA_MIRRED_PARMS:
action.(*MirredAction).TcMirred = *nl.DeserializeTcMirred(adatum.Value)
}
case "bpf":
switch adatum.Attr.Type {
case nl.TCA_ACT_BPF_PARMS:
action.(*BpfAction).TcActBpf = *nl.DeserializeTcActBpf(adatum.Value)
case nl.TCA_ACT_BPF_FD:
action.(*BpfAction).Fd = int(native.Uint32(adatum.Value[0:4]))
case nl.TCA_ACT_BPF_NAME:
action.(*BpfAction).Name = string(adatum.Value[:len(adatum.Value)-1])
}
}
}
}
}
actions = append(actions, action)
}
return actions, nil
}
func parseU32Data(filter Filter, data []syscall.NetlinkRouteAttr) (bool, error) { func parseU32Data(filter Filter, data []syscall.NetlinkRouteAttr) (bool, error) {
native = nl.NativeEndian() native = nl.NativeEndian()
u32 := filter.(*U32) u32 := filter.(*U32)
@ -197,34 +294,17 @@ func parseU32Data(filter Filter, data []syscall.NetlinkRouteAttr) (bool, error)
return detailed, nil return detailed, nil
} }
case nl.TCA_U32_ACT: case nl.TCA_U32_ACT:
table, err := nl.ParseRouteAttr(datum.Value) tables, err := nl.ParseRouteAttr(datum.Value)
if err != nil { if err != nil {
return detailed, err return detailed, err
} }
if len(table) != 1 || table[0].Attr.Type != nl.TCA_ACT_TAB { u32.Actions, err = parseActions(tables)
return detailed, fmt.Errorf("Action table not formed properly") if err != nil {
return detailed, err
} }
aattrs, err := nl.ParseRouteAttr(table[0].Value) for _, action := range u32.Actions {
for _, aattr := range aattrs { if action, ok := action.(*MirredAction); ok {
switch aattr.Attr.Type { u32.RedirIndex = int(action.Ifindex)
case nl.TCA_KIND:
actionType := string(aattr.Value[:len(aattr.Value)-1])
// only parse if the action is mirred
if actionType != "mirred" {
return detailed, nil
}
case nl.TCA_OPTIONS:
adata, err := nl.ParseRouteAttr(aattr.Value)
if err != nil {
return detailed, err
}
for _, adatum := range adata {
switch adatum.Attr.Type {
case nl.TCA_MIRRED_PARMS:
mir := nl.DeserializeTcMirred(adatum.Value)
u32.RedirIndex = int(mir.Ifindex)
}
}
} }
} }
} }
@ -261,6 +341,28 @@ func parseFwData(filter Filter, data []syscall.NetlinkRouteAttr) (bool, error) {
return detailed, nil return detailed, nil
} }
func parseBpfData(filter Filter, data []syscall.NetlinkRouteAttr) (bool, error) {
native = nl.NativeEndian()
bpf := filter.(*BpfFilter)
detailed := true
for _, datum := range data {
switch datum.Attr.Type {
case nl.TCA_BPF_FD:
bpf.Fd = int(native.Uint32(datum.Value[0:4]))
case nl.TCA_BPF_NAME:
bpf.Name = string(datum.Value[:len(datum.Value)-1])
case nl.TCA_BPF_CLASSID:
bpf.ClassId = native.Uint32(datum.Value[0:4])
case nl.TCA_BPF_FLAGS:
flags := native.Uint32(datum.Value[0:4])
if (flags & nl.TCA_BPF_FLAG_ACT_DIRECT) != 0 {
bpf.DirectAction = true
}
}
}
return detailed, nil
}
func AlignToAtm(size uint) uint { func AlignToAtm(size uint) uint {
var linksize, cells int var linksize, cells int
cells = int(size / nl.ATM_CELL_PAYLOAD) cells = int(size / nl.ATM_CELL_PAYLOAD)
@ -283,27 +385,27 @@ func AdjustSize(sz uint, mpu uint, linklayer int) uint {
} }
} }
func CalcRtable(rate *nl.TcRateSpec, rtab [256]uint32, cell_log int, mtu uint32, linklayer int) int { func CalcRtable(rate *nl.TcRateSpec, rtab [256]uint32, cellLog int, mtu uint32, linklayer int) int {
bps := rate.Rate bps := rate.Rate
mpu := rate.Mpu mpu := rate.Mpu
var sz uint var sz uint
if mtu == 0 { if mtu == 0 {
mtu = 2047 mtu = 2047
} }
if cell_log < 0 { if cellLog < 0 {
cell_log = 0 cellLog = 0
for (mtu >> uint(cell_log)) > 255 { for (mtu >> uint(cellLog)) > 255 {
cell_log++ cellLog++
} }
} }
for i := 0; i < 256; i++ { for i := 0; i < 256; i++ {
sz = AdjustSize(uint((i+1)<<uint32(cell_log)), uint(mpu), linklayer) sz = AdjustSize(uint((i+1)<<uint32(cellLog)), uint(mpu), linklayer)
rtab[i] = uint32(Xmittime(uint64(bps), uint32(sz))) rtab[i] = uint32(Xmittime(uint64(bps), uint32(sz)))
} }
rate.CellAlign = -1 rate.CellAlign = -1
rate.CellLog = uint8(cell_log) rate.CellLog = uint8(cellLog)
rate.Linklayer = uint8(linklayer & nl.TC_LINKLAYER_MASK) rate.Linklayer = uint8(linklayer & nl.TC_LINKLAYER_MASK)
return cell_log return cellLog
} }
func DeserializeRtab(b []byte) [256]uint32 { func DeserializeRtab(b []byte) [256]uint32 {

View file

@ -204,6 +204,7 @@ type Vxlan struct {
RSC bool RSC bool
L2miss bool L2miss bool
L3miss bool L3miss bool
UDPCSum bool
NoAge bool NoAge bool
GBP bool GBP bool
Age int Age int
@ -424,7 +425,7 @@ const (
BOND_AD_SELECT_COUNT BOND_AD_SELECT_COUNT
) )
// BondAdInfo // BondAdInfo represents ad info for bond
type BondAdInfo struct { type BondAdInfo struct {
AggregatorId int AggregatorId int
NumPorts int NumPorts int
@ -525,7 +526,7 @@ func (bond *Bond) Type() string {
return "bond" return "bond"
} }
// GreTap devices must specify LocalIP and RemoteIP on create // Gretap devices must specify LocalIP and RemoteIP on create
type Gretap struct { type Gretap struct {
LinkAttrs LinkAttrs
IKey uint32 IKey uint32

View file

@ -142,6 +142,54 @@ func LinkSetHardwareAddr(link Link, hwaddr net.HardwareAddr) error {
return err return err
} }
// LinkSetVfHardwareAddr sets the hardware address of a vf for the link.
// Equivalent to: `ip link set $link vf $vf mac $hwaddr`
func LinkSetVfHardwareAddr(link Link, vf int, hwaddr net.HardwareAddr) error {
base := link.Attrs()
ensureIndex(base)
req := nl.NewNetlinkRequest(syscall.RTM_SETLINK, syscall.NLM_F_ACK)
msg := nl.NewIfInfomsg(syscall.AF_UNSPEC)
msg.Index = int32(base.Index)
req.AddData(msg)
data := nl.NewRtAttr(nl.IFLA_VFINFO_LIST, nil)
info := nl.NewRtAttrChild(data, nl.IFLA_VF_INFO, nil)
vfmsg := nl.VfMac{
Vf: uint32(vf),
}
copy(vfmsg.Mac[:], []byte(hwaddr))
nl.NewRtAttrChild(info, nl.IFLA_VF_MAC, vfmsg.Serialize())
req.AddData(data)
_, err := req.Execute(syscall.NETLINK_ROUTE, 0)
return err
}
// LinkSetVfVlan sets the vlan of a vf for the link.
// Equivalent to: `ip link set $link vf $vf vlan $vlan`
func LinkSetVfVlan(link Link, vf, vlan int) error {
base := link.Attrs()
ensureIndex(base)
req := nl.NewNetlinkRequest(syscall.RTM_SETLINK, syscall.NLM_F_ACK)
msg := nl.NewIfInfomsg(syscall.AF_UNSPEC)
msg.Index = int32(base.Index)
req.AddData(msg)
data := nl.NewRtAttr(nl.IFLA_VFINFO_LIST, nil)
info := nl.NewRtAttrChild(data, nl.IFLA_VF_INFO, nil)
vfmsg := nl.VfVlan{
Vf: uint32(vf),
Vlan: uint32(vlan),
}
nl.NewRtAttrChild(info, nl.IFLA_VF_VLAN, vfmsg.Serialize())
req.AddData(data)
_, err := req.Execute(syscall.NETLINK_ROUTE, 0)
return err
}
// LinkSetMaster sets the master of the link device. // LinkSetMaster sets the master of the link device.
// Equivalent to: `ip link set $link master $master` // Equivalent to: `ip link set $link master $master`
func LinkSetMaster(link Link, master *Bridge) error { func LinkSetMaster(link Link, master *Bridge) error {
@ -277,10 +325,12 @@ func addVxlanAttrs(vxlan *Vxlan, linkInfo *nl.RtAttr) {
nl.NewRtAttrChild(data, nl.IFLA_VXLAN_L2MISS, boolAttr(vxlan.L2miss)) nl.NewRtAttrChild(data, nl.IFLA_VXLAN_L2MISS, boolAttr(vxlan.L2miss))
nl.NewRtAttrChild(data, nl.IFLA_VXLAN_L3MISS, boolAttr(vxlan.L3miss)) nl.NewRtAttrChild(data, nl.IFLA_VXLAN_L3MISS, boolAttr(vxlan.L3miss))
if vxlan.UDPCSum {
nl.NewRtAttrChild(data, nl.IFLA_VXLAN_UDP_CSUM, boolAttr(vxlan.UDPCSum))
}
if vxlan.GBP { if vxlan.GBP {
nl.NewRtAttrChild(data, nl.IFLA_VXLAN_GBP, boolAttr(vxlan.GBP)) nl.NewRtAttrChild(data, nl.IFLA_VXLAN_GBP, boolAttr(vxlan.GBP))
} }
if vxlan.NoAge { if vxlan.NoAge {
nl.NewRtAttrChild(data, nl.IFLA_VXLAN_AGEING, nl.Uint32Attr(0)) nl.NewRtAttrChild(data, nl.IFLA_VXLAN_AGEING, nl.Uint32Attr(0))
} else if vxlan.Age > 0 { } else if vxlan.Age > 0 {
@ -815,6 +865,7 @@ func LinkList() ([]Link, error) {
// LinkUpdate is used to pass information back from LinkSubscribe() // LinkUpdate is used to pass information back from LinkSubscribe()
type LinkUpdate struct { type LinkUpdate struct {
nl.IfInfomsg nl.IfInfomsg
Header syscall.NlMsghdr
Link Link
} }
@ -844,7 +895,7 @@ func LinkSubscribe(ch chan<- LinkUpdate, done <-chan struct{}) error {
if err != nil { if err != nil {
return return
} }
ch <- LinkUpdate{IfInfomsg: *ifmsg, Link: link} ch <- LinkUpdate{IfInfomsg: *ifmsg, Header: m.Header, Link: link}
} }
} }
}() }()
@ -935,6 +986,8 @@ func parseVxlanData(link Link, data []syscall.NetlinkRouteAttr) {
vxlan.L2miss = int8(datum.Value[0]) != 0 vxlan.L2miss = int8(datum.Value[0]) != 0
case nl.IFLA_VXLAN_L3MISS: case nl.IFLA_VXLAN_L3MISS:
vxlan.L3miss = int8(datum.Value[0]) != 0 vxlan.L3miss = int8(datum.Value[0]) != 0
case nl.IFLA_VXLAN_UDP_CSUM:
vxlan.UDPCSum = int8(datum.Value[0]) != 0
case nl.IFLA_VXLAN_GBP: case nl.IFLA_VXLAN_GBP:
vxlan.GBP = int8(datum.Value[0]) != 0 vxlan.GBP = int8(datum.Value[0]) != 0
case nl.IFLA_VXLAN_AGEING: case nl.IFLA_VXLAN_AGEING:

View file

@ -14,8 +14,8 @@ import (
"github.com/vishvananda/netlink/nl" "github.com/vishvananda/netlink/nl"
) )
// Family type definitions
const ( const (
// Family type definitions
FAMILY_ALL = nl.FAMILY_ALL FAMILY_ALL = nl.FAMILY_ALL
FAMILY_V4 = nl.FAMILY_V4 FAMILY_V4 = nl.FAMILY_V4
FAMILY_V6 = nl.FAMILY_V6 FAMILY_V6 = nl.FAMILY_V6

View file

@ -1,7 +1,13 @@
package nl package nl
import (
"unsafe"
)
const ( const (
DEFAULT_CHANGE = 0xFFFFFFFF DEFAULT_CHANGE = 0xFFFFFFFF
// doesn't exist in syscall
IFLA_VFINFO_LIST = 0x16
) )
const ( const (
@ -182,3 +188,209 @@ const (
GRE_FLAGS = 0x00F8 GRE_FLAGS = 0x00F8
GRE_VERSION = 0x0007 GRE_VERSION = 0x0007
) )
const (
IFLA_VF_INFO_UNSPEC = iota
IFLA_VF_INFO
IFLA_VF_INFO_MAX = IFLA_VF_INFO
)
const (
IFLA_VF_UNSPEC = iota
IFLA_VF_MAC /* Hardware queue specific attributes */
IFLA_VF_VLAN
IFLA_VF_TX_RATE /* Max TX Bandwidth Allocation */
IFLA_VF_SPOOFCHK /* Spoof Checking on/off switch */
IFLA_VF_LINK_STATE /* link state enable/disable/auto switch */
IFLA_VF_RATE /* Min and Max TX Bandwidth Allocation */
IFLA_VF_RSS_QUERY_EN /* RSS Redirection Table and Hash Key query
* on/off switch
*/
IFLA_VF_STATS /* network device statistics */
IFLA_VF_MAX = IFLA_VF_STATS
)
const (
IFLA_VF_LINK_STATE_AUTO = iota /* link state of the uplink */
IFLA_VF_LINK_STATE_ENABLE /* link always up */
IFLA_VF_LINK_STATE_DISABLE /* link always down */
IFLA_VF_LINK_STATE_MAX = IFLA_VF_LINK_STATE_DISABLE
)
const (
IFLA_VF_STATS_RX_PACKETS = iota
IFLA_VF_STATS_TX_PACKETS
IFLA_VF_STATS_RX_BYTES
IFLA_VF_STATS_TX_BYTES
IFLA_VF_STATS_BROADCAST
IFLA_VF_STATS_MULTICAST
IFLA_VF_STATS_MAX = IFLA_VF_STATS_MULTICAST
)
const (
SizeofVfMac = 0x24
SizeofVfVlan = 0x0c
SizeofVfTxRate = 0x08
SizeofVfRate = 0x0c
SizeofVfSpoofchk = 0x08
SizeofVfLinkState = 0x08
SizeofVfRssQueryEn = 0x08
)
// struct ifla_vf_mac {
// __u32 vf;
// __u8 mac[32]; /* MAX_ADDR_LEN */
// };
type VfMac struct {
Vf uint32
Mac [32]byte
}
func (msg *VfMac) Len() int {
return SizeofVfMac
}
func DeserializeVfMac(b []byte) *VfMac {
return (*VfMac)(unsafe.Pointer(&b[0:SizeofVfMac][0]))
}
func (msg *VfMac) Serialize() []byte {
return (*(*[SizeofVfMac]byte)(unsafe.Pointer(msg)))[:]
}
// struct ifla_vf_vlan {
// __u32 vf;
// __u32 vlan; /* 0 - 4095, 0 disables VLAN filter */
// __u32 qos;
// };
type VfVlan struct {
Vf uint32
Vlan uint32
Qos uint32
}
func (msg *VfVlan) Len() int {
return SizeofVfVlan
}
func DeserializeVfVlan(b []byte) *VfVlan {
return (*VfVlan)(unsafe.Pointer(&b[0:SizeofVfVlan][0]))
}
func (msg *VfVlan) Serialize() []byte {
return (*(*[SizeofVfVlan]byte)(unsafe.Pointer(msg)))[:]
}
// struct ifla_vf_tx_rate {
// __u32 vf;
// __u32 rate; /* Max TX bandwidth in Mbps, 0 disables throttling */
// };
type VfTxRate struct {
Vf uint32
Rate uint32
}
func (msg *VfTxRate) Len() int {
return SizeofVfTxRate
}
func DeserializeVfTxRate(b []byte) *VfTxRate {
return (*VfTxRate)(unsafe.Pointer(&b[0:SizeofVfTxRate][0]))
}
func (msg *VfTxRate) Serialize() []byte {
return (*(*[SizeofVfTxRate]byte)(unsafe.Pointer(msg)))[:]
}
// struct ifla_vf_rate {
// __u32 vf;
// __u32 min_tx_rate; /* Min Bandwidth in Mbps */
// __u32 max_tx_rate; /* Max Bandwidth in Mbps */
// };
type VfRate struct {
Vf uint32
MinTxRate uint32
MaxTxRate uint32
}
func (msg *VfRate) Len() int {
return SizeofVfRate
}
func DeserializeVfRate(b []byte) *VfRate {
return (*VfRate)(unsafe.Pointer(&b[0:SizeofVfRate][0]))
}
func (msg *VfRate) Serialize() []byte {
return (*(*[SizeofVfRate]byte)(unsafe.Pointer(msg)))[:]
}
// struct ifla_vf_spoofchk {
// __u32 vf;
// __u32 setting;
// };
type VfSpoofchk struct {
Vf uint32
Setting uint32
}
func (msg *VfSpoofchk) Len() int {
return SizeofVfSpoofchk
}
func DeserializeVfSpoofchk(b []byte) *VfSpoofchk {
return (*VfSpoofchk)(unsafe.Pointer(&b[0:SizeofVfSpoofchk][0]))
}
func (msg *VfSpoofchk) Serialize() []byte {
return (*(*[SizeofVfSpoofchk]byte)(unsafe.Pointer(msg)))[:]
}
// struct ifla_vf_link_state {
// __u32 vf;
// __u32 link_state;
// };
type VfLinkState struct {
Vf uint32
LinkState uint32
}
func (msg *VfLinkState) Len() int {
return SizeofVfLinkState
}
func DeserializeVfLinkState(b []byte) *VfLinkState {
return (*VfLinkState)(unsafe.Pointer(&b[0:SizeofVfLinkState][0]))
}
func (msg *VfLinkState) Serialize() []byte {
return (*(*[SizeofVfLinkState]byte)(unsafe.Pointer(msg)))[:]
}
// struct ifla_vf_rss_query_en {
// __u32 vf;
// __u32 setting;
// };
type VfRssQueryEn struct {
Vf uint32
Setting uint32
}
func (msg *VfRssQueryEn) Len() int {
return SizeofVfRssQueryEn
}
func DeserializeVfRssQueryEn(b []byte) *VfRssQueryEn {
return (*VfRssQueryEn)(unsafe.Pointer(&b[0:SizeofVfRssQueryEn][0]))
}
func (msg *VfRssQueryEn) Serialize() []byte {
return (*(*[SizeofVfRssQueryEn]byte)(unsafe.Pointer(msg)))[:]
}

View file

@ -49,6 +49,15 @@ const (
TCAA_MAX = 1 TCAA_MAX = 1
) )
const (
TCA_ACT_UNSPEC = iota
TCA_ACT_KIND
TCA_ACT_OPTIONS
TCA_ACT_INDEX
TCA_ACT_STATS
TCA_ACT_MAX
)
const ( const (
TCA_PRIO_UNSPEC = iota TCA_PRIO_UNSPEC = iota
TCA_PRIO_MQ TCA_PRIO_MQ
@ -69,6 +78,7 @@ const (
SizeofTcHtbGlob = 0x14 SizeofTcHtbGlob = 0x14
SizeofTcU32Key = 0x10 SizeofTcU32Key = 0x10
SizeofTcU32Sel = 0x10 // without keys SizeofTcU32Sel = 0x10 // without keys
SizeofTcActBpf = 0x14
SizeofTcMirred = 0x1c SizeofTcMirred = 0x1c
SizeofTcPolice = 2*SizeofTcRateSpec + 0x20 SizeofTcPolice = 2*SizeofTcRateSpec + 0x20
) )
@ -533,9 +543,34 @@ const (
TC_ACT_STOLEN = 4 TC_ACT_STOLEN = 4
TC_ACT_QUEUED = 5 TC_ACT_QUEUED = 5
TC_ACT_REPEAT = 6 TC_ACT_REPEAT = 6
TC_ACT_REDIRECT = 7
TC_ACT_JUMP = 0x10000000 TC_ACT_JUMP = 0x10000000
) )
type TcGen struct {
Index uint32
Capab uint32
Action int32
Refcnt int32
Bindcnt int32
}
type TcActBpf struct {
TcGen
}
func (msg *TcActBpf) Len() int {
return SizeofTcActBpf
}
func DeserializeTcActBpf(b []byte) *TcActBpf {
return (*TcActBpf)(unsafe.Pointer(&b[0:SizeofTcActBpf][0]))
}
func (x *TcActBpf) Serialize() []byte {
return (*(*[SizeofTcActBpf]byte)(unsafe.Pointer(x)))[:]
}
// #define tc_gen \ // #define tc_gen \
// __u32 index; \ // __u32 index; \
// __u32 capab; \ // __u32 capab; \
@ -549,11 +584,7 @@ const (
// }; // };
type TcMirred struct { type TcMirred struct {
Index uint32 TcGen
Capab uint32
Action int32
Refcnt int32
Bindcnt int32
Eaction int32 Eaction int32
Ifindex uint32 Ifindex uint32
} }
@ -625,3 +656,31 @@ const (
TCA_FW_MASK TCA_FW_MASK
TCA_FW_MAX = TCA_FW_MASK TCA_FW_MAX = TCA_FW_MASK
) )
const (
TCA_BPF_FLAG_ACT_DIRECT uint32 = 1 << iota
)
const (
TCA_BPF_UNSPEC = iota
TCA_BPF_ACT
TCA_BPF_POLICE
TCA_BPF_CLASSID
TCA_BPF_OPS_LEN
TCA_BPF_OPS
TCA_BPF_FD
TCA_BPF_NAME
TCA_BPF_FLAGS
TCA_BPF_MAX = TCA_BPF_FLAGS
)
const (
TCA_ACT_BPF_UNSPEC = iota
TCA_ACT_BPF_TM
TCA_ACT_BPF_PARMS
TCA_ACT_BPF_OPS_LEN
TCA_ACT_BPF_OPS
TCA_ACT_BPF_FD
TCA_ACT_BPF_NAME
TCA_ACT_BPF_MAX = TCA_ACT_BPF_NAME
)

View file

@ -8,16 +8,21 @@ import (
const ( const (
HANDLE_NONE = 0 HANDLE_NONE = 0
HANDLE_INGRESS = 0xFFFFFFF1 HANDLE_INGRESS = 0xFFFFFFF1
HANDLE_CLSACT = HANDLE_INGRESS
HANDLE_ROOT = 0xFFFFFFFF HANDLE_ROOT = 0xFFFFFFFF
PRIORITY_MAP_LEN = 16 PRIORITY_MAP_LEN = 16
) )
const (
HANDLE_MIN_INGRESS = 0xFFFFFFF2
HANDLE_MIN_EGRESS = 0xFFFFFFF3
)
type Qdisc interface { type Qdisc interface {
Attrs() *QdiscAttrs Attrs() *QdiscAttrs
Type() string Type() string
} }
// Qdisc represents a netlink qdisc. A qdisc is associated with a link, // QdiscAttrs represents a netlink qdisc. A qdisc is associated with a link,
// has a handle, a parent and a refcnt. The root qdisc of a device should // has a handle, a parent and a refcnt. The root qdisc of a device should
// have parent == HANDLE_ROOT. // have parent == HANDLE_ROOT.
type QdiscAttrs struct { type QdiscAttrs struct {
@ -28,7 +33,7 @@ type QdiscAttrs struct {
} }
func (q QdiscAttrs) String() string { func (q QdiscAttrs) String() string {
return fmt.Sprintf("{LinkIndex: %d, Handle: %s, Parent: %s, Refcnt: %s}", q.LinkIndex, HandleStr(q.Handle), HandleStr(q.Parent), q.Refcnt) return fmt.Sprintf("{LinkIndex: %d, Handle: %s, Parent: %s, Refcnt: %d}", q.LinkIndex, HandleStr(q.Handle), HandleStr(q.Parent), q.Refcnt)
} }
func MakeHandle(major, minor uint16) uint32 { func MakeHandle(major, minor uint16) uint32 {
@ -149,7 +154,7 @@ type NetemQdiscAttrs struct {
func (q NetemQdiscAttrs) String() string { func (q NetemQdiscAttrs) String() string {
return fmt.Sprintf( return fmt.Sprintf(
"{Latency: %d, Limit: %d, Loss: %d, Gap: %d, Duplicate: %d, Jitter: %d}", "{Latency: %d, Limit: %d, Loss: %f, Gap: %d, Duplicate: %f, Jitter: %d}",
q.Latency, q.Limit, q.Loss, q.Gap, q.Duplicate, q.Jitter, q.Latency, q.Limit, q.Loss, q.Gap, q.Duplicate, q.Jitter,
) )
} }
@ -173,9 +178,9 @@ type Netem struct {
func NewNetem(attrs QdiscAttrs, nattrs NetemQdiscAttrs) *Netem { func NewNetem(attrs QdiscAttrs, nattrs NetemQdiscAttrs) *Netem {
var limit uint32 = 1000 var limit uint32 = 1000
var loss_corr, delay_corr, duplicate_corr uint32 var lossCorr, delayCorr, duplicateCorr uint32
var reorder_prob, reorder_corr uint32 var reorderProb, reorderCorr uint32
var corrupt_prob, corrupt_corr uint32 var corruptProb, corruptCorr uint32
latency := nattrs.Latency latency := nattrs.Latency
loss := Percentage2u32(nattrs.Loss) loss := Percentage2u32(nattrs.Loss)
@ -185,13 +190,13 @@ func NewNetem(attrs QdiscAttrs, nattrs NetemQdiscAttrs) *Netem {
// Correlation // Correlation
if latency > 0 && jitter > 0 { if latency > 0 && jitter > 0 {
delay_corr = Percentage2u32(nattrs.DelayCorr) delayCorr = Percentage2u32(nattrs.DelayCorr)
} }
if loss > 0 { if loss > 0 {
loss_corr = Percentage2u32(nattrs.LossCorr) lossCorr = Percentage2u32(nattrs.LossCorr)
} }
if duplicate > 0 { if duplicate > 0 {
duplicate_corr = Percentage2u32(nattrs.DuplicateCorr) duplicateCorr = Percentage2u32(nattrs.DuplicateCorr)
} }
// FIXME should validate values(like loss/duplicate are percentages...) // FIXME should validate values(like loss/duplicate are percentages...)
latency = time2Tick(latency) latency = time2Tick(latency)
@ -204,34 +209,34 @@ func NewNetem(attrs QdiscAttrs, nattrs NetemQdiscAttrs) *Netem {
jitter = time2Tick(jitter) jitter = time2Tick(jitter)
} }
reorder_prob = Percentage2u32(nattrs.ReorderProb) reorderProb = Percentage2u32(nattrs.ReorderProb)
reorder_corr = Percentage2u32(nattrs.ReorderCorr) reorderCorr = Percentage2u32(nattrs.ReorderCorr)
if reorder_prob > 0 { if reorderProb > 0 {
// ERROR if lantency == 0 // ERROR if lantency == 0
if gap == 0 { if gap == 0 {
gap = 1 gap = 1
} }
} }
corrupt_prob = Percentage2u32(nattrs.CorruptProb) corruptProb = Percentage2u32(nattrs.CorruptProb)
corrupt_corr = Percentage2u32(nattrs.CorruptCorr) corruptCorr = Percentage2u32(nattrs.CorruptCorr)
return &Netem{ return &Netem{
QdiscAttrs: attrs, QdiscAttrs: attrs,
Latency: latency, Latency: latency,
DelayCorr: delay_corr, DelayCorr: delayCorr,
Limit: limit, Limit: limit,
Loss: loss, Loss: loss,
LossCorr: loss_corr, LossCorr: lossCorr,
Gap: gap, Gap: gap,
Duplicate: duplicate, Duplicate: duplicate,
DuplicateCorr: duplicate_corr, DuplicateCorr: duplicateCorr,
Jitter: jitter, Jitter: jitter,
ReorderProb: reorder_prob, ReorderProb: reorderProb,
ReorderCorr: reorder_corr, ReorderCorr: reorderCorr,
CorruptProb: corrupt_prob, CorruptProb: corruptProb,
CorruptCorr: corrupt_corr, CorruptCorr: corruptCorr,
} }
} }

View file

@ -334,9 +334,9 @@ const (
) )
var ( var (
tickInUsec float64 = 0.0 tickInUsec float64
clockFactor float64 = 0.0 clockFactor float64
hz float64 = 0.0 hz float64
) )
func initClock() { func initClock() {

View file

@ -59,8 +59,8 @@ type flagString struct {
} }
var testFlags = []flagString{ var testFlags = []flagString{
flagString{f: FLAG_ONLINK, s: "onlink"}, {f: FLAG_ONLINK, s: "onlink"},
flagString{f: FLAG_PERVASIVE, s: "pervasive"}, {f: FLAG_PERVASIVE, s: "pervasive"},
} }
func (r *Route) ListFlags() []string { func (r *Route) ListFlags() []string {

View file

@ -116,6 +116,7 @@ func routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error {
msg.Type = uint8(route.Type) msg.Type = uint8(route.Type)
} }
msg.Flags = uint32(route.Flags)
msg.Scope = uint8(route.Scope) msg.Scope = uint8(route.Scope)
msg.Family = uint8(family) msg.Family = uint8(family)
req.AddData(msg) req.AddData(msg)

View file

@ -29,7 +29,7 @@ func (e EncapType) String() string {
return "unknown" return "unknown"
} }
// XfrmEncap represents the encapsulation to use for the ipsec encryption. // XfrmStateEncap represents the encapsulation to use for the ipsec encryption.
type XfrmStateEncap struct { type XfrmStateEncap struct {
Type EncapType Type EncapType
SrcPort int SrcPort int

View file

@ -110,9 +110,6 @@ func XfrmStateDel(state *XfrmState) error {
func XfrmStateList(family int) ([]XfrmState, error) { func XfrmStateList(family int) ([]XfrmState, error) {
req := nl.NewNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP) req := nl.NewNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP)
msg := nl.NewIfInfomsg(family)
req.AddData(msg)
msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWSA) msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWSA)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -210,13 +210,13 @@ type CancelFunc func()
// call cancel as soon as the operations running in this Context complete. // call cancel as soon as the operations running in this Context complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := newCancelCtx(parent) c := newCancelCtx(parent)
propagateCancel(parent, &c) propagateCancel(parent, c)
return &c, func() { c.cancel(true, Canceled) } return c, func() { c.cancel(true, Canceled) }
} }
// newCancelCtx returns an initialized cancelCtx. // newCancelCtx returns an initialized cancelCtx.
func newCancelCtx(parent Context) cancelCtx { func newCancelCtx(parent Context) *cancelCtx {
return cancelCtx{ return &cancelCtx{
Context: parent, Context: parent,
done: make(chan struct{}), done: make(chan struct{}),
} }
@ -259,7 +259,7 @@ func parentCancelCtx(parent Context) (*cancelCtx, bool) {
case *cancelCtx: case *cancelCtx:
return c, true return c, true
case *timerCtx: case *timerCtx:
return &c.cancelCtx, true return c.cancelCtx, true
case *valueCtx: case *valueCtx:
parent = c.Context parent = c.Context
default: default:
@ -377,7 +377,7 @@ func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
// implement Done and Err. It implements cancel by stopping its timer then // implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel. // delegating to cancelCtx.cancel.
type timerCtx struct { type timerCtx struct {
cancelCtx *cancelCtx
timer *time.Timer // Under cancelCtx.mu. timer *time.Timer // Under cancelCtx.mu.
deadline time.Time deadline time.Time

View file

@ -7,6 +7,7 @@
package http2 package http2
import ( import (
"crypto/tls"
"net/http" "net/http"
"sync" "sync"
) )
@ -17,21 +18,29 @@ type ClientConnPool interface {
MarkDead(*ClientConn) MarkDead(*ClientConn)
} }
// TODO: use singleflight for dialing and addConnCalls?
type clientConnPool struct { type clientConnPool struct {
t *Transport t *Transport
mu sync.Mutex // TODO: maybe switch to RWMutex mu sync.Mutex // TODO: maybe switch to RWMutex
// TODO: add support for sharing conns based on cert names // TODO: add support for sharing conns based on cert names
// (e.g. share conn for googleapis.com and appspot.com) // (e.g. share conn for googleapis.com and appspot.com)
conns map[string][]*ClientConn // key is host:port conns map[string][]*ClientConn // key is host:port
dialing map[string]*dialCall // currently in-flight dials dialing map[string]*dialCall // currently in-flight dials
keys map[*ClientConn][]string keys map[*ClientConn][]string
addConnCalls map[string]*addConnCall // in-flight addConnIfNeede calls
} }
func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
return p.getClientConn(req, addr, true) return p.getClientConn(req, addr, dialOnMiss)
} }
func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) { const (
dialOnMiss = true
noDialOnMiss = false
)
func (p *clientConnPool) getClientConn(_ *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
p.mu.Lock() p.mu.Lock()
for _, cc := range p.conns[addr] { for _, cc := range p.conns[addr] {
if cc.CanTakeNewRequest() { if cc.CanTakeNewRequest() {
@ -85,6 +94,64 @@ func (c *dialCall) dial(addr string) {
c.p.mu.Unlock() c.p.mu.Unlock()
} }
// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
// already exist. It coalesces concurrent calls with the same key.
// This is used by the http1 Transport code when it creates a new connection. Because
// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
// the protocol), it can get into a situation where it has multiple TLS connections.
// This code decides which ones live or die.
// The return value used is whether c was used.
// c is never closed.
func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) {
p.mu.Lock()
for _, cc := range p.conns[key] {
if cc.CanTakeNewRequest() {
p.mu.Unlock()
return false, nil
}
}
call, dup := p.addConnCalls[key]
if !dup {
if p.addConnCalls == nil {
p.addConnCalls = make(map[string]*addConnCall)
}
call = &addConnCall{
p: p,
done: make(chan struct{}),
}
p.addConnCalls[key] = call
go call.run(t, key, c)
}
p.mu.Unlock()
<-call.done
if call.err != nil {
return false, call.err
}
return !dup, nil
}
type addConnCall struct {
p *clientConnPool
done chan struct{} // closed when done
err error
}
func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
cc, err := t.NewClientConn(tc)
p := c.p
p.mu.Lock()
if err != nil {
c.err = err
} else {
p.addConnLocked(key, cc)
}
delete(p.addConnCalls, key)
p.mu.Unlock()
close(c.done)
}
func (p *clientConnPool) addConn(key string, cc *ClientConn) { func (p *clientConnPool) addConn(key string, cc *ClientConn) {
p.mu.Lock() p.mu.Lock()
p.addConnLocked(key, cc) p.addConnLocked(key, cc)

View file

@ -12,11 +12,15 @@ import (
"net/http" "net/http"
) )
func configureTransport(t1 *http.Transport) error { func configureTransport(t1 *http.Transport) (*Transport, error) {
connPool := new(clientConnPool) connPool := new(clientConnPool)
t2 := &Transport{ConnPool: noDialClientConnPool{connPool}} t2 := &Transport{
ConnPool: noDialClientConnPool{connPool},
t1: t1,
}
connPool.t = t2
if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
return err return nil, err
} }
if t1.TLSClientConfig == nil { if t1.TLSClientConfig == nil {
t1.TLSClientConfig = new(tls.Config) t1.TLSClientConfig = new(tls.Config)
@ -28,12 +32,17 @@ func configureTransport(t1 *http.Transport) error {
t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
} }
upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper { upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
cc, err := t2.NewClientConn(c) addr := authorityAddr(authority)
if err != nil { if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
c.Close() go c.Close()
return erringRoundTripper{err} return erringRoundTripper{err}
} else if !used {
// Turns out we don't need this c.
// For example, two goroutines made requests to the same host
// at the same time, both kicking off TCP dials. (since protocol
// was unknown)
go c.Close()
} }
connPool.addConn(authorityAddr(authority), cc)
return t2 return t2
} }
if m := t1.TLSNextProto; len(m) == 0 { if m := t1.TLSNextProto; len(m) == 0 {
@ -43,7 +52,7 @@ func configureTransport(t1 *http.Transport) error {
} else { } else {
m["h2"] = upgradeFn m["h2"] = upgradeFn
} }
return nil return t2, nil
} }
// registerHTTPSProtocol calls Transport.RegisterProtocol but // registerHTTPSProtocol calls Transport.RegisterProtocol but
@ -64,8 +73,7 @@ func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error)
type noDialClientConnPool struct{ *clientConnPool } type noDialClientConnPool struct{ *clientConnPool }
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
const doDial = false return p.getClientConn(req, addr, noDialOnMiss)
return p.getClientConn(req, addr, doDial)
} }
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request // noDialH2RoundTripper is a RoundTripper which only tries to complete the request

View file

@ -4,7 +4,10 @@
package http2 package http2
import "fmt" import (
"errors"
"fmt"
)
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
type ErrCode uint32 type ErrCode uint32
@ -88,3 +91,32 @@ type connError struct {
func (e connError) Error() string { func (e connError) Error() string {
return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
} }
type pseudoHeaderError string
func (e pseudoHeaderError) Error() string {
return fmt.Sprintf("invalid pseudo-header %q", string(e))
}
type duplicatePseudoHeaderError string
func (e duplicatePseudoHeaderError) Error() string {
return fmt.Sprintf("duplicate pseudo-header %q", string(e))
}
type headerFieldNameError string
func (e headerFieldNameError) Error() string {
return fmt.Sprintf("invalid header field name %q", string(e))
}
type headerFieldValueError string
func (e headerFieldValueError) Error() string {
return fmt.Sprintf("invalid header field value %q", string(e))
}
var (
errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers")
errPseudoAfterRegular = errors.New("pseudo header field after regular")
)

View file

@ -10,7 +10,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"strings"
"sync" "sync"
"golang.org/x/net/http2/hpack"
) )
const frameHeaderLen = 9 const frameHeaderLen = 9
@ -171,6 +175,12 @@ func (h FrameHeader) Header() FrameHeader { return h }
func (h FrameHeader) String() string { func (h FrameHeader) String() string {
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString("[FrameHeader ") buf.WriteString("[FrameHeader ")
h.writeDebug(&buf)
buf.WriteByte(']')
return buf.String()
}
func (h FrameHeader) writeDebug(buf *bytes.Buffer) {
buf.WriteString(h.Type.String()) buf.WriteString(h.Type.String())
if h.Flags != 0 { if h.Flags != 0 {
buf.WriteString(" flags=") buf.WriteString(" flags=")
@ -187,15 +197,14 @@ func (h FrameHeader) String() string {
if name != "" { if name != "" {
buf.WriteString(name) buf.WriteString(name)
} else { } else {
fmt.Fprintf(&buf, "0x%x", 1<<i) fmt.Fprintf(buf, "0x%x", 1<<i)
} }
} }
} }
if h.StreamID != 0 { if h.StreamID != 0 {
fmt.Fprintf(&buf, " stream=%d", h.StreamID) fmt.Fprintf(buf, " stream=%d", h.StreamID)
} }
fmt.Fprintf(&buf, " len=%d]", h.Length) fmt.Fprintf(buf, " len=%d", h.Length)
return buf.String()
} }
func (h *FrameHeader) checkValid() { func (h *FrameHeader) checkValid() {
@ -255,7 +264,7 @@ type Frame interface {
type Framer struct { type Framer struct {
r io.Reader r io.Reader
lastFrame Frame lastFrame Frame
errReason string errDetail error
// lastHeaderStream is non-zero if the last frame was an // lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION. // unfinished HEADERS/CONTINUATION.
@ -287,13 +296,37 @@ type Framer struct {
// to return non-compliant frames or frame orders. // to return non-compliant frames or frame orders.
// This is for testing and permits using the Framer to test // This is for testing and permits using the Framer to test
// other HTTP/2 implementations' conformance to the spec. // other HTTP/2 implementations' conformance to the spec.
// It is not compatible with ReadMetaHeaders.
AllowIllegalReads bool AllowIllegalReads bool
// ReadMetaHeaders if non-nil causes ReadFrame to merge
// HEADERS and CONTINUATION frames together and return
// MetaHeadersFrame instead.
ReadMetaHeaders *hpack.Decoder
// MaxHeaderListSize is the http2 MAX_HEADER_LIST_SIZE.
// It's used only if ReadMetaHeaders is set; 0 means a sane default
// (currently 16MB)
// If the limit is hit, MetaHeadersFrame.Truncated is set true.
MaxHeaderListSize uint32
// TODO: track which type of frame & with which flags was sent // TODO: track which type of frame & with which flags was sent
// last. Then return an error (unless AllowIllegalWrites) if // last. Then return an error (unless AllowIllegalWrites) if
// we're in the middle of a header block and a // we're in the middle of a header block and a
// non-Continuation or Continuation on a different stream is // non-Continuation or Continuation on a different stream is
// attempted to be written. // attempted to be written.
logReads bool
debugFramer *Framer // only use for logging written writes
debugFramerBuf *bytes.Buffer
}
func (fr *Framer) maxHeaderListSize() uint32 {
if fr.MaxHeaderListSize == 0 {
return 16 << 20 // sane default, per docs
}
return fr.MaxHeaderListSize
} }
func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) { func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) {
@ -321,6 +354,10 @@ func (f *Framer) endWrite() error {
byte(length>>16), byte(length>>16),
byte(length>>8), byte(length>>8),
byte(length)) byte(length))
if logFrameWrites {
f.logWrite()
}
n, err := f.w.Write(f.wbuf) n, err := f.w.Write(f.wbuf)
if err == nil && n != len(f.wbuf) { if err == nil && n != len(f.wbuf) {
err = io.ErrShortWrite err = io.ErrShortWrite
@ -328,6 +365,24 @@ func (f *Framer) endWrite() error {
return err return err
} }
func (f *Framer) logWrite() {
if f.debugFramer == nil {
f.debugFramerBuf = new(bytes.Buffer)
f.debugFramer = NewFramer(nil, f.debugFramerBuf)
f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below
// Let us read anything, even if we accidentally wrote it
// in the wrong order:
f.debugFramer.AllowIllegalReads = true
}
f.debugFramerBuf.Write(f.wbuf)
fr, err := f.debugFramer.ReadFrame()
if err != nil {
log.Printf("http2: Framer %p: failed to decode just-written frame", f)
return
}
log.Printf("http2: Framer %p: wrote %v", f, summarizeFrame(fr))
}
func (f *Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } func (f *Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) }
func (f *Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) } func (f *Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) }
func (f *Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) } func (f *Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) }
@ -343,8 +398,9 @@ const (
// NewFramer returns a Framer that writes frames to w and reads them from r. // NewFramer returns a Framer that writes frames to w and reads them from r.
func NewFramer(w io.Writer, r io.Reader) *Framer { func NewFramer(w io.Writer, r io.Reader) *Framer {
fr := &Framer{ fr := &Framer{
w: w, w: w,
r: r, r: r,
logReads: logFrameReads,
} }
fr.getReadBuf = func(size uint32) []byte { fr.getReadBuf = func(size uint32) []byte {
if cap(fr.readBuf) >= int(size) { if cap(fr.readBuf) >= int(size) {
@ -368,6 +424,17 @@ func (fr *Framer) SetMaxReadFrameSize(v uint32) {
fr.maxReadSize = v fr.maxReadSize = v
} }
// ErrorDetail returns a more detailed error of the last error
// returned by Framer.ReadFrame. For instance, if ReadFrame
// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail
// will say exactly what was invalid. ErrorDetail is not guaranteed
// to return a non-nil value and like the rest of the http2 package,
// its return value is not protected by an API compatibility promise.
// ErrorDetail is reset after the next call to ReadFrame.
func (fr *Framer) ErrorDetail() error {
return fr.errDetail
}
// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer // ErrFrameTooLarge is returned from Framer.ReadFrame when the peer
// sends a frame that is larger than declared with SetMaxReadFrameSize. // sends a frame that is larger than declared with SetMaxReadFrameSize.
var ErrFrameTooLarge = errors.New("http2: frame too large") var ErrFrameTooLarge = errors.New("http2: frame too large")
@ -389,6 +456,7 @@ func terminalReadFrameError(err error) bool {
// ConnectionError, StreamError, or anything else from from the underlying // ConnectionError, StreamError, or anything else from from the underlying
// reader. // reader.
func (fr *Framer) ReadFrame() (Frame, error) { func (fr *Framer) ReadFrame() (Frame, error) {
fr.errDetail = nil
if fr.lastFrame != nil { if fr.lastFrame != nil {
fr.lastFrame.invalidate() fr.lastFrame.invalidate()
} }
@ -413,6 +481,12 @@ func (fr *Framer) ReadFrame() (Frame, error) {
if err := fr.checkFrameOrder(f); err != nil { if err := fr.checkFrameOrder(f); err != nil {
return nil, err return nil, err
} }
if fr.logReads {
log.Printf("http2: Framer %p: read %v", fr, summarizeFrame(f))
}
if fh.Type == FrameHeaders && fr.ReadMetaHeaders != nil {
return fr.readMetaFrame(f.(*HeadersFrame))
}
return f, nil return f, nil
} }
@ -421,7 +495,7 @@ func (fr *Framer) ReadFrame() (Frame, error) {
// to the peer before hanging up on them. This might help others debug // to the peer before hanging up on them. This might help others debug
// their implementations. // their implementations.
func (fr *Framer) connError(code ErrCode, reason string) error { func (fr *Framer) connError(code ErrCode, reason string) error {
fr.errReason = reason fr.errDetail = errors.New(reason)
return ConnectionError(code) return ConnectionError(code)
} }
@ -1026,10 +1100,6 @@ func parseContinuationFrame(fh FrameHeader, p []byte) (Frame, error) {
return &ContinuationFrame{fh, p}, nil return &ContinuationFrame{fh, p}, nil
} }
func (f *ContinuationFrame) StreamEnded() bool {
return f.FrameHeader.Flags.Has(FlagDataEndStream)
}
func (f *ContinuationFrame) HeaderBlockFragment() []byte { func (f *ContinuationFrame) HeaderBlockFragment() []byte {
f.checkValid() f.checkValid()
return f.headerFragBuf return f.headerFragBuf
@ -1191,3 +1261,236 @@ type streamEnder interface {
type headersEnder interface { type headersEnder interface {
HeadersEnded() bool HeadersEnded() bool
} }
type headersOrContinuation interface {
headersEnder
HeaderBlockFragment() []byte
}
// A MetaHeadersFrame is the representation of one HEADERS frame and
// zero or more contiguous CONTINUATION frames and the decoding of
// their HPACK-encoded contents.
//
// This type of frame does not appear on the wire and is only returned
// by the Framer when Framer.ReadMetaHeaders is set.
type MetaHeadersFrame struct {
*HeadersFrame
// Fields are the fields contained in the HEADERS and
// CONTINUATION frames. The underlying slice is owned by the
// Framer and must not be retained after the next call to
// ReadFrame.
//
// Fields are guaranteed to be in the correct http2 order and
// not have unknown pseudo header fields or invalid header
// field names or values. Required pseudo header fields may be
// missing, however. Use the MetaHeadersFrame.Pseudo accessor
// method access pseudo headers.
Fields []hpack.HeaderField
// Truncated is whether the max header list size limit was hit
// and Fields is incomplete. The hpack decoder state is still
// valid, however.
Truncated bool
}
// PseudoValue returns the given pseudo header field's value.
// The provided pseudo field should not contain the leading colon.
func (mh *MetaHeadersFrame) PseudoValue(pseudo string) string {
for _, hf := range mh.Fields {
if !hf.IsPseudo() {
return ""
}
if hf.Name[1:] == pseudo {
return hf.Value
}
}
return ""
}
// RegularFields returns the regular (non-pseudo) header fields of mh.
// The caller does not own the returned slice.
func (mh *MetaHeadersFrame) RegularFields() []hpack.HeaderField {
for i, hf := range mh.Fields {
if !hf.IsPseudo() {
return mh.Fields[i:]
}
}
return nil
}
// PseudoFields returns the pseudo header fields of mh.
// The caller does not own the returned slice.
func (mh *MetaHeadersFrame) PseudoFields() []hpack.HeaderField {
for i, hf := range mh.Fields {
if !hf.IsPseudo() {
return mh.Fields[:i]
}
}
return mh.Fields
}
func (mh *MetaHeadersFrame) checkPseudos() error {
var isRequest, isResponse bool
pf := mh.PseudoFields()
for i, hf := range pf {
switch hf.Name {
case ":method", ":path", ":scheme", ":authority":
isRequest = true
case ":status":
isResponse = true
default:
return pseudoHeaderError(hf.Name)
}
// Check for duplicates.
// This would be a bad algorithm, but N is 4.
// And this doesn't allocate.
for _, hf2 := range pf[:i] {
if hf.Name == hf2.Name {
return duplicatePseudoHeaderError(hf.Name)
}
}
}
if isRequest && isResponse {
return errMixPseudoHeaderTypes
}
return nil
}
func (fr *Framer) maxHeaderStringLen() int {
v := fr.maxHeaderListSize()
if uint32(int(v)) == v {
return int(v)
}
// They had a crazy big number for MaxHeaderBytes anyway,
// so give them unlimited header lengths:
return 0
}
// readMetaFrame returns 0 or more CONTINUATION frames from fr and
// merge them into into the provided hf and returns a MetaHeadersFrame
// with the decoded hpack values.
func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
if fr.AllowIllegalReads {
return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders")
}
mh := &MetaHeadersFrame{
HeadersFrame: hf,
}
var remainSize = fr.maxHeaderListSize()
var sawRegular bool
var invalid error // pseudo header field errors
hdec := fr.ReadMetaHeaders
hdec.SetEmitEnabled(true)
hdec.SetMaxStringLength(fr.maxHeaderStringLen())
hdec.SetEmitFunc(func(hf hpack.HeaderField) {
if !validHeaderFieldValue(hf.Value) {
invalid = headerFieldValueError(hf.Value)
}
isPseudo := strings.HasPrefix(hf.Name, ":")
if isPseudo {
if sawRegular {
invalid = errPseudoAfterRegular
}
} else {
sawRegular = true
if !validHeaderFieldName(hf.Name) {
invalid = headerFieldNameError(hf.Name)
}
}
if invalid != nil {
hdec.SetEmitEnabled(false)
return
}
size := hf.Size()
if size > remainSize {
hdec.SetEmitEnabled(false)
mh.Truncated = true
return
}
remainSize -= size
mh.Fields = append(mh.Fields, hf)
})
// Lose reference to MetaHeadersFrame:
defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {})
var hc headersOrContinuation = hf
for {
frag := hc.HeaderBlockFragment()
if _, err := hdec.Write(frag); err != nil {
return nil, ConnectionError(ErrCodeCompression)
}
if hc.HeadersEnded() {
break
}
if f, err := fr.ReadFrame(); err != nil {
return nil, err
} else {
hc = f.(*ContinuationFrame) // guaranteed by checkFrameOrder
}
}
mh.HeadersFrame.headerFragBuf = nil
mh.HeadersFrame.invalidate()
if err := hdec.Close(); err != nil {
return nil, ConnectionError(ErrCodeCompression)
}
if invalid != nil {
fr.errDetail = invalid
return nil, StreamError{mh.StreamID, ErrCodeProtocol}
}
if err := mh.checkPseudos(); err != nil {
fr.errDetail = err
return nil, StreamError{mh.StreamID, ErrCodeProtocol}
}
return mh, nil
}
func summarizeFrame(f Frame) string {
var buf bytes.Buffer
f.Header().writeDebug(&buf)
switch f := f.(type) {
case *SettingsFrame:
n := 0
f.ForeachSetting(func(s Setting) error {
n++
if n == 1 {
buf.WriteString(", settings:")
}
fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val)
return nil
})
if n > 0 {
buf.Truncate(buf.Len() - 1) // remove trailing comma
}
case *DataFrame:
data := f.Data()
const max = 256
if len(data) > max {
data = data[:max]
}
fmt.Fprintf(&buf, " data=%q", data)
if len(f.Data()) > max {
fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max)
}
case *WindowUpdateFrame:
if f.StreamID == 0 {
buf.WriteString(" (conn)")
}
fmt.Fprintf(&buf, " incr=%v", f.Increment)
case *PingFrame:
fmt.Fprintf(&buf, " ping=%q", f.Data[:])
case *GoAwayFrame:
fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q",
f.LastStreamID, f.ErrCode, f.debugData)
case *RSTStreamFrame:
fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode)
}
return buf.String()
}

View file

@ -144,7 +144,7 @@ func (e *Encoder) SetMaxDynamicTableSizeLimit(v uint32) {
// shouldIndex reports whether f should be indexed. // shouldIndex reports whether f should be indexed.
func (e *Encoder) shouldIndex(f HeaderField) bool { func (e *Encoder) shouldIndex(f HeaderField) bool {
return !f.Sensitive && f.size() <= e.dynTab.maxSize return !f.Sensitive && f.Size() <= e.dynTab.maxSize
} }
// appendIndexed appends index i, as encoded in "Indexed Header Field" // appendIndexed appends index i, as encoded in "Indexed Header Field"

View file

@ -41,7 +41,24 @@ type HeaderField struct {
Sensitive bool Sensitive bool
} }
func (hf *HeaderField) size() uint32 { // IsPseudo reports whether the header field is an http2 pseudo header.
// That is, it reports whether it starts with a colon.
// It is not otherwise guaranteed to be a valid psuedo header field,
// though.
func (hf HeaderField) IsPseudo() bool {
return len(hf.Name) != 0 && hf.Name[0] == ':'
}
func (hf HeaderField) String() string {
var suffix string
if hf.Sensitive {
suffix = " (sensitive)"
}
return fmt.Sprintf("header field %q = %q%s", hf.Name, hf.Value, suffix)
}
// Size returns the size of an entry per RFC 7540 section 5.2.
func (hf HeaderField) Size() uint32 {
// http://http2.github.io/http2-spec/compression.html#rfc.section.4.1 // http://http2.github.io/http2-spec/compression.html#rfc.section.4.1
// "The size of the dynamic table is the sum of the size of // "The size of the dynamic table is the sum of the size of
// its entries. The size of an entry is the sum of its name's // its entries. The size of an entry is the sum of its name's
@ -163,7 +180,7 @@ func (dt *dynamicTable) setMaxSize(v uint32) {
func (dt *dynamicTable) add(f HeaderField) { func (dt *dynamicTable) add(f HeaderField) {
dt.ents = append(dt.ents, f) dt.ents = append(dt.ents, f)
dt.size += f.size() dt.size += f.Size()
dt.evict() dt.evict()
} }
@ -171,7 +188,7 @@ func (dt *dynamicTable) add(f HeaderField) {
func (dt *dynamicTable) evict() { func (dt *dynamicTable) evict() {
base := dt.ents // keep base pointer of slice base := dt.ents // keep base pointer of slice
for dt.size > dt.maxSize { for dt.size > dt.maxSize {
dt.size -= dt.ents[0].size() dt.size -= dt.ents[0].Size()
dt.ents = dt.ents[1:] dt.ents = dt.ents[1:]
} }

View file

@ -17,16 +17,35 @@ package http2
import ( import (
"bufio" "bufio"
"crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os" "os"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
) )
var VerboseLogs = strings.Contains(os.Getenv("GODEBUG"), "h2debug=1") var (
VerboseLogs bool
logFrameWrites bool
logFrameReads bool
)
func init() {
e := os.Getenv("GODEBUG")
if strings.Contains(e, "http2debug=1") {
VerboseLogs = true
}
if strings.Contains(e, "http2debug=2") {
VerboseLogs = true
logFrameWrites = true
logFrameReads = true
}
}
const ( const (
// ClientPreface is the string that must be sent by new // ClientPreface is the string that must be sent by new
@ -142,17 +161,62 @@ func (s SettingID) String() string {
return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s))
} }
func validHeader(v string) bool { var (
errInvalidHeaderFieldName = errors.New("http2: invalid header field name")
errInvalidHeaderFieldValue = errors.New("http2: invalid header field value")
)
// validHeaderFieldName reports whether v is a valid header field name (key).
// RFC 7230 says:
// header-field = field-name ":" OWS field-value OWS
// field-name = token
// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
// "^" / "_" / "
// Further, http2 says:
// "Just as in HTTP/1.x, header field names are strings of ASCII
// characters that are compared in a case-insensitive
// fashion. However, header field names MUST be converted to
// lowercase prior to their encoding in HTTP/2. "
func validHeaderFieldName(v string) bool {
if len(v) == 0 { if len(v) == 0 {
return false return false
} }
for _, r := range v { for _, r := range v {
// "Just as in HTTP/1.x, header field names are if int(r) >= len(isTokenTable) || ('A' <= r && r <= 'Z') {
// strings of ASCII characters that are compared in a return false
// case-insensitive fashion. However, header field }
// names MUST be converted to lowercase prior to their if !isTokenTable[byte(r)] {
// encoding in HTTP/2. " return false
if r >= 127 || ('A' <= r && r <= 'Z') { }
}
return true
}
// validHeaderFieldValue reports whether v is a valid header field value.
//
// RFC 7230 says:
// field-value = *( field-content / obs-fold )
// obj-fold = N/A to http2, and deprecated
// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
// field-vchar = VCHAR / obs-text
// obs-text = %x80-FF
// VCHAR = "any visible [USASCII] character"
//
// http2 further says: "Similarly, HTTP/2 allows header field values
// that are not valid. While most of the values that can be encoded
// will not alter header field parsing, carriage return (CR, ASCII
// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII
// 0x0) might be exploited by an attacker if they are translated
// verbatim. Any request or response that contains a character not
// permitted in a header field value MUST be treated as malformed
// (Section 8.1.2.6). Valid characters are defined by the
// field-content ABNF rule in Section 3.2 of [RFC7230]."
//
// This function does not (yet?) properly handle the rejection of
// strings that begin or end with SP or HTAB.
func validHeaderFieldValue(v string) bool {
for i := 0; i < len(v); i++ {
if b := v[i]; b < ' ' && b != '\t' || b == 0x7f {
return false return false
} }
} }
@ -269,3 +333,131 @@ func bodyAllowedForStatus(status int) bool {
} }
return true return true
} }
type httpError struct {
msg string
timeout bool
}
func (e *httpError) Error() string { return e.msg }
func (e *httpError) Timeout() bool { return e.timeout }
func (e *httpError) Temporary() bool { return true }
var errTimeout error = &httpError{msg: "http2: timeout awaiting response headers", timeout: true}
var isTokenTable = [127]bool{
'!': true,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'*': true,
'+': true,
'-': true,
'.': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
}
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
var sorterPool = sync.Pool{New: func() interface{} { return new(sorter) }}
type sorter struct {
v []string // owned by sorter
}
func (s *sorter) Len() int { return len(s.v) }
func (s *sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] }
func (s *sorter) Less(i, j int) bool { return s.v[i] < s.v[j] }
// Keys returns the sorted keys of h.
//
// The returned slice is only valid until s used again or returned to
// its pool.
func (s *sorter) Keys(h http.Header) []string {
keys := s.v[:0]
for k := range h {
keys = append(keys, k)
}
s.v = keys
sort.Sort(s)
return keys
}
func (s *sorter) SortStrings(ss []string) {
// Our sorter works on s.v, which sorter owners, so
// stash it away while we sort the user's buffer.
save := s.v
s.v = ss
sort.Sort(s)
s.v = save
}

View file

@ -8,6 +8,6 @@ package http2
import "net/http" import "net/http"
func configureTransport(t1 *http.Transport) error { func configureTransport(t1 *http.Transport) (*Transport, error) {
return errTransportVersion return nil, errTransportVersion
} }

View file

@ -6,8 +6,8 @@
// instead, and make sure that on close we close all open // instead, and make sure that on close we close all open
// streams. then remove doneServing? // streams. then remove doneServing?
// TODO: finish GOAWAY support. Consider each incoming frame type and // TODO: re-audit GOAWAY support. Consider each incoming frame type and
// whether it should be ignored during a shutdown race. // whether it should be ignored during graceful shutdown.
// TODO: disconnect idle clients. GFE seems to do 4 minutes. make // TODO: disconnect idle clients. GFE seems to do 4 minutes. make
// configurable? or maximum number of idle clients and remove the // configurable? or maximum number of idle clients and remove the
@ -48,6 +48,8 @@ import (
"net/http" "net/http"
"net/textproto" "net/textproto"
"net/url" "net/url"
"os"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -192,28 +194,76 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if testHookOnConn != nil { if testHookOnConn != nil {
testHookOnConn() testHookOnConn()
} }
conf.handleConn(hs, c, h) conf.ServeConn(c, &ServeConnOpts{
Handler: h,
BaseConfig: hs,
})
} }
s.TLSNextProto[NextProtoTLS] = protoHandler s.TLSNextProto[NextProtoTLS] = protoHandler
s.TLSNextProto["h2-14"] = protoHandler // temporary; see above. s.TLSNextProto["h2-14"] = protoHandler // temporary; see above.
return nil return nil
} }
func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { // ServeConnOpts are options for the Server.ServeConn method.
type ServeConnOpts struct {
// BaseConfig optionally sets the base configuration
// for values. If nil, defaults are used.
BaseConfig *http.Server
// Handler specifies which handler to use for processing
// requests. If nil, BaseConfig.Handler is used. If BaseConfig
// or BaseConfig.Handler is nil, http.DefaultServeMux is used.
Handler http.Handler
}
func (o *ServeConnOpts) baseConfig() *http.Server {
if o != nil && o.BaseConfig != nil {
return o.BaseConfig
}
return new(http.Server)
}
func (o *ServeConnOpts) handler() http.Handler {
if o != nil {
if o.Handler != nil {
return o.Handler
}
if o.BaseConfig != nil && o.BaseConfig.Handler != nil {
return o.BaseConfig.Handler
}
}
return http.DefaultServeMux
}
// ServeConn serves HTTP/2 requests on the provided connection and
// blocks until the connection is no longer readable.
//
// ServeConn starts speaking HTTP/2 assuming that c has not had any
// reads or writes. It writes its initial settings frame and expects
// to be able to read the preface and settings frame from the
// client. If c has a ConnectionState method like a *tls.Conn, the
// ConnectionState is used to verify the TLS ciphersuite and to set
// the Request.TLS field in Handlers.
//
// ServeConn does not support h2c by itself. Any h2c support must be
// implemented in terms of providing a suitably-behaving net.Conn.
//
// The opts parameter is optional. If nil, default values are used.
func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc := &serverConn{ sc := &serverConn{
srv: srv, srv: s,
hs: hs, hs: opts.baseConfig(),
conn: c, conn: c,
remoteAddrStr: c.RemoteAddr().String(), remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(c), bw: newBufferedWriter(c),
handler: h, handler: opts.handler(),
streams: make(map[uint32]*stream), streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult), readFrameCh: make(chan readFrameResult),
wantWriteFrameCh: make(chan frameWriteMsg, 8), wantWriteFrameCh: make(chan frameWriteMsg, 8),
wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}), doneServing: make(chan struct{}),
advMaxStreams: srv.maxConcurrentStreams(), advMaxStreams: s.maxConcurrentStreams(),
writeSched: writeScheduler{ writeSched: writeScheduler{
maxFrameSize: initialMaxFrameSize, maxFrameSize: initialMaxFrameSize,
}, },
@ -225,14 +275,14 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
sc.flow.add(initialWindowSize) sc.flow.add(initialWindowSize)
sc.inflow.add(initialWindowSize) sc.inflow.add(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, nil)
sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
fr := NewFramer(sc.bw, c) fr := NewFramer(sc.bw, c)
fr.SetMaxReadFrameSize(srv.maxReadFrameSize()) fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize())
sc.framer = fr sc.framer = fr
if tc, ok := c.(*tls.Conn); ok { if tc, ok := c.(connectionStater); ok {
sc.tlsState = new(tls.ConnectionState) sc.tlsState = new(tls.ConnectionState)
*sc.tlsState = tc.ConnectionState() *sc.tlsState = tc.ConnectionState()
// 9.2 Use of TLS Features // 9.2 Use of TLS Features
@ -262,7 +312,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
// So for now, do nothing here again. // So for now, do nothing here again.
} }
if !srv.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error // "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of // (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated." // the prohibited cipher suites are negotiated."
@ -309,7 +359,7 @@ func isBadCipher(cipher uint16) bool {
} }
func (sc *serverConn) rejectConn(err ErrCode, debug string) { func (sc *serverConn) rejectConn(err ErrCode, debug string) {
sc.vlogf("REJECTING conn: %v, %s", err, debug) sc.vlogf("http2: server rejecting conn: %v, %s", err, debug)
// ignoring errors. hanging up anyway. // ignoring errors. hanging up anyway.
sc.framer.WriteGoAway(0, err, []byte(debug)) sc.framer.WriteGoAway(0, err, []byte(debug))
sc.bw.Flush() sc.bw.Flush()
@ -324,7 +374,6 @@ type serverConn struct {
bw *bufferedWriter // writing to conn bw *bufferedWriter // writing to conn
handler http.Handler handler http.Handler
framer *Framer framer *Framer
hpackDecoder *hpack.Decoder
doneServing chan struct{} // closed when serverConn.serve ends doneServing chan struct{} // closed when serverConn.serve ends
readFrameCh chan readFrameResult // written by serverConn.readFrames readFrameCh chan readFrameResult // written by serverConn.readFrames
wantWriteFrameCh chan frameWriteMsg // from handlers -> serve wantWriteFrameCh chan frameWriteMsg // from handlers -> serve
@ -351,7 +400,6 @@ type serverConn struct {
headerTableSize uint32 headerTableSize uint32
peerMaxHeaderListSize uint32 // zero means unknown (default) peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
req requestParam // non-zero while reading request headers
writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh
needsFrameFlush bool // last frame write wasn't a flush needsFrameFlush bool // last frame write wasn't a flush
writeSched writeScheduler writeSched writeScheduler
@ -360,22 +408,13 @@ type serverConn struct {
goAwayCode ErrCode goAwayCode ErrCode
shutdownTimerCh <-chan time.Time // nil until used shutdownTimerCh <-chan time.Time // nil until used
shutdownTimer *time.Timer // nil until used shutdownTimer *time.Timer // nil until used
freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf
// Owned by the writeFrameAsync goroutine: // Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer headerWriteBuf bytes.Buffer
hpackEncoder *hpack.Encoder hpackEncoder *hpack.Encoder
} }
func (sc *serverConn) maxHeaderStringLen() int {
v := sc.maxHeaderListSize()
if uint32(int(v)) == v {
return int(v)
}
// They had a crazy big number for MaxHeaderBytes anyway,
// so give them unlimited header lengths:
return 0
}
func (sc *serverConn) maxHeaderListSize() uint32 { func (sc *serverConn) maxHeaderListSize() uint32 {
n := sc.hs.MaxHeaderBytes n := sc.hs.MaxHeaderBytes
if n <= 0 { if n <= 0 {
@ -388,21 +427,6 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
return uint32(n + typicalHeaders*perFieldOverhead) return uint32(n + typicalHeaders*perFieldOverhead)
} }
// requestParam is the state of the next request, initialized over
// potentially several frames HEADERS + zero or more CONTINUATION
// frames.
type requestParam struct {
// stream is non-nil if we're reading (HEADER or CONTINUATION)
// frames for a request (but not DATA).
stream *stream
header http.Header
method, path string
scheme, authority string
sawRegularHeader bool // saw a non-pseudo header already
invalidHeader bool // an invalid header was seen
headerListSize int64 // actually uint32, but easier math this way
}
// stream represents a stream. This is the minimal metadata needed by // stream represents a stream. This is the minimal metadata needed by
// the serve goroutine. Most of the actual stream state is owned by // the serve goroutine. Most of the actual stream state is owned by
// the http.Handler's goroutine in the responseWriter. Because the // the http.Handler's goroutine in the responseWriter. Because the
@ -429,6 +453,7 @@ type stream struct {
sentReset bool // only true once detached from streams map sentReset bool // only true once detached from streams map
gotReset bool // only true once detacted from streams map gotReset bool // only true once detacted from streams map
gotTrailerHeader bool // HEADER frame for trailers was seen gotTrailerHeader bool // HEADER frame for trailers was seen
reqBuf []byte
trailer http.Header // accumulated trailers trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer reqTrailer http.Header // handler's Request.Trailer
@ -482,12 +507,55 @@ func (sc *serverConn) logf(format string, args ...interface{}) {
} }
} }
// errno returns v's underlying uintptr, else 0.
//
// TODO: remove this helper function once http2 can use build
// tags. See comment in isClosedConnError.
func errno(v error) uintptr {
if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr {
return uintptr(rv.Uint())
}
return 0
}
// isClosedConnError reports whether err is an error from use of a closed
// network connection.
func isClosedConnError(err error) bool {
if err == nil {
return false
}
// TODO: remove this string search and be more like the Windows
// case below. That might involve modifying the standard library
// to return better error types.
str := err.Error()
if strings.Contains(str, "use of closed network connection") {
return true
}
// TODO(bradfitz): x/tools/cmd/bundle doesn't really support
// build tags, so I can't make an http2_windows.go file with
// Windows-specific stuff. Fix that and move this, once we
// have a way to bundle this into std's net/http somehow.
if runtime.GOOS == "windows" {
if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" {
const WSAECONNABORTED = 10053
const WSAECONNRESET = 10054
if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED {
return true
}
}
}
}
return false
}
func (sc *serverConn) condlogf(err error, format string, args ...interface{}) { func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
if err == nil { if err == nil {
return return
} }
str := err.Error() if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) {
if err == io.EOF || strings.Contains(str, "use of closed network connection") {
// Boring, expected errors. // Boring, expected errors.
sc.vlogf(format, args...) sc.vlogf(format, args...)
} else { } else {
@ -495,86 +563,6 @@ func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
} }
} }
func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
sc.serveG.check()
sc.vlogf("got header field %+v", f)
switch {
case !validHeader(f.Name):
sc.req.invalidHeader = true
case strings.HasPrefix(f.Name, ":"):
if sc.req.sawRegularHeader {
sc.logf("pseudo-header after regular header")
sc.req.invalidHeader = true
return
}
var dst *string
switch f.Name {
case ":method":
dst = &sc.req.method
case ":path":
dst = &sc.req.path
case ":scheme":
dst = &sc.req.scheme
case ":authority":
dst = &sc.req.authority
default:
// 8.1.2.1 Pseudo-Header Fields
// "Endpoints MUST treat a request or response
// that contains undefined or invalid
// pseudo-header fields as malformed (Section
// 8.1.2.6)."
sc.logf("invalid pseudo-header %q", f.Name)
sc.req.invalidHeader = true
return
}
if *dst != "" {
sc.logf("duplicate pseudo-header %q sent", f.Name)
sc.req.invalidHeader = true
return
}
*dst = f.Value
default:
sc.req.sawRegularHeader = true
sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value)
const headerFieldOverhead = 32 // per spec
sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
if sc.req.headerListSize > int64(sc.maxHeaderListSize()) {
sc.hpackDecoder.SetEmitEnabled(false)
}
}
}
func (st *stream) onNewTrailerField(f hpack.HeaderField) {
sc := st.sc
sc.serveG.check()
sc.vlogf("got trailer field %+v", f)
switch {
case !validHeader(f.Name):
// TODO: change hpack signature so this can return
// errors? Or stash an error somewhere on st or sc
// for processHeaderBlockFragment etc to pick up and
// return after the hpack Write/Close. For now just
// ignore.
return
case strings.HasPrefix(f.Name, ":"):
// TODO: same TODO as above.
return
default:
key := sc.canonicalHeader(f.Name)
if st.trailer != nil {
vv := append(st.trailer[key], f.Value)
st.trailer[key] = vv
// arbitrary; TODO: read spec about header list size limits wrt trailers
const tooBig = 1000
if len(vv) >= tooBig {
sc.hpackDecoder.SetEmitEnabled(false)
}
}
}
}
func (sc *serverConn) canonicalHeader(v string) string { func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check() sc.serveG.check()
cv, ok := commonCanonHeader[v] cv, ok := commonCanonHeader[v]
@ -609,10 +597,11 @@ type readFrameResult struct {
// It's run on its own goroutine. // It's run on its own goroutine.
func (sc *serverConn) readFrames() { func (sc *serverConn) readFrames() {
gate := make(gate) gate := make(gate)
gateDone := gate.Done
for { for {
f, err := sc.framer.ReadFrame() f, err := sc.framer.ReadFrame()
select { select {
case sc.readFrameCh <- readFrameResult{f, err, gate.Done}: case sc.readFrameCh <- readFrameResult{f, err, gateDone}:
case <-sc.doneServing: case <-sc.doneServing:
return return
} }
@ -679,7 +668,9 @@ func (sc *serverConn) serve() {
defer sc.stopShutdownTimer() defer sc.stopShutdownTimer()
defer close(sc.doneServing) // unblocks handlers trying to send defer close(sc.doneServing) // unblocks handlers trying to send
sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) if VerboseLogs {
sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
}
sc.writeFrame(frameWriteMsg{ sc.writeFrame(frameWriteMsg{
write: writeSettings{ write: writeSettings{
@ -696,7 +687,7 @@ func (sc *serverConn) serve() {
sc.unackedSettings++ sc.unackedSettings++
if err := sc.readPreface(); err != nil { if err := sc.readPreface(); err != nil {
sc.condlogf(err, "error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
return return
} }
// Now that we've got the preface, get us out of the // Now that we've got the preface, get us out of the
@ -762,7 +753,9 @@ func (sc *serverConn) readPreface() error {
return errors.New("timeout waiting for client preface") return errors.New("timeout waiting for client preface")
case err := <-errc: case err := <-errc:
if err == nil { if err == nil {
sc.vlogf("client %v said hello", sc.conn.RemoteAddr()) if VerboseLogs {
sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr())
}
} }
return err return err
} }
@ -1026,7 +1019,7 @@ func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
sc.goAway(ErrCodeFrameSize) sc.goAway(ErrCodeFrameSize)
return true // goAway will close the loop return true // goAway will close the loop
} }
clientGone := err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err)
if clientGone { if clientGone {
// TODO: could we also get into this state if // TODO: could we also get into this state if
// the peer does a half close // the peer does a half close
@ -1040,7 +1033,9 @@ func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
} }
} else { } else {
f := res.f f := res.f
sc.vlogf("got %v: %#v", f.Header(), f) if VerboseLogs {
sc.vlogf("http2: server read frame %v", summarizeFrame(f))
}
err = sc.processFrame(f) err = sc.processFrame(f)
if err == nil { if err == nil {
return true return true
@ -1055,14 +1050,14 @@ func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
sc.goAway(ErrCodeFlowControl) sc.goAway(ErrCodeFlowControl)
return true return true
case ConnectionError: case ConnectionError:
sc.logf("%v: %v", sc.conn.RemoteAddr(), ev) sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
sc.goAway(ErrCode(ev)) sc.goAway(ErrCode(ev))
return true // goAway will handle shutdown return true // goAway will handle shutdown
default: default:
if res.err != nil { if res.err != nil {
sc.logf("disconnecting; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
} else { } else {
sc.logf("disconnection due to other error: %v", err) sc.logf("http2: server closing client connection: %v", err)
} }
return false return false
} }
@ -1082,10 +1077,8 @@ func (sc *serverConn) processFrame(f Frame) error {
switch f := f.(type) { switch f := f.(type) {
case *SettingsFrame: case *SettingsFrame:
return sc.processSettings(f) return sc.processSettings(f)
case *HeadersFrame: case *MetaHeadersFrame:
return sc.processHeaders(f) return sc.processHeaders(f)
case *ContinuationFrame:
return sc.processContinuation(f)
case *WindowUpdateFrame: case *WindowUpdateFrame:
return sc.processWindowUpdate(f) return sc.processWindowUpdate(f)
case *PingFrame: case *PingFrame:
@ -1101,7 +1094,7 @@ func (sc *serverConn) processFrame(f Frame) error {
// frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
return ConnectionError(ErrCodeProtocol) return ConnectionError(ErrCodeProtocol)
default: default:
sc.vlogf("Ignoring frame: %v", f.Header()) sc.vlogf("http2: server ignoring frame: %v", f.Header())
return nil return nil
} }
} }
@ -1185,6 +1178,18 @@ func (sc *serverConn) closeStream(st *stream, err error) {
} }
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.forgetStream(st.id) sc.writeSched.forgetStream(st.id)
if st.reqBuf != nil {
// Stash this request body buffer (64k) away for reuse
// by a future POST/PUT/etc.
//
// TODO(bradfitz): share on the server? sync.Pool?
// Server requires locks and might hurt contention.
// sync.Pool might work, or might be worse, depending
// on goroutine CPU migrations. (get and put on
// separate CPUs). Maybe a mix of strategies. But
// this is an easy win for now.
sc.freeRequestBodyBuf = st.reqBuf
}
} }
func (sc *serverConn) processSettings(f *SettingsFrame) error { func (sc *serverConn) processSettings(f *SettingsFrame) error {
@ -1212,7 +1217,9 @@ func (sc *serverConn) processSetting(s Setting) error {
if err := s.Valid(); err != nil { if err := s.Valid(); err != nil {
return err return err
} }
sc.vlogf("processing setting %v", s) if VerboseLogs {
sc.vlogf("http2: server processing setting %v", s)
}
switch s.ID { switch s.ID {
case SettingHeaderTableSize: case SettingHeaderTableSize:
sc.headerTableSize = s.Val sc.headerTableSize = s.Val
@ -1231,6 +1238,9 @@ func (sc *serverConn) processSetting(s Setting) error {
// Unknown setting: "An endpoint that receives a SETTINGS // Unknown setting: "An endpoint that receives a SETTINGS
// frame with any unknown or unsupported identifier MUST // frame with any unknown or unsupported identifier MUST
// ignore that setting." // ignore that setting."
if VerboseLogs {
sc.vlogf("http2: server ignoring unknown setting %v", s)
}
} }
return nil return nil
} }
@ -1336,7 +1346,7 @@ func (st *stream) copyTrailersToHandlerRequest() {
} }
} }
func (sc *serverConn) processHeaders(f *HeadersFrame) error { func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
sc.serveG.check() sc.serveG.check()
id := f.Header().StreamID id := f.Header().StreamID
if sc.inGoAway { if sc.inGoAway {
@ -1365,13 +1375,11 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
// endpoint has opened or reserved. [...] An endpoint that // endpoint has opened or reserved. [...] An endpoint that
// receives an unexpected stream identifier MUST respond with // receives an unexpected stream identifier MUST respond with
// a connection error (Section 5.4.1) of type PROTOCOL_ERROR. // a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
if id <= sc.maxStreamID || sc.req.stream != nil { if id <= sc.maxStreamID {
return ConnectionError(ErrCodeProtocol) return ConnectionError(ErrCodeProtocol)
} }
sc.maxStreamID = id
if id > sc.maxStreamID {
sc.maxStreamID = id
}
st = &stream{ st = &stream{
sc: sc, sc: sc,
id: id, id: id,
@ -1395,46 +1403,6 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
if sc.curOpenStreams == 1 { if sc.curOpenStreams == 1 {
sc.setConnState(http.StateActive) sc.setConnState(http.StateActive)
} }
sc.req = requestParam{
stream: st,
header: make(http.Header),
}
sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
sc.hpackDecoder.SetEmitEnabled(true)
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
}
func (st *stream) processTrailerHeaders(f *HeadersFrame) error {
sc := st.sc
sc.serveG.check()
if st.gotTrailerHeader {
return ConnectionError(ErrCodeProtocol)
}
st.gotTrailerHeader = true
return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
}
func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
sc.serveG.check()
st := sc.streams[f.Header().StreamID]
if st.gotTrailerHeader {
return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
}
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
}
func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bool) error {
sc.serveG.check()
if _, err := sc.hpackDecoder.Write(frag); err != nil {
return ConnectionError(ErrCodeCompression)
}
if !end {
return nil
}
if err := sc.hpackDecoder.Close(); err != nil {
return ConnectionError(ErrCodeCompression)
}
defer sc.resetPendingRequest()
if sc.curOpenStreams > sc.advMaxStreams { if sc.curOpenStreams > sc.advMaxStreams {
// "Endpoints MUST NOT exceed the limit set by their // "Endpoints MUST NOT exceed the limit set by their
// peer. An endpoint that receives a HEADERS frame // peer. An endpoint that receives a HEADERS frame
@ -1454,7 +1422,7 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
return StreamError{st.id, ErrCodeRefusedStream} return StreamError{st.id, ErrCodeRefusedStream}
} }
rw, req, err := sc.newWriterAndRequest() rw, req, err := sc.newWriterAndRequest(st, f)
if err != nil { if err != nil {
return err return err
} }
@ -1466,7 +1434,7 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
st.declBodyBytes = req.ContentLength st.declBodyBytes = req.ContentLength
handler := sc.handler.ServeHTTP handler := sc.handler.ServeHTTP
if !sc.hpackDecoder.EmitEnabled() { if f.Truncated {
// Their header list was too long. Send a 431 error. // Their header list was too long. Send a 431 error.
handler = handleHeaderListTooLong handler = handleHeaderListTooLong
} }
@ -1475,21 +1443,27 @@ func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bo
return nil return nil
} }
func (st *stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error { func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
sc := st.sc sc := st.sc
sc.serveG.check() sc.serveG.check()
sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField) if st.gotTrailerHeader {
if _, err := sc.hpackDecoder.Write(frag); err != nil { return ConnectionError(ErrCodeProtocol)
return ConnectionError(ErrCodeCompression)
} }
if !end { st.gotTrailerHeader = true
return nil if !f.StreamEnded() {
return StreamError{st.id, ErrCodeProtocol}
}
if len(f.PseudoFields()) > 0 {
return StreamError{st.id, ErrCodeProtocol}
}
if st.trailer != nil {
for _, hf := range f.RegularFields() {
key := sc.canonicalHeader(hf.Name)
st.trailer[key] = append(st.trailer[key], hf.Value)
}
} }
err := sc.hpackDecoder.Close()
st.endStream() st.endStream()
if err != nil {
return ConnectionError(ErrCodeCompression)
}
return nil return nil
} }
@ -1534,19 +1508,21 @@ func adjustStreamPriority(streams map[uint32]*stream, streamID uint32, priority
} }
} }
// resetPendingRequest zeros out all state related to a HEADERS frame func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
// and its zero or more CONTINUATION frames sent to start a new
// request.
func (sc *serverConn) resetPendingRequest() {
sc.serveG.check() sc.serveG.check()
sc.req = requestParam{}
}
func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) { method := f.PseudoValue("method")
sc.serveG.check() path := f.PseudoValue("path")
rp := &sc.req scheme := f.PseudoValue("scheme")
if rp.invalidHeader || rp.method == "" || rp.path == "" || authority := f.PseudoValue("authority")
(rp.scheme != "https" && rp.scheme != "http") {
isConnect := method == "CONNECT"
if isConnect {
if path != "" || scheme != "" || authority == "" {
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
}
} else if method == "" || path == "" ||
(scheme != "https" && scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses: // See 8.1.2.6 Malformed Requests and Responses:
// //
// Malformed requests or responses that are detected // Malformed requests or responses that are detected
@ -1557,33 +1533,40 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
// "All HTTP/2 requests MUST include exactly one valid // "All HTTP/2 requests MUST include exactly one valid
// value for the :method, :scheme, and :path // value for the :method, :scheme, and :path
// pseudo-header fields" // pseudo-header fields"
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol} return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
} }
bodyOpen := rp.stream.state == stateOpen
if rp.method == "HEAD" && bodyOpen { bodyOpen := !f.StreamEnded()
if method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies // HEAD requests can't have bodies
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol} return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
} }
var tlsState *tls.ConnectionState // nil if not scheme https var tlsState *tls.ConnectionState // nil if not scheme https
if rp.scheme == "https" {
if scheme == "https" {
tlsState = sc.tlsState tlsState = sc.tlsState
} }
authority := rp.authority
if authority == "" { header := make(http.Header)
authority = rp.header.Get("Host") for _, hf := range f.RegularFields() {
header.Add(sc.canonicalHeader(hf.Name), hf.Value)
} }
needsContinue := rp.header.Get("Expect") == "100-continue"
if authority == "" {
authority = header.Get("Host")
}
needsContinue := header.Get("Expect") == "100-continue"
if needsContinue { if needsContinue {
rp.header.Del("Expect") header.Del("Expect")
} }
// Merge Cookie headers into one "; "-delimited value. // Merge Cookie headers into one "; "-delimited value.
if cookies := rp.header["Cookie"]; len(cookies) > 1 { if cookies := header["Cookie"]; len(cookies) > 1 {
rp.header.Set("Cookie", strings.Join(cookies, "; ")) header.Set("Cookie", strings.Join(cookies, "; "))
} }
// Setup Trailers // Setup Trailers
var trailer http.Header var trailer http.Header
for _, v := range rp.header["Trailer"] { for _, v := range header["Trailer"] {
for _, key := range strings.Split(v, ",") { for _, key := range strings.Split(v, ",") {
key = http.CanonicalHeaderKey(strings.TrimSpace(key)) key = http.CanonicalHeaderKey(strings.TrimSpace(key))
switch key { switch key {
@ -1598,25 +1581,32 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
} }
} }
} }
delete(rp.header, "Trailer") delete(header, "Trailer")
body := &requestBody{ body := &requestBody{
conn: sc, conn: sc,
stream: rp.stream, stream: st,
needsContinue: needsContinue, needsContinue: needsContinue,
} }
// TODO: handle asterisk '*' requests + test var url_ *url.URL
url, err := url.ParseRequestURI(rp.path) var requestURI string
if err != nil { if isConnect {
// TODO: find the right error code? url_ = &url.URL{Host: authority}
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol} requestURI = authority // mimic HTTP/1 server behavior
} else {
var err error
url_, err = url.ParseRequestURI(path)
if err != nil {
return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
}
requestURI = path
} }
req := &http.Request{ req := &http.Request{
Method: rp.method, Method: method,
URL: url, URL: url_,
RemoteAddr: sc.remoteAddrStr, RemoteAddr: sc.remoteAddrStr,
Header: rp.header, Header: header,
RequestURI: rp.path, RequestURI: requestURI,
Proto: "HTTP/2.0", Proto: "HTTP/2.0",
ProtoMajor: 2, ProtoMajor: 2,
ProtoMinor: 0, ProtoMinor: 0,
@ -1626,11 +1616,12 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
Trailer: trailer, Trailer: trailer,
} }
if bodyOpen { if bodyOpen {
st.reqBuf = sc.getRequestBodyBuf()
body.pipe = &pipe{ body.pipe = &pipe{
b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: garbage b: &fixedBuffer{buf: st.reqBuf},
} }
if vv, ok := rp.header["Content-Length"]; ok { if vv, ok := header["Content-Length"]; ok {
req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
} else { } else {
req.ContentLength = -1 req.ContentLength = -1
@ -1643,7 +1634,7 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
rws.conn = sc rws.conn = sc
rws.bw = bwSave rws.bw = bwSave
rws.bw.Reset(chunkWriter{rws}) rws.bw.Reset(chunkWriter{rws})
rws.stream = rp.stream rws.stream = st
rws.req = req rws.req = req
rws.body = body rws.body = body
@ -1651,6 +1642,15 @@ func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, err
return rw, req, nil return rw, req, nil
} }
func (sc *serverConn) getRequestBodyBuf() []byte {
sc.serveG.check()
if buf := sc.freeRequestBodyBuf; buf != nil {
sc.freeRequestBodyBuf = nil
return buf
}
return make([]byte, initialWindowSize)
}
// Run on its own goroutine. // Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
didPanic := true didPanic := true
@ -1887,7 +1887,9 @@ func (rws *responseWriterState) declareTrailer(k string) {
// Forbidden by RFC 2616 14.40. // Forbidden by RFC 2616 14.40.
return return
} }
rws.trailers = append(rws.trailers, k) if !strSliceContains(rws.trailers, k) {
rws.trailers = append(rws.trailers, k)
}
} }
// writeChunk writes chunks from the bufio.Writer. But because // writeChunk writes chunks from the bufio.Writer. But because
@ -1955,6 +1957,10 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
return 0, nil return 0, nil
} }
if rws.handlerDone {
rws.promoteUndeclaredTrailers()
}
endStream := rws.handlerDone && !rws.hasTrailers() endStream := rws.handlerDone && !rws.hasTrailers()
if len(p) > 0 || endStream { if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream. // only send a 0 byte DATA frame if we're ending the stream.
@ -1975,6 +1981,58 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
return len(p), nil return len(p), nil
} }
// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
// that, if present, signals that the map entry is actually for
// the response trailers, and not the response headers. The prefix
// is stripped after the ServeHTTP call finishes and the values are
// sent in the trailers.
//
// This mechanism is intended only for trailers that are not known
// prior to the headers being written. If the set of trailers is fixed
// or known before the header is written, the normal Go trailers mechanism
// is preferred:
// https://golang.org/pkg/net/http/#ResponseWriter
// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
const TrailerPrefix = "Trailer:"
// promoteUndeclaredTrailers permits http.Handlers to set trailers
// after the header has already been flushed. Because the Go
// ResponseWriter interface has no way to set Trailers (only the
// Header), and because we didn't want to expand the ResponseWriter
// interface, and because nobody used trailers, and because RFC 2616
// says you SHOULD (but not must) predeclare any trailers in the
// header, the official ResponseWriter rules said trailers in Go must
// be predeclared, and then we reuse the same ResponseWriter.Header()
// map to mean both Headers and Trailers. When it's time to write the
// Trailers, we pick out the fields of Headers that were declared as
// trailers. That worked for a while, until we found the first major
// user of Trailers in the wild: gRPC (using them only over http2),
// and gRPC libraries permit setting trailers mid-stream without
// predeclarnig them. So: change of plans. We still permit the old
// way, but we also permit this hack: if a Header() key begins with
// "Trailer:", the suffix of that key is a Trailer. Because ':' is an
// invalid token byte anyway, there is no ambiguity. (And it's already
// filtered out) It's mildly hacky, but not terrible.
//
// This method runs after the Handler is done and promotes any Header
// fields to be trailers.
func (rws *responseWriterState) promoteUndeclaredTrailers() {
for k, vv := range rws.handlerHeader {
if !strings.HasPrefix(k, TrailerPrefix) {
continue
}
trailerKey := strings.TrimPrefix(k, TrailerPrefix)
rws.declareTrailer(trailerKey)
rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv
}
if len(rws.trailers) > 1 {
sorter := sorterPool.Get().(*sorter)
sorter.SortStrings(rws.trailers)
sorterPool.Put(sorter)
}
}
func (w *responseWriter) Flush() { func (w *responseWriter) Flush() {
rws := w.rws rws := w.rws
if rws == nil { if rws == nil {

File diff suppressed because it is too large Load diff

View file

@ -7,8 +7,8 @@ package http2
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"log"
"net/http" "net/http"
"sort"
"time" "time"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
@ -136,27 +136,31 @@ type writeResHeaders struct {
contentLength string contentLength string
} }
func encKV(enc *hpack.Encoder, k, v string) {
if VerboseLogs {
log.Printf("http2: server encoding header %q = %q", k, v)
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
func (w *writeResHeaders) writeFrame(ctx writeContext) error { func (w *writeResHeaders) writeFrame(ctx writeContext) error {
enc, buf := ctx.HeaderEncoder() enc, buf := ctx.HeaderEncoder()
buf.Reset() buf.Reset()
if w.httpResCode != 0 { if w.httpResCode != 0 {
enc.WriteField(hpack.HeaderField{ encKV(enc, ":status", httpCodeString(w.httpResCode))
Name: ":status",
Value: httpCodeString(w.httpResCode),
})
} }
encodeHeaders(enc, w.h, w.trailers) encodeHeaders(enc, w.h, w.trailers)
if w.contentType != "" { if w.contentType != "" {
enc.WriteField(hpack.HeaderField{Name: "content-type", Value: w.contentType}) encKV(enc, "content-type", w.contentType)
} }
if w.contentLength != "" { if w.contentLength != "" {
enc.WriteField(hpack.HeaderField{Name: "content-length", Value: w.contentLength}) encKV(enc, "content-length", w.contentLength)
} }
if w.date != "" { if w.date != "" {
enc.WriteField(hpack.HeaderField{Name: "date", Value: w.date}) encKV(enc, "date", w.date)
} }
headerBlock := buf.Bytes() headerBlock := buf.Bytes()
@ -206,7 +210,7 @@ type write100ContinueHeadersFrame struct {
func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error { func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error {
enc, buf := ctx.HeaderEncoder() enc, buf := ctx.HeaderEncoder()
buf.Reset() buf.Reset()
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) encKV(enc, ":status", "100")
return ctx.Framer().WriteHeaders(HeadersFrameParam{ return ctx.Framer().WriteHeaders(HeadersFrameParam{
StreamID: w.streamID, StreamID: w.streamID,
BlockFragment: buf.Bytes(), BlockFragment: buf.Bytes(),
@ -225,24 +229,34 @@ func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
} }
func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) { func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
// TODO: garbage. pool sorters like http1? hot path for 1 key?
if keys == nil { if keys == nil {
keys = make([]string, 0, len(h)) sorter := sorterPool.Get().(*sorter)
for k := range h { // Using defer here, since the returned keys from the
keys = append(keys, k) // sorter.Keys method is only valid until the sorter
} // is returned:
sort.Strings(keys) defer sorterPool.Put(sorter)
keys = sorter.Keys(h)
} }
for _, k := range keys { for _, k := range keys {
vv := h[k] vv := h[k]
k = lowerHeader(k) k = lowerHeader(k)
if !validHeaderFieldName(k) {
// TODO: return an error? golang.org/issue/14048
// For now just omit it.
continue
}
isTE := k == "transfer-encoding" isTE := k == "transfer-encoding"
for _, v := range vv { for _, v := range vv {
if !validHeaderFieldValue(v) {
// TODO: return an error? golang.org/issue/14048
// For now just omit it.
continue
}
// TODO: more of "8.1.2.2 Connection-Specific Header Fields" // TODO: more of "8.1.2.2 Connection-Specific Header Fields"
if isTE && v != "trailers" { if isTE && v != "trailers" {
continue continue
} }
enc.WriteField(hpack.HeaderField{Name: k, Value: v}) encKV(enc, k, v)
} }
} }
} }

View file

@ -95,11 +95,14 @@ var DebugUseAfterFinish = false
// //
// The default AuthRequest function returns (true, true) iff the request comes from localhost/127.0.0.1/[::1]. // The default AuthRequest function returns (true, true) iff the request comes from localhost/127.0.0.1/[::1].
var AuthRequest = func(req *http.Request) (any, sensitive bool) { var AuthRequest = func(req *http.Request) (any, sensitive bool) {
// RemoteAddr is commonly in the form "IP" or "IP:port".
// If it is in the form "IP:port", split off the port.
host, _, err := net.SplitHostPort(req.RemoteAddr) host, _, err := net.SplitHostPort(req.RemoteAddr)
switch { if err != nil {
case err != nil: // Badly formed address; fail closed. host = req.RemoteAddr
return false, false }
case host == "localhost" || host == "127.0.0.1" || host == "::1": switch host {
case "localhost", "127.0.0.1", "::1":
return true, true return true, true
default: default:
return false, false return false, false

View file

@ -4,20 +4,15 @@ set -e
workdir=.cover workdir=.cover
profile="$workdir/cover.out" profile="$workdir/cover.out"
mode=set mode=count
end2endtest="google.golang.org/grpc/test"
generate_cover_data() { generate_cover_data() {
rm -rf "$workdir" rm -rf "$workdir"
mkdir "$workdir" mkdir "$workdir"
for pkg in "$@"; do for pkg in "$@"; do
if [ $pkg == "google.golang.org/grpc" -o $pkg == "google.golang.org/grpc/transport" -o $pkg == "google.golang.org/grpc/metadata" -o $pkg == "google.golang.org/grpc/credentials" ] f="$workdir/$(echo $pkg | tr / -).cover"
then go test -covermode="$mode" -coverprofile="$f" "$pkg"
f="$workdir/$(echo $pkg | tr / -)"
go test -covermode="$mode" -coverprofile="$f.cover" "$pkg"
go test -covermode="$mode" -coverpkg "$pkg" -coverprofile="$f.e2e.cover" "$end2endtest"
fi
done done
echo "mode: $mode" >"$profile" echo "mode: $mode" >"$profile"
@ -37,8 +32,6 @@ show_cover_report func
case "$1" in case "$1" in
"") "")
;; ;;
--html)
show_cover_report html ;;
--coveralls) --coveralls)
push_to_coveralls ;; push_to_coveralls ;;
*) *)

View file

@ -172,7 +172,7 @@ func (p *unicastNamingPicker) processUpdates() error {
} }
p.mu.Unlock() p.mu.Unlock()
default: default:
grpclog.Println("Unknown update.Op ", update.Op) grpclog.Println("Unknown update.Op %d", update.Op)
} }
} }
return nil return nil