/* Copyright 2015 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package wsstream import ( "bytes" "encoding/base64" "fmt" "io" "io/ioutil" "net/http" "reflect" "strings" "testing" "time" "golang.org/x/net/websocket" ) func TestStream(t *testing.T) { input := "some random text" r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) r.SetIdleTimeout(time.Second) data, err := readWebSocket(r, t, nil) if !reflect.DeepEqual(data, []byte(input)) { t.Errorf("unexpected server read: %v", data) } if err != nil { t.Fatal(err) } } func TestStreamPing(t *testing.T) { input := "some random text" r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) r.SetIdleTimeout(time.Second) err := expectWebSocketFrames(r, t, nil, [][]byte{ {}, []byte(input), }) if err != nil { t.Fatal(err) } } func TestStreamBase64(t *testing.T) { input := "some random text" encoded := base64.StdEncoding.EncodeToString([]byte(input)) r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io") if !reflect.DeepEqual(data, []byte(encoded)) { t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded)) } if err != nil { t.Fatal(err) } } func TestStreamVersionedBase64(t *testing.T) { input := "some random text" encoded := base64.StdEncoding.EncodeToString([]byte(input)) r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{ "": {Binary: true}, "binary.k8s.io": {Binary: true}, "base64.binary.k8s.io": {Binary: false}, "v1.binary.k8s.io": {Binary: true}, "v1.base64.binary.k8s.io": {Binary: false}, "v2.binary.k8s.io": {Binary: true}, "v2.base64.binary.k8s.io": {Binary: false}, }) data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io") if !reflect.DeepEqual(data, []byte(encoded)) { t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded)) } if err != nil { t.Fatal(err) } } func TestStreamVersionedCopy(t *testing.T) { for i, test := range versionTests() { func() { supportedProtocols := map[string]ReaderProtocolConfig{} for p, binary := range test.supported { supportedProtocols[p] = ReaderProtocolConfig{ Binary: binary, } } input := "some random text" r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols) s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { err := r.Copy(w, req) if err != nil { w.WriteHeader(503) } })) defer s.Close() config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") if err != nil { t.Error(err) return } config.Protocol = test.requested client, err := websocket.DialConfig(config) if err != nil { if !test.error { t.Errorf("test %d: didn't expect error: %v", i, err) } return } defer client.Close() if test.error && err == nil { t.Errorf("test %d: expected an error", i) return } <-r.err if got, expected := r.selectedProtocol, test.expected; got != expected { t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected) } }() } } func TestStreamError(t *testing.T) { input := "some random text" errs := &errorReader{ reads: [][]byte{ []byte("some random"), []byte(" text"), }, err: fmt.Errorf("bad read"), } r := NewReader(errs, false, NewDefaultReaderProtocols()) data, err := readWebSocket(r, t, nil) if !reflect.DeepEqual(data, []byte(input)) { t.Errorf("unexpected server read: %v", data) } if err == nil || err.Error() != "bad read" { t.Fatal(err) } } func TestStreamSurvivesPanic(t *testing.T) { input := "some random text" errs := &errorReader{ reads: [][]byte{ []byte("some random"), []byte(" text"), }, panicMessage: "bad read", } r := NewReader(errs, false, NewDefaultReaderProtocols()) // do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted. r.handleCrash = func() { recover() } data, err := readWebSocket(r, t, nil) if !reflect.DeepEqual(data, []byte(input)) { t.Errorf("unexpected server read: %v", data) } if err != nil { t.Fatal(err) } } func TestStreamClosedDuringRead(t *testing.T) { for i := 0; i < 25; i++ { ch := make(chan struct{}) input := "some random text" errs := &errorReader{ reads: [][]byte{ []byte("some random"), []byte(" text"), }, err: fmt.Errorf("stuff"), pause: ch, } r := NewReader(errs, false, NewDefaultReaderProtocols()) data, err := readWebSocket(r, t, func(c *websocket.Conn) { c.Close() close(ch) }) // verify that the data returned by the server on an early close always has a specific error if err == nil || !strings.Contains(err.Error(), "use of closed network connection") { t.Fatal(err) } // verify that the data returned is a strict subset of the input if !bytes.HasPrefix([]byte(input), data) && len(data) != 0 { t.Fatalf("unexpected server read: %q", string(data)) } } } type errorReader struct { reads [][]byte err error panicMessage string pause chan struct{} } func (r *errorReader) Read(p []byte) (int, error) { if len(r.reads) == 0 { if r.pause != nil { <-r.pause } if len(r.panicMessage) != 0 { panic(r.panicMessage) } return 0, r.err } next := r.reads[0] r.reads = r.reads[1:] copy(p, next) return len(next), nil } func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) { errCh := make(chan error, 1) s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { errCh <- r.Copy(w, req) })) defer s.Close() config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr) config.Protocol = protocols client, err := websocket.DialConfig(config) if err != nil { return nil, err } defer client.Close() if fn != nil { fn(client) } data, err := ioutil.ReadAll(client) if err != nil { return data, err } return data, <-errCh } func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error { errCh := make(chan error, 1) s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { errCh <- r.Copy(w, req) })) defer s.Close() config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr) config.Protocol = protocols ws, err := websocket.DialConfig(config) if err != nil { return err } defer ws.Close() if fn != nil { fn(ws) } for i := range frames { var data []byte if err := websocket.Message.Receive(ws, &data); err != nil { return err } if !reflect.DeepEqual(frames[i], data) { return fmt.Errorf("frame %d did not match expected: %v", data, err) } } var data []byte if err := websocket.Message.Receive(ws, &data); err != io.EOF { return fmt.Errorf("expected no more frames: %v (%v)", err, data) } return <-errCh }