vendor: Update vendoring for the exec client and server implementations

Signed-off-by: Jacek J. Łakis <jacek.lakis@intel.com>
Signed-off-by: Samuel Ortiz <sameo@linux.intel.com>
This commit is contained in:
Jacek J. Łakis 2017-02-08 14:57:52 +01:00 committed by Samuel Ortiz
parent d25b88583f
commit bf51655a7b
2124 changed files with 809703 additions and 5 deletions

83
vendor/k8s.io/apiserver/pkg/util/cache/cache.go generated vendored Normal file
View file

@ -0,0 +1,83 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"sync"
)
const (
shardsCount int = 32
)
type Cache []*cacheShard
func NewCache(maxSize int) Cache {
if maxSize < shardsCount {
maxSize = shardsCount
}
cache := make(Cache, shardsCount)
for i := 0; i < shardsCount; i++ {
cache[i] = &cacheShard{
items: make(map[uint64]interface{}),
maxSize: maxSize / shardsCount,
}
}
return cache
}
func (c Cache) getShard(index uint64) *cacheShard {
return c[index%uint64(shardsCount)]
}
// Returns true if object already existed, false otherwise.
func (c *Cache) Add(index uint64, obj interface{}) bool {
return c.getShard(index).add(index, obj)
}
func (c *Cache) Get(index uint64) (obj interface{}, found bool) {
return c.getShard(index).get(index)
}
type cacheShard struct {
items map[uint64]interface{}
sync.RWMutex
maxSize int
}
// Returns true if object already existed, false otherwise.
func (s *cacheShard) add(index uint64, obj interface{}) bool {
s.Lock()
defer s.Unlock()
_, isOverwrite := s.items[index]
if !isOverwrite && len(s.items) >= s.maxSize {
var randomKey uint64
for randomKey = range s.items {
break
}
delete(s.items, randomKey)
}
s.items[index] = obj
return isOverwrite
}
func (s *cacheShard) get(index uint64) (obj interface{}, found bool) {
s.RLock()
defer s.RUnlock()
obj, found = s.items[index]
return
}

90
vendor/k8s.io/apiserver/pkg/util/cache/cache_test.go generated vendored Normal file
View file

@ -0,0 +1,90 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"testing"
)
const (
maxTestCacheSize int = shardsCount * 2
)
func ExpectEntry(t *testing.T, cache Cache, index uint64, expectedValue interface{}) bool {
elem, found := cache.Get(index)
if !found {
t.Errorf("Expected to find entry with key %d", index)
return false
} else if elem != expectedValue {
t.Errorf("Expected to find %v, got %v", expectedValue, elem)
return false
}
return true
}
func TestBasic(t *testing.T) {
cache := NewCache(maxTestCacheSize)
cache.Add(1, "xxx")
ExpectEntry(t, cache, 1, "xxx")
}
func TestOverflow(t *testing.T) {
cache := NewCache(maxTestCacheSize)
for i := 0; i < maxTestCacheSize+1; i++ {
cache.Add(uint64(i), "xxx")
}
foundIndexes := make([]uint64, 0)
for i := 0; i < maxTestCacheSize+1; i++ {
_, found := cache.Get(uint64(i))
if found {
foundIndexes = append(foundIndexes, uint64(i))
}
}
if len(foundIndexes) != maxTestCacheSize {
t.Errorf("Expect to find %d elements, got %d %v", maxTestCacheSize, len(foundIndexes), foundIndexes)
}
}
func TestOverwrite(t *testing.T) {
cache := NewCache(maxTestCacheSize)
cache.Add(1, "xxx")
ExpectEntry(t, cache, 1, "xxx")
cache.Add(1, "yyy")
ExpectEntry(t, cache, 1, "yyy")
}
// TestEvict this test will fail sporatically depending on what add()
// selects for the randomKey to be evicted. Ensure that randomKey
// is never the key we most recently added. Since the chance of failure
// on each evict is 50%, if we do it 7 times, it should catch the problem
// if it exists >99% of the time.
func TestEvict(t *testing.T) {
cache := NewCache(shardsCount)
var found bool
for retry := 0; retry < 7; retry++ {
cache.Add(uint64(shardsCount), "xxx")
found = ExpectEntry(t, cache, uint64(shardsCount), "xxx")
if !found {
break
}
cache.Add(0, "xxx")
found = ExpectEntry(t, cache, 0, "xxx")
if !found {
break
}
}
}

View file

@ -0,0 +1,85 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"sync"
"time"
"github.com/golang/groupcache/lru"
)
// Clock defines an interface for obtaining the current time
type Clock interface {
Now() time.Time
}
// realClock implements the Clock interface by calling time.Now()
type realClock struct{}
func (realClock) Now() time.Time { return time.Now() }
type LRUExpireCache struct {
// clock is used to obtain the current time
clock Clock
cache *lru.Cache
lock sync.Mutex
}
// NewLRUExpireCache creates an expiring cache with the given size
func NewLRUExpireCache(maxSize int) *LRUExpireCache {
return &LRUExpireCache{clock: realClock{}, cache: lru.New(maxSize)}
}
// NewLRUExpireCache creates an expiring cache with the given size, using the specified clock to obtain the current time
func NewLRUExpireCacheWithClock(maxSize int, clock Clock) *LRUExpireCache {
return &LRUExpireCache{clock: clock, cache: lru.New(maxSize)}
}
type cacheEntry struct {
value interface{}
expireTime time.Time
}
func (c *LRUExpireCache) Add(key lru.Key, value interface{}, ttl time.Duration) {
c.lock.Lock()
defer c.lock.Unlock()
c.cache.Add(key, &cacheEntry{value, c.clock.Now().Add(ttl)})
// Remove entry from cache after ttl.
time.AfterFunc(ttl, func() { c.remove(key) })
}
func (c *LRUExpireCache) Get(key lru.Key) (interface{}, bool) {
c.lock.Lock()
defer c.lock.Unlock()
e, ok := c.cache.Get(key)
if !ok {
return nil, false
}
if c.clock.Now().After(e.(*cacheEntry).expireTime) {
go c.remove(key)
return nil, false
}
return e.(*cacheEntry).value, true
}
func (c *LRUExpireCache) remove(key lru.Key) {
c.lock.Lock()
defer c.lock.Unlock()
c.cache.Remove(key)
}

View file

@ -0,0 +1,68 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cache
import (
"testing"
"time"
"k8s.io/client-go/util/clock"
"github.com/golang/groupcache/lru"
)
func expectEntry(t *testing.T, c *LRUExpireCache, key lru.Key, value interface{}) {
result, ok := c.Get(key)
if !ok || result != value {
t.Errorf("Expected cache[%v]: %v, got %v", key, value, result)
}
}
func expectNotEntry(t *testing.T, c *LRUExpireCache, key lru.Key) {
if result, ok := c.Get(key); ok {
t.Errorf("Expected cache[%v] to be empty, got %v", key, result)
}
}
func TestSimpleGet(t *testing.T) {
c := NewLRUExpireCache(10)
c.Add("long-lived", "12345", 10*time.Hour)
expectEntry(t, c, "long-lived", "12345")
}
func TestExpiredGet(t *testing.T) {
fakeClock := clock.NewFakeClock(time.Now())
c := NewLRUExpireCacheWithClock(10, fakeClock)
c.Add("short-lived", "12345", 1*time.Millisecond)
// ensure the entry expired
fakeClock.Step(2 * time.Millisecond)
expectNotEntry(t, c, "short-lived")
}
func TestLRUOverflow(t *testing.T) {
c := NewLRUExpireCache(4)
c.Add("elem1", "1", 10*time.Hour)
c.Add("elem2", "2", 10*time.Hour)
c.Add("elem3", "3", 10*time.Hour)
c.Add("elem4", "4", 10*time.Hour)
c.Add("elem5", "5", 10*time.Hour)
expectNotEntry(t, c, "elem1")
expectEntry(t, c, "elem2", "2")
expectEntry(t, c, "elem3", "3")
expectEntry(t, c, "elem4", "4")
expectEntry(t, c, "elem5", "5")
}

View file

@ -0,0 +1,232 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package feature
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/golang/glog"
"github.com/spf13/pflag"
)
type Feature string
const (
flagName = "feature-gates"
// allAlphaGate is a global toggle for alpha features. Per-feature key
// values override the default set by allAlphaGate. Examples:
// AllAlpha=false,NewFeature=true will result in newFeature=true
// AllAlpha=true,NewFeature=false will result in newFeature=false
allAlphaGate Feature = "AllAlpha"
)
var (
// The generic features.
defaultFeatures = map[Feature]FeatureSpec{
allAlphaGate: {Default: false, PreRelease: Alpha},
}
// Special handling for a few gates.
specialFeatures = map[Feature]func(f *featureGate, val bool){
allAlphaGate: setUnsetAlphaGates,
}
// DefaultFeatureGate is a shared global FeatureGate.
DefaultFeatureGate = &featureGate{
known: defaultFeatures,
special: specialFeatures,
}
)
type FeatureSpec struct {
Default bool
PreRelease prerelease
}
type prerelease string
const (
// Values for PreRelease.
Alpha = prerelease("ALPHA")
Beta = prerelease("BETA")
GA = prerelease("")
)
// FeatureGate parses and stores flag gates for known features from
// a string like feature1=true,feature2=false,...
type FeatureGate interface {
AddFlag(fs *pflag.FlagSet)
Set(value string) error
Add(features map[Feature]FeatureSpec)
KnownFeatures() []string
// Every feature gate should add method here following this template:
//
// // owner: @username
// // alpha: v1.4
// MyFeature() bool
// owner: @timstclair
// beta: v1.4
AppArmor() bool
// owner: @girishkalele
// alpha: v1.4
ExternalTrafficLocalOnly() bool
// owner: @saad-ali
// alpha: v1.3
DynamicVolumeProvisioning() bool
// owner: @mtaufen
// alpha: v1.4
DynamicKubeletConfig() bool
// owner: timstclair
// alpha: v1.5
StreamingProxyRedirects() bool
// owner: @pweil-
// alpha: v1.5
ExperimentalHostUserNamespaceDefaulting() bool
}
// featureGate implements FeatureGate as well as pflag.Value for flag parsing.
type featureGate struct {
known map[Feature]FeatureSpec
special map[Feature]func(*featureGate, bool)
enabled map[Feature]bool
// is set to true when AddFlag is called. Note: initialization is not go-routine safe, lookup is
closed bool
}
func setUnsetAlphaGates(f *featureGate, val bool) {
for k, v := range f.known {
if v.PreRelease == Alpha {
if _, found := f.enabled[k]; !found {
f.enabled[k] = val
}
}
}
}
// Set, String, and Type implement pflag.Value
var _ pflag.Value = &featureGate{}
// Set Parses a string of the form // "key1=value1,key2=value2,..." into a
// map[string]bool of known keys or returns an error.
func (f *featureGate) Set(value string) error {
f.enabled = make(map[Feature]bool)
for _, s := range strings.Split(value, ",") {
if len(s) == 0 {
continue
}
arr := strings.SplitN(s, "=", 2)
k := Feature(strings.TrimSpace(arr[0]))
_, ok := f.known[Feature(k)]
if !ok {
return fmt.Errorf("unrecognized key: %s", k)
}
if len(arr) != 2 {
return fmt.Errorf("missing bool value for %s", k)
}
v := strings.TrimSpace(arr[1])
boolValue, err := strconv.ParseBool(v)
if err != nil {
return fmt.Errorf("invalid value of %s: %s, err: %v", k, v, err)
}
f.enabled[k] = boolValue
// Handle "special" features like "all alpha gates"
if fn, found := f.special[k]; found {
fn(f, boolValue)
}
}
glog.Infof("feature gates: %v", f.enabled)
return nil
}
func (f *featureGate) String() string {
pairs := []string{}
for k, v := range f.enabled {
pairs = append(pairs, fmt.Sprintf("%s=%t", k, v))
}
sort.Strings(pairs)
return strings.Join(pairs, ",")
}
func (f *featureGate) Type() string {
return "mapStringBool"
}
func (f *featureGate) Add(features map[Feature]FeatureSpec) error {
if f.closed {
return fmt.Errorf("cannot add a feature gate after adding it to the flag set")
}
for name, spec := range features {
if existingSpec, found := f.known[name]; found {
if existingSpec == spec {
continue
}
return fmt.Errorf("feature gate %q with different spec already exists: %v", name, existingSpec)
}
f.known[name] = spec
}
return nil
}
func (f *featureGate) Enabled(key Feature) bool {
defaultValue := f.known[key].Default
if f.enabled != nil {
if v, ok := f.enabled[key]; ok {
return v
}
}
return defaultValue
}
// AddFlag adds a flag for setting global feature gates to the specified FlagSet.
func (f *featureGate) AddFlag(fs *pflag.FlagSet) {
f.closed = true
known := f.KnownFeatures()
fs.Var(f, flagName, ""+
"A set of key=value pairs that describe feature gates for alpha/experimental features. "+
"Options are:\n"+strings.Join(known, "\n"))
}
// Returns a string describing the FeatureGate's known features.
func (f *featureGate) KnownFeatures() []string {
var known []string
for k, v := range f.known {
pre := ""
if v.PreRelease != GA {
pre = fmt.Sprintf("%s - ", v.PreRelease)
}
known = append(known, fmt.Sprintf("%s=true|false (%sdefault=%t)", k, pre, v.Default))
}
sort.Strings(known)
return known
}

View file

@ -0,0 +1,159 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package feature
import (
"fmt"
"strings"
"testing"
"github.com/spf13/pflag"
)
func TestFeatureGateFlag(t *testing.T) {
// gates for testing
const testAlphaGate Feature = "TestAlpha"
const testBetaGate Feature = "TestBeta"
tests := []struct {
arg string
expect map[Feature]bool
parseError string
}{
{
arg: "",
expect: map[Feature]bool{
allAlphaGate: false,
testAlphaGate: false,
testBetaGate: false,
},
},
{
arg: "fooBarBaz=maybeidk",
expect: map[Feature]bool{
allAlphaGate: false,
testAlphaGate: false,
testBetaGate: false,
},
parseError: "unrecognized key: fooBarBaz",
},
{
arg: "AllAlpha=false",
expect: map[Feature]bool{
allAlphaGate: false,
testAlphaGate: false,
testBetaGate: false,
},
},
{
arg: "AllAlpha=true",
expect: map[Feature]bool{
allAlphaGate: true,
testAlphaGate: true,
testBetaGate: false,
},
},
{
arg: "AllAlpha=banana",
expect: map[Feature]bool{
allAlphaGate: false,
testAlphaGate: false,
testBetaGate: false,
},
parseError: "invalid value of AllAlpha",
},
{
arg: "AllAlpha=false,TestAlpha=true",
expect: map[Feature]bool{
allAlphaGate: false,
testAlphaGate: true,
testBetaGate: false,
},
},
{
arg: "TestAlpha=true,AllAlpha=false",
expect: map[Feature]bool{
allAlphaGate: false,
testAlphaGate: true,
testBetaGate: false,
},
},
{
arg: "AllAlpha=true,TestAlpha=false",
expect: map[Feature]bool{
allAlphaGate: true,
testAlphaGate: false,
testBetaGate: false,
},
},
{
arg: "TestAlpha=false,AllAlpha=true",
expect: map[Feature]bool{
allAlphaGate: true,
testAlphaGate: false,
testBetaGate: false,
},
},
{
arg: "TestBeta=true,AllAlpha=false",
expect: map[Feature]bool{
allAlphaGate: false,
testAlphaGate: false,
testBetaGate: true,
},
},
}
for i, test := range tests {
fs := pflag.NewFlagSet("testfeaturegateflag", pflag.ContinueOnError)
f := DefaultFeatureGate
f.known[testAlphaGate] = FeatureSpec{Default: false, PreRelease: Alpha}
f.known[testBetaGate] = FeatureSpec{Default: false, PreRelease: Beta}
f.AddFlag(fs)
err := fs.Parse([]string{fmt.Sprintf("--%s=%s", flagName, test.arg)})
if test.parseError != "" {
if !strings.Contains(err.Error(), test.parseError) {
t.Errorf("%d: Parse() Expected %v, Got %v", i, test.parseError, err)
}
} else if err != nil {
t.Errorf("%d: Parse() Expected nil, Got %v", i, err)
}
for k, v := range test.expect {
if f.enabled[k] != v {
t.Errorf("%d: expected %s=%v, Got %v", i, k, v, f.enabled[k])
}
}
}
}
func TestFeatureGateFlagDefaults(t *testing.T) {
// gates for testing
const testAlphaGate Feature = "TestAlpha"
const testBetaGate Feature = "TestBeta"
// Don't parse the flag, assert defaults are used.
f := DefaultFeatureGate
f.known[testAlphaGate] = FeatureSpec{Default: false, PreRelease: Alpha}
f.known[testBetaGate] = FeatureSpec{Default: true, PreRelease: Beta}
if f.Enabled(testAlphaGate) != false {
t.Errorf("Expected false")
}
if f.Enabled(testBetaGate) != true {
t.Errorf("Expected true")
}
}

View file

@ -0,0 +1,53 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flag
import (
"fmt"
"sort"
"strings"
)
type ConfigurationMap map[string]string
func (m *ConfigurationMap) String() string {
pairs := []string{}
for k, v := range *m {
pairs = append(pairs, fmt.Sprintf("%s=%s", k, v))
}
sort.Strings(pairs)
return strings.Join(pairs, ",")
}
func (m *ConfigurationMap) Set(value string) error {
for _, s := range strings.Split(value, ",") {
if len(s) == 0 {
continue
}
arr := strings.SplitN(s, "=", 2)
if len(arr) == 2 {
(*m)[strings.TrimSpace(arr[0])] = strings.TrimSpace(arr[1])
} else {
(*m)[strings.TrimSpace(arr[0])] = ""
}
}
return nil
}
func (*ConfigurationMap) Type() string {
return "mapStringString"
}

51
vendor/k8s.io/apiserver/pkg/util/flag/flags.go generated vendored Normal file
View file

@ -0,0 +1,51 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flag
import (
goflag "flag"
"strings"
"github.com/golang/glog"
"github.com/spf13/pflag"
)
// WordSepNormalizeFunc changes all flags that contain "_" separators
func WordSepNormalizeFunc(f *pflag.FlagSet, name string) pflag.NormalizedName {
if strings.Contains(name, "_") {
return pflag.NormalizedName(strings.Replace(name, "_", "-", -1))
}
return pflag.NormalizedName(name)
}
// WarnWordSepNormalizeFunc changes and warns for flags that contain "_" separators
func WarnWordSepNormalizeFunc(f *pflag.FlagSet, name string) pflag.NormalizedName {
if strings.Contains(name, "_") {
nname := strings.Replace(name, "_", "-", -1)
glog.Warningf("%s is DEPRECATED and will be removed in a future version. Use %s instead.", name, nname)
return pflag.NormalizedName(nname)
}
return pflag.NormalizedName(name)
}
// InitFlags normalizes and parses the command line flags
func InitFlags() {
pflag.CommandLine.SetNormalizeFunc(WordSepNormalizeFunc)
pflag.CommandLine.AddGoFlagSet(goflag.CommandLine)
pflag.Parse()
}

View file

@ -0,0 +1,113 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flag
import (
"errors"
"flag"
"strings"
)
// NamedCertKey is a flag value parsing "certfile,keyfile" and "certfile,keyfile:name,name,name".
type NamedCertKey struct {
Names []string
CertFile, KeyFile string
}
var _ flag.Value = &NamedCertKey{}
func (nkc *NamedCertKey) String() string {
s := nkc.CertFile + "," + nkc.KeyFile
if len(nkc.Names) > 0 {
s = s + ":" + strings.Join(nkc.Names, ",")
}
return s
}
func (nkc *NamedCertKey) Set(value string) error {
cs := strings.SplitN(value, ":", 2)
var keycert string
if len(cs) == 2 {
var names string
keycert, names = strings.TrimSpace(cs[0]), strings.TrimSpace(cs[1])
if names == "" {
return errors.New("empty names list is not allowed")
}
nkc.Names = nil
for _, name := range strings.Split(names, ",") {
nkc.Names = append(nkc.Names, strings.TrimSpace(name))
}
} else {
nkc.Names = nil
keycert = strings.TrimSpace(cs[0])
}
cs = strings.Split(keycert, ",")
if len(cs) != 2 {
return errors.New("expected comma separated certificate and key file paths")
}
nkc.CertFile = strings.TrimSpace(cs[0])
nkc.KeyFile = strings.TrimSpace(cs[1])
return nil
}
func (*NamedCertKey) Type() string {
return "namedCertKey"
}
// NamedCertKeyArray is a flag value parsing NamedCertKeys, each passed with its own
// flag instance (in contrast to comma separated slices).
type NamedCertKeyArray struct {
value *[]NamedCertKey
changed bool
}
var _ flag.Value = &NamedCertKey{}
// NewNamedKeyCertArray creates a new NamedCertKeyArray with the internal value
// pointing to p.
func NewNamedCertKeyArray(p *[]NamedCertKey) *NamedCertKeyArray {
return &NamedCertKeyArray{
value: p,
}
}
func (a *NamedCertKeyArray) Set(val string) error {
nkc := NamedCertKey{}
err := nkc.Set(val)
if err != nil {
return err
}
if !a.changed {
*a.value = []NamedCertKey{nkc}
a.changed = true
} else {
*a.value = append(*a.value, nkc)
}
return nil
}
func (a *NamedCertKeyArray) Type() string {
return "namedCertKey"
}
func (a *NamedCertKeyArray) String() string {
nkcs := make([]string, 0, len(*a.value))
for i := range *a.value {
nkcs = append(nkcs, (*a.value)[i].String())
}
return "[" + strings.Join(nkcs, ";") + "]"
}

View file

@ -0,0 +1,139 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flag
import (
"fmt"
"reflect"
"strings"
"testing"
"github.com/spf13/pflag"
)
func TestNamedCertKeyArrayFlag(t *testing.T) {
tests := []struct {
args []string
def []NamedCertKey
expected []NamedCertKey
parseError string
}{
{
args: []string{},
expected: nil,
},
{
args: []string{"foo.crt,foo.key"},
expected: []NamedCertKey{{
KeyFile: "foo.key",
CertFile: "foo.crt",
}},
},
{
args: []string{" foo.crt , foo.key "},
expected: []NamedCertKey{{
KeyFile: "foo.key",
CertFile: "foo.crt",
}},
},
{
args: []string{"foo.crt,foo.key:abc"},
expected: []NamedCertKey{{
KeyFile: "foo.key",
CertFile: "foo.crt",
Names: []string{"abc"},
}},
},
{
args: []string{"foo.crt,foo.key: abc "},
expected: []NamedCertKey{{
KeyFile: "foo.key",
CertFile: "foo.crt",
Names: []string{"abc"},
}},
},
{
args: []string{"foo.crt,foo.key:"},
parseError: "empty names list is not allowed",
},
{
args: []string{""},
parseError: "expected comma separated certificate and key file paths",
},
{
args: []string{" "},
parseError: "expected comma separated certificate and key file paths",
},
{
args: []string{"a,b,c"},
parseError: "expected comma separated certificate and key file paths",
},
{
args: []string{"foo.crt,foo.key:abc,def,ghi"},
expected: []NamedCertKey{{
KeyFile: "foo.key",
CertFile: "foo.crt",
Names: []string{"abc", "def", "ghi"},
}},
},
{
args: []string{"foo.crt,foo.key:*.*.*"},
expected: []NamedCertKey{{
KeyFile: "foo.key",
CertFile: "foo.crt",
Names: []string{"*.*.*"},
}},
},
{
args: []string{"foo.crt,foo.key", "bar.crt,bar.key"},
expected: []NamedCertKey{{
KeyFile: "foo.key",
CertFile: "foo.crt",
}, {
KeyFile: "bar.key",
CertFile: "bar.crt",
}},
},
}
for i, test := range tests {
fs := pflag.NewFlagSet("testNamedCertKeyArray", pflag.ContinueOnError)
var nkcs []NamedCertKey
for _, d := range test.def {
nkcs = append(nkcs, d)
}
fs.Var(NewNamedCertKeyArray(&nkcs), "tls-sni-cert-key", "usage")
args := []string{}
for _, a := range test.args {
args = append(args, fmt.Sprintf("--tls-sni-cert-key=%s", a))
}
err := fs.Parse(args)
if test.parseError != "" {
if err == nil {
t.Errorf("%d: expected error %q, got nil", i, test.parseError)
} else if !strings.Contains(err.Error(), test.parseError) {
t.Errorf("%d: expected error %q, got %q", i, test.parseError, err)
}
} else if err != nil {
t.Errorf("%d: expected nil error, got %v", i, err)
}
if !reflect.DeepEqual(nkcs, test.expected) {
t.Errorf("%d: expected %+v, got %+v", i, test.expected, nkcs)
}
}
}

56
vendor/k8s.io/apiserver/pkg/util/flag/string_flag.go generated vendored Normal file
View file

@ -0,0 +1,56 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flag
// StringFlag is a string flag compatible with flags and pflags that keeps track of whether it had a value supplied or not.
type StringFlag struct {
// If Set has been invoked this value is true
provided bool
// The exact value provided on the flag
value string
}
func NewStringFlag(defaultVal string) StringFlag {
return StringFlag{value: defaultVal}
}
func (f *StringFlag) Default(value string) {
f.value = value
}
func (f StringFlag) String() string {
return f.value
}
func (f StringFlag) Value() string {
return f.value
}
func (f *StringFlag) Set(value string) error {
f.value = value
f.provided = true
return nil
}
func (f StringFlag) Provided() bool {
return f.provided
}
func (f *StringFlag) Type() string {
return "string"
}

83
vendor/k8s.io/apiserver/pkg/util/flag/tristate.go generated vendored Normal file
View file

@ -0,0 +1,83 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flag
import (
"fmt"
"strconv"
)
// Tristate is a flag compatible with flags and pflags that
// keeps track of whether it had a value supplied or not.
type Tristate int
const (
Unset Tristate = iota // 0
True
False
)
func (f *Tristate) Default(value bool) {
*f = triFromBool(value)
}
func (f Tristate) String() string {
b := boolFromTri(f)
return fmt.Sprintf("%t", b)
}
func (f Tristate) Value() bool {
b := boolFromTri(f)
return b
}
func (f *Tristate) Set(value string) error {
boolVal, err := strconv.ParseBool(value)
if err != nil {
return err
}
*f = triFromBool(boolVal)
return nil
}
func (f Tristate) Provided() bool {
if f != Unset {
return true
}
return false
}
func (f *Tristate) Type() string {
return "tristate"
}
func boolFromTri(t Tristate) bool {
if t == True {
return true
} else {
return false
}
}
func triFromBool(b bool) Tristate {
if b {
return True
} else {
return False
}
}

19
vendor/k8s.io/apiserver/pkg/util/flushwriter/doc.go generated vendored Normal file
View file

@ -0,0 +1,19 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package flushwriter implements a wrapper for a writer that flushes on every
// write if that writer implements the io.Flusher interface
package flushwriter // import "k8s.io/apiserver/pkg/util/flushwriter"

53
vendor/k8s.io/apiserver/pkg/util/flushwriter/writer.go generated vendored Normal file
View file

@ -0,0 +1,53 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flushwriter
import (
"io"
"net/http"
)
// Wrap wraps an io.Writer into a writer that flushes after every write if
// the writer implements the Flusher interface.
func Wrap(w io.Writer) io.Writer {
fw := &flushWriter{
writer: w,
}
if flusher, ok := w.(http.Flusher); ok {
fw.flusher = flusher
}
return fw
}
// flushWriter provides wrapper for responseWriter with HTTP streaming capabilities
type flushWriter struct {
flusher http.Flusher
writer io.Writer
}
// Write is a FlushWriter implementation of the io.Writer that sends any buffered
// data to the client.
func (fw *flushWriter) Write(p []byte) (n int, err error) {
n, err = fw.writer.Write(p)
if err != nil {
return
}
if fw.flusher != nil {
fw.flusher.Flush()
}
return
}

View file

@ -0,0 +1,86 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package flushwriter
import (
"fmt"
"testing"
)
type writerWithFlush struct {
writeCount, flushCount int
err error
}
func (w *writerWithFlush) Flush() {
w.flushCount++
}
func (w *writerWithFlush) Write(p []byte) (n int, err error) {
w.writeCount++
return len(p), w.err
}
type writerWithNoFlush struct {
writeCount int
}
func (w *writerWithNoFlush) Write(p []byte) (n int, err error) {
w.writeCount++
return len(p), nil
}
func TestWriteWithFlush(t *testing.T) {
w := &writerWithFlush{}
fw := Wrap(w)
for i := 0; i < 10; i++ {
_, err := fw.Write([]byte("Test write"))
if err != nil {
t.Errorf("Unexpected error while writing with flush writer: %v", err)
}
}
if w.flushCount != 10 {
t.Errorf("Flush not called the expected number of times. Actual: %d", w.flushCount)
}
if w.writeCount != 10 {
t.Errorf("Write not called the expected number of times. Actual: %d", w.writeCount)
}
}
func TestWriteWithoutFlush(t *testing.T) {
w := &writerWithNoFlush{}
fw := Wrap(w)
for i := 0; i < 10; i++ {
_, err := fw.Write([]byte("Test write"))
if err != nil {
t.Errorf("Unexpected error while writing with flush writer: %v", err)
}
}
if w.writeCount != 10 {
t.Errorf("Write not called the expected number of times. Actual: %d", w.writeCount)
}
}
func TestWriteError(t *testing.T) {
e := fmt.Errorf("Error")
w := &writerWithFlush{err: e}
fw := Wrap(w)
_, err := fw.Write([]byte("Test write"))
if err != e {
t.Errorf("Did not get expected error. Got: %#v", err)
}
}

106
vendor/k8s.io/apiserver/pkg/util/proxy/dial.go generated vendored Normal file
View file

@ -0,0 +1,106 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"github.com/golang/glog"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)
func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
dialer, _ := utilnet.Dialer(transport)
switch url.Scheme {
case "http":
if dialer != nil {
return dialer("tcp", dialAddr)
}
return net.Dial("tcp", dialAddr)
case "https":
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
var tlsConn *tls.Conn
var err error
tlsConfig, _ = utilnet.TLSClientConfig(transport)
if dialer != nil {
// We have a dialer; use it to open the connection, then
// create a tls client using the connection.
netConn, err := dialer("tcp", dialAddr)
if err != nil {
return nil, err
}
if tlsConfig == nil {
// tls.Client requires non-nil config
glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify")
// tls.Handshake() requires ServerName or InsecureSkipVerify
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
} else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
// tls.Handshake() requires ServerName or InsecureSkipVerify
// infer the ServerName from the hostname we're connecting to.
inferredHost := dialAddr
if host, _, err := net.SplitHostPort(dialAddr); err == nil {
inferredHost = host
}
// Make a copy to avoid polluting the provided config
tlsConfigCopy := utilnet.CloneTLSConfig(tlsConfig)
tlsConfigCopy.ServerName = inferredHost
tlsConfig = tlsConfigCopy
}
tlsConn = tls.Client(netConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
netConn.Close()
return nil, err
}
} else {
// Dial
tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err
}
}
// Return if we were configured to skip validation
if tlsConfig != nil && tlsConfig.InsecureSkipVerify {
return tlsConn, nil
}
// Verify
host, _, _ := net.SplitHostPort(dialAddr)
if err := tlsConn.VerifyHostname(host); err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
default:
return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
}
}

171
vendor/k8s.io/apiserver/pkg/util/proxy/dial_test.go generated vendored Normal file
View file

@ -0,0 +1,171 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"testing"
utilnet "k8s.io/apimachinery/pkg/util/net"
)
func TestDialURL(t *testing.T) {
roots := x509.NewCertPool()
if !roots.AppendCertsFromPEM(localhostCert) {
t.Fatal("error setting up localhostCert pool")
}
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Fatal(err)
}
testcases := map[string]struct {
TLSConfig *tls.Config
Dial utilnet.DialFunc
ExpectError string
}{
"insecure": {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
},
"secure, no roots": {
TLSConfig: &tls.Config{InsecureSkipVerify: false},
ExpectError: "unknown authority",
},
"secure with roots": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots},
},
"secure with mismatched server": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"},
ExpectError: "not bogus.com",
},
"secure with matched server": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
},
"insecure, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
Dial: net.Dial,
},
"secure, no roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false},
Dial: net.Dial,
ExpectError: "unknown authority",
},
"secure with roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots},
Dial: net.Dial,
},
"secure with mismatched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"},
Dial: net.Dial,
ExpectError: "not bogus.com",
},
"secure with matched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
Dial: net.Dial,
},
}
for k, tc := range testcases {
func() {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {}))
defer ts.Close()
ts.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
ts.StartTLS()
tlsConfigCopy := utilnet.CloneTLSConfig(tc.TLSConfig)
transport := &http.Transport{
Dial: tc.Dial,
TLSClientConfig: tlsConfigCopy,
}
extractedDial, err := utilnet.Dialer(transport)
if err != nil {
t.Fatal(err)
}
if fmt.Sprintf("%p", extractedDial) != fmt.Sprintf("%p", tc.Dial) {
t.Fatalf("%s: Unexpected dial", k)
}
extractedTLSConfig, err := utilnet.TLSClientConfig(transport)
if err != nil {
t.Fatal(err)
}
if extractedTLSConfig == nil {
t.Fatalf("%s: Expected tlsConfig", k)
}
u, _ := url.Parse(ts.URL)
_, p, _ := net.SplitHostPort(u.Host)
u.Host = net.JoinHostPort("127.0.0.1", p)
conn, err := DialURL(u, transport)
// Make sure dialing doesn't mutate the transport's TLSConfig
if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) {
t.Errorf("%s: transport's copy of TLSConfig was mutated\n%#v\n\n%#v", k, tc.TLSConfig, tlsConfigCopy)
}
if err != nil {
if tc.ExpectError == "" {
t.Errorf("%s: expected no error, got %q", k, err.Error())
}
if !strings.Contains(err.Error(), tc.ExpectError) {
t.Errorf("%s: expected error containing %q, got %q", k, tc.ExpectError, err.Error())
}
return
}
conn.Close()
if tc.ExpectError != "" {
t.Errorf("%s: expected error %q, got none", k, tc.ExpectError)
}
}()
}
}
// localhostCert was generated from crypto/tls/generate_cert.go with the following command:
// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD
bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj
bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa
IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA
AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud
EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA
AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk
Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA==
-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0
0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV
NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d
AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW
MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD
EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA
1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE=
-----END RSA PRIVATE KEY-----`)

18
vendor/k8s.io/apiserver/pkg/util/proxy/doc.go generated vendored Normal file
View file

@ -0,0 +1,18 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package proxy provides transport and upgrade support for proxies
package proxy // import "k8s.io/apiserver/pkg/util/proxy"

241
vendor/k8s.io/apiserver/pkg/util/proxy/transport.go generated vendored Normal file
View file

@ -0,0 +1,241 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strings"
"github.com/golang/glog"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
"k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/sets"
)
// atomsToAttrs states which attributes of which tags require URL substitution.
// Sources: http://www.w3.org/TR/REC-html40/index/attributes.html
// http://www.w3.org/html/wg/drafts/html/master/index.html#attributes-1
var atomsToAttrs = map[atom.Atom]sets.String{
atom.A: sets.NewString("href"),
atom.Applet: sets.NewString("codebase"),
atom.Area: sets.NewString("href"),
atom.Audio: sets.NewString("src"),
atom.Base: sets.NewString("href"),
atom.Blockquote: sets.NewString("cite"),
atom.Body: sets.NewString("background"),
atom.Button: sets.NewString("formaction"),
atom.Command: sets.NewString("icon"),
atom.Del: sets.NewString("cite"),
atom.Embed: sets.NewString("src"),
atom.Form: sets.NewString("action"),
atom.Frame: sets.NewString("longdesc", "src"),
atom.Head: sets.NewString("profile"),
atom.Html: sets.NewString("manifest"),
atom.Iframe: sets.NewString("longdesc", "src"),
atom.Img: sets.NewString("longdesc", "src", "usemap"),
atom.Input: sets.NewString("src", "usemap", "formaction"),
atom.Ins: sets.NewString("cite"),
atom.Link: sets.NewString("href"),
atom.Object: sets.NewString("classid", "codebase", "data", "usemap"),
atom.Q: sets.NewString("cite"),
atom.Script: sets.NewString("src"),
atom.Source: sets.NewString("src"),
atom.Video: sets.NewString("poster", "src"),
// TODO: css URLs hidden in style elements.
}
// Transport is a transport for text/html content that replaces URLs in html
// content with the prefix of the proxy server
type Transport struct {
Scheme string
Host string
PathPrepend string
http.RoundTripper
}
// RoundTrip implements the http.RoundTripper interface
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// Add reverse proxy headers.
forwardedURI := path.Join(t.PathPrepend, req.URL.Path)
if strings.HasSuffix(req.URL.Path, "/") {
forwardedURI = forwardedURI + "/"
}
req.Header.Set("X-Forwarded-Uri", forwardedURI)
if len(t.Host) > 0 {
req.Header.Set("X-Forwarded-Host", t.Host)
}
if len(t.Scheme) > 0 {
req.Header.Set("X-Forwarded-Proto", t.Scheme)
}
rt := t.RoundTripper
if rt == nil {
rt = http.DefaultTransport
}
resp, err := rt.RoundTrip(req)
if err != nil {
message := fmt.Sprintf("Error: '%s'\nTrying to reach: '%v'", err.Error(), req.URL.String())
resp = &http.Response{
StatusCode: http.StatusServiceUnavailable,
Body: ioutil.NopCloser(strings.NewReader(message)),
}
return resp, nil
}
if redirect := resp.Header.Get("Location"); redirect != "" {
resp.Header.Set("Location", t.rewriteURL(redirect, req.URL))
return resp, nil
}
cType := resp.Header.Get("Content-Type")
cType = strings.TrimSpace(strings.SplitN(cType, ";", 2)[0])
if cType != "text/html" {
// Do nothing, simply pass through
return resp, nil
}
return t.rewriteResponse(req, resp)
}
var _ = net.RoundTripperWrapper(&Transport{})
func (rt *Transport) WrappedRoundTripper() http.RoundTripper {
return rt.RoundTripper
}
// rewriteURL rewrites a single URL to go through the proxy, if the URL refers
// to the same host as sourceURL, which is the page on which the target URL
// occurred. If any error occurs (e.g. parsing), it returns targetURL.
func (t *Transport) rewriteURL(targetURL string, sourceURL *url.URL) string {
url, err := url.Parse(targetURL)
if err != nil {
return targetURL
}
isDifferentHost := url.Host != "" && url.Host != sourceURL.Host
isRelative := !strings.HasPrefix(url.Path, "/")
if isDifferentHost || isRelative {
return targetURL
}
url.Scheme = t.Scheme
url.Host = t.Host
origPath := url.Path
// Do not rewrite URL if the sourceURL already contains the necessary prefix.
if strings.HasPrefix(url.Path, t.PathPrepend) {
return url.String()
}
url.Path = path.Join(t.PathPrepend, url.Path)
if strings.HasSuffix(origPath, "/") {
// Add back the trailing slash, which was stripped by path.Join().
url.Path += "/"
}
return url.String()
}
// rewriteHTML scans the HTML for tags with url-valued attributes, and updates
// those values with the urlRewriter function. The updated HTML is output to the
// writer.
func rewriteHTML(reader io.Reader, writer io.Writer, urlRewriter func(string) string) error {
// Note: This assumes the content is UTF-8.
tokenizer := html.NewTokenizer(reader)
var err error
for err == nil {
tokenType := tokenizer.Next()
switch tokenType {
case html.ErrorToken:
err = tokenizer.Err()
case html.StartTagToken, html.SelfClosingTagToken:
token := tokenizer.Token()
if urlAttrs, ok := atomsToAttrs[token.DataAtom]; ok {
for i, attr := range token.Attr {
if urlAttrs.Has(attr.Key) {
token.Attr[i].Val = urlRewriter(attr.Val)
}
}
}
_, err = writer.Write([]byte(token.String()))
default:
_, err = writer.Write(tokenizer.Raw())
}
}
if err != io.EOF {
return err
}
return nil
}
// rewriteResponse modifies an HTML response by updating absolute links referring
// to the original host to instead refer to the proxy transport.
func (t *Transport) rewriteResponse(req *http.Request, resp *http.Response) (*http.Response, error) {
origBody := resp.Body
defer origBody.Close()
newContent := &bytes.Buffer{}
var reader io.Reader = origBody
var writer io.Writer = newContent
encoding := resp.Header.Get("Content-Encoding")
switch encoding {
case "gzip":
var err error
reader, err = gzip.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("errorf making gzip reader: %v", err)
}
gzw := gzip.NewWriter(writer)
defer gzw.Close()
writer = gzw
// TODO: support flate, other encodings.
case "":
// This is fine
default:
// Some encoding we don't understand-- don't try to parse this
glog.Errorf("Proxy encountered encoding %v for text/html; can't understand this so not fixing links.", encoding)
return resp, nil
}
urlRewriter := func(targetUrl string) string {
return t.rewriteURL(targetUrl, req.URL)
}
err := rewriteHTML(reader, writer, urlRewriter)
if err != nil {
glog.Errorf("Failed to rewrite URLs: %v", err)
return resp, err
}
resp.Body = ioutil.NopCloser(newContent)
// Update header node with new content-length
// TODO: Remove any hash/signature headers here?
resp.Header.Del("Content-Length")
resp.ContentLength = int64(newContent.Len())
return resp, err
}

View file

@ -0,0 +1,261 @@
/*
Copyright 2014 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func parseURLOrDie(inURL string) *url.URL {
parsed, err := url.Parse(inURL)
if err != nil {
panic(err)
}
return parsed
}
func TestProxyTransport(t *testing.T) {
testTransport := &Transport{
Scheme: "http",
Host: "foo.com",
PathPrepend: "/proxy/node/node1:10250",
}
testTransport2 := &Transport{
Scheme: "https",
Host: "foo.com",
PathPrepend: "/proxy/node/node1:8080",
}
emptyHostTransport := &Transport{
Scheme: "https",
PathPrepend: "/proxy/node/node1:10250",
}
emptySchemeTransport := &Transport{
Host: "foo.com",
PathPrepend: "/proxy/node/node1:10250",
}
type Item struct {
input string
sourceURL string
transport *Transport
output string
contentType string
forwardedURI string
redirect string
redirectWant string
}
table := map[string]Item{
"normal": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"full document": {
input: `<html><header></header><body><pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre></body></html>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<html><header></header><body><pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre></body></html>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"trailing slash": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log/">google.log</a></pre>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log/">google.log</a></pre>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"content-type charset": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre>`,
contentType: "text/html; charset=utf-8",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"content-type passthrough": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre>`,
sourceURL: "http://mynode.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a></pre>`,
contentType: "text/plain",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"subdir": {
input: `<a href="kubelet.log">kubelet.log</a><a href="/google.log">google.log</a>`,
sourceURL: "http://mynode.com/whatever/apt/somelog.log",
transport: testTransport2,
output: `<a href="kubelet.log">kubelet.log</a><a href="https://foo.com/proxy/node/node1:8080/google.log">google.log</a>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:8080/whatever/apt/somelog.log",
},
"image": {
input: `<pre><img src="kubernetes.jpg"/><img src="/kubernetes_abs.jpg"/></pre>`,
sourceURL: "http://mynode.com/",
transport: testTransport,
output: `<pre><img src="kubernetes.jpg"/><img src="http://foo.com/proxy/node/node1:10250/kubernetes_abs.jpg"/></pre>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/",
},
"abs": {
input: `<script src="http://google.com/kubernetes.js"/>`,
sourceURL: "http://mynode.com/any/path/",
transport: testTransport,
output: `<script src="http://google.com/kubernetes.js"/>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/any/path/",
},
"abs but same host": {
input: `<script src="http://mynode.com/kubernetes.js"/>`,
sourceURL: "http://mynode.com/any/path/",
transport: testTransport,
output: `<script src="http://foo.com/proxy/node/node1:10250/kubernetes.js"/>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/any/path/",
},
"redirect rel": {
sourceURL: "http://mynode.com/redirect",
transport: testTransport,
redirect: "/redirected/target/",
redirectWant: "http://foo.com/proxy/node/node1:10250/redirected/target/",
forwardedURI: "/proxy/node/node1:10250/redirect",
},
"redirect abs same host": {
sourceURL: "http://mynode.com/redirect",
transport: testTransport,
redirect: "http://mynode.com/redirected/target/",
redirectWant: "http://foo.com/proxy/node/node1:10250/redirected/target/",
forwardedURI: "/proxy/node/node1:10250/redirect",
},
"redirect abs other host": {
sourceURL: "http://mynode.com/redirect",
transport: testTransport,
redirect: "http://example.com/redirected/target/",
redirectWant: "http://example.com/redirected/target/",
forwardedURI: "/proxy/node/node1:10250/redirect",
},
"source contains the redirect already": {
input: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre>`,
sourceURL: "http://foo.com/logs/log.log",
transport: testTransport,
output: `<pre><a href="kubelet.log">kubelet.log</a><a href="http://foo.com/proxy/node/node1:10250/google.log">google.log</a></pre>`,
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"no host": {
input: "<html></html>",
sourceURL: "http://mynode.com/logs/log.log",
transport: emptyHostTransport,
output: "<html></html>",
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
"no scheme": {
input: "<html></html>",
sourceURL: "http://mynode.com/logs/log.log",
transport: emptySchemeTransport,
output: "<html></html>",
contentType: "text/html",
forwardedURI: "/proxy/node/node1:10250/logs/log.log",
},
}
testItem := func(name string, item *Item) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check request headers.
if got, want := r.Header.Get("X-Forwarded-Uri"), item.forwardedURI; got != want {
t.Errorf("%v: X-Forwarded-Uri = %q, want %q", name, got, want)
}
if len(item.transport.Host) == 0 {
_, present := r.Header["X-Forwarded-Host"]
if present {
t.Errorf("%v: X-Forwarded-Host header should not be present", name)
}
} else {
if got, want := r.Header.Get("X-Forwarded-Host"), item.transport.Host; got != want {
t.Errorf("%v: X-Forwarded-Host = %q, want %q", name, got, want)
}
}
if len(item.transport.Scheme) == 0 {
_, present := r.Header["X-Forwarded-Proto"]
if present {
t.Errorf("%v: X-Forwarded-Proto header should not be present", name)
}
} else {
if got, want := r.Header.Get("X-Forwarded-Proto"), item.transport.Scheme; got != want {
t.Errorf("%v: X-Forwarded-Proto = %q, want %q", name, got, want)
}
}
// Send response.
if item.redirect != "" {
http.Redirect(w, r, item.redirect, http.StatusMovedPermanently)
return
}
w.Header().Set("Content-Type", item.contentType)
fmt.Fprint(w, item.input)
}))
defer server.Close()
// Replace source URL with our test server address.
sourceURL := parseURLOrDie(item.sourceURL)
serverURL := parseURLOrDie(server.URL)
item.input = strings.Replace(item.input, sourceURL.Host, serverURL.Host, -1)
item.redirect = strings.Replace(item.redirect, sourceURL.Host, serverURL.Host, -1)
sourceURL.Host = serverURL.Host
req, err := http.NewRequest("GET", sourceURL.String(), nil)
if err != nil {
t.Errorf("%v: Unexpected error: %v", name, err)
return
}
resp, err := item.transport.RoundTrip(req)
if err != nil {
t.Errorf("%v: Unexpected error: %v", name, err)
return
}
if item.redirect != "" {
// Check that redirect URLs get rewritten properly.
if got, want := resp.Header.Get("Location"), item.redirectWant; got != want {
t.Errorf("%v: Location header = %q, want %q", name, got, want)
}
return
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("%v: Unexpected error: %v", name, err)
return
}
if e, a := item.output, string(body); e != a {
t.Errorf("%v: expected %v, but got %v", name, e, a)
}
}
for name, item := range table {
testItem(name, &item)
}
}

72
vendor/k8s.io/apiserver/pkg/util/trace/trace.go generated vendored Normal file
View file

@ -0,0 +1,72 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package trace
import (
"bytes"
"fmt"
"time"
"github.com/golang/glog"
)
type traceStep struct {
stepTime time.Time
msg string
}
type Trace struct {
name string
startTime time.Time
steps []traceStep
}
func New(name string) *Trace {
return &Trace{name, time.Now(), nil}
}
func (t *Trace) Step(msg string) {
if t.steps == nil {
// traces almost always have less than 6 steps, do this to avoid more than a single allocation
t.steps = make([]traceStep, 0, 6)
}
t.steps = append(t.steps, traceStep{time.Now(), msg})
}
func (t *Trace) Log() {
endTime := time.Now()
var buffer bytes.Buffer
buffer.WriteString(fmt.Sprintf("Trace %q (started %v):\n", t.name, t.startTime))
lastStepTime := t.startTime
for _, step := range t.steps {
buffer.WriteString(fmt.Sprintf("[%v] [%v] %v\n", step.stepTime.Sub(t.startTime), step.stepTime.Sub(lastStepTime), step.msg))
lastStepTime = step.stepTime
}
buffer.WriteString(fmt.Sprintf("[%v] [%v] END\n", endTime.Sub(t.startTime), endTime.Sub(lastStepTime)))
glog.Info(buffer.String())
}
func (t *Trace) LogIfLong(threshold time.Duration) {
if time.Since(t.startTime) >= threshold {
t.Log()
}
}
func (t *Trace) TotalTime() time.Duration {
return time.Since(t.startTime)
}

79
vendor/k8s.io/apiserver/pkg/util/trie/trie.go generated vendored Normal file
View file

@ -0,0 +1,79 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package trie
// A simple trie implementation with Add an HasPrefix methods only.
type Trie struct {
children map[byte]*Trie
wordTail bool
word string
}
// New creates a Trie and add all strings in the provided list to it.
func New(list []string) Trie {
ret := Trie{
children: make(map[byte]*Trie),
wordTail: false,
}
for _, v := range list {
ret.Add(v)
}
return ret
}
// Add adds a word to this trie
func (t *Trie) Add(v string) {
root := t
for _, b := range []byte(v) {
child, exists := root.children[b]
if !exists {
child = &Trie{
children: make(map[byte]*Trie),
wordTail: false,
}
root.children[b] = child
}
root = child
}
root.wordTail = true
root.word = v
}
// HasPrefix returns true of v has any of the prefixes stored in this trie.
func (t *Trie) HasPrefix(v string) bool {
_, has := t.GetPrefix(v)
return has
}
// GetPrefix is like HasPrefix but return the prefix in case of match or empty string otherwise.
func (t *Trie) GetPrefix(v string) (string, bool) {
root := t
if root.wordTail {
return root.word, true
}
for _, b := range []byte(v) {
child, exists := root.children[b]
if !exists {
return "", false
}
if child.wordTail {
return child.word, true
}
root = child
}
return "", false
}

105
vendor/k8s.io/apiserver/pkg/util/webhook/webhook.go generated vendored Executable file
View file

@ -0,0 +1,105 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package webhook implements a generic HTTP webhook plugin.
package webhook
import (
"fmt"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/apimachinery/registered"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer"
runtimeserializer "k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
)
type GenericWebhook struct {
RestClient *rest.RESTClient
initialBackoff time.Duration
}
// NewGenericWebhook creates a new GenericWebhook from the provided kubeconfig file.
func NewGenericWebhook(registry *registered.APIRegistrationManager, codecFactory serializer.CodecFactory, kubeConfigFile string, groupVersions []schema.GroupVersion, initialBackoff time.Duration) (*GenericWebhook, error) {
for _, groupVersion := range groupVersions {
if !registry.IsEnabledVersion(groupVersion) {
return nil, fmt.Errorf("webhook plugin requires enabling extension resource: %s", groupVersion)
}
}
loadingRules := clientcmd.NewDefaultClientConfigLoadingRules()
loadingRules.ExplicitPath = kubeConfigFile
loader := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, &clientcmd.ConfigOverrides{})
clientConfig, err := loader.ClientConfig()
if err != nil {
return nil, err
}
codec := codecFactory.LegacyCodec(groupVersions...)
clientConfig.ContentConfig.NegotiatedSerializer = runtimeserializer.NegotiatedSerializerWrapper(runtime.SerializerInfo{Serializer: codec})
restClient, err := rest.UnversionedRESTClientFor(clientConfig)
if err != nil {
return nil, err
}
// TODO(ericchiang): Can we ensure remote service is reachable?
return &GenericWebhook{restClient, initialBackoff}, nil
}
// WithExponentialBackoff will retry webhookFn() up to 5 times with exponentially increasing backoff when
// it returns an error for which apierrors.SuggestsClientDelay() or apierrors.IsInternalError() returns true.
func (g *GenericWebhook) WithExponentialBackoff(webhookFn func() rest.Result) rest.Result {
var result rest.Result
WithExponentialBackoff(g.initialBackoff, func() error {
result = webhookFn()
return result.Error()
})
return result
}
// WithExponentialBackoff will retry webhookFn() up to 5 times with exponentially increasing backoff when
// it returns an error for which apierrors.SuggestsClientDelay() or apierrors.IsInternalError() returns true.
func WithExponentialBackoff(initialBackoff time.Duration, webhookFn func() error) error {
backoff := wait.Backoff{
Duration: initialBackoff,
Factor: 1.5,
Jitter: 0.2,
Steps: 5,
}
var err error
wait.ExponentialBackoff(backoff, func() (bool, error) {
err = webhookFn()
if _, shouldRetry := apierrors.SuggestsClientDelay(err); shouldRetry {
return false, nil
}
if apierrors.IsInternalError(err) {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
})
return err
}

349
vendor/k8s.io/apiserver/pkg/util/wsstream/conn.go generated vendored Normal file
View file

@ -0,0 +1,349 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package wsstream
import (
"encoding/base64"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/golang/glog"
"golang.org/x/net/websocket"
"k8s.io/apimachinery/pkg/util/runtime"
)
// The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating
// the channel number (zero indexed) the message was sent on. Messages in both directions should
// prefix their messages with this channel byte. When used for remote execution, the channel numbers
// are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, and STDERR
// (0, 1, and 2). No other conversion is performed on the raw subprotocol - writes are sent as they
// are received by the server.
//
// Example client session:
//
// CONNECT http://server.com with subprotocol "channel.k8s.io"
// WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN)
// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT)
// CLOSE
//
const ChannelWebSocketProtocol = "channel.k8s.io"
// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character
// indicating the channel number (zero indexed) the message was sent on. Messages in both directions
// should prefix their messages with this channel char. When used for remote execution, the channel
// numbers are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT,
// and STDERR ('0', '1', and '2'). The data received on the server is base64 decoded (and must be
// be valid) and data written by the server to the client is base64 encoded.
//
// Example client session:
//
// CONNECT http://server.com with subprotocol "base64.channel.k8s.io"
// WRITE []byte{48, 90, 109, 57, 118, 67, 103, 111, 61} # send "foo\n" (base64: "Zm9vCgo=") on channel '0' (STDIN)
// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
// CLOSE
//
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
type codecType int
const (
rawCodec codecType = iota
base64Codec
)
type ChannelType int
const (
IgnoreChannel ChannelType = iota
ReadChannel
WriteChannel
ReadWriteChannel
)
var (
// connectionUpgradeRegex matches any Connection header value that includes upgrade
connectionUpgradeRegex = regexp.MustCompile("(^|.*,\\s*)upgrade($|\\s*,)")
)
// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers
// for WebSockets.
func IsWebSocketRequest(req *http.Request) bool {
return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection"))) && strings.ToLower(req.Header.Get("Upgrade")) == "websocket"
}
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
// read and write deadlines are pushed every time a new message is received.
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
defer runtime.HandleCrash()
var data []byte
for {
resetTimeout(ws, timeout)
if err := websocket.Message.Receive(ws, &data); err != nil {
return
}
}
}
// handshake ensures the provided user protocol matches one of the allowed protocols. It returns
// no error if no protocol is specified.
func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
protocols := config.Protocol
if len(protocols) == 0 {
protocols = []string{""}
}
for _, protocol := range protocols {
for _, allow := range allowed {
if allow == protocol {
config.Protocol = []string{protocol}
return nil
}
}
}
return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
}
// ChannelProtocolConfig describes a websocket subprotocol with channels.
type ChannelProtocolConfig struct {
Binary bool
Channels []ChannelType
}
// NewDefaultChannelProtocols returns a channel protocol map with the
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
// channels.
func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
return map[string]ChannelProtocolConfig{
"": {Binary: true, Channels: channels},
ChannelWebSocketProtocol: {Binary: true, Channels: channels},
Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
}
}
// Conn supports sending multiple binary channels over a websocket connection.
type Conn struct {
protocols map[string]ChannelProtocolConfig
selectedProtocol string
channels []*websocketChannel
codec codecType
ready chan struct{}
ws *websocket.Conn
timeout time.Duration
}
// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
// future use. The channel types for each channel are passed as an array, supporting the different
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
//
// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
// name is used if websocket.Config.Protocol is empty.
func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
return &Conn{
ready: make(chan struct{}),
protocols: protocols,
}
}
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
// there is no timeout on the connection.
func (conn *Conn) SetIdleTimeout(duration time.Duration) {
conn.timeout = duration
}
// Open the connection and create channels for reading and writing. It returns
// the selected subprotocol, a slice of channels and an error.
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
go func() {
defer runtime.HandleCrash()
defer conn.Close()
websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
}()
<-conn.ready
rwc := make([]io.ReadWriteCloser, len(conn.channels))
for i := range conn.channels {
rwc[i] = conn.channels[i]
}
return conn.selectedProtocol, rwc, nil
}
func (conn *Conn) initialize(ws *websocket.Conn) {
negotiated := ws.Config().Protocol
conn.selectedProtocol = negotiated[0]
p := conn.protocols[conn.selectedProtocol]
if p.Binary {
conn.codec = rawCodec
} else {
conn.codec = base64Codec
}
conn.ws = ws
conn.channels = make([]*websocketChannel, len(p.Channels))
for i, t := range p.Channels {
switch t {
case ReadChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
case WriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
case ReadWriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
case IgnoreChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
}
}
close(conn.ready)
}
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
supportedProtocols := make([]string, 0, len(conn.protocols))
for p := range conn.protocols {
supportedProtocols = append(supportedProtocols, p)
}
return handshake(config, req, supportedProtocols)
}
func (conn *Conn) resetTimeout() {
if conn.timeout > 0 {
conn.ws.SetDeadline(time.Now().Add(conn.timeout))
}
}
// Close is only valid after Open has been called
func (conn *Conn) Close() error {
<-conn.ready
for _, s := range conn.channels {
s.Close()
}
conn.ws.Close()
return nil
}
// handle implements a websocket handler.
func (conn *Conn) handle(ws *websocket.Conn) {
defer conn.Close()
conn.initialize(ws)
for {
conn.resetTimeout()
var data []byte
if err := websocket.Message.Receive(ws, &data); err != nil {
if err != io.EOF {
glog.Errorf("Error on socket receive: %v", err)
}
break
}
if len(data) == 0 {
continue
}
channel := data[0]
if conn.codec == base64Codec {
channel = channel - '0'
}
data = data[1:]
if int(channel) >= len(conn.channels) {
glog.V(6).Infof("Frame is targeted for a reader %d that is not valid, possible protocol error", channel)
continue
}
if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
glog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data))
continue
}
}
}
// write multiplexes the specified channel onto the websocket
func (conn *Conn) write(num byte, data []byte) (int, error) {
conn.resetTimeout()
switch conn.codec {
case rawCodec:
frame := make([]byte, len(data)+1)
frame[0] = num
copy(frame[1:], data)
if err := websocket.Message.Send(conn.ws, frame); err != nil {
return 0, err
}
case base64Codec:
frame := string('0'+num) + base64.StdEncoding.EncodeToString(data)
if err := websocket.Message.Send(conn.ws, frame); err != nil {
return 0, err
}
}
return len(data), nil
}
// websocketChannel represents a channel in a connection
type websocketChannel struct {
conn *Conn
num byte
r io.Reader
w io.WriteCloser
read, write bool
}
// newWebsocketChannel creates a pipe for writing to a websocket. Do not write to this pipe
// prior to the connection being opened. It may be no, half, or full duplex depending on
// read and write.
func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel {
r, w := io.Pipe()
return &websocketChannel{conn, num, r, w, read, write}
}
func (p *websocketChannel) Write(data []byte) (int, error) {
if !p.write {
return len(data), nil
}
return p.conn.write(p.num, data)
}
// DataFromSocket is invoked by the connection receiver to move data from the connection
// into a specific channel.
func (p *websocketChannel) DataFromSocket(data []byte) (int, error) {
if !p.read {
return len(data), nil
}
switch p.conn.codec {
case rawCodec:
return p.w.Write(data)
case base64Codec:
dst := make([]byte, len(data))
n, err := base64.StdEncoding.Decode(dst, data)
if err != nil {
return 0, err
}
return p.w.Write(dst[:n])
}
return 0, nil
}
func (p *websocketChannel) Read(data []byte) (int, error) {
if !p.read {
return 0, io.EOF
}
return p.r.Read(data)
}
func (p *websocketChannel) Close() error {
return p.w.Close()
}

272
vendor/k8s.io/apiserver/pkg/util/wsstream/conn_test.go generated vendored Normal file
View file

@ -0,0 +1,272 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package wsstream
import (
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"sync"
"testing"
"golang.org/x/net/websocket"
)
func newServer(handler http.Handler) (*httptest.Server, string) {
server := httptest.NewServer(handler)
serverAddr := server.Listener.Addr().String()
return server, serverAddr
}
func TestRawConn(t *testing.T) {
channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
conn := NewConn(NewDefaultChannelProtocols(channels))
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conn.Open(w, req)
}))
defer s.Close()
client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
if err != nil {
t.Fatal(err)
}
defer client.Close()
<-conn.ready
wg := sync.WaitGroup{}
// verify we can read a client write
wg.Add(1)
go func() {
defer wg.Done()
data, err := ioutil.ReadAll(conn.channels[0])
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(data, []byte("client")) {
t.Errorf("unexpected server read: %v", data)
}
}()
if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 {
t.Fatalf("%d: %v", n, err)
}
// verify we can read a server write
wg.Add(1)
go func() {
defer wg.Done()
if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
t.Fatalf("%d: %v", n, err)
}
}()
data := make([]byte, 1024)
if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil {
t.Fatalf("%d: %v", n, err)
}
if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) {
t.Errorf("unexpected client read: %v", data[:7])
}
// verify that an ignore channel is empty in both directions.
if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil {
t.Errorf("writes should be ignored")
}
data = make([]byte, 1024)
if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF {
t.Errorf("reads should be ignored")
}
// verify that a write to a Read channel doesn't block
if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil {
t.Errorf("writes should be ignored")
}
// verify that a read from a Write channel doesn't block
data = make([]byte, 1024)
if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF {
t.Errorf("reads should be ignored")
}
// verify that a client write to a Write channel doesn't block (is dropped)
if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 {
t.Fatalf("%d: %v", n, err)
}
client.Close()
wg.Wait()
}
func TestBase64Conn(t *testing.T) {
conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conn.Open(w, req)
}))
defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
if err != nil {
t.Fatal(err)
}
config.Protocol = []string{"base64.channel.k8s.io"}
client, err := websocket.DialConfig(config)
if err != nil {
t.Fatal(err)
}
defer client.Close()
<-conn.ready
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
data, err := ioutil.ReadAll(conn.channels[0])
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(data, []byte("client")) {
t.Errorf("unexpected server read: %s", string(data))
}
}()
clientData := base64.StdEncoding.EncodeToString([]byte("client"))
if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 {
t.Fatalf("%d: %v", n, err)
}
wg.Add(1)
go func() {
defer wg.Done()
if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
t.Fatalf("%d: %v", n, err)
}
}()
data := make([]byte, 1024)
if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil {
t.Fatalf("%d: %v", n, err)
}
expect := []byte(base64.StdEncoding.EncodeToString([]byte("server")))
if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) {
t.Errorf("unexpected client read: %v", data[:9])
}
client.Close()
wg.Wait()
}
type versionTest struct {
supported map[string]bool // protocol -> binary
requested []string
error bool
expected string
}
func versionTests() []versionTest {
const (
binary = true
base64 = false
)
return []versionTest{
{
supported: nil,
requested: []string{"raw"},
error: true,
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: nil,
expected: "",
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw"},
error: true,
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw", "v1.base64"},
error: true,
}, {
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw", "raw"},
expected: "raw",
},
{
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
requested: []string{"v1.raw"},
expected: "v1.raw",
},
{
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
requested: []string{"v2.base64"},
expected: "v2.base64",
},
}
}
func TestVersionedConn(t *testing.T) {
for i, test := range versionTests() {
func() {
supportedProtocols := map[string]ChannelProtocolConfig{}
for p, binary := range test.supported {
supportedProtocols[p] = ChannelProtocolConfig{
Binary: binary,
Channels: []ChannelType{ReadWriteChannel},
}
}
conn := NewConn(supportedProtocols)
// note that it's not enough to wait for conn.ready to avoid a race here. Hence,
// we use a channel.
selectedProtocol := make(chan string, 0)
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
p, _, _ := conn.Open(w, req)
selectedProtocol <- p
}))
defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
if err != nil {
t.Fatal(err)
}
config.Protocol = test.requested
client, err := websocket.DialConfig(config)
if err != nil {
if !test.error {
t.Fatalf("test %d: didn't expect error: %v", i, err)
} else {
return
}
}
defer client.Close()
if test.error && err == nil {
t.Fatalf("test %d: expected an error", i)
}
<-conn.ready
if got, expected := <-selectedProtocol, test.expected; got != expected {
t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
}
}()
}
}

21
vendor/k8s.io/apiserver/pkg/util/wsstream/doc.go generated vendored Normal file
View file

@ -0,0 +1,21 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package wsstream contains utilities for streaming content over WebSockets.
// The Conn type allows callers to multiplex multiple read/write channels over
// a single websocket. The Reader type allows an io.Reader to be copied over
// a websocket channel as binary content.
package wsstream // import "k8s.io/apiserver/pkg/util/wsstream"

177
vendor/k8s.io/apiserver/pkg/util/wsstream/stream.go generated vendored Normal file
View file

@ -0,0 +1,177 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package wsstream
import (
"encoding/base64"
"io"
"net/http"
"sync"
"time"
"golang.org/x/net/websocket"
"k8s.io/apimachinery/pkg/util/runtime"
)
// The WebSocket subprotocol "binary.k8s.io" will only send messages to the
// client and ignore messages sent to the server. The received messages are
// the exact bytes written to the stream. Zero byte messages are possible.
const binaryWebSocketProtocol = "binary.k8s.io"
// The WebSocket subprotocol "base64.binary.k8s.io" will only send messages to the
// client and ignore messages sent to the server. The received messages are
// a base64 version of the bytes written to the stream. Zero byte messages are
// possible.
const base64BinaryWebSocketProtocol = "base64.binary.k8s.io"
// ReaderProtocolConfig describes a websocket subprotocol with one stream.
type ReaderProtocolConfig struct {
Binary bool
}
// NewDefaultReaderProtocols returns a stream protocol map with the
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io".
func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig {
return map[string]ReaderProtocolConfig{
"": {Binary: true},
binaryWebSocketProtocol: {Binary: true},
base64BinaryWebSocketProtocol: {Binary: false},
}
}
// Reader supports returning an arbitrary byte stream over a websocket channel.
type Reader struct {
err chan error
r io.Reader
ping bool
timeout time.Duration
protocols map[string]ReaderProtocolConfig
selectedProtocol string
handleCrash func() // overridable for testing
}
// NewReader creates a WebSocket pipe that will copy the contents of r to a provided
// WebSocket connection. If ping is true, a zero length message will be sent to the client
// before the stream begins reading.
//
// The protocols parameter maps subprotocol names to StreamProtocols. The empty string
// subprotocol name is used if websocket.Config.Protocol is empty.
func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader {
return &Reader{
r: r,
err: make(chan error),
ping: ping,
protocols: protocols,
handleCrash: func() { runtime.HandleCrash() },
}
}
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
// there is no timeout on the reader.
func (r *Reader) SetIdleTimeout(duration time.Duration) {
r.timeout = duration
}
func (r *Reader) handshake(config *websocket.Config, req *http.Request) error {
supportedProtocols := make([]string, 0, len(r.protocols))
for p := range r.protocols {
supportedProtocols = append(supportedProtocols, p)
}
return handshake(config, req, supportedProtocols)
}
// Copy the reader to the response. The created WebSocket is closed after this
// method completes.
func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
go func() {
defer r.handleCrash()
websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req)
}()
return <-r.err
}
// handle implements a WebSocket handler.
func (r *Reader) handle(ws *websocket.Conn) {
// Close the connection when the client requests it, or when we finish streaming, whichever happens first
closeConnOnce := &sync.Once{}
closeConn := func() {
closeConnOnce.Do(func() {
ws.Close()
})
}
negotiated := ws.Config().Protocol
r.selectedProtocol = negotiated[0]
defer close(r.err)
defer closeConn()
go func() {
defer runtime.HandleCrash()
// This blocks until the connection is closed.
// Client should not send anything.
IgnoreReceives(ws, r.timeout)
// Once the client closes, we should also close
closeConn()
}()
r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout)
}
func resetTimeout(ws *websocket.Conn, timeout time.Duration) {
if timeout > 0 {
ws.SetDeadline(time.Now().Add(timeout))
}
}
func messageCopy(ws *websocket.Conn, r io.Reader, base64Encode, ping bool, timeout time.Duration) error {
buf := make([]byte, 2048)
if ping {
resetTimeout(ws, timeout)
if base64Encode {
if err := websocket.Message.Send(ws, ""); err != nil {
return err
}
} else {
if err := websocket.Message.Send(ws, []byte{}); err != nil {
return err
}
}
}
for {
resetTimeout(ws, timeout)
n, err := r.Read(buf)
if err != nil {
if err == io.EOF {
return nil
}
return err
}
if n > 0 {
if base64Encode {
if err := websocket.Message.Send(ws, base64.StdEncoding.EncodeToString(buf[:n])); err != nil {
return err
}
} else {
if err := websocket.Message.Send(ws, buf[:n]); err != nil {
return err
}
}
}
}
}

View file

@ -0,0 +1,294 @@
/*
Copyright 2015 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package wsstream
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"strings"
"testing"
"time"
"golang.org/x/net/websocket"
)
func TestStream(t *testing.T) {
input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
r.SetIdleTimeout(time.Second)
data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) {
t.Errorf("unexpected server read: %v", data)
}
if err != nil {
t.Fatal(err)
}
}
func TestStreamPing(t *testing.T) {
input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
r.SetIdleTimeout(time.Second)
err := expectWebSocketFrames(r, t, nil, [][]byte{
{},
[]byte(input),
})
if err != nil {
t.Fatal(err)
}
}
func TestStreamBase64(t *testing.T) {
input := "some random text"
encoded := base64.StdEncoding.EncodeToString([]byte(input))
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
if !reflect.DeepEqual(data, []byte(encoded)) {
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
}
if err != nil {
t.Fatal(err)
}
}
func TestStreamVersionedBase64(t *testing.T) {
input := "some random text"
encoded := base64.StdEncoding.EncodeToString([]byte(input))
r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
"": {Binary: true},
"binary.k8s.io": {Binary: true},
"base64.binary.k8s.io": {Binary: false},
"v1.binary.k8s.io": {Binary: true},
"v1.base64.binary.k8s.io": {Binary: false},
"v2.binary.k8s.io": {Binary: true},
"v2.base64.binary.k8s.io": {Binary: false},
})
data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
if !reflect.DeepEqual(data, []byte(encoded)) {
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
}
if err != nil {
t.Fatal(err)
}
}
func TestStreamVersionedCopy(t *testing.T) {
for i, test := range versionTests() {
func() {
supportedProtocols := map[string]ReaderProtocolConfig{}
for p, binary := range test.supported {
supportedProtocols[p] = ReaderProtocolConfig{
Binary: binary,
}
}
input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
err := r.Copy(w, req)
if err != nil {
w.WriteHeader(503)
}
}))
defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
if err != nil {
t.Error(err)
return
}
config.Protocol = test.requested
client, err := websocket.DialConfig(config)
if err != nil {
if !test.error {
t.Errorf("test %d: didn't expect error: %v", i, err)
}
return
}
defer client.Close()
if test.error && err == nil {
t.Errorf("test %d: expected an error", i)
return
}
<-r.err
if got, expected := r.selectedProtocol, test.expected; got != expected {
t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
}
}()
}
}
func TestStreamError(t *testing.T) {
input := "some random text"
errs := &errorReader{
reads: [][]byte{
[]byte("some random"),
[]byte(" text"),
},
err: fmt.Errorf("bad read"),
}
r := NewReader(errs, false, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) {
t.Errorf("unexpected server read: %v", data)
}
if err == nil || err.Error() != "bad read" {
t.Fatal(err)
}
}
func TestStreamSurvivesPanic(t *testing.T) {
input := "some random text"
errs := &errorReader{
reads: [][]byte{
[]byte("some random"),
[]byte(" text"),
},
panicMessage: "bad read",
}
r := NewReader(errs, false, NewDefaultReaderProtocols())
// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
r.handleCrash = func() { recover() }
data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) {
t.Errorf("unexpected server read: %v", data)
}
if err != nil {
t.Fatal(err)
}
}
func TestStreamClosedDuringRead(t *testing.T) {
for i := 0; i < 25; i++ {
ch := make(chan struct{})
input := "some random text"
errs := &errorReader{
reads: [][]byte{
[]byte("some random"),
[]byte(" text"),
},
err: fmt.Errorf("stuff"),
pause: ch,
}
r := NewReader(errs, false, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, func(c *websocket.Conn) {
c.Close()
close(ch)
})
// verify that the data returned by the server on an early close always has a specific error
if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
t.Fatal(err)
}
// verify that the data returned is a strict subset of the input
if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 {
t.Fatalf("unexpected server read: %q", string(data))
}
}
}
type errorReader struct {
reads [][]byte
err error
panicMessage string
pause chan struct{}
}
func (r *errorReader) Read(p []byte) (int, error) {
if len(r.reads) == 0 {
if r.pause != nil {
<-r.pause
}
if len(r.panicMessage) != 0 {
panic(r.panicMessage)
}
return 0, r.err
}
next := r.reads[0]
r.reads = r.reads[1:]
copy(p, next)
return len(next), nil
}
func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
errCh := make(chan error, 1)
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
errCh <- r.Copy(w, req)
}))
defer s.Close()
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
config.Protocol = protocols
client, err := websocket.DialConfig(config)
if err != nil {
return nil, err
}
defer client.Close()
if fn != nil {
fn(client)
}
data, err := ioutil.ReadAll(client)
if err != nil {
return data, err
}
return data, <-errCh
}
func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
errCh := make(chan error, 1)
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
errCh <- r.Copy(w, req)
}))
defer s.Close()
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
config.Protocol = protocols
ws, err := websocket.DialConfig(config)
if err != nil {
return err
}
defer ws.Close()
if fn != nil {
fn(ws)
}
for i := range frames {
var data []byte
if err := websocket.Message.Receive(ws, &data); err != nil {
return err
}
if !reflect.DeepEqual(frames[i], data) {
return fmt.Errorf("frame %d did not match expected: %v", data, err)
}
}
var data []byte
if err := websocket.Message.Receive(ws, &data); err != io.EOF {
return fmt.Errorf("expected no more frames: %v (%v)", err, data)
}
return <-errCh
}