Signed-off-by: Olivier Gambier <olivier@docker.com>
This commit is contained in:
Olivier Gambier 2016-03-22 10:42:33 -07:00
parent 59401e277b
commit d1444b56e9
141 changed files with 19483 additions and 4205 deletions

View file

@ -1,2 +0,0 @@
*~
h2i/h2i

View file

@ -1,20 +0,0 @@
# This file is like Go's AUTHORS file: it lists Copyright holders.
# The list of humans who have contributd is in the CONTRIBUTORS file.
#
# To contribute to this project, because it will eventually be folded
# back in to Go itself, you need to submit a CLA:
#
# http://golang.org/doc/contribute.html#copyright
#
# Then you get added to CONTRIBUTORS and you or your company get added
# to the AUTHORS file.
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
Google, Inc.
Keith Rarick <kr@xph.us> github=kr
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
Matt Layher <mdlayher@gmail.com> github=mdlayher
Perry Abbott <perry.j.abbott@gmail.com> github=pabbott0
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t

View file

@ -1,20 +0,0 @@
# This file is like Go's CONTRIBUTORS file: it lists humans.
# The list of copyright holders (which may be companies) are in the AUTHORS file.
#
# To contribute to this project, because it will eventually be folded
# back in to Go itself, you need to submit a CLA:
#
# http://golang.org/doc/contribute.html#copyright
#
# Then you get added to CONTRIBUTORS and you or your company get added
# to the AUTHORS file.
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
Brad Fitzpatrick <bradfitz@golang.org> github=bradfitz
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
Keith Rarick <kr@xph.us> github=kr
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
Matt Layher <mdlayher@gmail.com> github=mdlayher
Perry Abbott <perry.j.abbott@gmail.com> github=pabbott0
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t

View file

@ -1,44 +0,0 @@
#
# This Dockerfile builds a recent curl with HTTP/2 client support, using
# a recent nghttp2 build.
#
# See the Makefile for how to tag it. If Docker and that image is found, the
# Go tests use this curl binary for integration tests.
#
FROM ubuntu:trusty
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install -y git-core build-essential wget
RUN apt-get install -y --no-install-recommends \
autotools-dev libtool pkg-config zlib1g-dev \
libcunit1-dev libssl-dev libxml2-dev libevent-dev \
automake autoconf
# Note: setting NGHTTP2_VER before the git clone, so an old git clone isn't cached:
ENV NGHTTP2_VER af24f8394e43f4
RUN cd /root && git clone https://github.com/tatsuhiro-t/nghttp2.git
WORKDIR /root/nghttp2
RUN git reset --hard $NGHTTP2_VER
RUN autoreconf -i
RUN automake
RUN autoconf
RUN ./configure
RUN make
RUN make install
WORKDIR /root
RUN wget http://curl.haxx.se/download/curl-7.40.0.tar.gz
RUN tar -zxvf curl-7.40.0.tar.gz
WORKDIR /root/curl-7.40.0
RUN ./configure --with-ssl --with-nghttp2=/usr/local
RUN make
RUN make install
RUN ldconfig
CMD ["-h"]
ENTRYPOINT ["/usr/local/bin/curl"]

View file

@ -1,5 +0,0 @@
We only accept contributions from users who have gone through Go's
contribution process (signed a CLA).
Please acknowledge whether you have (and use the same email) if
sending a pull request.

View file

@ -1,7 +0,0 @@
Copyright 2014 Google & the Go AUTHORS
Go AUTHORS are:
See https://code.google.com/p/go/source/browse/AUTHORS
Licensed under the terms of Go itself:
https://code.google.com/p/go/source/browse/LICENSE

View file

@ -1,3 +0,0 @@
curlimage:
docker build -t gohttp2/curl .

View file

@ -1,17 +0,0 @@
This is a work-in-progress HTTP/2 implementation for Go.
It will eventually live in the Go standard library and won't require
any changes to your code to use. It will just be automatic.
Status:
* The server support is pretty good. A few things are missing
but are being worked on.
* The client work has just started but shares a lot of code
is coming along much quicker.
Docs are at https://godoc.org/github.com/bradfitz/http2
Demo test server at https://http2.golang.org/
Help & bug reports welcome.

View file

@ -1,76 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"errors"
)
// buffer is an io.ReadWriteCloser backed by a fixed size buffer.
// It never allocates, but moves old data as new data is written.
type buffer struct {
buf []byte
r, w int
closed bool
err error // err to return to reader
}
var (
errReadEmpty = errors.New("read from empty buffer")
errWriteClosed = errors.New("write on closed buffer")
errWriteFull = errors.New("write on full buffer")
)
// Read copies bytes from the buffer into p.
// It is an error to read when no data is available.
func (b *buffer) Read(p []byte) (n int, err error) {
n = copy(p, b.buf[b.r:b.w])
b.r += n
if b.closed && b.r == b.w {
err = b.err
} else if b.r == b.w && n == 0 {
err = errReadEmpty
}
return n, err
}
// Len returns the number of bytes of the unread portion of the buffer.
func (b *buffer) Len() int {
return b.w - b.r
}
// Write copies bytes from p into the buffer.
// It is an error to write more data than the buffer can hold.
func (b *buffer) Write(p []byte) (n int, err error) {
if b.closed {
return 0, errWriteClosed
}
// Slide existing data to beginning.
if b.r > 0 && len(p) > len(b.buf)-b.w {
copy(b.buf, b.buf[b.r:b.w])
b.w -= b.r
b.r = 0
}
// Write new data.
n = copy(b.buf[b.w:], p)
b.w += n
if n < len(p) {
err = errWriteFull
}
return n, err
}
// Close marks the buffer as closed. Future calls to Write will
// return an error. Future calls to Read, once the buffer is
// empty, will return err.
func (b *buffer) Close(err error) {
if !b.closed {
b.closed = true
b.err = err
}
}

View file

@ -1,78 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import "fmt"
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
type ErrCode uint32
const (
ErrCodeNo ErrCode = 0x0
ErrCodeProtocol ErrCode = 0x1
ErrCodeInternal ErrCode = 0x2
ErrCodeFlowControl ErrCode = 0x3
ErrCodeSettingsTimeout ErrCode = 0x4
ErrCodeStreamClosed ErrCode = 0x5
ErrCodeFrameSize ErrCode = 0x6
ErrCodeRefusedStream ErrCode = 0x7
ErrCodeCancel ErrCode = 0x8
ErrCodeCompression ErrCode = 0x9
ErrCodeConnect ErrCode = 0xa
ErrCodeEnhanceYourCalm ErrCode = 0xb
ErrCodeInadequateSecurity ErrCode = 0xc
ErrCodeHTTP11Required ErrCode = 0xd
)
var errCodeName = map[ErrCode]string{
ErrCodeNo: "NO_ERROR",
ErrCodeProtocol: "PROTOCOL_ERROR",
ErrCodeInternal: "INTERNAL_ERROR",
ErrCodeFlowControl: "FLOW_CONTROL_ERROR",
ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT",
ErrCodeStreamClosed: "STREAM_CLOSED",
ErrCodeFrameSize: "FRAME_SIZE_ERROR",
ErrCodeRefusedStream: "REFUSED_STREAM",
ErrCodeCancel: "CANCEL",
ErrCodeCompression: "COMPRESSION_ERROR",
ErrCodeConnect: "CONNECT_ERROR",
ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM",
ErrCodeInadequateSecurity: "INADEQUATE_SECURITY",
ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED",
}
func (e ErrCode) String() string {
if s, ok := errCodeName[e]; ok {
return s
}
return fmt.Sprintf("unknown error code 0x%x", uint32(e))
}
// ConnectionError is an error that results in the termination of the
// entire connection.
type ConnectionError ErrCode
func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: %s", ErrCode(e)) }
// StreamError is an error that only affects one stream within an
// HTTP/2 connection.
type StreamError struct {
StreamID uint32
Code ErrCode
}
func (e StreamError) Error() string {
return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
}
// 6.9.1 The Flow Control Window
// "If a sender receives a WINDOW_UPDATE that causes a flow control
// window to exceed this maximum it MUST terminate either the stream
// or the connection, as appropriate. For streams, [...]; for the
// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code."
type goAwayFlowError struct{}
func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" }

View file

@ -1,51 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// Flow control
package http2
// flow is the flow control window's size.
type flow struct {
// n is the number of DATA bytes we're allowed to send.
// A flow is kept both on a conn and a per-stream.
n int32
// conn points to the shared connection-level flow that is
// shared by all streams on that conn. It is nil for the flow
// that's on the conn directly.
conn *flow
}
func (f *flow) setConnFlow(cf *flow) { f.conn = cf }
func (f *flow) available() int32 {
n := f.n
if f.conn != nil && f.conn.n < n {
n = f.conn.n
}
return n
}
func (f *flow) take(n int32) {
if n > f.available() {
panic("internal error: took too much")
}
f.n -= n
if f.conn != nil {
f.conn.n -= n
}
}
// add adds n bytes (positive or negative) to the flow control window.
// It returns false if the sum would exceed 2^31-1.
func (f *flow) add(n int32) bool {
remain := (1<<31 - 1) - f.n
if n > remain {
return false
}
f.n += n
return true
}

File diff suppressed because it is too large Load diff

View file

@ -1,173 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// Defensive debug-only utility to track that functions run on the
// goroutine that they're supposed to.
package http2
import (
"bytes"
"errors"
"fmt"
"os"
"runtime"
"strconv"
"sync"
)
var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
type goroutineLock uint64
func newGoroutineLock() goroutineLock {
if !DebugGoroutines {
return 0
}
return goroutineLock(curGoroutineID())
}
func (g goroutineLock) check() {
if !DebugGoroutines {
return
}
if curGoroutineID() != uint64(g) {
panic("running on the wrong goroutine")
}
}
func (g goroutineLock) checkNotOn() {
if !DebugGoroutines {
return
}
if curGoroutineID() == uint64(g) {
panic("running on the wrong goroutine")
}
}
var goroutineSpace = []byte("goroutine ")
func curGoroutineID() uint64 {
bp := littleBuf.Get().(*[]byte)
defer littleBuf.Put(bp)
b := *bp
b = b[:runtime.Stack(b, false)]
// Parse the 4707 out of "goroutine 4707 ["
b = bytes.TrimPrefix(b, goroutineSpace)
i := bytes.IndexByte(b, ' ')
if i < 0 {
panic(fmt.Sprintf("No space found in %q", b))
}
b = b[:i]
n, err := parseUintBytes(b, 10, 64)
if err != nil {
panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
}
return n
}
var littleBuf = sync.Pool{
New: func() interface{} {
buf := make([]byte, 64)
return &buf
},
}
// parseUintBytes is like strconv.ParseUint, but using a []byte.
func parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
var cutoff, maxVal uint64
if bitSize == 0 {
bitSize = int(strconv.IntSize)
}
s0 := s
switch {
case len(s) < 1:
err = strconv.ErrSyntax
goto Error
case 2 <= base && base <= 36:
// valid base; nothing to do
case base == 0:
// Look for octal, hex prefix.
switch {
case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
base = 16
s = s[2:]
if len(s) < 1 {
err = strconv.ErrSyntax
goto Error
}
case s[0] == '0':
base = 8
default:
base = 10
}
default:
err = errors.New("invalid base " + strconv.Itoa(base))
goto Error
}
n = 0
cutoff = cutoff64(base)
maxVal = 1<<uint(bitSize) - 1
for i := 0; i < len(s); i++ {
var v byte
d := s[i]
switch {
case '0' <= d && d <= '9':
v = d - '0'
case 'a' <= d && d <= 'z':
v = d - 'a' + 10
case 'A' <= d && d <= 'Z':
v = d - 'A' + 10
default:
n = 0
err = strconv.ErrSyntax
goto Error
}
if int(v) >= base {
n = 0
err = strconv.ErrSyntax
goto Error
}
if n >= cutoff {
// n*base overflows
n = 1<<64 - 1
err = strconv.ErrRange
goto Error
}
n *= uint64(base)
n1 := n + uint64(v)
if n1 < n || n1 > maxVal {
// n+v overflows
n = 1<<64 - 1
err = strconv.ErrRange
goto Error
}
n = n1
}
return n, nil
Error:
return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
}
// Return the first number n such that n*base >= 1<<64.
func cutoff64(base int) uint64 {
if base < 2 {
return 0
}
return (1<<64-1)/uint64(base) + 1
}

View file

@ -1,80 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"net/http"
"strings"
)
var (
commonLowerHeader = map[string]string{} // Go-Canonical-Case -> lower-case
commonCanonHeader = map[string]string{} // lower-case -> Go-Canonical-Case
)
func init() {
for _, v := range []string{
"accept",
"accept-charset",
"accept-encoding",
"accept-language",
"accept-ranges",
"age",
"access-control-allow-origin",
"allow",
"authorization",
"cache-control",
"content-disposition",
"content-encoding",
"content-language",
"content-length",
"content-location",
"content-range",
"content-type",
"cookie",
"date",
"etag",
"expect",
"expires",
"from",
"host",
"if-match",
"if-modified-since",
"if-none-match",
"if-unmodified-since",
"last-modified",
"link",
"location",
"max-forwards",
"proxy-authenticate",
"proxy-authorization",
"range",
"referer",
"refresh",
"retry-after",
"server",
"set-cookie",
"strict-transport-security",
"transfer-encoding",
"user-agent",
"vary",
"via",
"www-authenticate",
} {
chk := http.CanonicalHeaderKey(v)
commonLowerHeader[chk] = v
commonCanonHeader[v] = chk
}
}
func lowerHeader(v string) string {
if s, ok := commonLowerHeader[v]; ok {
return s
}
return strings.ToLower(v)
}

View file

@ -1,252 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
import (
"io"
)
const (
uint32Max = ^uint32(0)
initialHeaderTableSize = 4096
)
type Encoder struct {
dynTab dynamicTable
// minSize is the minimum table size set by
// SetMaxDynamicTableSize after the previous Header Table Size
// Update.
minSize uint32
// maxSizeLimit is the maximum table size this encoder
// supports. This will protect the encoder from too large
// size.
maxSizeLimit uint32
// tableSizeUpdate indicates whether "Header Table Size
// Update" is required.
tableSizeUpdate bool
w io.Writer
buf []byte
}
// NewEncoder returns a new Encoder which performs HPACK encoding. An
// encoded data is written to w.
func NewEncoder(w io.Writer) *Encoder {
e := &Encoder{
minSize: uint32Max,
maxSizeLimit: initialHeaderTableSize,
tableSizeUpdate: false,
w: w,
}
e.dynTab.setMaxSize(initialHeaderTableSize)
return e
}
// WriteField encodes f into a single Write to e's underlying Writer.
// This function may also produce bytes for "Header Table Size Update"
// if necessary. If produced, it is done before encoding f.
func (e *Encoder) WriteField(f HeaderField) error {
e.buf = e.buf[:0]
if e.tableSizeUpdate {
e.tableSizeUpdate = false
if e.minSize < e.dynTab.maxSize {
e.buf = appendTableSize(e.buf, e.minSize)
}
e.minSize = uint32Max
e.buf = appendTableSize(e.buf, e.dynTab.maxSize)
}
idx, nameValueMatch := e.searchTable(f)
if nameValueMatch {
e.buf = appendIndexed(e.buf, idx)
} else {
indexing := e.shouldIndex(f)
if indexing {
e.dynTab.add(f)
}
if idx == 0 {
e.buf = appendNewName(e.buf, f, indexing)
} else {
e.buf = appendIndexedName(e.buf, f, idx, indexing)
}
}
n, err := e.w.Write(e.buf)
if err == nil && n != len(e.buf) {
err = io.ErrShortWrite
}
return err
}
// searchTable searches f in both stable and dynamic header tables.
// The static header table is searched first. Only when there is no
// exact match for both name and value, the dynamic header table is
// then searched. If there is no match, i is 0. If both name and value
// match, i is the matched index and nameValueMatch becomes true. If
// only name matches, i points to that index and nameValueMatch
// becomes false.
func (e *Encoder) searchTable(f HeaderField) (i uint64, nameValueMatch bool) {
for idx, hf := range staticTable {
if !constantTimeStringCompare(hf.Name, f.Name) {
continue
}
if i == 0 {
i = uint64(idx + 1)
}
if f.Sensitive {
continue
}
if !constantTimeStringCompare(hf.Value, f.Value) {
continue
}
i = uint64(idx + 1)
nameValueMatch = true
return
}
j, nameValueMatch := e.dynTab.search(f)
if nameValueMatch || (i == 0 && j != 0) {
i = j + uint64(len(staticTable))
}
return
}
// SetMaxDynamicTableSize changes the dynamic header table size to v.
// The actual size is bounded by the value passed to
// SetMaxDynamicTableSizeLimit.
func (e *Encoder) SetMaxDynamicTableSize(v uint32) {
if v > e.maxSizeLimit {
v = e.maxSizeLimit
}
if v < e.minSize {
e.minSize = v
}
e.tableSizeUpdate = true
e.dynTab.setMaxSize(v)
}
// SetMaxDynamicTableSizeLimit changes the maximum value that can be
// specified in SetMaxDynamicTableSize to v. By default, it is set to
// 4096, which is the same size of the default dynamic header table
// size described in HPACK specification. If the current maximum
// dynamic header table size is strictly greater than v, "Header Table
// Size Update" will be done in the next WriteField call and the
// maximum dynamic header table size is truncated to v.
func (e *Encoder) SetMaxDynamicTableSizeLimit(v uint32) {
e.maxSizeLimit = v
if e.dynTab.maxSize > v {
e.tableSizeUpdate = true
e.dynTab.setMaxSize(v)
}
}
// shouldIndex reports whether f should be indexed.
func (e *Encoder) shouldIndex(f HeaderField) bool {
return !f.Sensitive && f.size() <= e.dynTab.maxSize
}
// appendIndexed appends index i, as encoded in "Indexed Header Field"
// representation, to dst and returns the extended buffer.
func appendIndexed(dst []byte, i uint64) []byte {
first := len(dst)
dst = appendVarInt(dst, 7, i)
dst[first] |= 0x80
return dst
}
// appendNewName appends f, as encoded in one of "Literal Header field
// - New Name" representation variants, to dst and returns the
// extended buffer.
//
// If f.Sensitive is true, "Never Indexed" representation is used. If
// f.Sensitive is false and indexing is true, "Inremental Indexing"
// representation is used.
func appendNewName(dst []byte, f HeaderField, indexing bool) []byte {
dst = append(dst, encodeTypeByte(indexing, f.Sensitive))
dst = appendHpackString(dst, f.Name)
return appendHpackString(dst, f.Value)
}
// appendIndexedName appends f and index i referring indexed name
// entry, as encoded in one of "Literal Header field - Indexed Name"
// representation variants, to dst and returns the extended buffer.
//
// If f.Sensitive is true, "Never Indexed" representation is used. If
// f.Sensitive is false and indexing is true, "Incremental Indexing"
// representation is used.
func appendIndexedName(dst []byte, f HeaderField, i uint64, indexing bool) []byte {
first := len(dst)
var n byte
if indexing {
n = 6
} else {
n = 4
}
dst = appendVarInt(dst, n, i)
dst[first] |= encodeTypeByte(indexing, f.Sensitive)
return appendHpackString(dst, f.Value)
}
// appendTableSize appends v, as encoded in "Header Table Size Update"
// representation, to dst and returns the extended buffer.
func appendTableSize(dst []byte, v uint32) []byte {
first := len(dst)
dst = appendVarInt(dst, 5, uint64(v))
dst[first] |= 0x20
return dst
}
// appendVarInt appends i, as encoded in variable integer form using n
// bit prefix, to dst and returns the extended buffer.
//
// See
// http://http2.github.io/http2-spec/compression.html#integer.representation
func appendVarInt(dst []byte, n byte, i uint64) []byte {
k := uint64((1 << n) - 1)
if i < k {
return append(dst, byte(i))
}
dst = append(dst, byte(k))
i -= k
for ; i >= 128; i >>= 7 {
dst = append(dst, byte(0x80|(i&0x7f)))
}
return append(dst, byte(i))
}
// appendHpackString appends s, as encoded in "String Literal"
// representation, to dst and returns the the extended buffer.
//
// s will be encoded in Huffman codes only when it produces strictly
// shorter byte string.
func appendHpackString(dst []byte, s string) []byte {
huffmanLength := HuffmanEncodeLength(s)
if huffmanLength < uint64(len(s)) {
first := len(dst)
dst = appendVarInt(dst, 7, huffmanLength)
dst = AppendHuffmanString(dst, s)
dst[first] |= 0x80
} else {
dst = appendVarInt(dst, 7, uint64(len(s)))
dst = append(dst, s...)
}
return dst
}
// encodeTypeByte returns type byte. If sensitive is true, type byte
// for "Never Indexed" representation is returned. If sensitive is
// false and indexing is true, type byte for "Incremental Indexing"
// representation is returned. Otherwise, type byte for "Without
// Indexing" is returned.
func encodeTypeByte(indexing, sensitive bool) byte {
if sensitive {
return 0x10
}
if indexing {
return 0x40
}
return 0
}

View file

@ -1,445 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// Package hpack implements HPACK, a compression format for
// efficiently representing HTTP header fields in the context of HTTP/2.
//
// See http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-09
package hpack
import (
"bytes"
"errors"
"fmt"
)
// A DecodingError is something the spec defines as a decoding error.
type DecodingError struct {
Err error
}
func (de DecodingError) Error() string {
return fmt.Sprintf("decoding error: %v", de.Err)
}
// An InvalidIndexError is returned when an encoder references a table
// entry before the static table or after the end of the dynamic table.
type InvalidIndexError int
func (e InvalidIndexError) Error() string {
return fmt.Sprintf("invalid indexed representation index %d", int(e))
}
// A HeaderField is a name-value pair. Both the name and value are
// treated as opaque sequences of octets.
type HeaderField struct {
Name, Value string
// Sensitive means that this header field should never be
// indexed.
Sensitive bool
}
func (hf *HeaderField) size() uint32 {
// http://http2.github.io/http2-spec/compression.html#rfc.section.4.1
// "The size of the dynamic table is the sum of the size of
// its entries. The size of an entry is the sum of its name's
// length in octets (as defined in Section 5.2), its value's
// length in octets (see Section 5.2), plus 32. The size of
// an entry is calculated using the length of the name and
// value without any Huffman encoding applied."
// This can overflow if somebody makes a large HeaderField
// Name and/or Value by hand, but we don't care, because that
// won't happen on the wire because the encoding doesn't allow
// it.
return uint32(len(hf.Name) + len(hf.Value) + 32)
}
// A Decoder is the decoding context for incremental processing of
// header blocks.
type Decoder struct {
dynTab dynamicTable
emit func(f HeaderField)
// buf is the unparsed buffer. It's only written to
// saveBuf if it was truncated in the middle of a header
// block. Because it's usually not owned, we can only
// process it under Write.
buf []byte // usually not owned
saveBuf bytes.Buffer
}
func NewDecoder(maxSize uint32, emitFunc func(f HeaderField)) *Decoder {
d := &Decoder{
emit: emitFunc,
}
d.dynTab.allowedMaxSize = maxSize
d.dynTab.setMaxSize(maxSize)
return d
}
// TODO: add method *Decoder.Reset(maxSize, emitFunc) to let callers re-use Decoders and their
// underlying buffers for garbage reasons.
func (d *Decoder) SetMaxDynamicTableSize(v uint32) {
d.dynTab.setMaxSize(v)
}
// SetAllowedMaxDynamicTableSize sets the upper bound that the encoded
// stream (via dynamic table size updates) may set the maximum size
// to.
func (d *Decoder) SetAllowedMaxDynamicTableSize(v uint32) {
d.dynTab.allowedMaxSize = v
}
type dynamicTable struct {
// ents is the FIFO described at
// http://http2.github.io/http2-spec/compression.html#rfc.section.2.3.2
// The newest (low index) is append at the end, and items are
// evicted from the front.
ents []HeaderField
size uint32
maxSize uint32 // current maxSize
allowedMaxSize uint32 // maxSize may go up to this, inclusive
}
func (dt *dynamicTable) setMaxSize(v uint32) {
dt.maxSize = v
dt.evict()
}
// TODO: change dynamicTable to be a struct with a slice and a size int field,
// per http://http2.github.io/http2-spec/compression.html#rfc.section.4.1:
//
//
// Then make add increment the size. maybe the max size should move from Decoder to
// dynamicTable and add should return an ok bool if there was enough space.
//
// Later we'll need a remove operation on dynamicTable.
func (dt *dynamicTable) add(f HeaderField) {
dt.ents = append(dt.ents, f)
dt.size += f.size()
dt.evict()
}
// If we're too big, evict old stuff (front of the slice)
func (dt *dynamicTable) evict() {
base := dt.ents // keep base pointer of slice
for dt.size > dt.maxSize {
dt.size -= dt.ents[0].size()
dt.ents = dt.ents[1:]
}
// Shift slice contents down if we evicted things.
if len(dt.ents) != len(base) {
copy(base, dt.ents)
dt.ents = base[:len(dt.ents)]
}
}
// constantTimeStringCompare compares string a and b in a constant
// time manner.
func constantTimeStringCompare(a, b string) bool {
if len(a) != len(b) {
return false
}
c := byte(0)
for i := 0; i < len(a); i++ {
c |= a[i] ^ b[i]
}
return c == 0
}
// Search searches f in the table. The return value i is 0 if there is
// no name match. If there is name match or name/value match, i is the
// index of that entry (1-based). If both name and value match,
// nameValueMatch becomes true.
func (dt *dynamicTable) search(f HeaderField) (i uint64, nameValueMatch bool) {
l := len(dt.ents)
for j := l - 1; j >= 0; j-- {
ent := dt.ents[j]
if !constantTimeStringCompare(ent.Name, f.Name) {
continue
}
if i == 0 {
i = uint64(l - j)
}
if f.Sensitive {
continue
}
if !constantTimeStringCompare(ent.Value, f.Value) {
continue
}
i = uint64(l - j)
nameValueMatch = true
return
}
return
}
func (d *Decoder) maxTableIndex() int {
return len(d.dynTab.ents) + len(staticTable)
}
func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) {
if i < 1 {
return
}
if i > uint64(d.maxTableIndex()) {
return
}
if i <= uint64(len(staticTable)) {
return staticTable[i-1], true
}
dents := d.dynTab.ents
return dents[len(dents)-(int(i)-len(staticTable))], true
}
// Decode decodes an entire block.
//
// TODO: remove this method and make it incremental later? This is
// easier for debugging now.
func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
var hf []HeaderField
saveFunc := d.emit
defer func() { d.emit = saveFunc }()
d.emit = func(f HeaderField) { hf = append(hf, f) }
if _, err := d.Write(p); err != nil {
return nil, err
}
if err := d.Close(); err != nil {
return nil, err
}
return hf, nil
}
func (d *Decoder) Close() error {
if d.saveBuf.Len() > 0 {
d.saveBuf.Reset()
return DecodingError{errors.New("truncated headers")}
}
return nil
}
func (d *Decoder) Write(p []byte) (n int, err error) {
if len(p) == 0 {
// Prevent state machine CPU attacks (making us redo
// work up to the point of finding out we don't have
// enough data)
return
}
// Only copy the data if we have to. Optimistically assume
// that p will contain a complete header block.
if d.saveBuf.Len() == 0 {
d.buf = p
} else {
d.saveBuf.Write(p)
d.buf = d.saveBuf.Bytes()
d.saveBuf.Reset()
}
for len(d.buf) > 0 {
err = d.parseHeaderFieldRepr()
if err != nil {
if err == errNeedMore {
err = nil
d.saveBuf.Write(d.buf)
}
break
}
}
return len(p), err
}
// errNeedMore is an internal sentinel error value that means the
// buffer is truncated and we need to read more data before we can
// continue parsing.
var errNeedMore = errors.New("need more data")
type indexType int
const (
indexedTrue indexType = iota
indexedFalse
indexedNever
)
func (v indexType) indexed() bool { return v == indexedTrue }
func (v indexType) sensitive() bool { return v == indexedNever }
// returns errNeedMore if there isn't enough data available.
// any other error is fatal.
// consumes d.buf iff it returns nil.
// precondition: must be called with len(d.buf) > 0
func (d *Decoder) parseHeaderFieldRepr() error {
b := d.buf[0]
switch {
case b&128 != 0:
// Indexed representation.
// High bit set?
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.1
return d.parseFieldIndexed()
case b&192 == 64:
// 6.2.1 Literal Header Field with Incremental Indexing
// 0b10xxxxxx: top two bits are 10
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.1
return d.parseFieldLiteral(6, indexedTrue)
case b&240 == 0:
// 6.2.2 Literal Header Field without Indexing
// 0b0000xxxx: top four bits are 0000
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.2
return d.parseFieldLiteral(4, indexedFalse)
case b&240 == 16:
// 6.2.3 Literal Header Field never Indexed
// 0b0001xxxx: top four bits are 0001
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.3
return d.parseFieldLiteral(4, indexedNever)
case b&224 == 32:
// 6.3 Dynamic Table Size Update
// Top three bits are '001'.
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.3
return d.parseDynamicTableSizeUpdate()
}
return DecodingError{errors.New("invalid encoding")}
}
// (same invariants and behavior as parseHeaderFieldRepr)
func (d *Decoder) parseFieldIndexed() error {
buf := d.buf
idx, buf, err := readVarInt(7, buf)
if err != nil {
return err
}
hf, ok := d.at(idx)
if !ok {
return DecodingError{InvalidIndexError(idx)}
}
d.emit(HeaderField{Name: hf.Name, Value: hf.Value})
d.buf = buf
return nil
}
// (same invariants and behavior as parseHeaderFieldRepr)
func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
buf := d.buf
nameIdx, buf, err := readVarInt(n, buf)
if err != nil {
return err
}
var hf HeaderField
if nameIdx > 0 {
ihf, ok := d.at(nameIdx)
if !ok {
return DecodingError{InvalidIndexError(nameIdx)}
}
hf.Name = ihf.Name
} else {
hf.Name, buf, err = readString(buf)
if err != nil {
return err
}
}
hf.Value, buf, err = readString(buf)
if err != nil {
return err
}
d.buf = buf
if it.indexed() {
d.dynTab.add(hf)
}
hf.Sensitive = it.sensitive()
d.emit(hf)
return nil
}
// (same invariants and behavior as parseHeaderFieldRepr)
func (d *Decoder) parseDynamicTableSizeUpdate() error {
buf := d.buf
size, buf, err := readVarInt(5, buf)
if err != nil {
return err
}
if size > uint64(d.dynTab.allowedMaxSize) {
return DecodingError{errors.New("dynamic table size update too large")}
}
d.dynTab.setMaxSize(uint32(size))
d.buf = buf
return nil
}
var errVarintOverflow = DecodingError{errors.New("varint integer overflow")}
// readVarInt reads an unsigned variable length integer off the
// beginning of p. n is the parameter as described in
// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1.
//
// n must always be between 1 and 8.
//
// The returned remain buffer is either a smaller suffix of p, or err != nil.
// The error is errNeedMore if p doesn't contain a complete integer.
func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
if n < 1 || n > 8 {
panic("bad n")
}
if len(p) == 0 {
return 0, p, errNeedMore
}
i = uint64(p[0])
if n < 8 {
i &= (1 << uint64(n)) - 1
}
if i < (1<<uint64(n))-1 {
return i, p[1:], nil
}
origP := p
p = p[1:]
var m uint64
for len(p) > 0 {
b := p[0]
p = p[1:]
i += uint64(b&127) << m
if b&128 == 0 {
return i, p, nil
}
m += 7
if m >= 63 { // TODO: proper overflow check. making this up.
return 0, origP, errVarintOverflow
}
}
return 0, origP, errNeedMore
}
func readString(p []byte) (s string, remain []byte, err error) {
if len(p) == 0 {
return "", p, errNeedMore
}
isHuff := p[0]&128 != 0
strLen, p, err := readVarInt(7, p)
if err != nil {
return "", p, err
}
if uint64(len(p)) < strLen {
return "", p, errNeedMore
}
if !isHuff {
return string(p[:strLen]), p[strLen:], nil
}
// TODO: optimize this garbage:
var buf bytes.Buffer
if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil {
return "", nil, err
}
return buf.String(), p[strLen:], nil
}

View file

@ -1,159 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
import (
"bytes"
"io"
"sync"
)
var bufPool = sync.Pool{
New: func() interface{} { return new(bytes.Buffer) },
}
// HuffmanDecode decodes the string in v and writes the expanded
// result to w, returning the number of bytes written to w and the
// Write call's return value. At most one Write call is made.
func HuffmanDecode(w io.Writer, v []byte) (int, error) {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
defer bufPool.Put(buf)
n := rootHuffmanNode
cur, nbits := uint(0), uint8(0)
for _, b := range v {
cur = cur<<8 | uint(b)
nbits += 8
for nbits >= 8 {
n = n.children[byte(cur>>(nbits-8))]
if n.children == nil {
buf.WriteByte(n.sym)
nbits -= n.codeLen
n = rootHuffmanNode
} else {
nbits -= 8
}
}
}
for nbits > 0 {
n = n.children[byte(cur<<(8-nbits))]
if n.children != nil || n.codeLen > nbits {
break
}
buf.WriteByte(n.sym)
nbits -= n.codeLen
n = rootHuffmanNode
}
return w.Write(buf.Bytes())
}
type node struct {
// children is non-nil for internal nodes
children []*node
// The following are only valid if children is nil:
codeLen uint8 // number of bits that led to the output of sym
sym byte // output symbol
}
func newInternalNode() *node {
return &node{children: make([]*node, 256)}
}
var rootHuffmanNode = newInternalNode()
func init() {
for i, code := range huffmanCodes {
if i > 255 {
panic("too many huffman codes")
}
addDecoderNode(byte(i), code, huffmanCodeLen[i])
}
}
func addDecoderNode(sym byte, code uint32, codeLen uint8) {
cur := rootHuffmanNode
for codeLen > 8 {
codeLen -= 8
i := uint8(code >> codeLen)
if cur.children[i] == nil {
cur.children[i] = newInternalNode()
}
cur = cur.children[i]
}
shift := 8 - codeLen
start, end := int(uint8(code<<shift)), int(1<<shift)
for i := start; i < start+end; i++ {
cur.children[i] = &node{sym: sym, codeLen: codeLen}
}
}
// AppendHuffmanString appends s, as encoded in Huffman codes, to dst
// and returns the extended buffer.
func AppendHuffmanString(dst []byte, s string) []byte {
rembits := uint8(8)
for i := 0; i < len(s); i++ {
if rembits == 8 {
dst = append(dst, 0)
}
dst, rembits = appendByteToHuffmanCode(dst, rembits, s[i])
}
if rembits < 8 {
// special EOS symbol
code := uint32(0x3fffffff)
nbits := uint8(30)
t := uint8(code >> (nbits - rembits))
dst[len(dst)-1] |= t
}
return dst
}
// HuffmanEncodeLength returns the number of bytes required to encode
// s in Huffman codes. The result is round up to byte boundary.
func HuffmanEncodeLength(s string) uint64 {
n := uint64(0)
for i := 0; i < len(s); i++ {
n += uint64(huffmanCodeLen[s[i]])
}
return (n + 7) / 8
}
// appendByteToHuffmanCode appends Huffman code for c to dst and
// returns the extended buffer and the remaining bits in the last
// element. The appending is not byte aligned and the remaining bits
// in the last element of dst is given in rembits.
func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) {
code := huffmanCodes[c]
nbits := huffmanCodeLen[c]
for {
if rembits > nbits {
t := uint8(code << (rembits - nbits))
dst[len(dst)-1] |= t
rembits -= nbits
break
}
t := uint8(code >> (nbits - rembits))
dst[len(dst)-1] |= t
nbits -= rembits
rembits = 8
if nbits == 0 {
break
}
dst = append(dst, 0)
}
return dst, rembits
}

View file

@ -1,353 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
func pair(name, value string) HeaderField {
return HeaderField{Name: name, Value: value}
}
// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-07#appendix-B
var staticTable = []HeaderField{
pair(":authority", ""), // index 1 (1-based)
pair(":method", "GET"),
pair(":method", "POST"),
pair(":path", "/"),
pair(":path", "/index.html"),
pair(":scheme", "http"),
pair(":scheme", "https"),
pair(":status", "200"),
pair(":status", "204"),
pair(":status", "206"),
pair(":status", "304"),
pair(":status", "400"),
pair(":status", "404"),
pair(":status", "500"),
pair("accept-charset", ""),
pair("accept-encoding", "gzip, deflate"),
pair("accept-language", ""),
pair("accept-ranges", ""),
pair("accept", ""),
pair("access-control-allow-origin", ""),
pair("age", ""),
pair("allow", ""),
pair("authorization", ""),
pair("cache-control", ""),
pair("content-disposition", ""),
pair("content-encoding", ""),
pair("content-language", ""),
pair("content-length", ""),
pair("content-location", ""),
pair("content-range", ""),
pair("content-type", ""),
pair("cookie", ""),
pair("date", ""),
pair("etag", ""),
pair("expect", ""),
pair("expires", ""),
pair("from", ""),
pair("host", ""),
pair("if-match", ""),
pair("if-modified-since", ""),
pair("if-none-match", ""),
pair("if-range", ""),
pair("if-unmodified-since", ""),
pair("last-modified", ""),
pair("link", ""),
pair("location", ""),
pair("max-forwards", ""),
pair("proxy-authenticate", ""),
pair("proxy-authorization", ""),
pair("range", ""),
pair("referer", ""),
pair("refresh", ""),
pair("retry-after", ""),
pair("server", ""),
pair("set-cookie", ""),
pair("strict-transport-security", ""),
pair("transfer-encoding", ""),
pair("user-agent", ""),
pair("vary", ""),
pair("via", ""),
pair("www-authenticate", ""),
}
var huffmanCodes = []uint32{
0x1ff8,
0x7fffd8,
0xfffffe2,
0xfffffe3,
0xfffffe4,
0xfffffe5,
0xfffffe6,
0xfffffe7,
0xfffffe8,
0xffffea,
0x3ffffffc,
0xfffffe9,
0xfffffea,
0x3ffffffd,
0xfffffeb,
0xfffffec,
0xfffffed,
0xfffffee,
0xfffffef,
0xffffff0,
0xffffff1,
0xffffff2,
0x3ffffffe,
0xffffff3,
0xffffff4,
0xffffff5,
0xffffff6,
0xffffff7,
0xffffff8,
0xffffff9,
0xffffffa,
0xffffffb,
0x14,
0x3f8,
0x3f9,
0xffa,
0x1ff9,
0x15,
0xf8,
0x7fa,
0x3fa,
0x3fb,
0xf9,
0x7fb,
0xfa,
0x16,
0x17,
0x18,
0x0,
0x1,
0x2,
0x19,
0x1a,
0x1b,
0x1c,
0x1d,
0x1e,
0x1f,
0x5c,
0xfb,
0x7ffc,
0x20,
0xffb,
0x3fc,
0x1ffa,
0x21,
0x5d,
0x5e,
0x5f,
0x60,
0x61,
0x62,
0x63,
0x64,
0x65,
0x66,
0x67,
0x68,
0x69,
0x6a,
0x6b,
0x6c,
0x6d,
0x6e,
0x6f,
0x70,
0x71,
0x72,
0xfc,
0x73,
0xfd,
0x1ffb,
0x7fff0,
0x1ffc,
0x3ffc,
0x22,
0x7ffd,
0x3,
0x23,
0x4,
0x24,
0x5,
0x25,
0x26,
0x27,
0x6,
0x74,
0x75,
0x28,
0x29,
0x2a,
0x7,
0x2b,
0x76,
0x2c,
0x8,
0x9,
0x2d,
0x77,
0x78,
0x79,
0x7a,
0x7b,
0x7ffe,
0x7fc,
0x3ffd,
0x1ffd,
0xffffffc,
0xfffe6,
0x3fffd2,
0xfffe7,
0xfffe8,
0x3fffd3,
0x3fffd4,
0x3fffd5,
0x7fffd9,
0x3fffd6,
0x7fffda,
0x7fffdb,
0x7fffdc,
0x7fffdd,
0x7fffde,
0xffffeb,
0x7fffdf,
0xffffec,
0xffffed,
0x3fffd7,
0x7fffe0,
0xffffee,
0x7fffe1,
0x7fffe2,
0x7fffe3,
0x7fffe4,
0x1fffdc,
0x3fffd8,
0x7fffe5,
0x3fffd9,
0x7fffe6,
0x7fffe7,
0xffffef,
0x3fffda,
0x1fffdd,
0xfffe9,
0x3fffdb,
0x3fffdc,
0x7fffe8,
0x7fffe9,
0x1fffde,
0x7fffea,
0x3fffdd,
0x3fffde,
0xfffff0,
0x1fffdf,
0x3fffdf,
0x7fffeb,
0x7fffec,
0x1fffe0,
0x1fffe1,
0x3fffe0,
0x1fffe2,
0x7fffed,
0x3fffe1,
0x7fffee,
0x7fffef,
0xfffea,
0x3fffe2,
0x3fffe3,
0x3fffe4,
0x7ffff0,
0x3fffe5,
0x3fffe6,
0x7ffff1,
0x3ffffe0,
0x3ffffe1,
0xfffeb,
0x7fff1,
0x3fffe7,
0x7ffff2,
0x3fffe8,
0x1ffffec,
0x3ffffe2,
0x3ffffe3,
0x3ffffe4,
0x7ffffde,
0x7ffffdf,
0x3ffffe5,
0xfffff1,
0x1ffffed,
0x7fff2,
0x1fffe3,
0x3ffffe6,
0x7ffffe0,
0x7ffffe1,
0x3ffffe7,
0x7ffffe2,
0xfffff2,
0x1fffe4,
0x1fffe5,
0x3ffffe8,
0x3ffffe9,
0xffffffd,
0x7ffffe3,
0x7ffffe4,
0x7ffffe5,
0xfffec,
0xfffff3,
0xfffed,
0x1fffe6,
0x3fffe9,
0x1fffe7,
0x1fffe8,
0x7ffff3,
0x3fffea,
0x3fffeb,
0x1ffffee,
0x1ffffef,
0xfffff4,
0xfffff5,
0x3ffffea,
0x7ffff4,
0x3ffffeb,
0x7ffffe6,
0x3ffffec,
0x3ffffed,
0x7ffffe7,
0x7ffffe8,
0x7ffffe9,
0x7ffffea,
0x7ffffeb,
0xffffffe,
0x7ffffec,
0x7ffffed,
0x7ffffee,
0x7ffffef,
0x7fffff0,
0x3ffffee,
}
var huffmanCodeLen = []uint8{
13, 23, 28, 28, 28, 28, 28, 28, 28, 24, 30, 28, 28, 30, 28, 28,
28, 28, 28, 28, 28, 28, 30, 28, 28, 28, 28, 28, 28, 28, 28, 28,
6, 10, 10, 12, 13, 6, 8, 11, 10, 10, 8, 11, 8, 6, 6, 6,
5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 8, 15, 6, 12, 10,
13, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7, 8, 7, 8, 13, 19, 13, 14, 6,
15, 5, 6, 5, 6, 5, 6, 6, 6, 5, 7, 7, 6, 6, 6, 5,
6, 7, 6, 5, 5, 6, 7, 7, 7, 7, 7, 15, 11, 14, 13, 28,
20, 22, 20, 20, 22, 22, 22, 23, 22, 23, 23, 23, 23, 23, 24, 23,
24, 24, 22, 23, 24, 23, 23, 23, 23, 21, 22, 23, 22, 23, 23, 24,
22, 21, 20, 22, 22, 23, 23, 21, 23, 22, 22, 24, 21, 22, 23, 23,
21, 21, 22, 21, 23, 22, 23, 23, 20, 22, 22, 22, 23, 22, 22, 23,
26, 26, 20, 19, 22, 23, 22, 25, 26, 26, 26, 27, 27, 26, 24, 25,
19, 21, 26, 27, 27, 26, 27, 24, 21, 21, 26, 26, 28, 27, 27, 27,
20, 24, 20, 21, 22, 21, 21, 23, 22, 22, 25, 25, 24, 24, 26, 23,
26, 27, 26, 26, 27, 27, 27, 27, 27, 28, 27, 27, 27, 27, 27, 26,
}

View file

@ -1,249 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// Package http2 implements the HTTP/2 protocol.
//
// This is a work in progress. This package is low-level and intended
// to be used directly by very few people. Most users will use it
// indirectly through integration with the net/http package. See
// ConfigureServer. That ConfigureServer call will likely be automatic
// or available via an empty import in the future.
//
// See http://http2.github.io/
package http2
import (
"bufio"
"fmt"
"io"
"net/http"
"strconv"
"sync"
)
var VerboseLogs = false
const (
// ClientPreface is the string that must be sent by new
// connections from clients.
ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
// SETTINGS_MAX_FRAME_SIZE default
// http://http2.github.io/http2-spec/#rfc.section.6.5.2
initialMaxFrameSize = 16384
// NextProtoTLS is the NPN/ALPN protocol negotiated during
// HTTP/2's TLS setup.
NextProtoTLS = "h2"
// http://http2.github.io/http2-spec/#SettingValues
initialHeaderTableSize = 4096
initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
defaultMaxReadFrameSize = 1 << 20
)
var (
clientPreface = []byte(ClientPreface)
)
type streamState int
const (
stateIdle streamState = iota
stateOpen
stateHalfClosedLocal
stateHalfClosedRemote
stateResvLocal
stateResvRemote
stateClosed
)
var stateName = [...]string{
stateIdle: "Idle",
stateOpen: "Open",
stateHalfClosedLocal: "HalfClosedLocal",
stateHalfClosedRemote: "HalfClosedRemote",
stateResvLocal: "ResvLocal",
stateResvRemote: "ResvRemote",
stateClosed: "Closed",
}
func (st streamState) String() string {
return stateName[st]
}
// Setting is a setting parameter: which setting it is, and its value.
type Setting struct {
// ID is which setting is being set.
// See http://http2.github.io/http2-spec/#SettingValues
ID SettingID
// Val is the value.
Val uint32
}
func (s Setting) String() string {
return fmt.Sprintf("[%v = %d]", s.ID, s.Val)
}
// Valid reports whether the setting is valid.
func (s Setting) Valid() error {
// Limits and error codes from 6.5.2 Defined SETTINGS Parameters
switch s.ID {
case SettingEnablePush:
if s.Val != 1 && s.Val != 0 {
return ConnectionError(ErrCodeProtocol)
}
case SettingInitialWindowSize:
if s.Val > 1<<31-1 {
return ConnectionError(ErrCodeFlowControl)
}
case SettingMaxFrameSize:
if s.Val < 16384 || s.Val > 1<<24-1 {
return ConnectionError(ErrCodeProtocol)
}
}
return nil
}
// A SettingID is an HTTP/2 setting as defined in
// http://http2.github.io/http2-spec/#iana-settings
type SettingID uint16
const (
SettingHeaderTableSize SettingID = 0x1
SettingEnablePush SettingID = 0x2
SettingMaxConcurrentStreams SettingID = 0x3
SettingInitialWindowSize SettingID = 0x4
SettingMaxFrameSize SettingID = 0x5
SettingMaxHeaderListSize SettingID = 0x6
)
var settingName = map[SettingID]string{
SettingHeaderTableSize: "HEADER_TABLE_SIZE",
SettingEnablePush: "ENABLE_PUSH",
SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
SettingMaxFrameSize: "MAX_FRAME_SIZE",
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
}
func (s SettingID) String() string {
if v, ok := settingName[s]; ok {
return v
}
return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s))
}
func validHeader(v string) bool {
if len(v) == 0 {
return false
}
for _, r := range v {
// "Just as in HTTP/1.x, header field names are
// strings of ASCII characters that are compared in a
// case-insensitive fashion. However, header field
// names MUST be converted to lowercase prior to their
// encoding in HTTP/2. "
if r >= 127 || ('A' <= r && r <= 'Z') {
return false
}
}
return true
}
var httpCodeStringCommon = map[int]string{} // n -> strconv.Itoa(n)
func init() {
for i := 100; i <= 999; i++ {
if v := http.StatusText(i); v != "" {
httpCodeStringCommon[i] = strconv.Itoa(i)
}
}
}
func httpCodeString(code int) string {
if s, ok := httpCodeStringCommon[code]; ok {
return s
}
return strconv.Itoa(code)
}
// from pkg io
type stringWriter interface {
WriteString(s string) (n int, err error)
}
// A gate lets two goroutines coordinate their activities.
type gate chan struct{}
func (g gate) Done() { g <- struct{}{} }
func (g gate) Wait() { <-g }
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type closeWaiter chan struct{}
// Init makes a closeWaiter usable.
// It exists because so a closeWaiter value can be placed inside a
// larger struct and have the Mutex and Cond's memory in the same
// allocation.
func (cw *closeWaiter) Init() {
*cw = make(chan struct{})
}
// Close marks the closeWaiter as closed and unblocks any waiters.
func (cw closeWaiter) Close() {
close(cw)
}
// Wait waits for the closeWaiter to become closed.
func (cw closeWaiter) Wait() {
<-cw
}
// bufferedWriter is a buffered writer that writes to w.
// Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections.
type bufferedWriter struct {
w io.Writer // immutable
bw *bufio.Writer // non-nil when data is buffered
}
func newBufferedWriter(w io.Writer) *bufferedWriter {
return &bufferedWriter{w: w}
}
var bufWriterPool = sync.Pool{
New: func() interface{} {
// TODO: pick something better? this is a bit under
// (3 x typical 1500 byte MTU) at least.
return bufio.NewWriterSize(nil, 4<<10)
},
}
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer)
bw.Reset(w.w)
w.bw = bw
}
return w.bw.Write(p)
}
func (w *bufferedWriter) Flush() error {
bw := w.bw
if bw == nil {
return nil
}
err := bw.Flush()
bw.Reset(nil)
bufWriterPool.Put(bw)
w.bw = nil
return err
}

View file

@ -1,43 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"sync"
)
type pipe struct {
b buffer
c sync.Cond
m sync.Mutex
}
// Read waits until data is available and copies bytes
// from the buffer into p.
func (r *pipe) Read(p []byte) (n int, err error) {
r.c.L.Lock()
defer r.c.L.Unlock()
for r.b.Len() == 0 && !r.b.closed {
r.c.Wait()
}
return r.b.Read(p)
}
// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
func (w *pipe) Write(p []byte) (n int, err error) {
w.c.L.Lock()
defer w.c.L.Unlock()
defer w.c.Signal()
return w.b.Write(p)
}
func (c *pipe) Close(err error) {
c.c.L.Lock()
defer c.c.L.Unlock()
defer c.c.Signal()
c.b.Close(err)
}

File diff suppressed because it is too large Load diff

View file

@ -1,553 +0,0 @@
// Copyright 2015 The Go Authors.
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://go.googlesource.com/go/+/master/LICENSE
package http2
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"github.com/bradfitz/http2/hpack"
)
type Transport struct {
Fallback http.RoundTripper
// TODO: remove this and make more general with a TLS dial hook, like http
InsecureTLSDial bool
connMu sync.Mutex
conns map[string][]*clientConn // key is host:port
}
type clientConn struct {
t *Transport
tconn *tls.Conn
tlsState *tls.ConnectionState
connKey []string // key(s) this connection is cached in, in t.conns
readerDone chan struct{} // closed on error
readerErr error // set before readerDone is closed
hdec *hpack.Decoder
nextRes *http.Response
mu sync.Mutex
closed bool
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
streams map[uint32]*clientStream
nextStreamID uint32
bw *bufio.Writer
werr error // first write error that has occurred
br *bufio.Reader
fr *Framer
// Settings from peer:
maxFrameSize uint32
maxConcurrentStreams uint32
initialWindowSize uint32
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
}
type clientStream struct {
ID uint32
resc chan resAndError
pw *io.PipeWriter
pr *io.PipeReader
}
type stickyErrWriter struct {
w io.Writer
err *error
}
func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
n, err = sew.w.Write(p)
*sew.err = err
return
}
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Scheme != "https" {
if t.Fallback == nil {
return nil, errors.New("http2: unsupported scheme and no Fallback")
}
return t.Fallback.RoundTrip(req)
}
host, port, err := net.SplitHostPort(req.URL.Host)
if err != nil {
host = req.URL.Host
port = "443"
}
for {
cc, err := t.getClientConn(host, port)
if err != nil {
return nil, err
}
res, err := cc.roundTrip(req)
if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
continue
}
if err != nil {
return nil, err
}
return res, nil
}
}
// CloseIdleConnections closes any connections which were previously
// connected from previous requests but are now sitting idle.
// It does not interrupt any connections currently in use.
func (t *Transport) CloseIdleConnections() {
t.connMu.Lock()
defer t.connMu.Unlock()
for _, vv := range t.conns {
for _, cc := range vv {
cc.closeIfIdle()
}
}
}
var errClientConnClosed = errors.New("http2: client conn is closed")
func shouldRetryRequest(err error) bool {
// TODO: or GOAWAY graceful shutdown stuff
return err == errClientConnClosed
}
func (t *Transport) removeClientConn(cc *clientConn) {
t.connMu.Lock()
defer t.connMu.Unlock()
for _, key := range cc.connKey {
vv, ok := t.conns[key]
if !ok {
continue
}
newList := filterOutClientConn(vv, cc)
if len(newList) > 0 {
t.conns[key] = newList
} else {
delete(t.conns, key)
}
}
}
func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
out := in[:0]
for _, v := range in {
if v != exclude {
out = append(out, v)
}
}
return out
}
func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
t.connMu.Lock()
defer t.connMu.Unlock()
key := net.JoinHostPort(host, port)
for _, cc := range t.conns[key] {
if cc.canTakeNewRequest() {
return cc, nil
}
}
if t.conns == nil {
t.conns = make(map[string][]*clientConn)
}
cc, err := t.newClientConn(host, port, key)
if err != nil {
return nil, err
}
t.conns[key] = append(t.conns[key], cc)
return cc, nil
}
func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
cfg := &tls.Config{
ServerName: host,
NextProtos: []string{NextProtoTLS},
InsecureSkipVerify: t.InsecureTLSDial,
}
tconn, err := tls.Dial("tcp", host+":"+port, cfg)
if err != nil {
return nil, err
}
if err := tconn.Handshake(); err != nil {
return nil, err
}
if !t.InsecureTLSDial {
if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
}
state := tconn.ConnectionState()
if p := state.NegotiatedProtocol; p != NextProtoTLS {
// TODO(bradfitz): fall back to Fallback
return nil, fmt.Errorf("bad protocol: %v", p)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("could not negotiate protocol mutually")
}
if _, err := tconn.Write(clientPreface); err != nil {
return nil, err
}
cc := &clientConn{
t: t,
tconn: tconn,
connKey: []string{key}, // TODO: cert's validated hostnames too
tlsState: &state,
readerDone: make(chan struct{}),
nextStreamID: 1,
maxFrameSize: 16 << 10, // spec default
initialWindowSize: 65535, // spec default
maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
streams: make(map[uint32]*clientStream),
}
cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
cc.br = bufio.NewReader(tconn)
cc.fr = NewFramer(cc.bw, cc.br)
cc.henc = hpack.NewEncoder(&cc.hbuf)
cc.fr.WriteSettings()
// TODO: re-send more conn-level flow control tokens when server uses all these.
cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
cc.bw.Flush()
if cc.werr != nil {
return nil, cc.werr
}
// Read the obligatory SETTINGS frame
f, err := cc.fr.ReadFrame()
if err != nil {
return nil, err
}
sf, ok := f.(*SettingsFrame)
if !ok {
return nil, fmt.Errorf("expected settings frame, got: %T", f)
}
cc.fr.WriteSettingsAck()
cc.bw.Flush()
sf.ForeachSetting(func(s Setting) error {
switch s.ID {
case SettingMaxFrameSize:
cc.maxFrameSize = s.Val
case SettingMaxConcurrentStreams:
cc.maxConcurrentStreams = s.Val
case SettingInitialWindowSize:
cc.initialWindowSize = s.Val
default:
// TODO(bradfitz): handle more
log.Printf("Unhandled Setting: %v", s)
}
return nil
})
// TODO: figure out henc size
cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
go cc.readLoop()
return cc, nil
}
func (cc *clientConn) setGoAway(f *GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.goAway = f
}
func (cc *clientConn) canTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.goAway == nil &&
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
cc.nextStreamID < 2147483647
}
func (cc *clientConn) closeIfIdle() {
cc.mu.Lock()
if len(cc.streams) > 0 {
cc.mu.Unlock()
return
}
cc.closed = true
// TODO: do clients send GOAWAY too? maybe? Just Close:
cc.mu.Unlock()
cc.tconn.Close()
}
func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
cc.mu.Lock()
if cc.closed {
cc.mu.Unlock()
return nil, errClientConnClosed
}
cs := cc.newStream()
hasBody := false // TODO
// we send: HEADERS[+CONTINUATION] + (DATA?)
hdrs := cc.encodeHeaders(req)
first := true
for len(hdrs) > 0 {
chunk := hdrs
if len(chunk) > int(cc.maxFrameSize) {
chunk = chunk[:cc.maxFrameSize]
}
hdrs = hdrs[len(chunk):]
endHeaders := len(hdrs) == 0
if first {
cc.fr.WriteHeaders(HeadersFrameParam{
StreamID: cs.ID,
BlockFragment: chunk,
EndStream: !hasBody,
EndHeaders: endHeaders,
})
first = false
} else {
cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
}
}
cc.bw.Flush()
werr := cc.werr
cc.mu.Unlock()
if hasBody {
// TODO: write data. and it should probably be interleaved:
// go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
}
if werr != nil {
return nil, werr
}
re := <-cs.resc
if re.err != nil {
return nil, re.err
}
res := re.res
res.Request = req
res.TLS = cc.tlsState
return res, nil
}
// requires cc.mu be held.
func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
cc.hbuf.Reset()
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
host := req.Host
if host == "" {
host = req.URL.Host
}
path := req.URL.Path
if path == "" {
path = "/"
}
cc.writeHeader(":authority", host) // probably not right for all sites
cc.writeHeader(":method", req.Method)
cc.writeHeader(":path", path)
cc.writeHeader(":scheme", "https")
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
if lowKey == "host" {
continue
}
for _, v := range vv {
cc.writeHeader(lowKey, v)
}
}
return cc.hbuf.Bytes()
}
func (cc *clientConn) writeHeader(name, value string) {
log.Printf("sending %q = %q", name, value)
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
type resAndError struct {
res *http.Response
err error
}
// requires cc.mu be held.
func (cc *clientConn) newStream() *clientStream {
cs := &clientStream{
ID: cc.nextStreamID,
resc: make(chan resAndError, 1),
}
cc.nextStreamID += 2
cc.streams[cs.ID] = cs
return cs
}
func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
cc.mu.Lock()
defer cc.mu.Unlock()
cs := cc.streams[id]
if andRemove {
delete(cc.streams, id)
}
return cs
}
// runs in its own goroutine.
func (cc *clientConn) readLoop() {
defer cc.t.removeClientConn(cc)
defer close(cc.readerDone)
activeRes := map[uint32]*clientStream{} // keyed by streamID
// Close any response bodies if the server closes prematurely.
// TODO: also do this if we've written the headers but not
// gotten a response yet.
defer func() {
err := cc.readerErr
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
for _, cs := range activeRes {
cs.pw.CloseWithError(err)
}
}()
// continueStreamID is the stream ID we're waiting for
// continuation frames for.
var continueStreamID uint32
for {
f, err := cc.fr.ReadFrame()
if err != nil {
cc.readerErr = err
return
}
log.Printf("Transport received %v: %#v", f.Header(), f)
streamID := f.Header().StreamID
_, isContinue := f.(*ContinuationFrame)
if isContinue {
if streamID != continueStreamID {
log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
cc.readerErr = ConnectionError(ErrCodeProtocol)
return
}
} else if continueStreamID != 0 {
// Continue frames need to be adjacent in the stream
// and we were in the middle of headers.
log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
cc.readerErr = ConnectionError(ErrCodeProtocol)
return
}
if streamID%2 == 0 {
// Ignore streams pushed from the server for now.
// These always have an even stream id.
continue
}
streamEnded := false
if ff, ok := f.(streamEnder); ok {
streamEnded = ff.StreamEnded()
}
cs := cc.streamByID(streamID, streamEnded)
if cs == nil {
log.Printf("Received frame for untracked stream ID %d", streamID)
continue
}
switch f := f.(type) {
case *HeadersFrame:
cc.nextRes = &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: make(http.Header),
}
cs.pr, cs.pw = io.Pipe()
cc.hdec.Write(f.HeaderBlockFragment())
case *ContinuationFrame:
cc.hdec.Write(f.HeaderBlockFragment())
case *DataFrame:
log.Printf("DATA: %q", f.Data())
cs.pw.Write(f.Data())
case *GoAwayFrame:
cc.t.removeClientConn(cc)
if f.ErrCode != 0 {
// TODO: deal with GOAWAY more. particularly the error code
log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
}
cc.setGoAway(f)
default:
log.Printf("Transport: unhandled response frame type %T", f)
}
headersEnded := false
if he, ok := f.(headersEnder); ok {
headersEnded = he.HeadersEnded()
if headersEnded {
continueStreamID = 0
} else {
continueStreamID = streamID
}
}
if streamEnded {
cs.pw.Close()
delete(activeRes, streamID)
}
if headersEnded {
if cs == nil {
panic("couldn't find stream") // TODO be graceful
}
// TODO: set the Body to one which notes the
// Close and also sends the server a
// RST_STREAM
cc.nextRes.Body = cs.pr
res := cc.nextRes
activeRes[streamID] = cs
cs.resc <- resAndError{res: res}
}
}
}
func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
// TODO: verifiy pseudo headers come before non-pseudo headers
// TODO: verifiy the status is set
log.Printf("Header field: %+v", f)
if f.Name == ":status" {
code, err := strconv.Atoi(f.Value)
if err != nil {
panic("TODO: be graceful")
}
cc.nextRes.Status = f.Value + " " + http.StatusText(code)
cc.nextRes.StatusCode = code
return
}
if strings.HasPrefix(f.Name, ":") {
// "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
// TODO: treat as invalid?
return
}
cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
}

View file

@ -1,204 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"bytes"
"fmt"
"net/http"
"time"
"github.com/bradfitz/http2/hpack"
)
// writeFramer is implemented by any type that is used to write frames.
type writeFramer interface {
writeFrame(writeContext) error
}
// writeContext is the interface needed by the various frame writer
// types below. All the writeFrame methods below are scheduled via the
// frame writing scheduler (see writeScheduler in writesched.go).
//
// This interface is implemented by *serverConn.
// TODO: use it from the client code too, once it exists.
type writeContext interface {
Framer() *Framer
Flush() error
CloseConn() error
// HeaderEncoder returns an HPACK encoder that writes to the
// returned buffer.
HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
}
// endsStream reports whether the given frame writer w will locally
// close the stream.
func endsStream(w writeFramer) bool {
switch v := w.(type) {
case *writeData:
return v.endStream
case *writeResHeaders:
return v.endStream
}
return false
}
type flushFrameWriter struct{}
func (flushFrameWriter) writeFrame(ctx writeContext) error {
return ctx.Flush()
}
type writeSettings []Setting
func (s writeSettings) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteSettings([]Setting(s)...)
}
type writeGoAway struct {
maxStreamID uint32
code ErrCode
}
func (p *writeGoAway) writeFrame(ctx writeContext) error {
err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
if p.code != 0 {
ctx.Flush() // ignore error: we're hanging up on them anyway
time.Sleep(50 * time.Millisecond)
ctx.CloseConn()
}
return err
}
type writeData struct {
streamID uint32
p []byte
endStream bool
}
func (w *writeData) String() string {
return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream)
}
func (w *writeData) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteData(w.streamID, w.endStream, w.p)
}
func (se StreamError) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
}
type writePingAck struct{ pf *PingFrame }
func (w writePingAck) writeFrame(ctx writeContext) error {
return ctx.Framer().WritePing(true, w.pf.Data)
}
type writeSettingsAck struct{}
func (writeSettingsAck) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteSettingsAck()
}
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
// for HTTP response headers from a server handler.
type writeResHeaders struct {
streamID uint32
httpResCode int
h http.Header // may be nil
endStream bool
contentType string
contentLength string
}
func (w *writeResHeaders) writeFrame(ctx writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
enc.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(w.httpResCode)})
for k, vv := range w.h {
k = lowerHeader(k)
for _, v := range vv {
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
if k == "transfer-encoding" && v != "trailers" {
continue
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
}
if w.contentType != "" {
enc.WriteField(hpack.HeaderField{Name: "content-type", Value: w.contentType})
}
if w.contentLength != "" {
enc.WriteField(hpack.HeaderField{Name: "content-length", Value: w.contentLength})
}
headerBlock := buf.Bytes()
if len(headerBlock) == 0 {
panic("unexpected empty hpack")
}
// For now we're lazy and just pick the minimum MAX_FRAME_SIZE
// that all peers must support (16KB). Later we could care
// more and send larger frames if the peer advertised it, but
// there's little point. Most headers are small anyway (so we
// generally won't have CONTINUATION frames), and extra frames
// only waste 9 bytes anyway.
const maxFrameSize = 16384
first := true
for len(headerBlock) > 0 {
frag := headerBlock
if len(frag) > maxFrameSize {
frag = frag[:maxFrameSize]
}
headerBlock = headerBlock[len(frag):]
endHeaders := len(headerBlock) == 0
var err error
if first {
first = false
err = ctx.Framer().WriteHeaders(HeadersFrameParam{
StreamID: w.streamID,
BlockFragment: frag,
EndStream: w.endStream,
EndHeaders: endHeaders,
})
} else {
err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag)
}
if err != nil {
return err
}
}
return nil
}
type write100ContinueHeadersFrame struct {
streamID uint32
}
func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
return ctx.Framer().WriteHeaders(HeadersFrameParam{
StreamID: w.streamID,
BlockFragment: buf.Bytes(),
EndStream: false,
EndHeaders: true,
})
}
type writeWindowUpdate struct {
streamID uint32 // or 0 for conn-level
n uint32
}
func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
}

View file

@ -1,286 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import "fmt"
// frameWriteMsg is a request to write a frame.
type frameWriteMsg struct {
// write is the interface value that does the writing, once the
// writeScheduler (below) has decided to select this frame
// to write. The write functions are all defined in write.go.
write writeFramer
stream *stream // used for prioritization. nil for non-stream frames.
// done, if non-nil, must be a buffered channel with space for
// 1 message and is sent the return value from write (or an
// earlier error) when the frame has been written.
done chan error
}
// for debugging only:
func (wm frameWriteMsg) String() string {
var streamID uint32
if wm.stream != nil {
streamID = wm.stream.id
}
var des string
if s, ok := wm.write.(fmt.Stringer); ok {
des = s.String()
} else {
des = fmt.Sprintf("%T", wm.write)
}
return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des)
}
// writeScheduler tracks pending frames to write, priorities, and decides
// the next one to use. It is not thread-safe.
type writeScheduler struct {
// zero are frames not associated with a specific stream.
// They're sent before any stream-specific freams.
zero writeQueue
// maxFrameSize is the maximum size of a DATA frame
// we'll write. Must be non-zero and between 16K-16M.
maxFrameSize uint32
// sq contains the stream-specific queues, keyed by stream ID.
// when a stream is idle, it's deleted from the map.
sq map[uint32]*writeQueue
// canSend is a slice of memory that's reused between frame
// scheduling decisions to hold the list of writeQueues (from sq)
// which have enough flow control data to send. After canSend is
// built, the best is selected.
canSend []*writeQueue
// pool of empty queues for reuse.
queuePool []*writeQueue
}
func (ws *writeScheduler) putEmptyQueue(q *writeQueue) {
if len(q.s) != 0 {
panic("queue must be empty")
}
ws.queuePool = append(ws.queuePool, q)
}
func (ws *writeScheduler) getEmptyQueue() *writeQueue {
ln := len(ws.queuePool)
if ln == 0 {
return new(writeQueue)
}
q := ws.queuePool[ln-1]
ws.queuePool = ws.queuePool[:ln-1]
return q
}
func (ws *writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 }
func (ws *writeScheduler) add(wm frameWriteMsg) {
st := wm.stream
if st == nil {
ws.zero.push(wm)
} else {
ws.streamQueue(st.id).push(wm)
}
}
func (ws *writeScheduler) streamQueue(streamID uint32) *writeQueue {
if q, ok := ws.sq[streamID]; ok {
return q
}
if ws.sq == nil {
ws.sq = make(map[uint32]*writeQueue)
}
q := ws.getEmptyQueue()
ws.sq[streamID] = q
return q
}
// take returns the most important frame to write and removes it from the scheduler.
// It is illegal to call this if the scheduler is empty or if there are no connection-level
// flow control bytes available.
func (ws *writeScheduler) take() (wm frameWriteMsg, ok bool) {
if ws.maxFrameSize == 0 {
panic("internal error: ws.maxFrameSize not initialized or invalid")
}
// If there any frames not associated with streams, prefer those first.
// These are usually SETTINGS, etc.
if !ws.zero.empty() {
return ws.zero.shift(), true
}
if len(ws.sq) == 0 {
return
}
// Next, prioritize frames on streams that aren't DATA frames (no cost).
for id, q := range ws.sq {
if q.firstIsNoCost() {
return ws.takeFrom(id, q)
}
}
// Now, all that remains are DATA frames with non-zero bytes to
// send. So pick the best one.
if len(ws.canSend) != 0 {
panic("should be empty")
}
for _, q := range ws.sq {
if n := ws.streamWritableBytes(q); n > 0 {
ws.canSend = append(ws.canSend, q)
}
}
if len(ws.canSend) == 0 {
return
}
defer ws.zeroCanSend()
// TODO: find the best queue
q := ws.canSend[0]
return ws.takeFrom(q.streamID(), q)
}
// zeroCanSend is defered from take.
func (ws *writeScheduler) zeroCanSend() {
for i := range ws.canSend {
ws.canSend[i] = nil
}
ws.canSend = ws.canSend[:0]
}
// streamWritableBytes returns the number of DATA bytes we could write
// from the given queue's stream, if this stream/queue were
// selected. It is an error to call this if q's head isn't a
// *writeData.
func (ws *writeScheduler) streamWritableBytes(q *writeQueue) int32 {
wm := q.head()
ret := wm.stream.flow.available() // max we can write
if ret == 0 {
return 0
}
if int32(ws.maxFrameSize) < ret {
ret = int32(ws.maxFrameSize)
}
if ret == 0 {
panic("internal error: ws.maxFrameSize not initialized or invalid")
}
wd := wm.write.(*writeData)
if len(wd.p) < int(ret) {
ret = int32(len(wd.p))
}
return ret
}
func (ws *writeScheduler) takeFrom(id uint32, q *writeQueue) (wm frameWriteMsg, ok bool) {
wm = q.head()
// If the first item in this queue costs flow control tokens
// and we don't have enough, write as much as we can.
if wd, ok := wm.write.(*writeData); ok && len(wd.p) > 0 {
allowed := wm.stream.flow.available() // max we can write
if allowed == 0 {
// No quota available. Caller can try the next stream.
return frameWriteMsg{}, false
}
if int32(ws.maxFrameSize) < allowed {
allowed = int32(ws.maxFrameSize)
}
// TODO: further restrict the allowed size, because even if
// the peer says it's okay to write 16MB data frames, we might
// want to write smaller ones to properly weight competing
// streams' priorities.
if len(wd.p) > int(allowed) {
wm.stream.flow.take(allowed)
chunk := wd.p[:allowed]
wd.p = wd.p[allowed:]
// Make up a new write message of a valid size, rather
// than shifting one off the queue.
return frameWriteMsg{
stream: wm.stream,
write: &writeData{
streamID: wd.streamID,
p: chunk,
// even if the original had endStream set, there
// arebytes remaining because len(wd.p) > allowed,
// so we know endStream is false:
endStream: false,
},
// our caller is blocking on the final DATA frame, not
// these intermediates, so no need to wait:
done: nil,
}, true
}
wm.stream.flow.take(int32(len(wd.p)))
}
q.shift()
if q.empty() {
ws.putEmptyQueue(q)
delete(ws.sq, id)
}
return wm, true
}
func (ws *writeScheduler) forgetStream(id uint32) {
q, ok := ws.sq[id]
if !ok {
return
}
delete(ws.sq, id)
// But keep it for others later.
for i := range q.s {
q.s[i] = frameWriteMsg{}
}
q.s = q.s[:0]
ws.putEmptyQueue(q)
}
type writeQueue struct {
s []frameWriteMsg
}
// streamID returns the stream ID for a non-empty stream-specific queue.
func (q *writeQueue) streamID() uint32 { return q.s[0].stream.id }
func (q *writeQueue) empty() bool { return len(q.s) == 0 }
func (q *writeQueue) push(wm frameWriteMsg) {
q.s = append(q.s, wm)
}
// head returns the next item that would be removed by shift.
func (q *writeQueue) head() frameWriteMsg {
if len(q.s) == 0 {
panic("invalid use of queue")
}
return q.s[0]
}
func (q *writeQueue) shift() frameWriteMsg {
if len(q.s) == 0 {
panic("invalid use of queue")
}
wm := q.s[0]
// TODO: less copy-happy queue.
copy(q.s, q.s[1:])
q.s[len(q.s)-1] = frameWriteMsg{}
q.s = q.s[:len(q.s)-1]
return wm
}
func (q *writeQueue) firstIsNoCost() bool {
if df, ok := q.s[0].write.(*writeData); ok {
return len(df.p) == 0
}
return true
}

View file

@ -39,5 +39,5 @@ test: install generate-test-pbs
generate-test-pbs:
make install
make -C testdata
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata:. proto3_proto/proto3.proto
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata,Mgoogle/protobuf/any.proto=github.com/golang/protobuf/ptypes/any:. proto3_proto/proto3.proto
make

View file

@ -30,7 +30,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer deep copy and merge.
// TODO: MessageSet and RawMessage.
// TODO: RawMessage.
package proto
@ -120,6 +120,17 @@ func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) {
return
}
out.Set(in)
case reflect.Interface:
// Probably a oneof field; copy non-nil values.
if in.IsNil() {
return
}
// Allocate destination if it is not set, or set to a different type.
// Otherwise we will merge as normal.
if out.IsNil() || out.Elem().Type() != in.Elem().Type() {
out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T)
}
mergeAny(out.Elem(), in.Elem(), false, nil)
case reflect.Map:
if in.Len() == 0 {
return

View file

@ -46,6 +46,10 @@ import (
// errOverflow is returned when an integer is too large to be represented.
var errOverflow = errors.New("proto: integer overflow")
// ErrInternalBadWireType is returned by generated code when an incorrect
// wire type is encountered. It does not get returned to user code.
var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
// The fundamental decoders that interpret bytes on the wire.
// Those that take integer types all return uint64 and are
// therefore of type valueDecoder.
@ -314,6 +318,24 @@ func UnmarshalMerge(buf []byte, pb Message) error {
return NewBuffer(buf).Unmarshal(pb)
}
// DecodeMessage reads a count-delimited message from the Buffer.
func (p *Buffer) DecodeMessage(pb Message) error {
enc, err := p.DecodeRawBytes(false)
if err != nil {
return err
}
return NewBuffer(enc).Unmarshal(pb)
}
// DecodeGroup reads a tag-delimited group from the Buffer.
func (p *Buffer) DecodeGroup(pb Message) error {
typ, base, err := getbase(pb)
if err != nil {
return err
}
return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
}
// Unmarshal parses the protocol buffer representation in the
// Buffer and places the decoded result in pb. If the struct
// underlying pb does not match the data in the buffer, the results can be
@ -377,6 +399,20 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
continue
}
}
// Maybe it's a oneof?
if prop.oneofUnmarshaler != nil {
m := structPointer_Interface(base, st).(Message)
// First return value indicates whether tag is a oneof field.
ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
if err == ErrInternalBadWireType {
// Map the error to something more descriptive.
// Do the formatting here to save generated code space.
err = fmt.Errorf("bad wiretype for oneof field in %T", m)
}
if ok {
continue
}
}
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
continue
}
@ -561,9 +597,13 @@ func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error
return err
}
nb := int(nn) // number of bytes of encoded bools
fin := o.index + nb
if fin < o.index {
return errOverflow
}
y := *v
for i := 0; i < nb; i++ {
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
@ -728,10 +768,11 @@ func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
}
}
keyelem, valelem := keyptr.Elem(), valptr.Elem()
if !keyelem.IsValid() || !valelem.IsValid() {
// We did not decode the key or the value in the map entry.
// Either way, it's an invalid map entry.
return fmt.Errorf("proto: bad map data: missing key/val")
if !keyelem.IsValid() {
keyelem = reflect.Zero(p.mtype.Key())
}
if !valelem.IsValid() {
valelem = reflect.Zero(p.mtype.Elem())
}
v.SetMapIndex(keyelem, valelem)

View file

@ -64,6 +64,10 @@ var (
// a struct with a repeated field containing a nil element.
errRepeatedHasNil = errors.New("proto: repeated field has nil element")
// errOneofHasNil is the error returned if Marshal is called with
// a struct with a oneof field containing a nil element.
errOneofHasNil = errors.New("proto: oneof field has nil value")
// ErrNil is the error returned if Marshal is called with nil.
ErrNil = errors.New("proto: Marshal called with nil")
)
@ -105,6 +109,11 @@ func (p *Buffer) EncodeVarint(x uint64) error {
return nil
}
// SizeVarint returns the varint encoding size of an integer.
func SizeVarint(x uint64) int {
return sizeVarint(x)
}
func sizeVarint(x uint64) (n int) {
for {
n++
@ -228,6 +237,20 @@ func Marshal(pb Message) ([]byte, error) {
return p.buf, err
}
// EncodeMessage writes the protocol buffer to the Buffer,
// prefixed by a varint-encoded length.
func (p *Buffer) EncodeMessage(pb Message) error {
t, base, err := getbase(pb)
if structPointer_IsNil(base) {
return ErrNil
}
if err == nil {
var state errorState
err = p.enc_len_struct(GetProperties(t.Elem()), base, &state)
}
return err
}
// Marshal takes the protocol buffer
// and encodes it into the wire format, writing the result to the
// Buffer.
@ -318,7 +341,7 @@ func size_bool(p *Properties, base structPointer) int {
func size_proto3_bool(p *Properties, base structPointer) int {
v := *structPointer_BoolVal(base, p.field)
if !v {
if !v && !p.oneof {
return 0
}
return len(p.tagcode) + 1 // each bool takes exactly one byte
@ -361,7 +384,7 @@ func size_int32(p *Properties, base structPointer) (n int) {
func size_proto3_int32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32Val(base, p.field)
x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range
if x == 0 {
if x == 0 && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -407,7 +430,7 @@ func size_uint32(p *Properties, base structPointer) (n int) {
func size_proto3_uint32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32Val(base, p.field)
x := word32Val_Get(v)
if x == 0 {
if x == 0 && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -452,7 +475,7 @@ func size_int64(p *Properties, base structPointer) (n int) {
func size_proto3_int64(p *Properties, base structPointer) (n int) {
v := structPointer_Word64Val(base, p.field)
x := word64Val_Get(v)
if x == 0 {
if x == 0 && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -495,7 +518,7 @@ func size_string(p *Properties, base structPointer) (n int) {
func size_proto3_string(p *Properties, base structPointer) (n int) {
v := *structPointer_StringVal(base, p.field)
if v == "" {
if v == "" && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -529,7 +552,7 @@ func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
}
o.buf = append(o.buf, p.tagcode...)
o.EncodeRawBytes(data)
return nil
return state.err
}
o.buf = append(o.buf, p.tagcode...)
@ -667,7 +690,7 @@ func (o *Buffer) enc_proto3_slice_byte(p *Properties, base structPointer) error
func size_slice_byte(p *Properties, base structPointer) (n int) {
s := *structPointer_Bytes(base, p.field)
if s == nil {
if s == nil && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -677,7 +700,7 @@ func size_slice_byte(p *Properties, base structPointer) (n int) {
func size_proto3_slice_byte(p *Properties, base structPointer) (n int) {
s := *structPointer_Bytes(base, p.field)
if len(s) == 0 {
if len(s) == 0 && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -1101,9 +1124,8 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
return nil
}
keys := v.MapKeys()
sort.Sort(mapKeys(keys))
for _, key := range keys {
// Don't sort map keys. It is not required by the spec, and C++ doesn't do it.
for _, key := range v.MapKeys() {
val := v.MapIndex(key)
// The only illegal map entry values are nil message pointers.
@ -1201,6 +1223,16 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
}
}
// Do oneof fields.
if prop.oneofMarshaler != nil {
m := structPointer_Interface(base, prop.stype).(Message)
if err := prop.oneofMarshaler(m, o); err == ErrNil {
return errOneofHasNil
} else if err != nil {
return err
}
}
// Add unrecognized fields at the end.
if prop.unrecField.IsValid() {
v := *structPointer_Bytes(base, prop.unrecField)
@ -1226,6 +1258,12 @@ func size_struct(prop *StructProperties, base structPointer) (n int) {
n += len(v)
}
// Factor in any oneof fields.
if prop.oneofSizer != nil {
m := structPointer_Interface(base, prop.stype).(Message)
n += prop.oneofSizer(m)
}
return
}

View file

@ -30,7 +30,6 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer comparison.
// TODO: MessageSet.
package proto
@ -51,7 +50,9 @@ Equality is defined in this way:
are equal, and extensions sets are equal.
- Two set scalar fields are equal iff their values are equal.
If the fields are of a floating-point type, remember that
NaN != x for all x, including NaN.
NaN != x for all x, including NaN. If the message is defined
in a proto3 .proto file, fields are not "set"; specifically,
zero length proto3 "bytes" fields are equal (nil == {}).
- Two repeated fields are equal iff their lengths are the same,
and their corresponding elements are equal (a "bytes" field,
although represented by []byte, is not a repeated field)
@ -89,6 +90,7 @@ func Equal(a, b Message) bool {
// v1 and v2 are known to have the same type.
func equalStruct(v1, v2 reflect.Value) bool {
sprop := GetProperties(v1.Type())
for i := 0; i < v1.NumField(); i++ {
f := v1.Type().Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
@ -114,7 +116,7 @@ func equalStruct(v1, v2 reflect.Value) bool {
}
f1, f2 = f1.Elem(), f2.Elem()
}
if !equalAny(f1, f2) {
if !equalAny(f1, f2, sprop.Prop[i]) {
return false
}
}
@ -141,7 +143,8 @@ func equalStruct(v1, v2 reflect.Value) bool {
}
// v1 and v2 are known to have the same type.
func equalAny(v1, v2 reflect.Value) bool {
// prop may be nil.
func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
if v1.Type() == protoMessageType {
m1, _ := v1.Interface().(Message)
m2, _ := v2.Interface().(Message)
@ -154,6 +157,17 @@ func equalAny(v1, v2 reflect.Value) bool {
return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
case reflect.Interface:
// Probably a oneof field; compare the inner values.
n1, n2 := v1.IsNil(), v2.IsNil()
if n1 || n2 {
return n1 == n2
}
e1, e2 := v1.Elem(), v2.Elem()
if e1.Type() != e2.Type() {
return false
}
return equalAny(e1, e2, nil)
case reflect.Map:
if v1.Len() != v2.Len() {
return false
@ -164,16 +178,22 @@ func equalAny(v1, v2 reflect.Value) bool {
// This key was not found in the second map.
return false
}
if !equalAny(v1.MapIndex(key), val2) {
if !equalAny(v1.MapIndex(key), val2, nil) {
return false
}
}
return true
case reflect.Ptr:
return equalAny(v1.Elem(), v2.Elem())
return equalAny(v1.Elem(), v2.Elem(), prop)
case reflect.Slice:
if v1.Type().Elem().Kind() == reflect.Uint8 {
// short circuit: []byte
// Edge case: if this is in a proto3 message, a zero length
// bytes field is considered the zero value.
if prop != nil && prop.proto3 && v1.Len() == 0 && v2.Len() == 0 {
return true
}
if v1.IsNil() != v2.IsNil() {
return false
}
@ -184,7 +204,7 @@ func equalAny(v1, v2 reflect.Value) bool {
return false
}
for i := 0; i < v1.Len(); i++ {
if !equalAny(v1.Index(i), v2.Index(i)) {
if !equalAny(v1.Index(i), v2.Index(i), prop) {
return false
}
}
@ -219,7 +239,7 @@ func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
if m1 != nil && m2 != nil {
// Both are unencoded.
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
return false
}
continue
@ -247,7 +267,7 @@ func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
return false
}
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
return false
}
}

View file

@ -301,7 +301,6 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
o := NewBuffer(b)
t := reflect.TypeOf(extension.ExtensionType)
rep := extension.repeated()
props := extensionProperties(extension)
@ -323,7 +322,7 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
return nil, err
}
if !rep || o.index >= len(o.buf) {
if o.index >= len(o.buf) {
break
}
}

View file

@ -66,8 +66,16 @@ for a protocol buffer variable v:
that contain it (if any) followed by the CamelCased name of the
extension field itself. HasExtension, ClearExtension, GetExtension
and SetExtension are functions for manipulating extensions.
- Oneof field sets are given a single field in their message,
with distinguished wrapper types for each possible field value.
- Marshal and Unmarshal are functions to encode and decode the wire format.
When the .proto file specifies `syntax="proto3"`, there are some differences:
- Non-repeated fields of non-message type are values instead of pointers.
- Getters are only generated for message and oneof fields.
- Enum types do not get an Enum method.
The simplest way to describe this is to see an example.
Given file test.proto, containing
@ -82,6 +90,10 @@ Given file test.proto, containing
optional group OptionalGroup = 4 {
required string RequiredField = 5;
}
oneof union {
int32 number = 6;
string name = 7;
}
}
The resulting file, test.pb.go, is:
@ -120,15 +132,40 @@ The resulting file, test.pb.go, is:
}
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
XXX_unrecognized []byte `json:"-"`
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
// Types that are valid to be assigned to Union:
// *Test_Number
// *Test_Name
Union isTest_Union `protobuf_oneof:"union"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test) Reset() { *m = Test{} }
func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
func (*Test) ProtoMessage() {}
type isTest_Union interface {
isTest_Union()
}
type Test_Number struct {
Number int32 `protobuf:"varint,6,opt,name=number"`
}
type Test_Name struct {
Name string `protobuf:"bytes,7,opt,name=name"`
}
func (*Test_Number) isTest_Union() {}
func (*Test_Name) isTest_Union() {}
func (m *Test) GetUnion() isTest_Union {
if m != nil {
return m.Union
}
return nil
}
const Default_Test_Type int32 = 77
func (m *Test) GetLabel() string {
@ -165,13 +202,27 @@ The resulting file, test.pb.go, is:
return ""
}
func (m *Test) GetNumber() int32 {
if x, ok := m.GetUnion().(*Test_Number); ok {
return x.Number
}
return 0
}
func (m *Test) GetName() string {
if x, ok := m.GetUnion().(*Test_Name); ok {
return x.Name
}
return ""
}
func init() {
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
}
To create and play with a Test object:
package main
package main
import (
"log"
@ -184,9 +235,11 @@ package main
test := &pb.Test{
Label: proto.String("hello"),
Type: proto.Int32(17),
Reps: []int64{1, 2, 3},
Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"),
},
Union: &pb.Test_Name{"fred"},
}
data, err := proto.Marshal(test)
if err != nil {
@ -201,6 +254,11 @@ package main
if test.GetLabel() != newTest.GetLabel() {
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
}
// Use a type switch to determine which oneof was set.
switch u := test.Union.(type) {
case *pb.Test_Number: // u.Number contains the number.
case *pb.Test_Name: // u.Name contains the string.
}
// etc.
}
*/
@ -211,6 +269,7 @@ import (
"fmt"
"log"
"reflect"
"sort"
"strconv"
"sync"
)
@ -459,7 +518,6 @@ out:
break out
}
fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u)
break
case WireVarint:
u, err = p.DecodeVarint()
@ -470,19 +528,11 @@ out:
fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u)
case WireStartGroup:
if err != nil {
fmt.Printf("%3d: t=%3d start err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d start\n", index, tag)
depth++
case WireEndGroup:
depth--
if err != nil {
fmt.Printf("%3d: t=%3d end err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d end\n", index, tag)
}
}
@ -787,12 +837,39 @@ func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMes
// If this turns out to be inefficient we can always consider other options,
// such as doing a Schwartzian transform.
type mapKeys []reflect.Value
func mapKeys(vs []reflect.Value) sort.Interface {
s := mapKeySorter{
vs: vs,
// default Less function: textual comparison
less: func(a, b reflect.Value) bool {
return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface())
},
}
func (s mapKeys) Len() int { return len(s) }
func (s mapKeys) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s mapKeys) Less(i, j int) bool {
return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface())
// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps;
// numeric keys are sorted numerically.
if len(vs) == 0 {
return s
}
switch vs[0].Kind() {
case reflect.Int32, reflect.Int64:
s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
case reflect.Uint32, reflect.Uint64:
s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
}
return s
}
type mapKeySorter struct {
vs []reflect.Value
less func(a, b reflect.Value) bool
}
func (s mapKeySorter) Len() int { return len(s.vs) }
func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
func (s mapKeySorter) Less(i, j int) bool {
return s.less(s.vs[i], s.vs[j])
}
// isProto3Zero reports whether v is a zero proto3 value.
@ -811,3 +888,7 @@ func isProto3Zero(v reflect.Value) bool {
}
return false
}
// ProtoPackageIsVersion1 is referenced from generated protocol buffer files
// to assert that that code is compatible with this version of the proto package.
const ProtoPackageIsVersion1 = true

View file

@ -44,11 +44,11 @@ import (
"sort"
)
// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
// errNoMessageTypeID occurs when a protocol buffer does not have a message type ID.
// A message type ID is required for storing a protocol buffer in a message set.
var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
var errNoMessageTypeID = errors.New("proto does not have a message type ID")
// The first two types (_MessageSet_Item and MessageSet)
// The first two types (_MessageSet_Item and messageSet)
// model what the protocol compiler produces for the following protocol message:
// message MessageSet {
// repeated group Item = 1 {
@ -58,27 +58,20 @@ var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
// }
// That is the MessageSet wire format. We can't use a proto to generate these
// because that would introduce a circular dependency between it and this package.
//
// When a proto1 proto has a field that looks like:
// optional message<MessageSet> info = 3;
// the protocol compiler produces a field in the generated struct that looks like:
// Info *_proto_.MessageSet `protobuf:"bytes,3,opt,name=info"`
// The package is automatically inserted so there is no need for that proto file to
// import this package.
type _MessageSet_Item struct {
TypeId *int32 `protobuf:"varint,2,req,name=type_id"`
Message []byte `protobuf:"bytes,3,req,name=message"`
}
type MessageSet struct {
type messageSet struct {
Item []*_MessageSet_Item `protobuf:"group,1,rep"`
XXX_unrecognized []byte
// TODO: caching?
}
// Make sure MessageSet is a Message.
var _ Message = (*MessageSet)(nil)
// Make sure messageSet is a Message.
var _ Message = (*messageSet)(nil)
// messageTypeIder is an interface satisfied by a protocol buffer type
// that may be stored in a MessageSet.
@ -86,7 +79,7 @@ type messageTypeIder interface {
MessageTypeId() int32
}
func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
func (ms *messageSet) find(pb Message) *_MessageSet_Item {
mti, ok := pb.(messageTypeIder)
if !ok {
return nil
@ -100,24 +93,24 @@ func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
return nil
}
func (ms *MessageSet) Has(pb Message) bool {
func (ms *messageSet) Has(pb Message) bool {
if ms.find(pb) != nil {
return true
}
return false
}
func (ms *MessageSet) Unmarshal(pb Message) error {
func (ms *messageSet) Unmarshal(pb Message) error {
if item := ms.find(pb); item != nil {
return Unmarshal(item.Message, pb)
}
if _, ok := pb.(messageTypeIder); !ok {
return ErrNoMessageTypeId
return errNoMessageTypeID
}
return nil // TODO: return error instead?
}
func (ms *MessageSet) Marshal(pb Message) error {
func (ms *messageSet) Marshal(pb Message) error {
msg, err := Marshal(pb)
if err != nil {
return err
@ -130,7 +123,7 @@ func (ms *MessageSet) Marshal(pb Message) error {
mti, ok := pb.(messageTypeIder)
if !ok {
return ErrNoMessageTypeId
return errNoMessageTypeID
}
mtid := mti.MessageTypeId()
@ -141,9 +134,9 @@ func (ms *MessageSet) Marshal(pb Message) error {
return nil
}
func (ms *MessageSet) Reset() { *ms = MessageSet{} }
func (ms *MessageSet) String() string { return CompactTextString(ms) }
func (*MessageSet) ProtoMessage() {}
func (ms *messageSet) Reset() { *ms = messageSet{} }
func (ms *messageSet) String() string { return CompactTextString(ms) }
func (*messageSet) ProtoMessage() {}
// Support for the message_set_wire_format message option.
@ -169,7 +162,7 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
}
sort.Ints(ids)
ms := &MessageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
ms := &messageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
for _, id := range ids {
e := m[int32(id)]
// Remove the wire type and field number varint, as well as the length varint.
@ -186,7 +179,7 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
ms := new(MessageSet)
ms := new(messageSet)
if err := Unmarshal(buf, ms); err != nil {
return err
}

View file

@ -37,6 +37,7 @@ package proto
import (
"fmt"
"log"
"os"
"reflect"
"sort"
@ -84,6 +85,15 @@ type decoder func(p *Buffer, prop *Properties, base structPointer) error
// A valueDecoder decodes a single integer in a particular encoding.
type valueDecoder func(o *Buffer) (x uint64, err error)
// A oneofMarshaler does the marshaling for all oneof fields in a message.
type oneofMarshaler func(Message, *Buffer) error
// A oneofUnmarshaler does the unmarshaling for a oneof field in a message.
type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error)
// A oneofSizer does the sizing for all oneof fields in a message.
type oneofSizer func(Message) int
// tagMap is an optimization over map[int]int for typical protocol buffer
// use-cases. Encoded protocol buffers are often in tag order with small tag
// numbers.
@ -132,6 +142,22 @@ type StructProperties struct {
order []int // list of struct field numbers in tag order
unrecField field // field id of the XXX_unrecognized []byte field
extendable bool // is this an extendable proto
oneofMarshaler oneofMarshaler
oneofUnmarshaler oneofUnmarshaler
oneofSizer oneofSizer
stype reflect.Type
// OneofTypes contains information about the oneof fields in this message.
// It is keyed by the original name of a field.
OneofTypes map[string]*OneofProperties
}
// OneofProperties represents information about a specific field in a oneof.
type OneofProperties struct {
Type reflect.Type // pointer to generated struct type for this oneof field
Field int // struct field number of the containing oneof in the message
Prop *Properties
}
// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec.
@ -147,6 +173,7 @@ func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order
type Properties struct {
Name string // name of the field, for error messages
OrigName string // original name before protocol compiler (always set)
JSONName string // name to use for JSON; determined by protoc
Wire string
WireType int
Tag int
@ -156,6 +183,7 @@ type Properties struct {
Packed bool // relevant for repeated primitives only
Enum string // set for enum types only
proto3 bool // whether this is known to be a proto3 field; set for []byte only
oneof bool // whether this is a oneof field
Default string // default value
HasDefault bool // whether an explicit default was provided
@ -202,12 +230,16 @@ func (p *Properties) String() string {
if p.Packed {
s += ",packed"
}
if p.OrigName != p.Name {
s += ",name=" + p.OrigName
s += ",name=" + p.OrigName
if p.JSONName != p.OrigName {
s += ",json=" + p.JSONName
}
if p.proto3 {
s += ",proto3"
}
if p.oneof {
s += ",oneof"
}
if len(p.Enum) > 0 {
s += ",enum=" + p.Enum
}
@ -280,10 +312,14 @@ func (p *Properties) Parse(s string) {
p.Packed = true
case strings.HasPrefix(f, "name="):
p.OrigName = f[5:]
case strings.HasPrefix(f, "json="):
p.JSONName = f[5:]
case strings.HasPrefix(f, "enum="):
p.Enum = f[5:]
case f == "proto3":
p.proto3 = true
case f == "oneof":
p.oneof = true
case strings.HasPrefix(f, "def="):
p.HasDefault = true
p.Default = f[4:] // rest of string
@ -665,6 +701,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
if f.Name == "XXX_unrecognized" { // special case
prop.unrecField = toField(&f)
}
oneof := f.Tag.Get("protobuf_oneof") != "" // special case
prop.Prop[i] = p
prop.order[i] = i
if debug {
@ -674,7 +711,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
}
print("\n")
}
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") {
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && !oneof {
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
}
}
@ -682,6 +719,41 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
// Re-order prop.order.
sort.Sort(prop)
type oneofMessage interface {
XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), func(Message) int, []interface{})
}
if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); ok {
var oots []interface{}
prop.oneofMarshaler, prop.oneofUnmarshaler, prop.oneofSizer, oots = om.XXX_OneofFuncs()
prop.stype = t
// Interpret oneof metadata.
prop.OneofTypes = make(map[string]*OneofProperties)
for _, oot := range oots {
oop := &OneofProperties{
Type: reflect.ValueOf(oot).Type(), // *T
Prop: new(Properties),
}
sft := oop.Type.Elem().Field(0)
oop.Prop.Name = sft.Name
oop.Prop.Parse(sft.Tag.Get("protobuf"))
// There will be exactly one interface field that
// this new value is assignable to.
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Type.Kind() != reflect.Interface {
continue
}
if !oop.Type.AssignableTo(f.Type) {
continue
}
oop.Field = i
break
}
prop.OneofTypes[oop.Prop.OrigName] = oop
}
}
// build required counts
// build tags
reqCount := 0
@ -740,3 +812,35 @@ func RegisterEnum(typeName string, unusedNameMap map[int32]string, valueMap map[
}
enumValueMaps[typeName] = valueMap
}
// EnumValueMap returns the mapping from names to integers of the
// enum type enumType, or a nil if not found.
func EnumValueMap(enumType string) map[string]int32 {
return enumValueMaps[enumType]
}
// A registry of all linked message types.
// The string is a fully-qualified proto name ("pkg.Message").
var (
protoTypes = make(map[string]reflect.Type)
revProtoTypes = make(map[reflect.Type]string)
)
// RegisterType is called from generated code and maps from the fully qualified
// proto name to the type (pointer to struct) of the protocol buffer.
func RegisterType(x Message, name string) {
if _, ok := protoTypes[name]; ok {
// TODO: Some day, make this a panic.
log.Printf("proto: duplicate proto type registered: %s", name)
return
}
t := reflect.TypeOf(x)
protoTypes[name] = t
revProtoTypes[t] = name
}
// MessageName returns the fully-qualified proto name for the given message type.
func MessageName(x Message) string { return revProtoTypes[reflect.TypeOf(x)] }
// MessageType returns the message type (pointer to struct) for a named message.
func MessageType(name string) reflect.Type { return protoTypes[name] }

View file

@ -37,6 +37,7 @@ import (
"bufio"
"bytes"
"encoding"
"errors"
"fmt"
"io"
"log"
@ -169,20 +170,98 @@ func writeName(w *textWriter, props *Properties) error {
return nil
}
var (
messageSetType = reflect.TypeOf((*MessageSet)(nil)).Elem()
)
// raw is the interface satisfied by RawMessage.
type raw interface {
Bytes() []byte
}
func writeStruct(w *textWriter, sv reflect.Value) error {
if sv.Type() == messageSetType {
return writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
func requiresQuotes(u string) bool {
// When type URL contains any characters except [0-9A-Za-z./\-]*, it must be quoted.
for _, ch := range u {
switch {
case ch == '.' || ch == '/' || ch == '_':
continue
case '0' <= ch && ch <= '9':
continue
case 'A' <= ch && ch <= 'Z':
continue
case 'a' <= ch && ch <= 'z':
continue
default:
return true
}
}
return false
}
// isAny reports whether sv is a google.protobuf.Any message
func isAny(sv reflect.Value) bool {
type wkt interface {
XXX_WellKnownType() string
}
t, ok := sv.Addr().Interface().(wkt)
return ok && t.XXX_WellKnownType() == "Any"
}
// writeProto3Any writes an expanded google.protobuf.Any message.
//
// It returns (false, nil) if sv value can't be unmarshaled (e.g. because
// required messages are not linked in).
//
// It returns (true, error) when sv was written in expanded format or an error
// was encountered.
func (tm *TextMarshaler) writeProto3Any(w *textWriter, sv reflect.Value) (bool, error) {
turl := sv.FieldByName("TypeUrl")
val := sv.FieldByName("Value")
if !turl.IsValid() || !val.IsValid() {
return true, errors.New("proto: invalid google.protobuf.Any message")
}
b, ok := val.Interface().([]byte)
if !ok {
return true, errors.New("proto: invalid google.protobuf.Any message")
}
parts := strings.Split(turl.String(), "/")
mt := MessageType(parts[len(parts)-1])
if mt == nil {
return false, nil
}
m := reflect.New(mt.Elem())
if err := Unmarshal(b, m.Interface().(Message)); err != nil {
return false, nil
}
w.Write([]byte("["))
u := turl.String()
if requiresQuotes(u) {
writeString(w, u)
} else {
w.Write([]byte(u))
}
if w.compact {
w.Write([]byte("]:<"))
} else {
w.Write([]byte("]: <\n"))
w.ind++
}
if err := tm.writeStruct(w, m.Elem()); err != nil {
return true, err
}
if w.compact {
w.Write([]byte("> "))
} else {
w.ind--
w.Write([]byte(">\n"))
}
return true, nil
}
func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
if tm.ExpandAny && isAny(sv) {
if canExpand, err := tm.writeProto3Any(w, sv); canExpand {
return err
}
}
st := sv.Type()
sprops := GetProperties(st)
for i := 0; i < sv.NumField(); i++ {
@ -234,7 +313,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
}
continue
}
if err := writeAny(w, v, props); err != nil {
if err := tm.writeAny(w, v, props); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
@ -245,7 +324,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
}
if fv.Kind() == reflect.Map {
// Map fields are rendered as a repeated struct with key/value fields.
keys := fv.MapKeys() // TODO: should we sort these for deterministic output?
keys := fv.MapKeys()
sort.Sort(mapKeys(keys))
for _, key := range keys {
val := fv.MapIndex(key)
@ -276,7 +355,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
return err
}
}
if err := writeAny(w, key, props.mkeyprop); err != nil {
if err := tm.writeAny(w, key, props.mkeyprop); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
@ -293,7 +372,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
return err
}
}
if err := writeAny(w, val, props.mvalprop); err != nil {
if err := tm.writeAny(w, val, props.mvalprop); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
@ -322,6 +401,33 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
}
}
if fv.Kind() == reflect.Interface {
// Check if it is a oneof.
if st.Field(i).Tag.Get("protobuf_oneof") != "" {
// fv is nil, or holds a pointer to generated struct.
// That generated struct has exactly one field,
// which has a protobuf struct tag.
if fv.IsNil() {
continue
}
inner := fv.Elem().Elem() // interface -> *T -> T
tag := inner.Type().Field(0).Tag.Get("protobuf")
props = new(Properties) // Overwrite the outer props var, but not its pointee.
props.Parse(tag)
// Write the value in the oneof, not the oneof itself.
fv = inner.Field(0)
// Special case to cope with malformed messages gracefully:
// If the value in the oneof is a nil pointer, don't panic
// in writeAny.
if fv.Kind() == reflect.Ptr && fv.IsNil() {
// Use errors.New so writeAny won't render quotes.
msg := errors.New("/* nil */")
fv = reflect.ValueOf(&msg).Elem()
}
}
}
if err := writeName(w, props); err != nil {
return err
}
@ -338,7 +444,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
}
// Enums have a String method, so writeAny will work fine.
if err := writeAny(w, fv, props); err != nil {
if err := tm.writeAny(w, fv, props); err != nil {
return err
}
@ -350,7 +456,7 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
// Extensions (the XXX_extensions field).
pv := sv.Addr()
if pv.Type().Implements(extendableProtoType) {
if err := writeExtensions(w, pv); err != nil {
if err := tm.writeExtensions(w, pv); err != nil {
return err
}
}
@ -380,7 +486,7 @@ func writeRaw(w *textWriter, b []byte) error {
}
// writeAny writes an arbitrary field.
func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error {
v = reflect.Indirect(v)
// Floats have special cases.
@ -429,15 +535,15 @@ func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
}
}
w.indent()
if tm, ok := v.Interface().(encoding.TextMarshaler); ok {
text, err := tm.MarshalText()
if etm, ok := v.Interface().(encoding.TextMarshaler); ok {
text, err := etm.MarshalText()
if err != nil {
return err
}
if _, err = w.Write(text); err != nil {
return err
}
} else if err := writeStruct(w, v); err != nil {
} else if err := tm.writeStruct(w, v); err != nil {
return err
}
w.unindent()
@ -497,44 +603,6 @@ func writeString(w *textWriter, s string) error {
return w.WriteByte('"')
}
func writeMessageSet(w *textWriter, ms *MessageSet) error {
for _, item := range ms.Item {
id := *item.TypeId
if msd, ok := messageSetMap[id]; ok {
// Known message set type.
if _, err := fmt.Fprintf(w, "[%s]: <\n", msd.name); err != nil {
return err
}
w.indent()
pb := reflect.New(msd.t.Elem())
if err := Unmarshal(item.Message, pb.Interface().(Message)); err != nil {
if _, err := fmt.Fprintf(w, "/* bad message: %v */\n", err); err != nil {
return err
}
} else {
if err := writeStruct(w, pb.Elem()); err != nil {
return err
}
}
} else {
// Unknown type.
if _, err := fmt.Fprintf(w, "[%d]: <\n", id); err != nil {
return err
}
w.indent()
if err := writeUnknownStruct(w, item.Message); err != nil {
return err
}
}
w.unindent()
if _, err := w.Write(gtNewline); err != nil {
return err
}
}
return nil
}
func writeUnknownStruct(w *textWriter, data []byte) (err error) {
if !w.compact {
if _, err := fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data)); err != nil {
@ -619,7 +687,7 @@ func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// writeExtensions writes all the extensions in pv.
// pv is assumed to be a pointer to a protocol message struct that is extendable.
func writeExtensions(w *textWriter, pv reflect.Value) error {
func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error {
emap := extensionMaps[pv.Type().Elem()]
ep := pv.Interface().(extendableProto)
@ -654,13 +722,13 @@ func writeExtensions(w *textWriter, pv reflect.Value) error {
// Repeated extensions will appear as a slice.
if !desc.repeated() {
if err := writeExtension(w, desc.Name, pb); err != nil {
if err := tm.writeExtension(w, desc.Name, pb); err != nil {
return err
}
} else {
v := reflect.ValueOf(pb)
for i := 0; i < v.Len(); i++ {
if err := writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil {
if err := tm.writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil {
return err
}
}
@ -669,7 +737,7 @@ func writeExtensions(w *textWriter, pv reflect.Value) error {
return nil
}
func writeExtension(w *textWriter, name string, pb interface{}) error {
func (tm *TextMarshaler) writeExtension(w *textWriter, name string, pb interface{}) error {
if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil {
return err
}
@ -678,7 +746,7 @@ func writeExtension(w *textWriter, name string, pb interface{}) error {
return err
}
}
if err := writeAny(w, reflect.ValueOf(pb), nil); err != nil {
if err := tm.writeAny(w, reflect.ValueOf(pb), nil); err != nil {
return err
}
if err := w.WriteByte('\n'); err != nil {
@ -703,7 +771,15 @@ func (w *textWriter) writeIndent() {
w.complete = false
}
func marshalText(w io.Writer, pb Message, compact bool) error {
// TextMarshaler is a configurable text format marshaler.
type TextMarshaler struct {
Compact bool // use compact text format (one line).
ExpandAny bool // expand google.protobuf.Any messages of known types
}
// Marshal writes a given protocol buffer in text format.
// The only errors returned are from w.
func (tm *TextMarshaler) Marshal(w io.Writer, pb Message) error {
val := reflect.ValueOf(pb)
if pb == nil || val.IsNil() {
w.Write([]byte("<nil>"))
@ -718,11 +794,11 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
aw := &textWriter{
w: ww,
complete: true,
compact: compact,
compact: tm.Compact,
}
if tm, ok := pb.(encoding.TextMarshaler); ok {
text, err := tm.MarshalText()
if etm, ok := pb.(encoding.TextMarshaler); ok {
text, err := etm.MarshalText()
if err != nil {
return err
}
@ -736,7 +812,7 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
}
// Dereference the received pointer so we don't have outer < and >.
v := reflect.Indirect(val)
if err := writeStruct(aw, v); err != nil {
if err := tm.writeStruct(aw, v); err != nil {
return err
}
if bw != nil {
@ -745,25 +821,29 @@ func marshalText(w io.Writer, pb Message, compact bool) error {
return nil
}
// Text is the same as Marshal, but returns the string directly.
func (tm *TextMarshaler) Text(pb Message) string {
var buf bytes.Buffer
tm.Marshal(&buf, pb)
return buf.String()
}
var (
defaultTextMarshaler = TextMarshaler{}
compactTextMarshaler = TextMarshaler{Compact: true}
)
// TODO: consider removing some of the Marshal functions below.
// MarshalText writes a given protocol buffer in text format.
// The only errors returned are from w.
func MarshalText(w io.Writer, pb Message) error {
return marshalText(w, pb, false)
}
func MarshalText(w io.Writer, pb Message) error { return defaultTextMarshaler.Marshal(w, pb) }
// MarshalTextString is the same as MarshalText, but returns the string directly.
func MarshalTextString(pb Message) string {
var buf bytes.Buffer
marshalText(&buf, pb, false)
return buf.String()
}
func MarshalTextString(pb Message) string { return defaultTextMarshaler.Text(pb) }
// CompactText writes a given protocol buffer in compact text format (one line).
func CompactText(w io.Writer, pb Message) error { return marshalText(w, pb, true) }
func CompactText(w io.Writer, pb Message) error { return compactTextMarshaler.Marshal(w, pb) }
// CompactTextString is the same as CompactText, but returns the string directly.
func CompactTextString(pb Message) string {
var buf bytes.Buffer
marshalText(&buf, pb, true)
return buf.String()
}
func CompactTextString(pb Message) string { return compactTextMarshaler.Text(pb) }

View file

@ -119,6 +119,14 @@ func isWhitespace(c byte) bool {
return false
}
func isQuote(c byte) bool {
switch c {
case '"', '\'':
return true
}
return false
}
func (p *textParser) skipWhitespace() {
i := 0
for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') {
@ -155,7 +163,7 @@ func (p *textParser) advance() {
p.cur.offset, p.cur.line = p.offset, p.line
p.cur.unquoted = ""
switch p.s[0] {
case '<', '>', '{', '}', ':', '[', ']', ';', ',':
case '<', '>', '{', '}', ':', '[', ']', ';', ',', '/':
// Single symbol
p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)]
case '"', '\'':
@ -333,13 +341,13 @@ func (p *textParser) next() *token {
p.advance()
if p.done {
p.cur.value = ""
} else if len(p.cur.value) > 0 && p.cur.value[0] == '"' {
} else if len(p.cur.value) > 0 && isQuote(p.cur.value[0]) {
// Look for multiple quoted strings separated by whitespace,
// and concatenate them.
cat := p.cur
for {
p.skipWhitespace()
if p.done || p.s[0] != '"' {
if p.done || !isQuote(p.s[0]) {
break
}
p.advance()
@ -385,8 +393,7 @@ func (p *textParser) missingRequiredFieldError(sv reflect.Value) *RequiredNotSet
}
// Returns the index in the struct for the named field, as well as the parsed tag properties.
func structFieldByName(st reflect.Type, name string) (int, *Properties, bool) {
sprops := GetProperties(st)
func structFieldByName(sprops *StructProperties, name string) (int, *Properties, bool) {
i, ok := sprops.decoderOrigNames[name]
if ok {
return i, sprops.Prop[i], true
@ -438,12 +445,16 @@ func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseEr
func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
st := sv.Type()
reqCount := GetProperties(st).reqCount
sprops := GetProperties(st)
reqCount := sprops.reqCount
var reqFieldErr error
fieldSet := make(map[string]bool)
// A struct is a sequence of "name: value", terminated by one of
// '>' or '}', or the end of the input. A name may also be
// "[extension]".
// "[extension]" or "[type/url]".
//
// The whole struct can also be an expanded Any message, like:
// [type/url] < ... struct contents ... >
for {
tok := p.next()
if tok.err != nil {
@ -453,33 +464,66 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
break
}
if tok.value == "[" {
// Looks like an extension.
// Looks like an extension or an Any.
//
// TODO: Check whether we need to handle
// namespace rooted names (e.g. ".something.Foo").
tok = p.next()
if tok.err != nil {
return tok.err
extName, err := p.consumeExtName()
if err != nil {
return err
}
if s := strings.LastIndex(extName, "/"); s >= 0 {
// If it contains a slash, it's an Any type URL.
messageName := extName[s+1:]
mt := MessageType(messageName)
if mt == nil {
return p.errorf("unrecognized message %q in google.protobuf.Any", messageName)
}
tok = p.next()
if tok.err != nil {
return tok.err
}
// consume an optional colon
if tok.value == ":" {
tok = p.next()
if tok.err != nil {
return tok.err
}
}
var terminator string
switch tok.value {
case "<":
terminator = ">"
case "{":
terminator = "}"
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
v := reflect.New(mt.Elem())
if pe := p.readStruct(v.Elem(), terminator); pe != nil {
return pe
}
b, err := Marshal(v.Interface().(Message))
if err != nil {
return p.errorf("failed to marshal message of type %q: %v", messageName, err)
}
sv.FieldByName("TypeUrl").SetString(extName)
sv.FieldByName("Value").SetBytes(b)
continue
}
var desc *ExtensionDesc
// This could be faster, but it's functional.
// TODO: Do something smarter than a linear scan.
for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) {
if d.Name == tok.value {
if d.Name == extName {
desc = d
break
}
}
if desc == nil {
return p.errorf("unrecognized extension %q", tok.value)
}
// Check the extension terminator.
tok = p.next()
if tok.err != nil {
return tok.err
}
if tok.value != "]" {
return p.errorf("unrecognized extension terminator %q", tok.value)
return p.errorf("unrecognized extension %q", extName)
}
props := &Properties{}
@ -520,95 +564,107 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
sl = reflect.Append(sl, ext)
SetExtension(ep, desc, sl.Interface())
}
} else {
// This is a normal, non-extension field.
name := tok.value
fi, props, ok := structFieldByName(st, name)
if !ok {
return p.errorf("unknown field name %q in %v", name, st)
if err := p.consumeOptionalSeparator(); err != nil {
return err
}
continue
}
dst := sv.Field(fi)
// This is a normal, non-extension field.
name := tok.value
var dst reflect.Value
fi, props, ok := structFieldByName(sprops, name)
if ok {
dst = sv.Field(fi)
} else if oop, ok := sprops.OneofTypes[name]; ok {
// It is a oneof.
props = oop.Prop
nv := reflect.New(oop.Type.Elem())
dst = nv.Elem().Field(0)
sv.Field(oop.Field).Set(nv)
}
if !dst.IsValid() {
return p.errorf("unknown field name %q in %v", name, st)
}
if dst.Kind() == reflect.Map {
// Consume any colon.
if err := p.checkForColon(props, dst.Type()); err != nil {
return err
}
// Construct the map if it doesn't already exist.
if dst.IsNil() {
dst.Set(reflect.MakeMap(dst.Type()))
}
key := reflect.New(dst.Type().Key()).Elem()
val := reflect.New(dst.Type().Elem()).Elem()
// The map entry should be this sequence of tokens:
// < key : KEY value : VALUE >
// Technically the "key" and "value" could come in any order,
// but in practice they won't.
tok := p.next()
var terminator string
switch tok.value {
case "<":
terminator = ">"
case "{":
terminator = "}"
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
if err := p.consumeToken("key"); err != nil {
return err
}
if err := p.consumeToken(":"); err != nil {
return err
}
if err := p.readAny(key, props.mkeyprop); err != nil {
return err
}
if err := p.consumeOptionalSeparator(); err != nil {
return err
}
if err := p.consumeToken("value"); err != nil {
return err
}
if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
return err
}
if err := p.readAny(val, props.mvalprop); err != nil {
return err
}
if err := p.consumeOptionalSeparator(); err != nil {
return err
}
if err := p.consumeToken(terminator); err != nil {
return err
}
dst.SetMapIndex(key, val)
continue
}
// Check that it's not already set if it's not a repeated field.
if !props.Repeated && fieldSet[name] {
return p.errorf("non-repeated field %q was repeated", name)
}
if err := p.checkForColon(props, st.Field(fi).Type); err != nil {
if dst.Kind() == reflect.Map {
// Consume any colon.
if err := p.checkForColon(props, dst.Type()); err != nil {
return err
}
// Parse into the field.
fieldSet[name] = true
if err := p.readAny(dst, props); err != nil {
if _, ok := err.(*RequiredNotSetError); !ok {
return err
}
reqFieldErr = err
} else if props.Required {
reqCount--
// Construct the map if it doesn't already exist.
if dst.IsNil() {
dst.Set(reflect.MakeMap(dst.Type()))
}
key := reflect.New(dst.Type().Key()).Elem()
val := reflect.New(dst.Type().Elem()).Elem()
// The map entry should be this sequence of tokens:
// < key : KEY value : VALUE >
// Technically the "key" and "value" could come in any order,
// but in practice they won't.
tok := p.next()
var terminator string
switch tok.value {
case "<":
terminator = ">"
case "{":
terminator = "}"
default:
return p.errorf("expected '{' or '<', found %q", tok.value)
}
if err := p.consumeToken("key"); err != nil {
return err
}
if err := p.consumeToken(":"); err != nil {
return err
}
if err := p.readAny(key, props.mkeyprop); err != nil {
return err
}
if err := p.consumeOptionalSeparator(); err != nil {
return err
}
if err := p.consumeToken("value"); err != nil {
return err
}
if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
return err
}
if err := p.readAny(val, props.mvalprop); err != nil {
return err
}
if err := p.consumeOptionalSeparator(); err != nil {
return err
}
if err := p.consumeToken(terminator); err != nil {
return err
}
dst.SetMapIndex(key, val)
continue
}
// Check that it's not already set if it's not a repeated field.
if !props.Repeated && fieldSet[name] {
return p.errorf("non-repeated field %q was repeated", name)
}
if err := p.checkForColon(props, dst.Type()); err != nil {
return err
}
// Parse into the field.
fieldSet[name] = true
if err := p.readAny(dst, props); err != nil {
if _, ok := err.(*RequiredNotSetError); !ok {
return err
}
reqFieldErr = err
} else if props.Required {
reqCount--
}
if err := p.consumeOptionalSeparator(); err != nil {
@ -623,6 +679,35 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
return reqFieldErr
}
// consumeExtName consumes extension name or expanded Any type URL and the
// following ']'. It returns the name or URL consumed.
func (p *textParser) consumeExtName() (string, error) {
tok := p.next()
if tok.err != nil {
return "", tok.err
}
// If extension name or type url is quoted, it's a single token.
if len(tok.value) > 2 && isQuote(tok.value[0]) && tok.value[len(tok.value)-1] == tok.value[0] {
name, err := unquoteC(tok.value[1:len(tok.value)-1], rune(tok.value[0]))
if err != nil {
return "", err
}
return name, p.consumeToken("]")
}
// Consume everything up to "]"
var parts []string
for tok.value != "]" {
parts = append(parts, tok.value)
tok = p.next()
if tok.err != nil {
return "", p.errorf("unrecognized type_url or extension name: %s", tok.err)
}
}
return strings.Join(parts, ""), nil
}
// consumeOptionalSeparator consumes an optional semicolon or comma.
// It is used in readStruct to provide backward compatibility.
func (p *textParser) consumeOptionalSeparator() error {
@ -660,18 +745,32 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error {
fv.Set(reflect.ValueOf(bytes))
return nil
}
// Repeated field. May already exist.
flen := fv.Len()
if flen == fv.Cap() {
nav := reflect.MakeSlice(at, flen, 2*flen+1)
reflect.Copy(nav, fv)
fv.Set(nav)
// Repeated field.
if tok.value == "[" {
// Repeated field with list notation, like [1,2,3].
for {
fv.Set(reflect.Append(fv, reflect.New(at.Elem()).Elem()))
err := p.readAny(fv.Index(fv.Len()-1), props)
if err != nil {
return err
}
tok := p.next()
if tok.err != nil {
return tok.err
}
if tok.value == "]" {
break
}
if tok.value != "," {
return p.errorf("Expected ']' or ',' found %q", tok.value)
}
}
return nil
}
fv.SetLen(flen + 1)
// Read one.
// One value of the repeated field.
p.back()
return p.readAny(fv.Index(flen), props)
fv.Set(reflect.Append(fv, reflect.New(at.Elem()).Elem()))
return p.readAny(fv.Index(fv.Len()-1), props)
case reflect.Bool:
// Either "true", "false", 1 or 0.
switch tok.value {