[tls socket] use tls.TlsClient(fd)

This commit is contained in:
alan crouau 2024-09-09 17:15:26 +02:00
parent 66f57f0f34
commit 470ee52159
5 changed files with 130 additions and 98 deletions

View file

@ -100,7 +100,8 @@ TOOL_NET_REDBEAN_LUA_MODULES = \
o/$(MODE)/tool/net/lmaxmind.o \
o/$(MODE)/tool/net/lsqlite3.o \
o/$(MODE)/tool/net/largon2.o \
o/$(MODE)/tool/net/launch.o
o/$(MODE)/tool/net/launch.o \
o/$(MODE)/tool/net/ltls.o
o/$(MODE)/tool/net/redbean.dbg: \
$(TOOL_NET_DEPS) \

View file

@ -4992,11 +4992,12 @@ unix = {
local tls = {}
--- Creates a new TLS socket.
---@param fd integer File descriptor of the socket
---@param verify? boolean Whether to verify the server's certificate (default: true)
---@param timeout? integer Read timeout in milliseconds (default: 0, no timeout)
---@return TlsContext|nil context
---@return string? error
function tls.socket(verify, timeout) end
function tls.TlsClient(fd, verify, timeout) end
--- Connects to a server using TLS.
---@param context TlsContext

View file

@ -1,21 +1,51 @@
static const char *const tls_meta = ":mbedtls";
/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│
vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi
Copyright 2022 Justine Alexandra Roberts Tunney
Permission to use, copy, modify, and/or distribute this software for
any purpose with or without fee is hereby granted, provided that the
above copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR
PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER
TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.
*/
#include "ltls.h"
#include "libc/calls/struct/iovec.h"
#include "third_party/mbedtls/ctr_drbg.h"
#include "third_party/mbedtls/debug.h"
#include "third_party/mbedtls/iana.h"
#include "third_party/mbedtls/net_sockets.h"
#include "third_party/mbedtls/oid.h"
#include "third_party/mbedtls/san.h"
#include "third_party/mbedtls/ssl.h"
#include "third_party/mbedtls/ssl_ticket.h"
#include "third_party/mbedtls/x509.h"
#include "third_party/mbedtls/x509_crt.h"
#include "third_party/mbedtls/entropy.h"
#include "net/https/https.h"
typedef enum {
TLS_STATE_INIT,
TLS_STATE_CONNECTED,
TLS_STATE_CLOSED
} TlsConnectionState;
#ifndef MIN
#define MIN(a,b) ( (a) < (b) ? (a) :(b) )
#endif
static const char *const tls_meta = ":mbedtls";
typedef struct {
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
mbedtls_net_context server_fd;
int ref; // Reference to self in the Lua registry
TlsConnectionState connection_state;
char *read_buffer;
size_t read_buffer_size;
int fd; // File descriptor
} TlsContext;
static TlsContext **checktls(lua_State *L) {
@ -25,19 +55,53 @@ static TlsContext **checktls(lua_State *L) {
return tls;
}
static int TlsSend(void *c, const unsigned char *p, size_t n) {
int rc;
if ((rc = write(*(int *)c, p, n)) == -1) {
perror("TlsSend");
fprintf(stderr, "TlsSend error: rc=%d, c=%d\n", rc, *(int *)c);
exit(1);
}
return rc;
}
static int TlsRecv(void *c, unsigned char *p, size_t n, uint32_t o) {
int r;
struct iovec v[2];
static unsigned a, b;
static unsigned char t[4096];
if (a < b) {
r = MIN(n, b - a);
memcpy(p, t + a, r);
if ((a += r) == b) {
a = b = 0;
}
return r;
}
v[0].iov_base = p;
v[0].iov_len = n;
v[1].iov_base = t;
v[1].iov_len = sizeof(t);
if ((r = readv(*(int *)c, v, 2)) == -1) {
perror("TlsRecv");
exit(1);
}
if (r > n) {
b = r - n;
}
return MIN(n, r);
}
static int tls_gc(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
if (tls) {
if (tls->connection_state != TLS_STATE_CLOSED) {
mbedtls_net_free(&tls->server_fd);
mbedtls_ssl_free(&tls->ssl);
mbedtls_ssl_config_free(&tls->conf);
mbedtls_ctr_drbg_free(&tls->ctr_drbg);
mbedtls_entropy_free(&tls->entropy);
}
mbedtls_ssl_free(&tls->ssl);
mbedtls_ssl_config_free(&tls->conf);
mbedtls_ctr_drbg_free(&tls->ctr_drbg);
mbedtls_entropy_free(&tls->entropy);
luaL_unref(L, LUA_REGISTRYINDEX, tls->ref);
free(tls->read_buffer);
free(tls);
@ -46,10 +110,16 @@ static int tls_gc(lua_State *L) {
return 0;
}
static int tls_socket(lua_State *L) {
if (!sslinitialized) {
TlsInit();
}
static void my_debug(void *ctx, int level, const char *file, int line,
const char *str) {
((void)level);
fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str);
fflush((FILE *)ctx);
}
static int tls_client(lua_State *L) {
int fd = luaL_checkinteger(L, 1);
printf("fd: %d\n", fd);
TlsContext **tlsp = (TlsContext **)lua_newuserdata(L, sizeof(TlsContext *));
*tlsp = NULL;
@ -65,16 +135,16 @@ static int tls_socket(lua_State *L) {
}
*tlsp = tls;
tls->connection_state = TLS_STATE_INIT;
tls->read_buffer = NULL;
tls->read_buffer_size = 0;
tls->fd = fd;
mbedtls_net_init(&tls->server_fd);
mbedtls_ssl_init(&tls->ssl);
mbedtls_ssl_config_init(&tls->conf);
mbedtls_ctr_drbg_init(&tls->ctr_drbg);
mbedtls_entropy_init(&tls->entropy);
int sslVerify = lua_isnone(L, 1) ? 1 : lua_toboolean(L, 1);
int sslVerify = lua_isnone(L, 2) ? 1 : lua_toboolean(L, 2);
if (sslVerify) {
mbedtls_ssl_conf_ca_chain(&tls->conf, GetSslRoots(), 0);
mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
@ -82,10 +152,10 @@ static int tls_socket(lua_State *L) {
mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_NONE);
}
int timeout = lua_isnone(L, 2) ? 0 : luaL_checkinteger(L, 2);
int timeout = lua_isnone(L, 3) ? 0 : luaL_checkinteger(L, 3);
mbedtls_ssl_conf_read_timeout(&tls->conf, timeout);
const char *pers = "tls_socket";
const char *pers = "tls_client";
int ret;
if ((ret = mbedtls_ctr_drbg_seed(&tls->ctr_drbg, mbedtls_entropy_func,
&tls->entropy, (const unsigned char *)pers,
@ -97,6 +167,19 @@ static int tls_socket(lua_State *L) {
return 2;
}
if ((ret = mbedtls_ssl_config_defaults(&tls->conf, MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
free(tls);
*tlsp = NULL;
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ssl_config_defaults failed: %d", ret);
return 2;
}
mbedtls_ssl_conf_rng(&tls->conf, mbedtls_ctr_drbg_random, &tls->ctr_drbg);
mbedtls_ssl_conf_dbg(&tls->conf, my_debug, stdout);
if ((ret = mbedtls_ssl_setup(&tls->ssl, &tls->conf)) != 0) {
free(tls);
*tlsp = NULL;
@ -105,55 +188,7 @@ static int tls_socket(lua_State *L) {
return 2;
}
tls->ref = luaL_ref(L, LUA_REGISTRYINDEX);
lua_rawgeti(L, LUA_REGISTRYINDEX, tls->ref);
return 1;
}
static void my_debug(void *ctx, int level, const char *file, int line,
const char *str) {
((void)level);
fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str);
fflush((FILE *)ctx);
}
static int tls_connect(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
const char *server_name = luaL_checkstring(L, 2);
const char *server_port = luaL_checkstring(L, 3);
int ret;
if ((ret = mbedtls_net_connect(&tls->server_fd, server_name, server_port,
MBEDTLS_NET_PROTO_TCP)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "connect failed: %d", ret);
return 2;
}
if ((ret = mbedtls_ssl_config_defaults(&tls->conf, MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ssl_config_defaults failed: %d", ret);
return 2;
}
// mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_NONE); // only for
// test, conf mbedtls_x509_crt_init instead
mbedtls_ssl_conf_rng(&tls->conf, mbedtls_ctr_drbg_random, &tls->ctr_drbg);
mbedtls_ssl_conf_dbg(&tls->conf, my_debug, stdout);
if ((ret = mbedtls_ssl_set_hostname(&tls->ssl, server_name)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ssl_set_hostname failed: %d", ret);
return 2;
}
mbedtls_ssl_set_bio(&tls->ssl, &tls->server_fd, mbedtls_net_send, NULL,
mbedtls_net_recv_timeout);
mbedtls_ssl_set_bio(&tls->ssl, &tls->fd, TlsSend, 0, TlsRecv);
if ((ret = mbedtls_ssl_handshake(&tls->ssl)) != 0) {
lua_pushnil(L);
@ -161,9 +196,9 @@ static int tls_connect(lua_State *L) {
return 2;
}
tls->connection_state = TLS_STATE_CONNECTED;
tls->ref = luaL_ref(L, LUA_REGISTRYINDEX);
lua_rawgeti(L, LUA_REGISTRYINDEX, tls->ref);
lua_pushboolean(L, 1);
return 1;
}
@ -223,7 +258,6 @@ static int tls_close(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
mbedtls_net_free(&tls->server_fd);
mbedtls_ssl_free(&tls->ssl);
mbedtls_ssl_config_free(&tls->conf);
mbedtls_ctr_drbg_free(&tls->ctr_drbg);
@ -231,7 +265,6 @@ static int tls_close(lua_State *L) {
free(tls->read_buffer);
tls->read_buffer = NULL;
tls->read_buffer_size = 0;
tls->connection_state = TLS_STATE_CLOSED;
return 0;
}
@ -239,37 +272,25 @@ static int tls_close(lua_State *L) {
static int tls_tostring(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
const char *state_str;
switch (tls->connection_state) {
case TLS_STATE_INIT:
state_str = "initialized";
break;
case TLS_STATE_CONNECTED:
state_str = "connected";
break;
case TLS_STATE_CLOSED:
state_str = "closed";
break;
default:
state_str = "unknown";
}
lua_pushfstring(L, "TLS connection (%p): %s", tls, state_str);
lua_pushfstring(L, "tls.TlsClient(fd=%d)", tls->fd);
return 1;
}
static const struct luaL_Reg tls_methods[] = {
{"connect", tls_connect},
{"write", tls_write},
{"read", tls_read},
{"close", tls_close},
{"__gc", tls_gc},
{"__tostring", tls_tostring},
{"__repr", tls_tostring},
{NULL, NULL}
};
static const struct luaL_Reg tlslib[] = {{"socket", tls_socket}, {NULL, NULL}};
static const struct luaL_Reg tlslib[] = {
{"TlsClient", tls_client},
{NULL, NULL}
};
static void create_meta(lua_State *L, const char *name,
const luaL_Reg *methods) {

9
tool/net/ltls.h Normal file
View file

@ -0,0 +1,9 @@
#ifndef COSMOPOLITAN_TOOL_NET_LTLS_H_
#define COSMOPOLITAN_TOOL_NET_LTLS_H_
#include "third_party/lua/lauxlib.h"
COSMOPOLITAN_C_START_
int luaopen_tls(lua_State *);
COSMOPOLITAN_C_END_
#endif /* COSMOPOLITAN_TOOL_NET_LTLS_H_ */

View file

@ -137,6 +137,7 @@
#include "tool/net/lfuncs.h"
#include "tool/net/ljson.h"
#include "tool/net/lpath.h"
#include "tool/net/ltls.h"
#include "tool/net/luacheck.h"
#include "tool/net/sandbox.h"
@ -3978,7 +3979,6 @@ static int LuaNilTlsError(lua_State *L, const char *s, int r) {
}
#include "tool/net/fetch.inc"
#include "tool/net/ltls.inc"
static int LuaGetDate(lua_State *L) {
lua_pushinteger(L, shared->nowish.tv_sec);