diff --git a/test/tool/net/redbean_test.c b/test/tool/net/redbean_test.c index e13c3cfdd..06996f0d0 100644 --- a/test/tool/net/redbean_test.c +++ b/test/tool/net/redbean_test.c @@ -291,3 +291,45 @@ Z\n", EXPECT_NE(-1, wait(0)); EXPECT_NE(-1, sigprocmask(SIG_SETMASK, &savemask, 0)); } + +TEST(redbean, testWebSockets) { + if (IsWindows()) + return; + char portbuf[16]; + int pid, pipefds[2]; + sigset_t chldmask, savemask; + sigaddset(&chldmask, SIGCHLD); + EXPECT_NE(-1, sigprocmask(SIG_BLOCK, &chldmask, &savemask)); + ASSERT_NE(-1, pipe(pipefds)); + ASSERT_NE(-1, (pid = fork())); + if (!pid) { + setpgrp(); + close(0); + open("/dev/null", O_RDWR); + close(pipefds[0]); + dup2(pipefds[1], 1); + sigprocmask(SIG_SETMASK, &savemask, NULL); + execv("bin/redbean-tester", + (char *const[]){"bin/redbean-tester", "-vvszXp0", "-l127.0.0.1", // "-L/tmp/redbean-tester.log", + __strace > 0 ? "--strace" : 0, 0}); + _exit(127); + } + EXPECT_NE(-1, close(pipefds[1])); + EXPECT_NE(-1, read(pipefds[0], portbuf, sizeof(portbuf))); + port = atoi(portbuf); + EXPECT_TRUE(Matches("HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + "Date: .*\r\n" + "Server: redbean/.*\r\n" + "\r\n", + gc(SendHttpRequest("GET /ws HTTP/1.1\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\r\n")))); + EXPECT_EQ(0, close(pipefds[0])); + EXPECT_NE(-1, kill(pid, SIGTERM)); + EXPECT_NE(-1, wait(0)); + EXPECT_NE(-1, sigprocmask(SIG_SETMASK, &savemask, 0)); +} diff --git a/tool/net/tester/.init.lua b/tool/net/tester/.init.lua index 6fa68a912..ccdd769fa 100644 --- a/tool/net/tester/.init.lua +++ b/tool/net/tester/.init.lua @@ -10,10 +10,14 @@ function OnHttpRequest() ws.Write(nil) -- upgrade without sending a response coroutine.yield() - local fds = {[GetClientFd()] = unix.POLLIN} + local fd = GetClientFd() + local fds = {[fd] = unix.POLLIN | unix.POLLHUP | unix.POLLRDHUP} -- simple echo server while true do - unix.poll(fds) + res = unix.poll(fds) + if (res[fd] & unix.POLLHUP == unix.POLLHUP) or (res[fd] & unix.POLLRDHUP == unix.POLLRDHUP) then + return + end local s, t = ws.Read() if t == ws.CLOSE then return