Refactor out some duplicated code

This commit is contained in:
Justine Tunney 2021-08-14 06:17:56 -07:00
parent e963d9c8e3
commit 579b597ded
58 changed files with 1110 additions and 3214 deletions

View file

@ -36,8 +36,8 @@ TOOL_BUILD_DIRECTDEPS = \
LIBC_MEM \
LIBC_NEXGEN32E \
LIBC_NT_KERNEL32 \
LIBC_NT_WS2_32 \
LIBC_NT_USER32 \
LIBC_NT_WS2_32 \
LIBC_RAND \
LIBC_RUNTIME \
LIBC_SOCK \
@ -50,12 +50,13 @@ TOOL_BUILD_DIRECTDEPS = \
LIBC_TINYMATH \
LIBC_UNICODE \
LIBC_X \
NET_HTTPS \
THIRD_PARTY_COMPILER_RT \
THIRD_PARTY_GDTOA \
THIRD_PARTY_GETOPT \
THIRD_PARTY_MBEDTLS \
THIRD_PARTY_STB \
THIRD_PARTY_XED \
THIRD_PARTY_MBEDTLS \
THIRD_PARTY_ZLIB \
TOOL_BUILD_LIB

View file

@ -33,6 +33,7 @@
#include "libc/nexgen32e/kcpuids.h"
#include "libc/runtime/runtime.h"
#include "libc/stdio/append.internal.h"
#include "libc/stdio/stdio.h"
#include "libc/str/str.h"
#include "libc/sysv/consts/rlimit.h"
#include "libc/sysv/consts/sig.h"
@ -353,7 +354,7 @@ int main(int argc, char *argv[]) {
/*
* parse prefix arguments
*/
while ((opt = getopt(argc, argv, "?hntC:M:F:A:T:V:")) != -1) {
while ((opt = getopt(argc, argv, "hntC:M:F:A:T:V:")) != -1) {
switch (opt) {
case 'n':
exit(0);
@ -378,17 +379,16 @@ int main(int argc, char *argv[]) {
case 'F':
fszquota = sizetol(optarg, 1000);
break;
case '?':
case 'h':
write(1, MANUAL, sizeof(MANUAL) - 1);
fputs(MANUAL, stdout);
exit(0);
default:
write(2, MANUAL, sizeof(MANUAL) - 1);
fputs(MANUAL, stderr);
exit(1);
}
}
if (optind == argc) {
write(2, MANUAL, sizeof(MANUAL) - 1);
fputs("error: missing arguments\n", stderr);
exit(1);
}

View file

@ -45,6 +45,7 @@ TOOL_BUILD_LIB_A_DIRECTDEPS = \
LIBC_TINYMATH \
LIBC_UNICODE \
LIBC_X \
NET_HTTPS \
THIRD_PARTY_COMPILER_RT \
THIRD_PARTY_MBEDTLS \
THIRD_PARTY_XED

View file

@ -25,9 +25,11 @@
#include "libc/sock/sock.h"
#include "libc/sysv/consts/sig.h"
#include "libc/x/x.h"
#include "net/https/https.h"
#include "third_party/mbedtls/ctr_drbg.h"
#include "third_party/mbedtls/ecp.h"
#include "third_party/mbedtls/error.h"
#include "third_party/mbedtls/platform.h"
#include "third_party/mbedtls/ssl.h"
#include "tool/build/lib/eztls.h"
#include "tool/build/lib/psk.h"
@ -37,34 +39,6 @@ mbedtls_ssl_config ezconf;
mbedtls_ssl_context ezssl;
mbedtls_ctr_drbg_context ezrng;
static char *EzTlsError(int r) {
static char b[128];
mbedtls_strerror(r, b, sizeof(b));
return b;
}
void EzTlsDie(const char *s, int r) {
if (IsTiny()) {
fprintf(stderr, "error: %s (-0x%04x %s)\n", s, -r, EzTlsError(r));
} else {
fprintf(stderr, "error: %s (grep -0x%04x)\n", s, -r);
}
exit(1);
}
static int EzGetEntropy(void *c, unsigned char *p, size_t n) {
CHECK_EQ(n, getrandom(p, n, 0));
return 0;
}
static void EzInitializeRng(mbedtls_ctr_drbg_context *r) {
volatile unsigned char b[64];
mbedtls_ctr_drbg_init(r);
CHECK(getrandom(b, 64, 0) == 64);
CHECK(!mbedtls_ctr_drbg_seed(r, EzGetEntropy, 0, b, 64));
mbedtls_platform_zeroize(b, 64);
}
static ssize_t EzWritevAll(int fd, struct iovec *iov, int iovlen) {
int i;
ssize_t rc;
@ -165,34 +139,38 @@ static int EzTlsRecv(void *ctx, unsigned char *buf, size_t len, uint32_t tmo) {
return EzTlsRecvImpl(ctx, buf, len, tmo);
}
void EzFd(int fd) {
mbedtls_ssl_session_reset(&ezssl);
mbedtls_platform_zeroize(&ezbio, sizeof(ezbio));
ezbio.fd = fd;
}
void EzHandshake(void) {
int rc;
while ((rc = mbedtls_ssl_handshake(&ezssl))) {
if (rc != MBEDTLS_ERR_SSL_WANT_READ) {
EzTlsDie("handshake failed", rc);
TlsDie("handshake failed", rc);
}
}
while ((rc = EzTlsFlush(&ezbio, 0, 0))) {
if (rc != MBEDTLS_ERR_SSL_WANT_READ) {
EzTlsDie("handshake flush failed", rc);
TlsDie("handshake flush failed", rc);
}
}
}
/*
* openssl s_client -connect 127.0.0.1:31337 \
* -psk $(hex <~/.runit.psk) \
* -psk_identity runit
*/
void SetupPresharedKeySsl(int endpoint) {
void EzInitialize(void) {
xsigaction(SIGPIPE, SIG_IGN, 0, 0, 0);
EzInitializeRng(&ezrng);
ezconf.disable_compression = 1; /* TODO(jart): Why does it behave weirdly? */
mbedtls_ssl_config_defaults(&ezconf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_SUITEC);
InitializeRng(&ezrng);
}
void EzSetup(char psk[32]) {
int rc;
mbedtls_ssl_conf_rng(&ezconf, mbedtls_ctr_drbg_random, &ezrng);
DCHECK_EQ(0, mbedtls_ssl_conf_psk(&ezconf, GetRunitPsk(), 32, "runit", 5));
DCHECK_EQ(0, mbedtls_ssl_setup(&ezssl, &ezconf));
if ((rc = mbedtls_ssl_conf_psk(&ezconf, psk, 32, "runit", 5)) ||
(rc = mbedtls_ssl_setup(&ezssl, &ezconf))) {
TlsDie("EzSetup", rc);
}
mbedtls_ssl_set_bio(&ezssl, &ezbio, EzTlsSend, 0, EzTlsRecv);
}

View file

@ -2,6 +2,7 @@
#define COSMOPOLITAN_TOOL_BUILD_LIB_EZTLS_H_
#include "third_party/mbedtls/ctr_drbg.h"
#include "third_party/mbedtls/ssl.h"
#include "third_party/mbedtls/x509_crt.h"
#if !(__ASSEMBLER__ + __LINKER__ + 0)
COSMOPOLITAN_C_START_
@ -17,11 +18,24 @@ extern mbedtls_ssl_config ezconf;
extern mbedtls_ssl_context ezssl;
extern mbedtls_ctr_drbg_context ezrng;
void EzFd(int);
void EzHandshake(void);
void SetupPresharedKeySsl(int);
void EzTlsDie(const char *, int) wontreturn;
void EzSetup(char[32]);
void EzInitialize(void);
int EzTlsFlush(struct EzTlsBio *, const unsigned char *, size_t);
/*
* openssl s_client -connect 127.0.0.1:31337 \
* -psk $(hex <~/.runit.psk) \
* -psk_identity runit
*/
forceinline void SetupPresharedKeySsl(int endpoint, char psk[32]) {
EzInitialize();
mbedtls_ssl_config_defaults(&ezconf, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_SUITEC);
EzSetup(psk);
}
COSMOPOLITAN_C_END_
#endif /* !(__ASSEMBLER__ + __LINKER__ + 0) */
#endif /* COSMOPOLITAN_TOOL_BUILD_LIB_EZTLS_H_ */

View file

@ -42,8 +42,10 @@
#include "libc/sysv/consts/sock.h"
#include "libc/time/time.h"
#include "libc/x/x.h"
#include "net/https/https.h"
#include "third_party/mbedtls/ssl.h"
#include "tool/build/lib/eztls.h"
#include "tool/build/lib/psk.h"
#include "tool/build/runit.h"
/**
@ -336,7 +338,7 @@ bool Recv(unsigned char *p, size_t n) {
usleep((backoff = (backoff + 1000) * 2));
return false;
} else if (rc < 0) {
EzTlsDie("read response failed", rc);
TlsDie("read response failed", rc);
}
}
return true;
@ -387,9 +389,8 @@ int RunOnHost(char *spec) {
1);
if (!strchr(g_hostname, '.')) strcat(g_hostname, ".test.");
do {
mbedtls_ssl_session_reset(&ezssl);
Connect();
ezbio.fd = g_sock;
EzFd(g_sock);
EzHandshake();
SendRequest();
} while ((rc = ReadResponse()) == -1);
@ -454,7 +455,7 @@ int RunRemoteTestsInParallel(char *hosts[], int count) {
int main(int argc, char *argv[]) {
showcrashreports();
SetupPresharedKeySsl(MBEDTLS_SSL_IS_CLIENT);
SetupPresharedKeySsl(MBEDTLS_SSL_IS_CLIENT, GetRunitPsk());
/* __log_level = kLogDebug; */
if (argc > 1 &&
(strcmp(argv[1], "-h") == 0 || strcmp(argv[1], "--help") == 0)) {

View file

@ -45,9 +45,11 @@
#include "libc/sysv/consts/w.h"
#include "libc/time/time.h"
#include "libc/x/x.h"
#include "net/https/https.h"
#include "third_party/getopt/getopt.h"
#include "third_party/mbedtls/ssl.h"
#include "tool/build/lib/eztls.h"
#include "tool/build/lib/psk.h"
#include "tool/build/runit.h"
/**
@ -264,13 +266,13 @@ void HandleClient(void) {
close(g_clifd);
return;
}
ezbio.fd = g_clifd;
EzFd(g_clifd);
EzHandshake();
addrstr = gc(DescribeAddress(&addr));
DEBUGF("%s %s %s", gc(DescribeAddress(&g_servaddr)), "accepted", addrstr);
while ((got = mbedtls_ssl_read(&ezssl, (p = g_buf), sizeof(g_buf))) < 0) {
if (got != MBEDTLS_ERR_SSL_WANT_READ) {
EzTlsDie("ssl read failed", got);
TlsDie("ssl read failed", got);
}
}
CHECK_GE(got, kMinMsgSize);
@ -302,7 +304,7 @@ void HandleClient(void) {
while (remaining) {
while ((got = mbedtls_ssl_read(&ezssl, g_buf, sizeof(g_buf))) < 0) {
if (got != MBEDTLS_ERR_SSL_WANT_READ) {
EzTlsDie("ssl read failed", got);
TlsDie("ssl read failed", got);
}
}
CHECK_LE(got, remaining);
@ -443,7 +445,7 @@ void Daemonize(void) {
int main(int argc, char *argv[]) {
showcrashreports();
SetupPresharedKeySsl(MBEDTLS_SSL_IS_SERVER);
SetupPresharedKeySsl(MBEDTLS_SSL_IS_SERVER, GetRunitPsk());
/* __log_level = kLogDebug; */
GetOpts(argc, argv);
CHECK_NE(-1, (g_devnullfd = open("/dev/null", O_RDWR)));

View file

@ -124,34 +124,10 @@ void GetOpts(int *argc, char ***argv) {
CHECK_NOTNULL(outpath_);
}
bool IsUtf8(const void *data, size_t size) {
const unsigned char *p, *pe;
for (p = data, pe = p + size; p + 2 <= pe; ++p) {
if (p[0] >= 0300) {
if (p[1] >= 0200 && p[1] < 0300) {
return true;
} else {
return false;
}
}
}
return false;
}
bool IsText(const void *data, size_t size) {
const unsigned char *p, *pe;
for (p = data, pe = p + size; p < pe; ++p) {
if (*p <= 3) {
return false;
}
}
return true;
}
bool ShouldCompress(const char *name, size_t namesize,
const unsigned char *data, size_t datasize) {
return !nocompress_ && datasize >= 64 && !IsNoCompressExt(name, namesize) &&
(datasize < 1000 || MeasureEntropy((void *)data, 1000) < 6);
(datasize < 1000 || MeasureEntropy((void *)data, 1000) < 7);
}
void GetDosLocalTime(int64_t utcunixts, uint16_t *out_time,