Support fragmentation, utf-8 checks

This commit is contained in:
wingdeans 2023-12-28 00:49:52 -05:00
parent 988bbd2475
commit 8a8fc9a65f

View file

@ -412,7 +412,7 @@ struct ClearedPerMessage {
bool hascontenttype; bool hascontenttype;
bool gotcachecontrol; bool gotcachecontrol;
bool gotxcontenttypeoptions; bool gotxcontenttypeoptions;
char websockettype; char wstype;
int frags; int frags;
int statuscode; int statuscode;
int isyielding; int isyielding;
@ -493,6 +493,8 @@ static uint8_t *zmap;
static uint8_t *zcdir; static uint8_t *zcdir;
static size_t hdrsize; static size_t hdrsize;
static size_t amtread; static size_t amtread;
static size_t wsfragread;
static char wsfragtype;
static reader_f reader; static reader_f reader;
static writer_f writer; static writer_f writer;
static char *extrahdrs; static char *extrahdrs;
@ -5100,14 +5102,15 @@ static int LuaUpgradeWS(lua_State *L) {
p = AppendHeader(p, "Sec-WebSocket-Accept", accept); p = AppendHeader(p, "Sec-WebSocket-Accept", accept);
cpm.luaheaderp = p; cpm.luaheaderp = p;
cpm.websockettype = 1; cpm.wstype = 1;
return 0; return 0;
} }
static int LuaReadWS(lua_State *L) { static int LuaReadWS(lua_State *L) {
ssize_t rc; ssize_t rc;
size_t i, got, amt; size_t i, got, amt, bufsize;
unsigned char wshdr[10], wshdrlen, *extlen, *mask; unsigned char wshdr[10], wshdrlen, *extlen, *mask, op;
char *bufstart;
uint64_t len; uint64_t len;
struct iovec iov[2]; struct iovec iov[2];
OnlyCallDuringRequest(L, "ReadWS"); OnlyCallDuringRequest(L, "ReadWS");
@ -5118,12 +5121,19 @@ static int LuaReadWS(lua_State *L) {
luaL_error(L, "Could not read WS header"); luaL_error(L, "Could not read WS header");
} while ((got += rc) < 2); } while ((got += rc) < 2);
op = wshdr[0] & 0xF;
if (wshdr[0] & 0x70) goto close; // reserved bit set if (wshdr[0] & 0x70) goto close; // reserved bit set
if (!(wshdr[1] | (1 << 7))) goto close; // unmasked if (!(wshdr[1] | (1 << 7))) goto close; // unmasked
if ((wshdr[0] & 0x7) >= 0x3) goto close; // reserved opcode if ((wshdr[0] & 0x7) >= 0x3) goto close; // reserved opcode
if (!wsfragtype && !op) goto close; // not in continuation
len = wshdr[1] & ~(1 << 7); len = wshdr[1] & ~(1 << 7);
if (wshdr[0] & 0x8 && len >= 126) goto close; // long control frame if (wshdr[0] & 0x8) { // control frame
if (!(wshdr[0] & 0x80) || len >= 126) goto close; // fragmented or too long
} else {
if (op && wsfragtype) goto close; // during fragmented seq
}
wshdrlen = 6; wshdrlen = 6;
if (len == 126) { if (len == 126) {
@ -5146,29 +5156,54 @@ static int LuaReadWS(lua_State *L) {
len = be64toh(*(uint64_t *)extlen); len = be64toh(*(uint64_t *)extlen);
} }
if (len >= inbuf.n - amtread) if (len >= inbuf.n - wsfragread)
luaL_error(L, "Required %d bytes to read WS frame, %d bytes available", len, luaL_error(L, "Required %d bytes to read WS frame, %d bytes available", len,
inbuf.n - amtread); inbuf.n - wsfragread);
for (got = 0, amt = amtread; got < len; got += rc, amt += rc) { for (got = 0, amt = wsfragread; got < len; got += rc, amt += rc) {
if ((rc = reader(client, inbuf.p + amt, len - got)) == -1) if ((rc = reader(client, inbuf.p + amt, len - got)) == -1)
luaL_error(L, "Could not read WS data"); luaL_error(L, "Could not read WS data");
} }
for (i = 0, amt = amtread; i < got; ++i, ++amt) inbuf.p[amt] ^= mask[i & 0x3]; for (i = 0, amt = wsfragread; i < got; ++i, ++amt)
inbuf.p[amt] ^= mask[i & 0x3];
if ((wshdr[0] & 0xF) == 0x9) { if (op == 0x9) {
wshdr[0] = (wshdr[0] & ~0xF) | 0xA; wshdr[0] = (wshdr[0] & ~0xF) | 0xA;
wshdr[1] = wshdr[1] & ~0x80; wshdr[1] = wshdr[1] & ~0x80;
iov[0].iov_base = wshdr; iov[0].iov_base = wshdr;
iov[0].iov_len = wshdrlen - 4; iov[0].iov_len = wshdrlen - 4;
iov[1].iov_base = inbuf.p + amtread; iov[1].iov_base = inbuf.p + wsfragread;
iov[1].iov_len = got; iov[1].iov_len = got;
Send(iov, 2); Send(iov, 2);
} }
lua_pushlstring(L, inbuf.p + amtread, got); if (wshdr[0] & 0x80) {
lua_pushnumber(L, wshdr[0]); if (op) {
bufstart = inbuf.p + wsfragread;
bufsize = got;
if (op == 0x1 && !isutf8(bufstart, bufsize)) goto close;
lua_pushlstring(L, bufstart, bufsize);
lua_pushnumber(L, wshdr[0]);
} else {
bufstart = inbuf.p + amtread;
bufsize = (wsfragread - amtread) + got;
if (wsfragtype == 0x1 && !isutf8(bufstart, bufsize)) goto close;
lua_pushlstring(L, bufstart, bufsize);
lua_pushnumber(L, wsfragtype);
wsfragread = amtread;
wsfragtype = 0;
}
} else {
lua_pushnil(L);
lua_pushnumber(L, 0);
if (!wsfragtype) wsfragtype = op;
wsfragread += got;
}
return 2; return 2;
@ -5182,9 +5217,9 @@ static int LuaSetWSFlags(lua_State *L) {
OnlyCallDuringRequest(L, "SetWSFlags"); OnlyCallDuringRequest(L, "SetWSFlags");
char flags = lround(lua_tonumber(L, 1)); char flags = lround(lua_tonumber(L, 1));
if (flags & 0x01) { if (flags & 0x01) {
cpm.websockettype = 1; cpm.wstype = 1;
} else if (flags & 0x02) { } else if (flags & 0x02) {
cpm.websockettype = 2; cpm.wstype = 2;
} }
return 0; return 0;
} }
@ -6492,6 +6527,8 @@ static bool StreamWS(char *p) {
iov[0].iov_base = wshdr; iov[0].iov_base = wshdr;
extlen = &wshdr[2]; extlen = &wshdr[2];
wsfragread = amtread;
wsfragtype = 0;
for (;;) { for (;;) {
if (!YL || lua_status(YL) != LUA_YIELD) break; // done yielding if (!YL || lua_status(YL) != LUA_YIELD) break; // done yielding
@ -6525,7 +6562,7 @@ static bool StreamWS(char *p) {
*(uint64_t *)extlen = htobe64(rc); *(uint64_t *)extlen = htobe64(rc);
iov[0].iov_len = 10; iov[0].iov_len = 10;
} }
wshdr[0] = cpm.websockettype | (1 << 7); wshdr[0] = cpm.wstype | (1 << 7);
if (Send(iov, 2) == -1) break; if (Send(iov, 2) == -1) break;
} }
@ -6596,7 +6633,7 @@ static bool HandleMessageActual(void) {
} }
if (!cpm.generator) { if (!cpm.generator) {
return TransmitResponse(p); return TransmitResponse(p);
} else if (cpm.websockettype) { } else if (cpm.wstype) {
return StreamWS(p); return StreamWS(p);
} else { } else {
return StreamResponse(p); return StreamResponse(p);