redbean: clean up websockets support

This commit is contained in:
Derek Meer 2025-03-19 02:29:08 -07:00
parent b5993e4e1d
commit c6002e00fb
6 changed files with 118 additions and 79 deletions

View file

@ -45,6 +45,8 @@ static const char kUtf8Dispatch[] = {
* - Incorrect sequencing of 0300 (FIRST) and 0200 (CONT) chars * - Incorrect sequencing of 0300 (FIRST) and 0200 (CONT) chars
* - Thompson-Pike varint sequence not encodable as UTF-16 * - Thompson-Pike varint sequence not encodable as UTF-16
* - Overlong UTF-8 encoding * - Overlong UTF-8 encoding
* - any UTF-16 surrogate code points
* - last character in a multi-byte UTF-8 sequence exceeds the valid limit
* *
* @param size if -1 implies strlen * @param size if -1 implies strlen
*/ */

View file

@ -104,4 +104,4 @@ CF-Visitor, kHttpCfVisitor
CF-Connecting-IP, kHttpCfConnectingIp CF-Connecting-IP, kHttpCfConnectingIp
CF-IPCountry, kHttpCfIpcountry CF-IPCountry, kHttpCfIpcountry
CDN-Loop, kHttpCdnLoop CDN-Loop, kHttpCdnLoop
Sec-WebSocket-Key, kHttpWebsocketKey Sec-WebSocket-Key, kHttpWebSocketKey

View file

@ -389,7 +389,7 @@ LookupHttpHeader (register const char *str, register size_t len)
{""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""},
{""}, {""}, {""}, {""},
#line 107 "gethttpheader.gperf" #line 107 "gethttpheader.gperf"
{"Sec-WebSocket-Key", kHttpWebsocketKey}, {"Sec-WebSocket-Key", kHttpWebSocketKey},
{""}, {""}, {""}, {""},
#line 22 "gethttpheader.gperf" #line 22 "gethttpheader.gperf"
{"X-Forwarded-For", kHttpXForwardedFor}, {"X-Forwarded-For", kHttpXForwardedFor},

View file

@ -206,7 +206,7 @@ const char *GetHttpHeaderName(int h) {
return "CDN-Loop"; return "CDN-Loop";
case kHttpSecChUaPlatform: case kHttpSecChUaPlatform:
return "Sec-CH-UA-Platform"; return "Sec-CH-UA-Platform";
case kHttpWebsocketKey: case kHttpWebSocketKey:
return "Sec-WebSocket-Key"; return "Sec-WebSocket-Key";
default: default:
return NULL; return NULL;

View file

@ -138,7 +138,7 @@
#define kHttpCfIpcountry 90 #define kHttpCfIpcountry 90
#define kHttpSecChUaPlatform 91 #define kHttpSecChUaPlatform 91
#define kHttpCdnLoop 92 #define kHttpCdnLoop 92
#define kHttpWebsocketKey 93 #define kHttpWebSocketKey 93
#define kHttpHeadersMax 94 #define kHttpHeadersMax 94
COSMOPOLITAN_C_START_ COSMOPOLITAN_C_START_

View file

@ -5165,131 +5165,159 @@ static bool LuaRunAsset(const char *path, bool mandatory) {
} }
static int LuaWSUpgrade(lua_State *L) { static int LuaWSUpgrade(lua_State *L) {
size_t i;
char *p, *q;
bool haskey;
mbedtls_sha1_context ctx; mbedtls_sha1_context ctx;
unsigned char hash[20]; unsigned char hash[20];
char *accept, *p, *q;
if (cpm.generator) {
return luaL_error(L, "Cannot upgrade to websocket after yielding normally");
}
if (cpm.generator) if (!HasHeader(kHttpWebSocketKey)) {
luaL_error(L, "Cannot upgrade to websocket after yielding normally"); return luaL_error(L, "No Sec-WebSocket-Key header");
}
if (!HasHeader(kHttpWebsocketKey)) // Prepare Sec-WebSocket-Accept response header (See RFC6455 1.3)
luaL_error(L, "No Sec-WebSocket-Key header");
mbedtls_sha1_init(&ctx); mbedtls_sha1_init(&ctx);
mbedtls_sha1_starts_ret(&ctx); mbedtls_sha1_starts_ret(&ctx);
mbedtls_sha1_update_ret(&ctx, (unsigned char*) mbedtls_sha1_update_ret(&ctx,
HeaderData(kHttpWebsocketKey), (unsigned char*)HeaderData(kHttpWebSocketKey),
HeaderLength(kHttpWebsocketKey)); HeaderLength(kHttpWebSocketKey));
mbedtls_sha1_update_ret(&ctx,
(unsigned char*)"258EAFA5-E914-47DA-95CA-C5AB0DC85B11",
36);
mbedtls_sha1_finish_ret(&ctx, hash);
accept = EncodeBase64((char *)hash, 20, NULL);
// prepare response
p = SetStatus(101, "Switching Protocols"); p = SetStatus(101, "Switching Protocols");
while (p - hdrbuf.p + (20 + 21 + (20 + 28 + 4)) + 512 > hdrbuf.n) { // make enough space for the handshake message:
// "Upgrade: websocket\r\n" (20 bytes)
// "Connection: Upgrade\r\n" (21 bytes)
// "Sec-WebSocket-Accept: <accept>\r\n" (54 bytes)
// <accept> will always be 28 bytes, as len(b64(hash)) = 4*ceil(20/3) = 28
while (p - hdrbuf.p + 95 + 512 > hdrbuf.n) {
hdrbuf.n += hdrbuf.n >> 1; hdrbuf.n += hdrbuf.n >> 1;
q = xrealloc(hdrbuf.p, hdrbuf.n); q = xrealloc(hdrbuf.p, hdrbuf.n);
cpm.luaheaderp = p = q + (p - hdrbuf.p); cpm.luaheaderp = p = q + (p - hdrbuf.p);
hdrbuf.p = q; hdrbuf.p = q;
} }
mbedtls_sha1_update_ret(
&ctx, (unsigned char *)"258EAFA5-E914-47DA-95CA-C5AB0DC85B11", 36);
mbedtls_sha1_finish_ret(&ctx, hash);
char *accept = EncodeBase64((char *)hash, 20, NULL);
p = stpcpy(p, "Upgrade: websocket\r\n"); p = stpcpy(p, "Upgrade: websocket\r\n");
p = stpcpy(p, "Connection: upgrade\r\n"); p = stpcpy(p, "Connection: Upgrade\r\n");
p = AppendHeader(p, "Sec-WebSocket-Accept", accept); p = AppendHeader(p, "Sec-WebSocket-Accept", accept);
cpm.luaheaderp = p; cpm.luaheaderp = p;
cpm.wstype = 1; cpm.wstype = 1;
return 0; return 0;
} }
// see RFC6455 5.2 for details on the websocket data frame structure
static int LuaWSRead(lua_State *L) { static int LuaWSRead(lua_State *L) {
ssize_t rc; ssize_t rc;
size_t i, got, amt, bufsize; size_t i, got, amt, bufsize;
unsigned char wshdr[10], wshdrlen, *extlen, *mask, op; unsigned char header[10], headerlen, opcode, *extlen, *mask;
char *bufstart; char *bufstart;
uint64_t len; uint64_t len;
struct iovec iov[2]; struct iovec iov[2];
OnlyCallDuringRequest(L, "ws.Read"); OnlyCallDuringRequest(L, "ws.Read");
got = 0; got = 0;
// read 2 bytes of the frame header
do { do {
if ((rc = reader(client, wshdr + got, 2 - got)) == -1) if ((rc = reader(client, header + got, 2 - got)) == -1) {
luaL_error(L, "Could not read WS header"); return luaL_error(L, "Could not read WS header");
}
} while ((got += rc) < 2); } while ((got += rc) < 2);
op = wshdr[0] & 0xF; // reserved bit set
if (header[0] & 0x70) goto close;
// reserved opcode
if ((header[0] & 0x7) > 0x3) goto close;
// payload data is unmasked
if (!(header[1] | (1 << 7))) goto close;
if (wshdr[0] & 0x70) goto close; // reserved bit set opcode = header[0] & 0xF;
if (!(wshdr[1] | (1 << 7))) goto close; // unmasked // not in continuation
if ((wshdr[0] & 0x7) >= 0x3) goto close; // reserved opcode if (!wsfragtype && !opcode) goto close;
if (!wsfragtype && !op) goto close; // not in continuation
len = wshdr[1] & ~(1 << 7); len = header[1] & ~(1 << 7);
if (wshdr[0] & 0x8) { // control frame if (header[0] & 0x8) {
if (!(wshdr[0] & 0x80) || len >= 126) goto close; // fragmented or too long // control frame is fragmented or too long
if (!(header[0] & 0x80) || len >= 126) goto close;
} else { } else {
if (op && wsfragtype) goto close; // during fragmented seq // control frame during fragmented sequence
if (opcode && wsfragtype) goto close;
} }
wshdrlen = 6; headerlen = 6;
if (len == 126) { if (len == 126) {
wshdrlen = 8; headerlen = 8;
} else if (len == 127) { } else if (len == 127) {
wshdrlen = 14; headerlen = 14;
} }
while (got < wshdrlen) { // read rest of header, if necessary
if ((rc = reader(client, wshdr + got, wshdrlen - got)) == -1) while (got < headerlen) {
luaL_error(L, "Could not read WS extended length"); if ((rc = reader(client, header + got, headerlen - got)) == -1) {
return luaL_error(L, "Could not read WS extended length");
}
got += rc; got += rc;
} }
extlen = &wshdr[2]; extlen = &header[2];
mask = &wshdr[wshdrlen - 4]; mask = &header[headerlen - 4];
// multibyte length quantities are expressed in network byte order
if (len == 126) { if (len == 126) {
len = be16toh(*(uint16_t *)extlen); len = be16toh(*(uint16_t *)extlen);
} else if (len == 127) { } else if (len == 127) {
len = be64toh(*(uint64_t *)extlen); len = be64toh(*(uint64_t *)extlen);
} }
if (len >= inbuf.n - wsfragread) if (len >= inbuf.n - wsfragread) {
luaL_error(L, "Required %d bytes to read WS frame, %d bytes available", len, return luaL_error(L,
"Required %d bytes to read WS frame, %d bytes available",
len,
inbuf.n - wsfragread); inbuf.n - wsfragread);
for (got = 0, amt = wsfragread; got < len; got += rc, amt += rc) {
if ((rc = reader(client, inbuf.p + amt, len - got)) == -1)
luaL_error(L, "Could not read WS data");
} }
for (i = 0, amt = wsfragread; i < got; ++i, ++amt) // read in frame data
inbuf.p[amt] ^= mask[i & 0x3]; for (got = 0, amt = wsfragread; got < len; got += rc, amt += rc) {
if ((rc = reader(client, inbuf.p + amt, len - got)) == -1) {
return luaL_error(L, "Could not read WS data");
}
}
if (op == 0x9) { // unmask data
wshdr[0] = (wshdr[0] & ~0xF) | 0xA; for (i = 0, amt = wsfragread; i < got; ++i, ++amt) {
wshdr[1] = wshdr[1] & ~0x80; inbuf.p[amt] ^= mask[i & 0x3];
iov[0].iov_base = wshdr; }
iov[0].iov_len = wshdrlen - 4;
// ping received, respond with pong
if (opcode == 0x9) {
header[0] = (header[0] & ~0xF) | 0xA;
header[1] = header[1] & ~0x80;
// pong data must be identical to ping
iov[0].iov_base = header;
iov[0].iov_len = headerlen - 4;
iov[1].iov_base = inbuf.p + wsfragread; iov[1].iov_base = inbuf.p + wsfragread;
iov[1].iov_len = got; iov[1].iov_len = got;
Send(iov, 2); Send(iov, 2);
} }
if (wshdr[0] & 0x80) { // final fragment
if (op) { if (header[0] & 0x80) {
// non-continuation frame
if (opcode) {
bufstart = inbuf.p + wsfragread; bufstart = inbuf.p + wsfragread;
bufsize = got; bufsize = got;
if (op == 0x1 && !isutf8(bufstart, bufsize)) goto close; // text frame with invalid text
if (opcode == 0x1 && !isutf8(bufstart, bufsize)) goto close;
lua_pushlstring(L, bufstart, bufsize); lua_pushlstring(L, bufstart, bufsize);
lua_pushinteger(L, op); lua_pushinteger(L, opcode);
} else { } else {
bufstart = inbuf.p + amtread; bufstart = inbuf.p + amtread;
bufsize = (wsfragread - amtread) + got; bufsize = (wsfragread - amtread) + got;
// text frame with invalid text
if (wsfragtype == 0x1 && !isutf8(bufstart, bufsize)) goto close; if (wsfragtype == 0x1 && !isutf8(bufstart, bufsize)) goto close;
lua_pushlstring(L, bufstart, bufsize); lua_pushlstring(L, bufstart, bufsize);
lua_pushinteger(L, wsfragtype); lua_pushinteger(L, wsfragtype);
@ -5301,7 +5329,7 @@ static int LuaWSRead(lua_State *L) {
lua_pushnil(L); lua_pushnil(L);
lua_pushinteger(L, 0); lua_pushinteger(L, 0);
if (!wsfragtype) wsfragtype = op; if (!wsfragtype) wsfragtype = opcode;
wsfragread += got; wsfragread += got;
} }
@ -5319,20 +5347,22 @@ static int LuaWSWrite(lua_State *L) {
const char *data; const char *data;
OnlyCallDuringRequest(L, "ws.Write"); OnlyCallDuringRequest(L, "ws.Write");
if (!cpm.wstype) if (!cpm.wstype) {
LuaWSUpgrade(L); LuaWSUpgrade(L);
}
type = luaL_optinteger(L, 2, -1); type = luaL_optinteger(L, 2, -1);
if (type == 1 || type == 2) { if (type == 1 || type == 2) {
cpm.wstype = type; cpm.wstype = type;
} else if (type != -1) { } else if (type != -1) {
luaL_error(L, "Invalid WS type"); return luaL_error(L, "Invalid WS type");
} }
if (!lua_isnil(L, 1)) { if (!lua_isnil(L, 1)) {
data = luaL_checklstring(L, 1, &size); data = luaL_checklstring(L, 1, &size);
appendd(&cpm.outbuf, data, size); appendd(&cpm.outbuf, data, size);
} }
return 0; return 0;
} }
@ -5620,7 +5650,7 @@ static const luaL_Reg kLuaLibs[] = {
{"path", LuaPath}, // {"path", LuaPath}, //
{"re", LuaRe}, // {"re", LuaRe}, //
{"unix", LuaUnix}, // {"unix", LuaUnix}, //
{"ws", LuaWS} // {"ws", LuaWS}, //
}; };
static void LuaSetArgv(lua_State *L) { static void LuaSetArgv(lua_State *L) {
@ -6677,7 +6707,7 @@ static bool StreamResponse(char *p) {
static bool StreamWS(char *p) { static bool StreamWS(char *p) {
ssize_t rc; ssize_t rc;
struct iovec iov[2]; struct iovec iov[2];
char *s, wshdr[10], *extlen; char header[10], *s, *extlen;
int nresults, status; int nresults, status;
p = AppendCrlf(p); p = AppendCrlf(p);
@ -6690,14 +6720,17 @@ static bool StreamWS(char *p) {
Send(iov, 1); Send(iov, 1);
bzero(iov, sizeof(iov)); bzero(iov, sizeof(iov));
iov[0].iov_base = wshdr; iov[0].iov_base = header;
extlen = &wshdr[2]; extlen = &header[2];
wsfragread = amtread; wsfragread = amtread;
wsfragtype = 0; wsfragtype = 0;
for (;;) { for (;;) {
if (!YL || lua_status(YL) != LUA_YIELD) break; // done yielding // done yielding
if (!YL || lua_status(YL) != LUA_YIELD) {
break;
}
cpm.contentlength = 0; cpm.contentlength = 0;
status = lua_resume(YL, NULL, 0, &nresults); status = lua_resume(YL, NULL, 0, &nresults);
if (status == LUA_OK) { if (status == LUA_OK) {
@ -6709,7 +6742,9 @@ static bool StreamWS(char *p) {
break; break;
} }
lua_pop(YL, nresults); lua_pop(YL, nresults);
if (!cpm.contentlength) UseOutput(); if (!cpm.contentlength) {
UseOutput();
}
DEBUGF("(lua) ws yielded with %ld bytes generated", cpm.contentlength); DEBUGF("(lua) ws yielded with %ld bytes generated", cpm.contentlength);
@ -6717,23 +6752,25 @@ static bool StreamWS(char *p) {
iov[1].iov_len = rc = cpm.contentlength; iov[1].iov_len = rc = cpm.contentlength;
if (rc < 126) { if (rc < 126) {
wshdr[1] = rc; header[1] = rc;
iov[0].iov_len = 2; iov[0].iov_len = 2;
} else if (rc <= 0xFFFF) { } else if (rc <= 0xFFFF) {
wshdr[1] = 126; header[1] = 126;
*(uint16_t *)extlen = htobe16(rc); *(uint16_t *)extlen = htobe16(rc);
iov[0].iov_len = 4; iov[0].iov_len = 4;
} else { } else {
wshdr[1] = 127; header[1] = 127;
*(uint64_t *)extlen = htobe64(rc); *(uint64_t *)extlen = htobe64(rc);
iov[0].iov_len = 10; iov[0].iov_len = 10;
} }
wshdr[0] = cpm.wstype | (1 << 7); header[0] = cpm.wstype | (1 << 7);
if (Send(iov, 2) == -1) break; if (Send(iov, 2) == -1) {
break;
}
} }
wshdr[0] = 0x8 | (1 << 7); header[0] = 0x8 | (1 << 7);
wshdr[1] = 0; header[1] = 0;
iov[0].iov_len = 2; iov[0].iov_len = 2;
Send(iov, 1); Send(iov, 1);
connectionclose = true; connectionclose = true;