From 6c02e000bd8530503720af85579ac7d6edff1398 Mon Sep 17 00:00:00 2001 From: wingdeans <66850754+wingdeans@users.noreply.github.com> Date: Wed, 27 Dec 2023 22:21:38 -0500 Subject: [PATCH] Support pings, binary messages --- tool/net/redbean.c | 54 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/tool/net/redbean.c b/tool/net/redbean.c index e3ff05b69..5d88f643b 100644 --- a/tool/net/redbean.c +++ b/tool/net/redbean.c @@ -407,7 +407,7 @@ struct ClearedPerMessage { bool hascontenttype; bool gotcachecontrol; bool gotxcontenttypeoptions; - bool iswebsocket; + char websockettype; int frags; int statuscode; int isyielding; @@ -5206,7 +5206,7 @@ static int LuaUpgradeWS(lua_State *L) { p = AppendHeader(p, "Sec-WebSocket-Accept", accept); cpm.luaheaderp = p; - cpm.iswebsocket = true; + cpm.websockettype = 1; return 0; } @@ -5215,6 +5215,7 @@ static int LuaReadWS(lua_State *L) { size_t i, got, amt; unsigned char wshdr[10], wshdrlen, *extlen, *mask; uint64_t len; + struct iovec iov[2]; OnlyCallDuringRequest(L, "ReadWS"); got = 0; @@ -5223,9 +5224,13 @@ static int LuaReadWS(lua_State *L) { luaL_error(L, "Could not read WS header"); } while ((got += rc) < 2); - if (!(wshdr[1] | (1 << 7))) luaL_error(L, "Unmasked WS frame"); + if (wshdr[0] & 0x70) goto close; // reserved bit set + if (!(wshdr[1] | (1 << 7))) goto close; // unmasked + if ((wshdr[0] & 0x7) >= 0x3) goto close; // reserved opcode len = wshdr[1] & ~(1 << 7); + if (wshdr[0] & 0x8 && len >= 126) goto close; // long control frame + wshdrlen = 6; if (len == 126) { wshdrlen = 8; @@ -5256,13 +5261,38 @@ static int LuaReadWS(lua_State *L) { luaL_error(L, "Could not read WS data"); } - if ((wshdr[0] & 0xF) == 0x8) - luaL_error(L, "WS connection closed"); - for (i = 0, amt = amtread; i < got; ++i, ++amt) inbuf.p[amt] ^= mask[i & 0x3]; + if ((wshdr[0] & 0xF) == 0x9) { + wshdr[0] = (wshdr[0] & ~0xF) | 0xA; + wshdr[1] = wshdr[1] & ~0x80; + iov[0].iov_base = wshdr; + iov[0].iov_len = wshdrlen - 4; + iov[1].iov_base = inbuf.p + amtread; + iov[1].iov_len = got; + Send(iov, 2); + } + lua_pushlstring(L, inbuf.p + amtread, got); - return 1; + lua_pushnumber(L, wshdr[0]); + + return 2; + +close: + lua_pushnil(L); + lua_pushnumber(L, 0x08); + return 2; +} + +static int LuaSetWSFlags(lua_State *L) { + OnlyCallDuringRequest(L, "SetWSFlags"); + char flags = lround(lua_tonumber(L, 1)); + if (flags & 0x01) { + cpm.websockettype = 1; + } else if (flags & 0x02) { + cpm.websockettype = 2; + } + return 0; } // @@ -5483,6 +5513,7 @@ static const luaL_Reg kLuaFuncs[] = { {"SetHeader", LuaSetHeader}, // {"SetLogLevel", LuaSetLogLevel}, // {"SetStatus", LuaSetStatus}, // + {"SetWSFlags", LuaSetWSFlags}, // undocumented {"Sha1", LuaSha1}, // {"Sha224", LuaSha224}, // {"Sha256", LuaSha256}, // @@ -6605,14 +6636,16 @@ static bool StreamWS(char *p) { bzero(iov, sizeof(iov)); iov[0].iov_base = wshdr; - wshdr[0] = 0x1 | (1 << 7); extlen = &wshdr[2]; for (;;) { if (!YL || lua_status(YL) != LUA_YIELD) break; // done yielding cpm.contentlength = 0; status = lua_resume(YL, NULL, 0, &nresults); - if (status != LUA_OK && status != LUA_YIELD) { + if (status == LUA_OK) { + lua_pop(YL, nresults); + break; + } else if (status != LUA_YIELD) { LogLuaError("resume", lua_tostring(YL, -1)); lua_pop(YL, 1); break; @@ -6637,6 +6670,7 @@ static bool StreamWS(char *p) { *(uint64_t *)extlen = htobe64(rc); iov[0].iov_len = 10; } + wshdr[0] = cpm.websockettype | (1 << 7); if (Send(iov, 2) == -1) break; } @@ -6713,7 +6747,7 @@ static bool HandleMessageActual(void) { } if (!cpm.generator) { return TransmitResponse(p); - } else if (cpm.iswebsocket) { + } else if (cpm.websockettype) { return StreamWS(p); } else { return StreamResponse(p);