Add OnServerListen hook to configure listening sockets

This commit is contained in:
Paul Kulchenko 2022-07-02 23:18:01 -07:00
parent 5df3e4e7a8
commit cdd0406bc9
2 changed files with 35 additions and 6 deletions

View file

@ -501,7 +501,7 @@ HOOKS
OnClientConnection(ip:int,port:int,serverip:int,serverport:int) → bool OnClientConnection(ip:int,port:int,serverip:int,serverport:int) → bool
If this function is defined it'll be called from the main process If this function is defined it'll be called from the main process
each time redbean accepts a new client connection. If it returns each time redbean accepts a new client connection. If it returns
true then redbean will close the connection without calling fork. `true`, redbean will close the connection without calling fork.
OnProcessCreate(pid:int,ip:int,port:int,serverip:int,serverport:int) OnProcessCreate(pid:int,ip:int,port:int,serverip:int,serverport:int)
If this function is defined it'll be called from the main process If this function is defined it'll be called from the main process
@ -517,6 +517,12 @@ HOOKS
each time redbean reaps a child connection process using wait4(). each time redbean reaps a child connection process using wait4().
This won't be called in uniprocess mode. This won't be called in uniprocess mode.
OnServerListen(socketdescriptor:int,serverip:int,serverport:int) → bool
If this function is defined it'll be called from the main process
before redbean starts listening on a port. This hook can be used
to modify socket configuration to set `SO_REUSEPORT`, for example.
If it returns `true`, redbean will not listen to that ip/port.
OnServerStart() OnServerStart()
If this function is defined it'll be called from the main process If this function is defined it'll be called from the main process
right before the main event loop starts. right before the main event loop starts.

View file

@ -1081,8 +1081,8 @@ static bool LuaEvalFile(const char *path) {
} }
static bool LuaOnClientConnection(void) { static bool LuaOnClientConnection(void) {
bool dropit = false;
#ifndef STATIC #ifndef STATIC
bool dropit;
uint32_t ip, serverip; uint32_t ip, serverip;
uint16_t port, serverport; uint16_t port, serverport;
lua_State *L = GL; lua_State *L = GL;
@ -1097,14 +1097,11 @@ static bool LuaOnClientConnection(void) {
dropit = lua_toboolean(L, -1); dropit = lua_toboolean(L, -1);
} else { } else {
LogLuaError("OnClientConnection", lua_tostring(L, -1)); LogLuaError("OnClientConnection", lua_tostring(L, -1));
dropit = false;
} }
lua_pop(L, 1); // pop result or error lua_pop(L, 1); // pop result or error
AssertLuaStackIsAt(L, 0); AssertLuaStackIsAt(L, 0);
return dropit;
#else
return false;
#endif #endif
return dropit;
} }
static void LuaOnProcessCreate(int pid) { static void LuaOnProcessCreate(int pid) {
@ -1128,6 +1125,25 @@ static void LuaOnProcessCreate(int pid) {
#endif #endif
} }
static bool LuaOnServerListen(int fd, uint32_t ip, uint16_t port) {
bool nouse = false;
#ifndef STATIC
lua_State *L = GL;
lua_getglobal(L, "OnServerListen");
lua_pushinteger(L, fd);
lua_pushinteger(L, ip);
lua_pushinteger(L, port);
if (LuaCallWithTrace(L, 3, 1, NULL) == LUA_OK) {
nouse = lua_toboolean(L, -1);
} else {
LogLuaError("OnServerListen", lua_tostring(L, -1));
}
lua_pop(L, 1); // pop result or error
AssertLuaStackIsAt(L, 0);
#endif
return nouse;
}
static void LuaOnProcessDestroy(int pid) { static void LuaOnProcessDestroy(int pid) {
#ifndef STATIC #ifndef STATIC
lua_State *L = GL; lua_State *L = GL;
@ -6914,6 +6930,7 @@ static void Listen(void) {
char ipbuf[16]; char ipbuf[16];
size_t i, j, n; size_t i, j, n;
uint32_t ip, port, addrsize, *ifs, *ifp; uint32_t ip, port, addrsize, *ifs, *ifp;
bool hasonserverlisten = IsHookDefined("OnServerListen");
if (!ports.n) { if (!ports.n) {
ProgramPort(8080); ProgramPort(8080);
} }
@ -6940,6 +6957,12 @@ static void Listen(void) {
IPPROTO_TCP, true, &timeout)) == -1) { IPPROTO_TCP, true, &timeout)) == -1) {
DIEF("(srvr) socket: %m"); DIEF("(srvr) socket: %m");
} }
if (hasonserverlisten &&
LuaOnServerListen(servers.p[n].fd, ips.p[i], ports.p[j])) {
n--; // skip this server instance
continue;
}
if (bind(servers.p[n].fd, &servers.p[n].addr, if (bind(servers.p[n].fd, &servers.p[n].addr,
sizeof(servers.p[n].addr)) == -1) { sizeof(servers.p[n].addr)) == -1) {
DIEF("(srvr) bind error: %m: %hhu.%hhu.%hhu.%hhu:%hu", ips.p[i] >> 24, DIEF("(srvr) bind error: %m: %hhu.%hhu.%hhu.%hhu:%hu", ips.p[i] >> 24,