350 lines
10 KiB
Go
350 lines
10 KiB
Go
|
/*
|
||
|
Copyright 2015 The Kubernetes Authors.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
*/
|
||
|
|
||
|
package wsstream
|
||
|
|
||
|
import (
|
||
|
"encoding/base64"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net/http"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"github.com/golang/glog"
|
||
|
"golang.org/x/net/websocket"
|
||
|
|
||
|
"k8s.io/apimachinery/pkg/util/runtime"
|
||
|
)
|
||
|
|
||
|
// ChannelWebSocketProtocol names the "channel.k8s.io" websocket subprotocol.
//
// Every binary message, in either direction, begins with a single byte holding
// the zero-indexed channel number; the remainder of the message is the payload,
// which is passed through unchanged (no encoding is applied). For remote
// execution the channel numbers conventionally mirror the POSIX file
// descriptors: 0 is STDIN, 1 is STDOUT and 2 is STDERR.
//
// Example client session:
//
//	CONNECT http://server.com with subprotocol "channel.k8s.io"
//	WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN)
//	READ  []byte{1, 10}                # receive "\n" on channel 1 (STDOUT)
//	CLOSE
const ChannelWebSocketProtocol = "channel.k8s.io"
|
||
|
|
||
|
// Base64ChannelWebSocketProtocol names the "base64.channel.k8s.io" websocket
// subprotocol.
//
// Every message, in either direction, begins with a single character equal to
// '0' plus the zero-indexed channel number. Payloads received by the server are
// base64 decoded (and must be valid base64). Payloads written by the server to
// the client are base64 encoded. For remote execution the channels
// conventionally mirror the POSIX file descriptors: '0' is STDIN, '1' is STDOUT
// and '2' is STDERR.
//
// Example client session:
//
//	CONNECT http://server.com with subprotocol "base64.channel.k8s.io"
//	WRITE []byte{48, 90, 109, 57, 118, 67, 103, 61, 61} # send "foo\n" (base64: "Zm9vCg==") on channel '0' (STDIN)
//	READ  []byte{49, 67, 103, 61, 61}                   # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
//	CLOSE
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
|
||
|
|
||
|
// codecType selects how a Conn frames channel data on the wire.
type codecType int

const (
	// rawCodec sends binary frames whose first byte is the channel number,
	// followed by the unmodified payload.
	rawCodec codecType = 0
	// base64Codec sends text frames whose first character is '0'+channel,
	// followed by the base64-encoded payload.
	base64Codec codecType = 1
)
|
||
|
|
||
|
// ChannelType describes which directions of a multiplexed channel are active.
type ChannelType int

const (
	// IgnoreChannel accepts neither direction: incoming data is dropped, Read
	// reports io.EOF and Write discards its input.
	IgnoreChannel ChannelType = 0
	// ReadChannel only delivers data received from the client.
	ReadChannel ChannelType = 1
	// WriteChannel only sends data to the client.
	WriteChannel ChannelType = 2
	// ReadWriteChannel delivers received data and sends written data.
	ReadWriteChannel ChannelType = 3
)
|
||
|
|
||
|
// connectionUpgradeRegex matches a lower-cased Connection header value whose
// comma-separated token list contains the token "upgrade".
var connectionUpgradeRegex = regexp.MustCompile(`(^|.*,\s*)upgrade($|\s*,)`)
|
||
|
|
||
|
// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers
|
||
|
// for WebSockets.
|
||
|
func IsWebSocketRequest(req *http.Request) bool {
|
||
|
return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection"))) && strings.ToLower(req.Header.Get("Upgrade")) == "websocket"
|
||
|
}
|
||
|
|
||
|
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
|
||
|
// read and write deadlines are pushed every time a new message is received.
|
||
|
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
|
||
|
defer runtime.HandleCrash()
|
||
|
var data []byte
|
||
|
for {
|
||
|
resetTimeout(ws, timeout)
|
||
|
if err := websocket.Message.Receive(ws, &data); err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// handshake ensures the provided user protocol matches one of the allowed protocols. It returns
|
||
|
// no error if no protocol is specified.
|
||
|
func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
|
||
|
protocols := config.Protocol
|
||
|
if len(protocols) == 0 {
|
||
|
protocols = []string{""}
|
||
|
}
|
||
|
|
||
|
for _, protocol := range protocols {
|
||
|
for _, allow := range allowed {
|
||
|
if allow == protocol {
|
||
|
config.Protocol = []string{protocol}
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
|
||
|
}
|
||
|
|
||
|
// ChannelProtocolConfig describes a websocket subprotocol with channels.
|
||
|
type ChannelProtocolConfig struct {
|
||
|
Binary bool
|
||
|
Channels []ChannelType
|
||
|
}
|
||
|
|
||
|
// NewDefaultChannelProtocols returns a channel protocol map with the
|
||
|
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
|
||
|
// channels.
|
||
|
func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
|
||
|
return map[string]ChannelProtocolConfig{
|
||
|
"": {Binary: true, Channels: channels},
|
||
|
ChannelWebSocketProtocol: {Binary: true, Channels: channels},
|
||
|
Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Conn supports sending multiple binary channels over a websocket connection.
|
||
|
type Conn struct {
|
||
|
protocols map[string]ChannelProtocolConfig
|
||
|
selectedProtocol string
|
||
|
channels []*websocketChannel
|
||
|
codec codecType
|
||
|
ready chan struct{}
|
||
|
ws *websocket.Conn
|
||
|
timeout time.Duration
|
||
|
}
|
||
|
|
||
|
// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
|
||
|
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
|
||
|
// future use. The channel types for each channel are passed as an array, supporting the different
|
||
|
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
|
||
|
//
|
||
|
// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
|
||
|
// name is used if websocket.Config.Protocol is empty.
|
||
|
func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
|
||
|
return &Conn{
|
||
|
ready: make(chan struct{}),
|
||
|
protocols: protocols,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
|
||
|
// there is no timeout on the connection.
|
||
|
func (conn *Conn) SetIdleTimeout(duration time.Duration) {
|
||
|
conn.timeout = duration
|
||
|
}
|
||
|
|
||
|
// Open the connection and create channels for reading and writing. It returns
|
||
|
// the selected subprotocol, a slice of channels and an error.
|
||
|
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
|
||
|
go func() {
|
||
|
defer runtime.HandleCrash()
|
||
|
defer conn.Close()
|
||
|
websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
|
||
|
}()
|
||
|
<-conn.ready
|
||
|
rwc := make([]io.ReadWriteCloser, len(conn.channels))
|
||
|
for i := range conn.channels {
|
||
|
rwc[i] = conn.channels[i]
|
||
|
}
|
||
|
return conn.selectedProtocol, rwc, nil
|
||
|
}
|
||
|
|
||
|
func (conn *Conn) initialize(ws *websocket.Conn) {
|
||
|
negotiated := ws.Config().Protocol
|
||
|
conn.selectedProtocol = negotiated[0]
|
||
|
p := conn.protocols[conn.selectedProtocol]
|
||
|
if p.Binary {
|
||
|
conn.codec = rawCodec
|
||
|
} else {
|
||
|
conn.codec = base64Codec
|
||
|
}
|
||
|
conn.ws = ws
|
||
|
conn.channels = make([]*websocketChannel, len(p.Channels))
|
||
|
for i, t := range p.Channels {
|
||
|
switch t {
|
||
|
case ReadChannel:
|
||
|
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
|
||
|
case WriteChannel:
|
||
|
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
|
||
|
case ReadWriteChannel:
|
||
|
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
|
||
|
case IgnoreChannel:
|
||
|
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
close(conn.ready)
|
||
|
}
|
||
|
|
||
|
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
|
||
|
supportedProtocols := make([]string, 0, len(conn.protocols))
|
||
|
for p := range conn.protocols {
|
||
|
supportedProtocols = append(supportedProtocols, p)
|
||
|
}
|
||
|
return handshake(config, req, supportedProtocols)
|
||
|
}
|
||
|
|
||
|
func (conn *Conn) resetTimeout() {
|
||
|
if conn.timeout > 0 {
|
||
|
conn.ws.SetDeadline(time.Now().Add(conn.timeout))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Close is only valid after Open has been called
|
||
|
func (conn *Conn) Close() error {
|
||
|
<-conn.ready
|
||
|
for _, s := range conn.channels {
|
||
|
s.Close()
|
||
|
}
|
||
|
conn.ws.Close()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// handle implements a websocket handler.
|
||
|
func (conn *Conn) handle(ws *websocket.Conn) {
|
||
|
defer conn.Close()
|
||
|
conn.initialize(ws)
|
||
|
|
||
|
for {
|
||
|
conn.resetTimeout()
|
||
|
var data []byte
|
||
|
if err := websocket.Message.Receive(ws, &data); err != nil {
|
||
|
if err != io.EOF {
|
||
|
glog.Errorf("Error on socket receive: %v", err)
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
if len(data) == 0 {
|
||
|
continue
|
||
|
}
|
||
|
channel := data[0]
|
||
|
if conn.codec == base64Codec {
|
||
|
channel = channel - '0'
|
||
|
}
|
||
|
data = data[1:]
|
||
|
if int(channel) >= len(conn.channels) {
|
||
|
glog.V(6).Infof("Frame is targeted for a reader %d that is not valid, possible protocol error", channel)
|
||
|
continue
|
||
|
}
|
||
|
if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
|
||
|
glog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data))
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// write multiplexes the specified channel onto the websocket
|
||
|
func (conn *Conn) write(num byte, data []byte) (int, error) {
|
||
|
conn.resetTimeout()
|
||
|
switch conn.codec {
|
||
|
case rawCodec:
|
||
|
frame := make([]byte, len(data)+1)
|
||
|
frame[0] = num
|
||
|
copy(frame[1:], data)
|
||
|
if err := websocket.Message.Send(conn.ws, frame); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
case base64Codec:
|
||
|
frame := string('0'+num) + base64.StdEncoding.EncodeToString(data)
|
||
|
if err := websocket.Message.Send(conn.ws, frame); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
return len(data), nil
|
||
|
}
|
||
|
|
||
|
// websocketChannel represents a channel in a connection
|
||
|
type websocketChannel struct {
|
||
|
conn *Conn
|
||
|
num byte
|
||
|
r io.Reader
|
||
|
w io.WriteCloser
|
||
|
|
||
|
read, write bool
|
||
|
}
|
||
|
|
||
|
// newWebsocketChannel creates a pipe for writing to a websocket. Do not write to this pipe
|
||
|
// prior to the connection being opened. It may be no, half, or full duplex depending on
|
||
|
// read and write.
|
||
|
func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel {
|
||
|
r, w := io.Pipe()
|
||
|
return &websocketChannel{conn, num, r, w, read, write}
|
||
|
}
|
||
|
|
||
|
func (p *websocketChannel) Write(data []byte) (int, error) {
|
||
|
if !p.write {
|
||
|
return len(data), nil
|
||
|
}
|
||
|
return p.conn.write(p.num, data)
|
||
|
}
|
||
|
|
||
|
// DataFromSocket is invoked by the connection receiver to move data from the connection
|
||
|
// into a specific channel.
|
||
|
func (p *websocketChannel) DataFromSocket(data []byte) (int, error) {
|
||
|
if !p.read {
|
||
|
return len(data), nil
|
||
|
}
|
||
|
|
||
|
switch p.conn.codec {
|
||
|
case rawCodec:
|
||
|
return p.w.Write(data)
|
||
|
case base64Codec:
|
||
|
dst := make([]byte, len(data))
|
||
|
n, err := base64.StdEncoding.Decode(dst, data)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
return p.w.Write(dst[:n])
|
||
|
}
|
||
|
return 0, nil
|
||
|
}
|
||
|
|
||
|
func (p *websocketChannel) Read(data []byte) (int, error) {
|
||
|
if !p.read {
|
||
|
return 0, io.EOF
|
||
|
}
|
||
|
return p.r.Read(data)
|
||
|
}
|
||
|
|
||
|
func (p *websocketChannel) Close() error {
|
||
|
return p.w.Close()
|
||
|
}
|