From b3ba9ba38c628f1c0f3ae9424b69667dc0de1719 Mon Sep 17 00:00:00 2001 From: Paul Kulchenko Date: Sat, 19 Mar 2022 10:03:52 -0700 Subject: [PATCH] Update redbean to allow yielding from Lua to support streaming --- tool/net/redbean.c | 69 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 13 deletions(-) diff --git a/tool/net/redbean.c b/tool/net/redbean.c index df1e1d189..7e15641ea 100644 --- a/tool/net/redbean.c +++ b/tool/net/redbean.c @@ -385,6 +385,7 @@ static int gmtoff; static int client; static int changeuid; static int changegid; +static int isyielding; static int statuscode; static int sslpskindex; static int oldloglevel; @@ -395,7 +396,7 @@ static uint32_t clientaddrsize; static size_t zsize; static char *outbuf; -static lua_State *GL; +static lua_State *GL, *YL; static char *content; static uint8_t *zmap; static uint8_t *zbase; @@ -1063,10 +1064,12 @@ static nodiscard char *LuaFormatStack(lua_State *L) { // -1 is error or result (assuming nres == 1) // @param L is main Lua interpreter // @note this needs to be reentrant -static int LuaCallWithTrace(lua_State *L, int nargs, int nres) { +static int LuaCallWithTrace(lua_State *L, int nargs, int nres, lua_State *C) { int nresults, status; - lua_State *C = lua_newthread(L); // create a new coroutine - lua_insert(L, 1); // move coroutine to the bottom of the stack + bool canyield = !!C; // allow yield if coroutine is provided + if (!C) C = lua_newthread(L); // create a new coroutine if not passed + // move coroutine to the bottom of the stack (including one that is passed) + lua_insert(L, 1); // move the function (and arguments) to the top of the coro stack lua_xmove(L, C, 1 + nargs); @@ -1085,7 +1088,7 @@ static int LuaCallWithTrace(lua_State *L, int nargs, int nres) { // than the caller expects, as lua_resume // doesn't adjust the stack for needed results for (; nresults < nres; nresults++) lua_pushnil(L); - status = LUA_OK; // treat LUA_YIELD the same as LUA_OK + if (!canyield) status = LUA_OK; // treat LUA_YIELD the same as LUA_OK } return status; } @@ -1097,7 +1100,7 @@ static void LogLuaError(char *hook, char *err) { static bool LuaRunCode(const char *code) { lua_State *L = GL; int status = luaL_loadstring(L, code); - if (status != LUA_OK || LuaCallWithTrace(L, 0, 0) != LUA_OK) { + if (status != LUA_OK || LuaCallWithTrace(L, 0, 0, NULL) != LUA_OK) { LogLuaError("lua code", lua_tostring(L, -1)); lua_pop(L, 1); // pop error return false; @@ -1118,7 +1121,7 @@ static bool LuaOnClientConnection(void) { lua_pushinteger(L, port); lua_pushinteger(L, serverip); lua_pushinteger(L, serverport); - if (LuaCallWithTrace(L, 4, 1) == LUA_OK) { + if (LuaCallWithTrace(L, 4, 1, NULL) == LUA_OK) { dropit = lua_toboolean(L, -1); } else { LogLuaError("OnClientConnection", lua_tostring(L, -1)); @@ -1141,7 +1144,7 @@ static void LuaOnProcessCreate(int pid) { lua_pushinteger(L, port); lua_pushinteger(L, serverip); lua_pushinteger(L, serverport); - if (LuaCallWithTrace(L, 5, 0) != LUA_OK) { + if (LuaCallWithTrace(L, 5, 0, NULL) != LUA_OK) { LogLuaError("OnProcessCreate", lua_tostring(L, -1)); lua_pop(L, 1); // pop error } @@ -1152,7 +1155,7 @@ static void LuaOnProcessDestroy(int pid) { lua_State *L = GL; lua_getglobal(L, "OnProcessDestroy"); lua_pushinteger(L, pid); - if (LuaCallWithTrace(L, 1, 0) != LUA_OK) { + if (LuaCallWithTrace(L, 1, 0, NULL) != LUA_OK) { LogLuaError("OnProcessDestroy", lua_tostring(L, -1)); lua_pop(L, 1); // pop error } @@ -1173,7 +1176,7 @@ static inline bool IsHookDefined(const char *s) { static void CallSimpleHook(const char *s) { lua_State *L = GL; lua_getglobal(L, s); - if (LuaCallWithTrace(L, 0, 0) != LUA_OK) { + if (LuaCallWithTrace(L, 0, 0, NULL) != LUA_OK) { LogLuaError(s, lua_tostring(L, -1)); lua_pop(L, 1); // pop error } @@ -2401,6 +2404,45 @@ static char *ServeFailure(unsigned code, const char *reason) { return ServeErrorImpl(code, reason, NULL); } +static ssize_t YieldGenerator(struct iovec v[3]) { + int nresults, status; + if (isyielding > 1) { + do { + if (!YL || lua_status(YL) != LUA_YIELD) return 0; // done yielding + contentlength = 0; + status = lua_resume(YL, NULL, 0, &nresults); + if (status != LUA_OK && status != LUA_YIELD) { + LogLuaError("resume", lua_tostring(YL, -1)); + lua_pop(YL, 1); + return -1; + } + lua_pop(YL, nresults); + if (!contentlength) UseOutput(); + // continue yielding if nothing to return to keep generator running + } while (!contentlength); + } + DEBUGF("(lua) yielded with %ld bytes generated", contentlength); + isyielding++; + v[0].iov_base = content; + v[0].iov_len = contentlength; + return contentlength; +} + +static int LuaCallWithYield(lua_State *L) { + int status; + // since yield may happen in OnHttpRequest and in ServeLua, + // need to fully restart the yield generator; + // the second set of headers is not going to be sent + lua_State *co = lua_newthread(L); + if ((status = LuaCallWithTrace(L, 0, 0, co)) == LUA_YIELD) { + YL = co; + generator = YieldGenerator; + if (!isyielding) isyielding = 1; + status = LUA_OK; + } + return status; +} + static ssize_t DeflateGenerator(struct iovec v[3]) { int i, rc; size_t no; @@ -2978,7 +3020,7 @@ static char *LuaOnHttpRequest(void) { effectivepath.p = url.path.p; effectivepath.n = url.path.n; lua_getglobal(L, "OnHttpRequest"); - if (LuaCallWithTrace(L, 0, 0) == LUA_OK) { + if (LuaCallWithYield(L) == LUA_OK) { AssertLuaStackIsEmpty(L); return CommitOutput(GetLuaResponse()); } else { @@ -3003,7 +3045,7 @@ static char *ServeLua(struct Asset *a, const char *s, size_t n) { int status = luaL_loadbuffer(L, code, codelen, FreeLater(xasprintf("@%s", FreeLater(strndup(s, n))))); - if (status == LUA_OK && LuaCallWithTrace(L, 0, 0) == LUA_OK) { + if (status == LUA_OK && LuaCallWithYield(L) == LUA_OK) { return CommitOutput(GetLuaResponse()); } else { char *error; @@ -5684,7 +5726,7 @@ static bool LuaRunAsset(const char *path, bool mandatory) { DEBUGF("(lua) LuaRunAsset(%`'s)", path); status = luaL_loadbuffer(L, code, codelen, FreeLater(xasprintf("@%s", path))); - if (status != LUA_OK || LuaCallWithTrace(L, 0, 0) != LUA_OK) { + if (status != LUA_OK || LuaCallWithTrace(L, 0, 0, NULL) != LUA_OK) { LogLuaError("lua code", lua_tostring(L, -1)); lua_pop(L, 1); // pop error if (mandatory) exit(1); @@ -6756,6 +6798,7 @@ static void InitRequest(void) { loops.n = 0; generator = 0; luaheaderp = 0; + isyielding = 0; contentlength = 0; hascontenttype = false; referrerpolicy = 0;