Merge parse aes parse options functions

This commit is contained in:
Miguel Terron 2025-06-04 21:42:17 +12:00
parent d06d0879b8
commit e35f99c7db

View file

@ -1,18 +1,18 @@
#include "libc/log/log.h"
#include "net/https/https.h"
#include "third_party/lua/lauxlib.h"
#include "third_party/mbedtls/aes.h"
#include "third_party/mbedtls/base64.h"
#include "third_party/mbedtls/ctr_drbg.h"
#include "third_party/mbedtls/ecdsa.h"
#include "third_party/mbedtls/entropy.h"
#include "third_party/mbedtls/error.h"
#include "third_party/mbedtls/gcm.h"
#include "third_party/mbedtls/md.h"
#include "third_party/mbedtls/oid.h"
#include "third_party/mbedtls/pk.h"
#include "third_party/mbedtls/rsa.h"
#include "third_party/mbedtls/ecdsa.h"
#include "third_party/mbedtls/x509_csr.h"
#include "third_party/mbedtls/oid.h"
#include "third_party/mbedtls/md.h"
#include "third_party/mbedtls/base64.h"
#include "third_party/mbedtls/aes.h"
#include "third_party/mbedtls/ctr_drbg.h"
#include "third_party/mbedtls/entropy.h"
#include "third_party/mbedtls/gcm.h"
// Standard C library and redbean utilities
#include "libc/errno.h"
@ -29,8 +29,10 @@ static int LuaConvertPemToJwk(lua_State *L) {
int ret;
// Parse the PEM key
if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key, strlen(pem_key) + 1, NULL, 0)) != 0 &&
(ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key, strlen(pem_key) + 1)) != 0) {
if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key,
strlen(pem_key) + 1, NULL, 0)) != 0 &&
(ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key,
strlen(pem_key) + 1)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret);
mbedtls_pk_free(&key);
@ -80,8 +82,10 @@ static int LuaConvertPemToJwk(lua_State *L) {
return 2;
}
mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len);
mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len);
mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n,
n_len);
mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e,
e_len);
n_b64[n_b64_len] = '\0';
e_b64[e_b64_len] = '\0';
@ -139,8 +143,10 @@ static int LuaConvertPemToJwk(lua_State *L) {
return 2;
}
mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len);
mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len);
mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x,
x_len);
mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y,
y_len);
x_b64[x_b64_len] = '\0';
y_b64[y_b64_len] = '\0';
@ -181,7 +187,6 @@ static int LuaGenerateCSR(lua_State *L) {
subject_name = luaL_checkstring(L, 2);
}
if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') {
lua_pushnil(L);
lua_pushstring(L, "Subject name or SANs are required");
@ -195,7 +200,8 @@ static int LuaGenerateCSR(lua_State *L) {
mbedtls_pk_init(&key);
mbedtls_x509write_csr_init(&req);
if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem, strlen(key_pem) + 1, NULL, 0)) != 0) {
if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem,
strlen(key_pem) + 1, NULL, 0)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "Failed to parse key: %d", ret);
return 2;
@ -206,14 +212,18 @@ static int LuaGenerateCSR(lua_State *L) {
mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256);
if (san_list) {
if ((ret = mbedtls_x509write_csr_set_extension(&req, MBEDTLS_OID_SUBJECT_ALT_NAME, MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), (const unsigned char *)san_list, strlen(san_list))) != 0) {
if ((ret = mbedtls_x509write_csr_set_extension(
&req, MBEDTLS_OID_SUBJECT_ALT_NAME,
MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME),
(const unsigned char *)san_list, strlen(san_list))) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "Failed to set SANs: %d", ret);
return 2;
}
}
if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf), NULL, NULL)) < 0) {
if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf),
NULL, NULL)) < 0) {
lua_pushnil(L);
lua_pushfstring(L, "Failed to write CSR: %d", ret);
return 2;
@ -281,15 +291,6 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len,
mbedtls_pk_free(&key);
return true;
}
/**
* Lua wrapper for RSA key pair generation
*
* Lua function signature: RSAGenerateKeyPair([key_length])
* @param L Lua state
* @return 2 on success (private_key, public_key), 2 on failure (nil,
* error_message)
*/
static int LuaRSAGenerateKeyPair(lua_State *L) {
int bits = 2048;
// If no arguments, or first argument is nil, default to 2048
@ -305,7 +306,8 @@ static int LuaRSAGenerateKeyPair(lua_State *L) {
size_t private_len, public_len;
// Call the C function to generate the key pair
if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, bits)) {
if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len,
bits)) {
lua_pushnil(L);
lua_pushstring(L, "Failed to generate RSA key pair");
return 2;
@ -383,7 +385,8 @@ static int LuaRSAEncrypt(lua_State *L) {
// Args: key, plaintext, options table
size_t keylen, ptlen;
const char *key = luaL_checklstring(L, 1, &keylen);
const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen);
const unsigned char *plaintext =
(const unsigned char *)luaL_checklstring(L, 2, &ptlen);
// int options_idx = 3;
// const char *padding = parse_rsa_options(L, options_idx);
size_t out_len;
@ -402,8 +405,8 @@ static int LuaRSAEncrypt(lua_State *L) {
}
static char *RSADecrypt(const char *private_key_pem,
const unsigned char *encrypted_data, size_t encrypted_len,
size_t *out_len) {
const unsigned char *encrypted_data,
size_t encrypted_len, size_t *out_len) {
int rc;
// Parse private key
@ -450,7 +453,8 @@ static int LuaRSADecrypt(lua_State *L) {
// Args: key, ciphertext, options table
size_t keylen, ctlen;
const char *key = luaL_checklstring(L, 1, &keylen);
const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen);
const unsigned char *ciphertext =
(const unsigned char *)luaL_checklstring(L, 2, &ctlen);
// int options_idx = 3;
// const char *padding = parse_rsa_options(L, options_idx);
size_t out_len;
@ -470,7 +474,8 @@ static int LuaRSADecrypt(lua_State *L) {
// RSA Signing
static char *RSASign(const char *private_key_pem, const unsigned char *data,
size_t data_len, const char *hash_algo_str, size_t *sig_len) {
size_t data_len, const char *hash_algo_str,
size_t *sig_len) {
int rc;
unsigned char hash[64]; // Large enough for SHA-512
size_t hash_len = 32; // Default for SHA-256
@ -1089,7 +1094,8 @@ static int LuaECDSAVerify(lua_State *L) {
const char *pub_key_pem = luaL_checkstring(L, 1);
const char *message = luaL_checkstring(L, 2);
size_t sig_len;
const unsigned char *signature = (const unsigned char *)luaL_checklstring(L, 3, &sig_len);
const unsigned char *signature =
(const unsigned char *)luaL_checklstring(L, 3, &sig_len);
const char *hash_name = luaL_optstring(L, 4, "sha256");
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
@ -1100,7 +1106,6 @@ static int LuaECDSAVerify(lua_State *L) {
return 1;
}
// AES
// AES key generation helper
@ -1110,7 +1115,8 @@ static int LuaAesGenerateKey(lua_State *L) {
keybits = luaL_checkinteger(L, 1);
}
int keylen = keybits / 8;
if ((keybits != 128 && keybits != 192 && keybits != 256) || (keylen != 16 && keylen != 24 && keylen != 32)) {
if ((keybits != 128 && keybits != 192 && keybits != 256) ||
(keylen != 16 && keylen != 24 && keylen != 32)) {
lua_pushnil(L);
lua_pushstring(L, "AES key length must be 128, 192, or 256 bits");
return 2;
@ -1121,7 +1127,8 @@ static int LuaAesGenerateKey(lua_State *L) {
mbedtls_entropy_init(&entropy);
mbedtls_ctr_drbg_init(&ctr_drbg);
const char *pers = "aes_keygen";
int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const unsigned char *)pers, strlen(pers));
int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy,
(const unsigned char *)pers, strlen(pers));
if (ret != 0) {
lua_pushnil(L);
lua_pushstring(L, "Failed to initialize RNG for AES key");
@ -1142,29 +1149,6 @@ static int LuaAesGenerateKey(lua_State *L) {
}
// Helper to get string field from options table
typedef struct {
const char *mode;
const unsigned char *iv;
size_t ivlen;
} aes_options_t;
static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) {
opts->mode = "cbc";
opts->iv = NULL;
opts->ivlen = 0;
if (lua_istable(L, options_idx)) {
lua_getfield(L, options_idx, "mode");
if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1);
lua_pop(L, 1);
lua_getfield(L, options_idx, "iv");
if (lua_isstring(L, -1)) {
opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen);
}
lua_pop(L, 1);
}
}
// Helper for AES decrypt options
typedef struct {
const char *mode;
const unsigned char *iv;
@ -1173,9 +1157,10 @@ typedef struct {
size_t taglen;
const unsigned char *aad;
size_t aadlen;
} aes_decrypt_options_t;
} aes_options_t;
static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt_options_t *opts) {
static void parse_aes_options(lua_State *L, int options_idx,
aes_options_t *opts) {
opts->mode = "cbc";
opts->iv = NULL;
opts->ivlen = 0;
@ -1185,7 +1170,8 @@ static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt
opts->aadlen = 0;
if (lua_istable(L, options_idx)) {
lua_getfield(L, options_idx, "mode");
if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1);
if (!lua_isnil(L, -1))
opts->mode = lua_tostring(L, -1);
lua_pop(L, 1);
lua_getfield(L, options_idx, "iv");
if (lua_isstring(L, -1)) {
@ -1209,8 +1195,10 @@ static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt
static int LuaAesEncrypt(lua_State *L) {
// Args: key, plaintext, options table
size_t keylen, ptlen;
const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen);
const unsigned char *key =
(const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *plaintext =
(const unsigned char *)luaL_checklstring(L, 2, &ptlen);
int options_idx = 3;
aes_options_t opts;
parse_aes_options(L, options_idx, &opts);
@ -1290,7 +1278,8 @@ static int LuaAesEncrypt(lua_State *L) {
}
unsigned char iv_copy[16];
memcpy(iv_copy, iv, 16);
ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy, input, output);
ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy,
input, output);
mbedtls_aes_free(&aes);
free(input);
if (ret != 0) {
@ -1302,7 +1291,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushlstring(L, (const char *)output, ctlen);
lua_pushlstring(L, (const char *)iv, ivlen);
free(output);
if (iv_was_generated) free(gen_iv);
if (iv_was_generated)
free(gen_iv);
return 2;
} else if (is_ctr) {
// CTR mode: no padding
@ -1327,7 +1317,8 @@ static int LuaAesEncrypt(lua_State *L) {
size_t nc_off = 0;
memcpy(nonce_counter, iv, 16);
memset(stream_block, 0, 16);
ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter, stream_block, plaintext, output);
ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter,
stream_block, plaintext, output);
mbedtls_aes_free(&aes);
if (ret != 0) {
free(output);
@ -1338,7 +1329,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushlstring(L, (const char *)output, ptlen);
lua_pushlstring(L, (const char *)iv, ivlen);
free(output);
if (iv_was_generated) free(gen_iv);
if (iv_was_generated)
free(gen_iv);
return 2;
} else if (is_gcm) {
// GCM mode: authenticated encryption
@ -1360,8 +1352,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushstring(L, "Failed to set AES GCM key");
return 2;
}
// Use actual ivlen, not hardcoded 16
ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen, NULL, 0, plaintext, output, taglen, tag);
ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen,
NULL, 0, plaintext, output, taglen, tag);
mbedtls_gcm_free(&gcm);
if (ret != 0) {
free(output);
@ -1373,7 +1365,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushlstring(L, (const char *)iv, ivlen);
lua_pushlstring(L, (const char *)tag, taglen);
free(output);
if (iv_was_generated) free(gen_iv);
if (iv_was_generated)
free(gen_iv);
return 3;
}
lua_pushnil(L);
@ -1385,11 +1378,13 @@ static int LuaAesEncrypt(lua_State *L) {
static int LuaAesDecrypt(lua_State *L) {
// Args: key, ciphertext, options table
size_t keylen, ctlen;
const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen);
const unsigned char *key =
(const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *ciphertext =
(const unsigned char *)luaL_checklstring(L, 2, &ctlen);
int options_idx = 3;
aes_decrypt_options_t opts;
parse_aes_decrypt_options(L, options_idx, &opts);
aes_options_t opts;
parse_aes_options(L, options_idx, &opts);
const char *mode = opts.mode;
const unsigned char *iv = opts.iv;
size_t ivlen = opts.ivlen;
@ -1467,7 +1462,8 @@ static int LuaAesDecrypt(lua_State *L) {
}
unsigned char iv_copy[16];
memcpy(iv_copy, iv, 16);
ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy, ciphertext, output);
ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy,
ciphertext, output);
mbedtls_aes_free(&aes);
if (ret != 0) {
free(output);
@ -1524,7 +1520,8 @@ static int LuaAesDecrypt(lua_State *L) {
size_t nc_off = 0;
memcpy(nonce_counter, iv, 16);
memset(stream_block, 0, 16);
ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter, stream_block, ciphertext, output);
ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter,
stream_block, ciphertext, output);
mbedtls_aes_free(&aes);
if (ret != 0) {
free(output);
@ -1553,7 +1550,8 @@ static int LuaAesDecrypt(lua_State *L) {
lua_pushstring(L, "Failed to set AES GCM key");
return 2;
}
ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag, taglen, ciphertext, output);
ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag,
taglen, ciphertext, output);
mbedtls_gcm_free(&gcm);
if (ret != 0) {
free(output);
@ -1572,8 +1570,10 @@ static int LuaAesDecrypt(lua_State *L) {
// LuaCrypto compatible API
static int LuaCryptoSign(lua_State *L) {
const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa")
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
const char *dtype =
luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa")
lua_remove(L, 1); // Remove the first argument (key type or cipher type)
// before dispatching
if (strcasecmp(dtype, "rsa") == 0) {
return LuaRSASign(L);
@ -1585,8 +1585,10 @@ static int LuaCryptoSign(lua_State *L) {
}
static int LuaCryptoVerify(lua_State *L) {
const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa")
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
const char *dtype =
luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa")
lua_remove(L, 1); // Remove the first argument (key type or cipher type)
// before dispatching
if (strcasecmp(dtype, "rsa") == 0) {
return LuaRSAVerify(L);
@ -1600,10 +1602,10 @@ static int LuaCryptoVerify(lua_State *L) {
static int LuaCryptoEncrypt(lua_State *L) {
// Args: cipher_type, key, msg, options table
const char *cipher = luaL_checkstring(L, 1);
// Remove cipher_type from stack, so key is at 1, msg at 2, options at 3
lua_remove(L, 1);
lua_remove(L, 1); // Remove cipher_type from stack, so key is at 1, msg at 2,
// options at 3
if (strcasecmp(cipher, "rsa") == 0) {
// Update LuaRSAEncrypt to accept (key, msg, options)
return LuaRSAEncrypt(L);
} else if (strcasecmp(cipher, "aes") == 0) {
return LuaAesEncrypt(L);
@ -1615,7 +1617,10 @@ static int LuaCryptoEncrypt(lua_State *L) {
static int LuaCryptoDecrypt(lua_State *L) {
// Args: cipher_type, key, ciphertext, options table
const char *cipher = luaL_checkstring(L, 1);
lua_remove(L, 1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3
lua_remove(
L,
1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3
if (strcasecmp(cipher, "rsa") == 0) {
return LuaRSADecrypt(L);
} else if (strcasecmp(cipher, "aes") == 0) {
@ -1626,12 +1631,13 @@ static int LuaCryptoDecrypt(lua_State *L) {
}
static int LuaCryptoGenerateKeyPair(lua_State *L) {
// If the first argument is a number, treat as RSA key length
// If the first argument is a number, treat it as RSA key length
if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) {
// Call LuaRSAGenerateKeyPair with the number as the key length
return LuaRSAGenerateKeyPair(L);
}
// Otherwise, get the key type from the first argument, default to "rsa" if not provided
// Otherwise, get the key type from the first argument, default to "rsa" if
// not provided
const char *type = luaL_optstring(L, 1, "rsa");
lua_remove(L, 1);
if (strcasecmp(type, "rsa") == 0) {