diff --git a/tool/net/redbean.c b/tool/net/redbean.c index 93816d1aa..bc70f1892 100644 --- a/tool/net/redbean.c +++ b/tool/net/redbean.c @@ -42,6 +42,7 @@ #include "libc/intrin/atomic.h" #include "libc/intrin/bsr.h" #include "libc/intrin/likely.h" +#include "libc/intrin/newbie.h" #include "libc/intrin/nomultics.h" #include "libc/intrin/safemacros.h" #include "libc/log/appendresourcereport.internal.h" @@ -125,6 +126,7 @@ #include "third_party/mbedtls/net_sockets.h" #include "third_party/mbedtls/oid.h" #include "third_party/mbedtls/san.h" +#include "third_party/mbedtls/sha1.h" #include "third_party/mbedtls/ssl.h" #include "third_party/mbedtls/ssl_ticket.h" #include "third_party/mbedtls/x509.h" @@ -405,6 +407,7 @@ struct ClearedPerMessage { bool hascontenttype; bool gotcachecontrol; bool gotxcontenttypeoptions; + bool iswebsocket; int frags; int statuscode; int isyielding; @@ -5159,6 +5162,106 @@ static bool LuaRunAsset(const char *path, bool mandatory) { return !!a; } +static int LuaUpgradeWS(lua_State *L) { + size_t i; + char *p, *q; + bool haskey; + mbedtls_sha1_context ctx; + unsigned char hash[20]; + OnlyCallDuringRequest(L, "UpgradeWS"); + + haskey = true; + for (i = 0; i < cpm.msg.xheaders.n; ++i) { + if (SlicesEqualCase( + "Sec-WebSocket-Key", strlen("Sec-WebSocket-Key"), + inbuf.p + cpm.msg.xheaders.p[i].k.a, + cpm.msg.xheaders.p[i].k.b - cpm.msg.xheaders.p[i].k.a)) { + mbedtls_sha1_init(&ctx); + mbedtls_sha1_starts_ret(&ctx); + mbedtls_sha1_update_ret( + &ctx, (unsigned char *)inbuf.p + cpm.msg.xheaders.p[i].v.a, + cpm.msg.xheaders.p[i].v.b - cpm.msg.xheaders.p[i].v.a); + haskey = true; + break; + } + } + + if (!haskey) luaL_error(L, "No Sec-WebSocket-Key header"); + + p = SetStatus(101, "Switching Protocols"); + while (p - hdrbuf.p + (20 + 21 + (20 + 28 + 4)) + 512 > hdrbuf.n) { + hdrbuf.n += hdrbuf.n >> 1; + q = xrealloc(hdrbuf.p, hdrbuf.n); + cpm.luaheaderp = p = q + (p - hdrbuf.p); + hdrbuf.p = q; + } + + mbedtls_sha1_update_ret( + &ctx, (unsigned char *)"258EAFA5-E914-47DA-95CA-C5AB0DC85B11", 36); + mbedtls_sha1_finish_ret(&ctx, hash); + char *accept = EncodeBase64((char *)hash, 20, NULL); + + p = stpcpy(p, "Upgrade: websocket\r\n"); + p = stpcpy(p, "Connection: upgrade\r\n"); + p = AppendHeader(p, "Sec-WebSocket-Accept", accept); + + cpm.luaheaderp = p; + cpm.iswebsocket = true; + return 0; +} + +static int LuaReadWS(lua_State *L) { + ssize_t rc; + size_t i, got, amt; + unsigned char wshdr[10], wshdrlen, *extlen, *mask; + uint64_t len; + OnlyCallDuringRequest(L, "ReadWS"); + + got = 0; + do { + if ((rc = reader(client, wshdr + got, 2 - got)) == -1) + luaL_error(L, "Could not read WS header"); + } while ((got += rc) < 2); + + if (!(wshdr[1] | (1 << 7))) luaL_error(L, "Unmasked WS frame"); + + len = wshdr[1] & ~(1 << 7); + wshdrlen = 6; + if (len == 126) { + wshdrlen = 8; + } else if (len == 127) { + wshdrlen = 14; + } + + while (got < wshdrlen) { + if ((rc = reader(client, wshdr + got, wshdrlen - got)) == -1) + luaL_error(L, "Could not read WS extended length"); + got += rc; + } + + extlen = &wshdr[2]; + mask = &wshdr[wshdrlen - 4]; + if (len == 126) { + len = be16toh(*(uint16_t *)extlen); + } else if (len == 127) { + len = be64toh(*(uint64_t *)extlen); + } + + if (len >= inbuf.n - amtread) + luaL_error(L, "Required %d bytes to read WS frame, %d bytes available", len, + inbuf.n - amtread); + + for (got = 0, amt = amtread; got < len; got += rc, amt += rc) { + if ((rc = reader(client, inbuf.p + amt, len - got)) == -1) + luaL_error(L, "Could not read WS data"); + } + + for (i = 0, amt = amtread; i < got; ++i, ++amt) inbuf.p[amt] ^= mask[i & 0x3]; + + lua_pushlstring(L, inbuf.p + amtread, got); + return 1; +} + // // list of functions that can't be run from the repl static const char *const kDontAutoComplete[] = { @@ -5359,6 +5462,7 @@ static const luaL_Reg kLuaFuncs[] = { {"ProgramUid", LuaProgramUid}, // {"ProgramUniprocess", LuaProgramUniprocess}, // {"Rand64", LuaRand64}, // + {"ReadWS", LuaReadWS}, // undocumented {"Rdrand", LuaRdrand}, // {"Rdseed", LuaRdseed}, // {"Rdtsc", LuaRdtsc}, // @@ -5388,6 +5492,7 @@ static const luaL_Reg kLuaFuncs[] = { {"Underlong", LuaUnderlong}, // {"UuidV4", LuaUuidV4}, // {"UuidV7", LuaUuidV7}, // + {"UpgradeWS", LuaUpgradeWS}, // undocumented {"VisualizeControlCodes", LuaVisualizeControlCodes}, // {"Write", LuaWrite}, // {"bin", LuaBin}, // @@ -6479,6 +6584,57 @@ static bool StreamResponse(char *p) { return true; } +static bool StreamWS(char *p) { + ssize_t rc; + struct iovec iov[4]; + char *s, wshdr[10], *extlen; + + p = AppendCrlf(p); + CHECK_LE(p - hdrbuf.p, hdrbuf.n); + if (logmessages) { + LogMessage("sending", hdrbuf.p, p - hdrbuf.p); + } + iov[0].iov_base = hdrbuf.p; + iov[0].iov_len = p - hdrbuf.p; + Send(iov, 1); + + bzero(iov, sizeof(iov)); + iov[0].iov_base = wshdr; + + wshdr[0] = 0x1 | (0x1 << 7); + extlen = &wshdr[2]; + + cpm.isyielding = 2; // skip first YieldGenerator + + for (;;) { + if ((rc = cpm.generator(iov + 1)) <= 0) break; + + if (rc < 126) { + wshdr[1] = rc; + iov[0].iov_len = 2; + } else if (rc <= 0xFFFF) { + wshdr[1] = 126; + *(uint16_t *)extlen = htobe16(rc); + iov[0].iov_len = 4; + } else { + wshdr[1] = 127; + *(uint64_t *)extlen = htobe64(rc); + iov[0].iov_len = 10; + } + if (Send(iov, 4) == -1) break; + } + + if (rc != -1) { + wshdr[0] = 0x8; + wshdr[1] = 0; + iov[0].iov_len = 2; + Send(iov, 1); + } else { + connectionclose = true; + } + return true; +} + static bool HandleMessageActual(void) { int rc; long reqtime, contime; @@ -6543,6 +6699,8 @@ static bool HandleMessageActual(void) { } if (!cpm.generator) { return TransmitResponse(p); + } else if (cpm.iswebsocket) { + return StreamWS(p); } else { return StreamResponse(p); }