295 lines
7.4 KiB
Go
295 lines
7.4 KiB
Go
|
/*
Copyright 2015 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
||
|
|
||
|
package wsstream
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/base64"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"golang.org/x/net/websocket"
|
||
|
)
|
||
|
|
||
|
func TestStream(t *testing.T) {
|
||
|
input := "some random text"
|
||
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||
|
r.SetIdleTimeout(time.Second)
|
||
|
data, err := readWebSocket(r, t, nil)
|
||
|
if !reflect.DeepEqual(data, []byte(input)) {
|
||
|
t.Errorf("unexpected server read: %v", data)
|
||
|
}
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStreamPing(t *testing.T) {
|
||
|
input := "some random text"
|
||
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||
|
r.SetIdleTimeout(time.Second)
|
||
|
err := expectWebSocketFrames(r, t, nil, [][]byte{
|
||
|
{},
|
||
|
[]byte(input),
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStreamBase64(t *testing.T) {
|
||
|
input := "some random text"
|
||
|
encoded := base64.StdEncoding.EncodeToString([]byte(input))
|
||
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||
|
data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
|
||
|
if !reflect.DeepEqual(data, []byte(encoded)) {
|
||
|
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
|
||
|
}
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStreamVersionedBase64(t *testing.T) {
|
||
|
input := "some random text"
|
||
|
encoded := base64.StdEncoding.EncodeToString([]byte(input))
|
||
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
|
||
|
"": {Binary: true},
|
||
|
"binary.k8s.io": {Binary: true},
|
||
|
"base64.binary.k8s.io": {Binary: false},
|
||
|
"v1.binary.k8s.io": {Binary: true},
|
||
|
"v1.base64.binary.k8s.io": {Binary: false},
|
||
|
"v2.binary.k8s.io": {Binary: true},
|
||
|
"v2.base64.binary.k8s.io": {Binary: false},
|
||
|
})
|
||
|
data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
|
||
|
if !reflect.DeepEqual(data, []byte(encoded)) {
|
||
|
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
|
||
|
}
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStreamVersionedCopy(t *testing.T) {
|
||
|
for i, test := range versionTests() {
|
||
|
func() {
|
||
|
supportedProtocols := map[string]ReaderProtocolConfig{}
|
||
|
for p, binary := range test.supported {
|
||
|
supportedProtocols[p] = ReaderProtocolConfig{
|
||
|
Binary: binary,
|
||
|
}
|
||
|
}
|
||
|
input := "some random text"
|
||
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
|
||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||
|
err := r.Copy(w, req)
|
||
|
if err != nil {
|
||
|
w.WriteHeader(503)
|
||
|
}
|
||
|
}))
|
||
|
defer s.Close()
|
||
|
|
||
|
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
return
|
||
|
}
|
||
|
config.Protocol = test.requested
|
||
|
client, err := websocket.DialConfig(config)
|
||
|
if err != nil {
|
||
|
if !test.error {
|
||
|
t.Errorf("test %d: didn't expect error: %v", i, err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
defer client.Close()
|
||
|
if test.error && err == nil {
|
||
|
t.Errorf("test %d: expected an error", i)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
<-r.err
|
||
|
if got, expected := r.selectedProtocol, test.expected; got != expected {
|
||
|
t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStreamError(t *testing.T) {
|
||
|
input := "some random text"
|
||
|
errs := &errorReader{
|
||
|
reads: [][]byte{
|
||
|
[]byte("some random"),
|
||
|
[]byte(" text"),
|
||
|
},
|
||
|
err: fmt.Errorf("bad read"),
|
||
|
}
|
||
|
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||
|
|
||
|
data, err := readWebSocket(r, t, nil)
|
||
|
if !reflect.DeepEqual(data, []byte(input)) {
|
||
|
t.Errorf("unexpected server read: %v", data)
|
||
|
}
|
||
|
if err == nil || err.Error() != "bad read" {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStreamSurvivesPanic(t *testing.T) {
|
||
|
input := "some random text"
|
||
|
errs := &errorReader{
|
||
|
reads: [][]byte{
|
||
|
[]byte("some random"),
|
||
|
[]byte(" text"),
|
||
|
},
|
||
|
panicMessage: "bad read",
|
||
|
}
|
||
|
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||
|
|
||
|
// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
|
||
|
r.handleCrash = func() { recover() }
|
||
|
|
||
|
data, err := readWebSocket(r, t, nil)
|
||
|
if !reflect.DeepEqual(data, []byte(input)) {
|
||
|
t.Errorf("unexpected server read: %v", data)
|
||
|
}
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestStreamClosedDuringRead(t *testing.T) {
|
||
|
for i := 0; i < 25; i++ {
|
||
|
ch := make(chan struct{})
|
||
|
input := "some random text"
|
||
|
errs := &errorReader{
|
||
|
reads: [][]byte{
|
||
|
[]byte("some random"),
|
||
|
[]byte(" text"),
|
||
|
},
|
||
|
err: fmt.Errorf("stuff"),
|
||
|
pause: ch,
|
||
|
}
|
||
|
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||
|
|
||
|
data, err := readWebSocket(r, t, func(c *websocket.Conn) {
|
||
|
c.Close()
|
||
|
close(ch)
|
||
|
})
|
||
|
// verify that the data returned by the server on an early close always has a specific error
|
||
|
if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
// verify that the data returned is a strict subset of the input
|
||
|
if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 {
|
||
|
t.Fatalf("unexpected server read: %q", string(data))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// errorReader is a scripted io.Reader for tests. Read works through the
// following stages:
//   - it returns the chunks in reads, in order;
//   - once reads is exhausted, it first blocks on pause if pause is non-nil;
//   - it then panics with panicMessage if that is non-empty;
//   - otherwise it returns err with zero bytes.
type errorReader struct {
	// reads are the chunks delivered by successive Read calls.
	reads [][]byte
	// err is returned once reads is exhausted (unless panicMessage is set).
	err error
	// panicMessage, when non-empty, makes Read panic once reads is exhausted.
	panicMessage string
	// pause, when non-nil, is received from before the final error or panic,
	// letting a test hold the reader at end of data.
	pause chan struct{}
}

// Read implements io.Reader. A chunk larger than p is delivered across
// several calls, so Read never reports more bytes than it copied into p.
func (r *errorReader) Read(p []byte) (int, error) {
	if len(r.reads) == 0 {
		if r.pause != nil {
			<-r.pause
		}
		if len(r.panicMessage) != 0 {
			panic(r.panicMessage)
		}
		return 0, r.err
	}
	next := r.reads[0]
	n := copy(p, next)
	if n < len(next) {
		// p was too small: keep the uncopied tail for the next call.
		r.reads[0] = next[n:]
	} else {
		r.reads = r.reads[1:]
	}
	return n, nil
}
|
||
|
|
||
|
func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
|
||
|
errCh := make(chan error, 1)
|
||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||
|
errCh <- r.Copy(w, req)
|
||
|
}))
|
||
|
defer s.Close()
|
||
|
|
||
|
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
|
||
|
config.Protocol = protocols
|
||
|
client, err := websocket.DialConfig(config)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer client.Close()
|
||
|
|
||
|
if fn != nil {
|
||
|
fn(client)
|
||
|
}
|
||
|
|
||
|
data, err := ioutil.ReadAll(client)
|
||
|
if err != nil {
|
||
|
return data, err
|
||
|
}
|
||
|
return data, <-errCh
|
||
|
}
|
||
|
|
||
|
func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
|
||
|
errCh := make(chan error, 1)
|
||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||
|
errCh <- r.Copy(w, req)
|
||
|
}))
|
||
|
defer s.Close()
|
||
|
|
||
|
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
|
||
|
config.Protocol = protocols
|
||
|
ws, err := websocket.DialConfig(config)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer ws.Close()
|
||
|
|
||
|
if fn != nil {
|
||
|
fn(ws)
|
||
|
}
|
||
|
|
||
|
for i := range frames {
|
||
|
var data []byte
|
||
|
if err := websocket.Message.Receive(ws, &data); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if !reflect.DeepEqual(frames[i], data) {
|
||
|
return fmt.Errorf("frame %d did not match expected: %v", data, err)
|
||
|
}
|
||
|
}
|
||
|
var data []byte
|
||
|
if err := websocket.Message.Receive(ws, &data); err != io.EOF {
|
||
|
return fmt.Errorf("expected no more frames: %v (%v)", err, data)
|
||
|
}
|
||
|
return <-errCh
|
||
|
}
|