Support pings, binary messages

This commit is contained in:
wingdeans 2023-12-27 22:21:38 -05:00
parent a3c060b2ff
commit 988bbd2475

View file

@ -412,7 +412,7 @@ struct ClearedPerMessage {
bool hascontenttype;
bool gotcachecontrol;
bool gotxcontenttypeoptions;
bool iswebsocket;
char websockettype;
int frags;
int statuscode;
int isyielding;
@ -5100,7 +5100,7 @@ static int LuaUpgradeWS(lua_State *L) {
p = AppendHeader(p, "Sec-WebSocket-Accept", accept);
cpm.luaheaderp = p;
cpm.iswebsocket = true;
cpm.websockettype = 1;
return 0;
}
@ -5109,6 +5109,7 @@ static int LuaReadWS(lua_State *L) {
size_t i, got, amt;
unsigned char wshdr[10], wshdrlen, *extlen, *mask;
uint64_t len;
struct iovec iov[2];
OnlyCallDuringRequest(L, "ReadWS");
got = 0;
@ -5117,9 +5118,13 @@ static int LuaReadWS(lua_State *L) {
luaL_error(L, "Could not read WS header");
} while ((got += rc) < 2);
if (!(wshdr[1] | (1 << 7))) luaL_error(L, "Unmasked WS frame");
if (wshdr[0] & 0x70) goto close; // reserved bit set
if (!(wshdr[1] | (1 << 7))) goto close; // unmasked
if ((wshdr[0] & 0x7) >= 0x3) goto close; // reserved opcode
len = wshdr[1] & ~(1 << 7);
if (wshdr[0] & 0x8 && len >= 126) goto close; // long control frame
wshdrlen = 6;
if (len == 126) {
wshdrlen = 8;
@ -5150,13 +5155,38 @@ static int LuaReadWS(lua_State *L) {
luaL_error(L, "Could not read WS data");
}
if ((wshdr[0] & 0xF) == 0x8)
luaL_error(L, "WS connection closed");
for (i = 0, amt = amtread; i < got; ++i, ++amt) inbuf.p[amt] ^= mask[i & 0x3];
if ((wshdr[0] & 0xF) == 0x9) {
wshdr[0] = (wshdr[0] & ~0xF) | 0xA;
wshdr[1] = wshdr[1] & ~0x80;
iov[0].iov_base = wshdr;
iov[0].iov_len = wshdrlen - 4;
iov[1].iov_base = inbuf.p + amtread;
iov[1].iov_len = got;
Send(iov, 2);
}
lua_pushlstring(L, inbuf.p + amtread, got);
return 1;
lua_pushnumber(L, wshdr[0]);
return 2;
close:
lua_pushnil(L);
lua_pushnumber(L, 0x08);
return 2;
}
static int LuaSetWSFlags(lua_State *L) {
OnlyCallDuringRequest(L, "SetWSFlags");
char flags = lround(lua_tonumber(L, 1));
if (flags & 0x01) {
cpm.websockettype = 1;
} else if (flags & 0x02) {
cpm.websockettype = 2;
}
return 0;
}
// <SORTED>
@ -5376,6 +5406,7 @@ static const luaL_Reg kLuaFuncs[] = {
{"SetHeader", LuaSetHeader}, //
{"SetLogLevel", LuaSetLogLevel}, //
{"SetStatus", LuaSetStatus}, //
{"SetWSFlags", LuaSetWSFlags}, // undocumented
{"Sha1", LuaSha1}, //
{"Sha224", LuaSha224}, //
{"Sha256", LuaSha256}, //
@ -6460,14 +6491,16 @@ static bool StreamWS(char *p) {
bzero(iov, sizeof(iov));
iov[0].iov_base = wshdr;
wshdr[0] = 0x1 | (1 << 7);
extlen = &wshdr[2];
for (;;) {
if (!YL || lua_status(YL) != LUA_YIELD) break; // done yielding
cpm.contentlength = 0;
status = lua_resume(YL, NULL, 0, &nresults);
if (status != LUA_OK && status != LUA_YIELD) {
if (status == LUA_OK) {
lua_pop(YL, nresults);
break;
} else if (status != LUA_YIELD) {
LogLuaError("resume", lua_tostring(YL, -1));
lua_pop(YL, 1);
break;
@ -6492,6 +6525,7 @@ static bool StreamWS(char *p) {
*(uint64_t *)extlen = htobe64(rc);
iov[0].iov_len = 10;
}
wshdr[0] = cpm.websockettype | (1 << 7);
if (Send(iov, 2) == -1) break;
}
@ -6562,7 +6596,7 @@ static bool HandleMessageActual(void) {
}
if (!cpm.generator) {
return TransmitResponse(p);
} else if (cpm.iswebsocket) {
} else if (cpm.websockettype) {
return StreamWS(p);
} else {
return StreamResponse(p);