fix path handling in bind/connect

This commit is contained in:
codehz 2023-12-13 22:26:16 +08:00
parent 3b302e6379
commit b4af492c6e
5 changed files with 103 additions and 5 deletions

View file

@ -64,6 +64,54 @@ textwindows size_t __normntpath(char16_t *p, size_t n) {
return j;
}
/**
* Copies path for unix sockets on Windows NT.
*
* This function does the following chores:
*
* 1. Fixing drive letter paths, e.g. `/c/` `c:\`
* 2. Turn `/tmp` into GetTempPath()
*
* I don't think we need normalize the path here.
*
* @param path is input unix-style path
* @return short count excluding NUL on success, or -1 w/ errno
* @error ENAMETOOLONG
*/
textwindows int __mkwin32_sun_path(
const char *path, char sun_path[hasatleast UNIX_SOCKET_NAME_MAX]) {
// 1. Need +1 for NUL-terminator
if (!path || (IsAsan() && !__asan_is_valid_str(path))) {
return efault();
}
char16_t tmp_path[UNIX_SOCKET_NAME_MAX];
char *p = sun_path;
const char *q = path;
size_t n = 0;
if (IsSlash(q[0]) && IsAlpha(q[1]) && IsSlash(q[2])) {
// turn "\c\foo" into "c:\foo"
p[0] = q[1];
p[1] = ':';
p[2] = '\\';
p += 3;
n = 3;
} else if (IsSlash(q[0]) && q[1] == 't' && q[2] == 'm' && q[3] == 'p' &&
(IsSlash(q[4]) || !q[4])) {
if (!q[4] || !q[5]) return efault();
GetTempPath(UNIX_SOCKET_NAME_MAX, tmp_path);
n = tprecode16to8(p, UNIX_SOCKET_NAME_MAX, tmp_path).ax;
p += n;
q += 5;
}
for (; n < UNIX_SOCKET_NAME_MAX; n++) {
if (!(*p++ = *q++)) break;
}
if (n == UNIX_SOCKET_NAME_MAX) {
return efault();
}
return n;
}
textwindows int __mkntpath(const char *path,
char16_t path16[hasatleast PATH_MAX]) {
return __mkntpath2(path, path16, -1);

View file

@ -4,12 +4,15 @@
#include "libc/nt/struct/overlapped.h"
COSMOPOLITAN_C_START_
#define UNIX_SOCKET_NAME_MAX 108
bool isdirectory_nt(const char *);
bool isregularfile_nt(const char *);
bool issymlink_nt(const char *);
bool32 ntsetprivilege(int64_t, const char16_t *, uint32_t);
char16_t *__create_pipe_name(char16_t *);
size_t __normntpath(char16_t *, size_t);
int __mkwin32_sun_path(const char *, char[hasatleast UNIX_SOCKET_NAME_MAX]);
int __mkntpath(const char *, char16_t[hasatleast PATH_MAX]);
int __mkntpath2(const char *, char16_t[hasatleast PATH_MAX], int);
int __mkntpathath(int64_t, const char *, int, char16_t[hasatleast PATH_MAX]);

View file

@ -16,14 +16,25 @@
TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
*/
#include "libc/calls/syscall_support-nt.internal.h"
#include "libc/nt/thunk/msabi.h"
#include "libc/sock/internal.h"
#include "libc/sock/struct/sockaddr.h"
#include "libc/sock/syscall_fd.internal.h"
#include "libc/sysv/consts/af.h"
#ifdef __x86_64__
__msabi extern typeof(__sys_bind_nt) *const __imp_bind;
textwindows int sys_bind_nt(struct Fd *f, const void *addr, uint32_t addrsize) {
struct sockaddr_un *sun, nt_sun;
if (f->family == AF_UNIX && ((struct sockaddr *)addr)->sa_family == AF_UNIX &&
addrsize >= sizeof(struct sockaddr_un)) {
sun = (struct sockaddr_un *)addr;
nt_sun.sun_family = AF_UNIX;
if (__mkwin32_sun_path(sun->sun_path, nt_sun.sun_path) == -1) return -1;
addr = &nt_sun;
}
if (__imp_bind(f->handle, addr, addrsize) != -1) {
f->isbound = true;
return 0;

View file

@ -20,6 +20,7 @@
#include "libc/atomic.h"
#include "libc/calls/struct/fd.internal.h"
#include "libc/calls/struct/sigset.internal.h"
#include "libc/calls/syscall_support-nt.internal.h"
#include "libc/cosmo.h"
#include "libc/errno.h"
#include "libc/mem/mem.h"
@ -34,6 +35,7 @@
#include "libc/sock/struct/sockaddr.h"
#include "libc/sock/syscall_fd.internal.h"
#include "libc/sock/wsaid.internal.h"
#include "libc/sysv/consts/af.h"
#include "libc/sysv/consts/o.h"
#include "libc/sysv/errfuns.h"
@ -140,6 +142,14 @@ static textwindows int sys_connect_nt_impl(struct Fd *f, const void *addr,
textwindows int sys_connect_nt(struct Fd *f, const void *addr,
uint32_t addrsize) {
struct sockaddr_un *sun, nt_sun;
if (f->family == AF_UNIX && ((struct sockaddr *)addr)->sa_family == AF_UNIX &&
addrsize >= sizeof(struct sockaddr_un)) {
sun = (struct sockaddr_un *)addr;
nt_sun.sun_family = AF_UNIX;
if (__mkwin32_sun_path(sun->sun_path, nt_sun.sun_path) == -1) return -1;
addr = &nt_sun;
}
sigset_t mask = __sig_block();
int rc = sys_connect_nt_impl(f, addr, addrsize, mask);
__sig_unblock(mask);

View file

@ -70,17 +70,20 @@ TEST(unix, datagram) {
munmap(ready, 1);
}
void StreamServer(atomic_bool *ready) {
void StreamServer(atomic_bool *ready, const char *path, bool check_path) {
char buf[256] = {0};
uint32_t len = sizeof(struct sockaddr_un);
struct sockaddr_un addr = {AF_UNIX, "foo.sock"};
struct sockaddr_un addr = {AF_UNIX};
strcpy(addr.sun_path, path);
ASSERT_SYS(0, 3, socket(AF_UNIX, SOCK_STREAM, 0));
unlink(path);
errno = 0;
ASSERT_SYS(0, 0, bind(3, (void *)&addr, len));
bzero(&addr, sizeof(addr));
ASSERT_SYS(0, 0, getsockname(3, (void *)&addr, &len));
ASSERT_EQ(2 + 8 + 1, len);
if (check_path) ASSERT_EQ(2 + strlen(path) + 1, len);
ASSERT_EQ(AF_UNIX, addr.sun_family);
ASSERT_STREQ("foo.sock", addr.sun_path);
if (check_path) ASSERT_STREQ(path, addr.sun_path);
ASSERT_SYS(0, 0, listen(3, 10));
bzero(&addr, sizeof(addr));
len = sizeof(addr);
@ -101,7 +104,7 @@ TEST(unix, stream) {
ASSERT_SYS(0, 3, socket(AF_UNIX, SOCK_STREAM, 0));
if (!fork()) {
close(3);
StreamServer(ready);
StreamServer(ready, "foo.sock", true);
_Exit(0);
}
while (!*ready) sched_yield();
@ -116,6 +119,29 @@ TEST(unix, stream) {
munmap(ready, 1);
}
TEST(unix, stream_absolute) {
int ws;
if (IsWindows() && !IsAtLeastWindows10()) return;
atomic_bool *ready = _mapshared(1);
// TODO(jart): move this line down when kFdProcess is gone
ASSERT_SYS(0, 3, socket(AF_UNIX, SOCK_STREAM, 0));
if (!fork()) {
close(3);
StreamServer(ready, "/tmp/foo.sock", !IsWindows());
_Exit(0);
}
while (!*ready) sched_yield();
uint32_t len = sizeof(struct sockaddr_un);
struct sockaddr_un addr = {AF_UNIX, "/tmp/foo.sock"};
ASSERT_SYS(0, 0, connect(3, (void *)&addr, len));
ASSERT_SYS(0, 5, write(3, "hello", 5));
ASSERT_SYS(0, 0, close(3));
ASSERT_NE(-1, wait(&ws));
EXPECT_TRUE(WIFEXITED(ws));
EXPECT_EQ(0, WEXITSTATUS(ws));
munmap(ready, 1);
}
TEST(unix, serverGoesDown_deletedSockFile) { // field of landmine
if (IsWindows()) return;
if (IsCygwin()) return;