273 lines
7.1 KiB
Go
273 lines
7.1 KiB
Go
|
/*
Copyright 2015 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
||
|
|
||
|
package wsstream
|
||
|
|
||
|
import (
|
||
|
"encoding/base64"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"reflect"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
|
||
|
"golang.org/x/net/websocket"
|
||
|
)
|
||
|
|
||
|
// newServer starts an httptest server that dispatches every request to
// handler. It returns the running server together with the address its
// listener is bound to, as reported by the listener itself. The caller is
// responsible for closing the returned server.
func newServer(handler http.Handler) (*httptest.Server, string) {
	srv := httptest.NewServer(handler)
	return srv, srv.Listener.Addr().String()
}
|
||
|
|
||
|
func TestRawConn(t *testing.T) {
|
||
|
channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
|
||
|
conn := NewConn(NewDefaultChannelProtocols(channels))
|
||
|
|
||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||
|
conn.Open(w, req)
|
||
|
}))
|
||
|
defer s.Close()
|
||
|
|
||
|
client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
defer client.Close()
|
||
|
|
||
|
<-conn.ready
|
||
|
wg := sync.WaitGroup{}
|
||
|
|
||
|
// verify we can read a client write
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
data, err := ioutil.ReadAll(conn.channels[0])
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if !reflect.DeepEqual(data, []byte("client")) {
|
||
|
t.Errorf("unexpected server read: %v", data)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 {
|
||
|
t.Fatalf("%d: %v", n, err)
|
||
|
}
|
||
|
|
||
|
// verify we can read a server write
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
|
||
|
t.Fatalf("%d: %v", n, err)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
data := make([]byte, 1024)
|
||
|
if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil {
|
||
|
t.Fatalf("%d: %v", n, err)
|
||
|
}
|
||
|
if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) {
|
||
|
t.Errorf("unexpected client read: %v", data[:7])
|
||
|
}
|
||
|
|
||
|
// verify that an ignore channel is empty in both directions.
|
||
|
if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil {
|
||
|
t.Errorf("writes should be ignored")
|
||
|
}
|
||
|
data = make([]byte, 1024)
|
||
|
if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF {
|
||
|
t.Errorf("reads should be ignored")
|
||
|
}
|
||
|
|
||
|
// verify that a write to a Read channel doesn't block
|
||
|
if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil {
|
||
|
t.Errorf("writes should be ignored")
|
||
|
}
|
||
|
|
||
|
// verify that a read from a Write channel doesn't block
|
||
|
data = make([]byte, 1024)
|
||
|
if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF {
|
||
|
t.Errorf("reads should be ignored")
|
||
|
}
|
||
|
|
||
|
// verify that a client write to a Write channel doesn't block (is dropped)
|
||
|
if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 {
|
||
|
t.Fatalf("%d: %v", n, err)
|
||
|
}
|
||
|
|
||
|
client.Close()
|
||
|
wg.Wait()
|
||
|
}
|
||
|
|
||
|
func TestBase64Conn(t *testing.T) {
|
||
|
conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
|
||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||
|
conn.Open(w, req)
|
||
|
}))
|
||
|
defer s.Close()
|
||
|
|
||
|
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
config.Protocol = []string{"base64.channel.k8s.io"}
|
||
|
client, err := websocket.DialConfig(config)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
defer client.Close()
|
||
|
|
||
|
<-conn.ready
|
||
|
wg := sync.WaitGroup{}
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
data, err := ioutil.ReadAll(conn.channels[0])
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if !reflect.DeepEqual(data, []byte("client")) {
|
||
|
t.Errorf("unexpected server read: %s", string(data))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
clientData := base64.StdEncoding.EncodeToString([]byte("client"))
|
||
|
if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 {
|
||
|
t.Fatalf("%d: %v", n, err)
|
||
|
}
|
||
|
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
|
||
|
t.Fatalf("%d: %v", n, err)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
data := make([]byte, 1024)
|
||
|
if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil {
|
||
|
t.Fatalf("%d: %v", n, err)
|
||
|
}
|
||
|
expect := []byte(base64.StdEncoding.EncodeToString([]byte("server")))
|
||
|
|
||
|
if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) {
|
||
|
t.Errorf("unexpected client read: %v", data[:9])
|
||
|
}
|
||
|
|
||
|
client.Close()
|
||
|
wg.Wait()
|
||
|
}
|
||
|
|
||
|
// versionTest describes one protocol negotiation scenario: what the server
// offers, what the client asks for, and the outcome the test should see.
type versionTest struct {
	// Inputs.

	// supported maps each protocol name the server offers to whether that
	// protocol is binary (true) or not (false).
	supported map[string]bool
	// requested lists the protocol names the client asks for.
	requested []string

	// Expectations.

	// expected is the protocol name that should be selected.
	expected string
	// error is true when the client's dial is expected to fail.
	error bool
}
|
||
|
|
||
|
func versionTests() []versionTest {
|
||
|
const (
|
||
|
binary = true
|
||
|
base64 = false
|
||
|
)
|
||
|
return []versionTest{
|
||
|
{
|
||
|
supported: nil,
|
||
|
requested: []string{"raw"},
|
||
|
error: true,
|
||
|
},
|
||
|
{
|
||
|
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||
|
requested: nil,
|
||
|
expected: "",
|
||
|
},
|
||
|
{
|
||
|
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||
|
requested: []string{"v1.raw"},
|
||
|
error: true,
|
||
|
},
|
||
|
{
|
||
|
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||
|
requested: []string{"v1.raw", "v1.base64"},
|
||
|
error: true,
|
||
|
}, {
|
||
|
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||
|
requested: []string{"v1.raw", "raw"},
|
||
|
expected: "raw",
|
||
|
},
|
||
|
{
|
||
|
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
|
||
|
requested: []string{"v1.raw"},
|
||
|
expected: "v1.raw",
|
||
|
},
|
||
|
{
|
||
|
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
|
||
|
requested: []string{"v2.base64"},
|
||
|
expected: "v2.base64",
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestVersionedConn(t *testing.T) {
|
||
|
for i, test := range versionTests() {
|
||
|
func() {
|
||
|
supportedProtocols := map[string]ChannelProtocolConfig{}
|
||
|
for p, binary := range test.supported {
|
||
|
supportedProtocols[p] = ChannelProtocolConfig{
|
||
|
Binary: binary,
|
||
|
Channels: []ChannelType{ReadWriteChannel},
|
||
|
}
|
||
|
}
|
||
|
conn := NewConn(supportedProtocols)
|
||
|
// note that it's not enough to wait for conn.ready to avoid a race here. Hence,
|
||
|
// we use a channel.
|
||
|
selectedProtocol := make(chan string, 0)
|
||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||
|
p, _, _ := conn.Open(w, req)
|
||
|
selectedProtocol <- p
|
||
|
}))
|
||
|
defer s.Close()
|
||
|
|
||
|
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
config.Protocol = test.requested
|
||
|
client, err := websocket.DialConfig(config)
|
||
|
if err != nil {
|
||
|
if !test.error {
|
||
|
t.Fatalf("test %d: didn't expect error: %v", i, err)
|
||
|
} else {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
defer client.Close()
|
||
|
if test.error && err == nil {
|
||
|
t.Fatalf("test %d: expected an error", i)
|
||
|
}
|
||
|
|
||
|
<-conn.ready
|
||
|
if got, expected := <-selectedProtocol, test.expected; got != expected {
|
||
|
t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
}
|