From 55dcce4f7d6b16cbaf8fc47709c82a4d24f2758c Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Sun, 1 Jun 2025 21:23:34 +1200 Subject: [PATCH 01/18] Add LuaCrypto compatible functions (plus some auxiliary functions) as per #1136 --- third_party/mbedtls/config.h | 2 +- tool/net/BUILD.mk | 1 + tool/net/lcrypto.c | 1147 ++++++++++++++++++++++++++++++++++ tool/net/lcrypto.h | 10 + tool/net/lfuncs.h | 1 + tool/net/redbean.c | 1 + 6 files changed, 1161 insertions(+), 1 deletion(-) create mode 100644 tool/net/lcrypto.c create mode 100644 tool/net/lcrypto.h diff --git a/third_party/mbedtls/config.h b/third_party/mbedtls/config.h index 24f2c227b..88087503e 100644 --- a/third_party/mbedtls/config.h +++ b/third_party/mbedtls/config.h @@ -71,10 +71,10 @@ /* eliptic curves */ #define MBEDTLS_ECP_DP_SECP256R1_ENABLED #define MBEDTLS_ECP_DP_SECP384R1_ENABLED +#define MBEDTLS_ECP_DP_SECP521R1_ENABLED #define MBEDTLS_ECP_DP_CURVE25519_ENABLED #ifndef TINY #define MBEDTLS_ECP_DP_CURVE448_ENABLED -/*#define MBEDTLS_ECP_DP_SECP521R1_ENABLED*/ /*#define MBEDTLS_ECP_DP_BP384R1_ENABLED*/ /*#define MBEDTLS_ECP_DP_SECP192R1_ENABLED*/ /*#define MBEDTLS_ECP_DP_SECP224R1_ENABLED*/ diff --git a/tool/net/BUILD.mk b/tool/net/BUILD.mk index 06a80d5f8..6528b8854 100644 --- a/tool/net/BUILD.mk +++ b/tool/net/BUILD.mk @@ -100,6 +100,7 @@ TOOL_NET_REDBEAN_LUA_MODULES = \ o/$(MODE)/tool/net/lmaxmind.o \ o/$(MODE)/tool/net/lsqlite3.o \ o/$(MODE)/tool/net/largon2.o \ + o/$(MODE)/tool/net/lcrypto.o \ o/$(MODE)/tool/net/launch.o o/$(MODE)/tool/net/redbean.dbg: \ diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c new file mode 100644 index 000000000..c33e73a03 --- /dev/null +++ b/tool/net/lcrypto.c @@ -0,0 +1,1147 @@ +#include "libc/log/log.h" +#include "net/https/https.h" +#include "third_party/lua/lauxlib.h" +#include "third_party/mbedtls/error.h" +#include "third_party/mbedtls/pk.h" +#include "third_party/mbedtls/rsa.h" +#include "third_party/mbedtls/ecdsa.h" +#include "third_party/mbedtls/x509_csr.h" +#include "third_party/mbedtls/oid.h" +#include "third_party/mbedtls/md.h" +#include "third_party/mbedtls/base64.h" + +// Standard C library and redbean utilities +#include "libc/errno.h" +#include "libc/mem/mem.h" +#include "libc/str/str.h" +#include "tool/net/luacheck.h" + +// Updated PemToJwk to parse PEM keys and convert them into JWK format +static int PemToJwk(lua_State *L) { + const char *pem_key = luaL_checkstring(L, 1); + + mbedtls_pk_context key; + mbedtls_pk_init(&key); + int ret; + + // Parse the PEM key + if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key, strlen(pem_key) + 1, NULL, 0)) != 0 && + (ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key, strlen(pem_key) + 1)) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret); + mbedtls_pk_free(&key); + return 2; + } + + lua_newtable(L); // Create a new Lua table + + if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_RSA) { + // Handle RSA keys + const mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + size_t n_len = mbedtls_mpi_size(&rsa->N); + size_t e_len = mbedtls_mpi_size(&rsa->E); + + unsigned char *n = malloc(n_len); + unsigned char *e = malloc(e_len); + + if (!n || !e) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(n); + free(e); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_mpi_write_binary(&rsa->N, n, n_len); + mbedtls_mpi_write_binary(&rsa->E, e, e_len); + + char *n_b64 = NULL, *e_b64 = NULL; + size_t n_b64_len, e_b64_len; + + mbedtls_base64_encode(NULL, 0, &n_b64_len, n, n_len); + mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); + + n_b64 = malloc(n_b64_len + 1); + e_b64 = malloc(e_b64_len + 1); + + if (!n_b64 || !e_b64) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(n); + free(e); + free(n_b64); + free(e_b64); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len); + mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len); + + n_b64[n_b64_len] = '\0'; + e_b64[e_b64_len] = '\0'; + + lua_pushstring(L, "RSA"); + lua_setfield(L, -2, "kty"); + lua_pushstring(L, n_b64); + lua_setfield(L, -2, "n"); + lua_pushstring(L, e_b64); + lua_setfield(L, -2, "e"); + + free(n); + free(e); + free(n_b64); + free(e_b64); + } else if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_ECKEY) { + // Handle ECDSA keys + const mbedtls_ecp_keypair *ec = mbedtls_pk_ec(key); + const mbedtls_ecp_point *Q = &ec->Q; + size_t x_len = (ec->grp.pbits + 7) / 8; + size_t y_len = (ec->grp.pbits + 7) / 8; + + unsigned char *x = malloc(x_len); + unsigned char *y = malloc(y_len); + + if (!x || !y) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(x); + free(y); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_mpi_write_binary(&Q->X, x, x_len); + mbedtls_mpi_write_binary(&Q->Y, y, y_len); + + char *x_b64 = NULL, *y_b64 = NULL; + size_t x_b64_len, y_b64_len; + + mbedtls_base64_encode(NULL, 0, &x_b64_len, x, x_len); + mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len); + + x_b64 = malloc(x_b64_len + 1); + y_b64 = malloc(y_b64_len + 1); + + if (!x_b64 || !y_b64) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(x); + free(y); + free(x_b64); + free(y_b64); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len); + mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); + + x_b64[x_b64_len] = '\0'; + y_b64[y_b64_len] = '\0'; + + lua_pushstring(L, "EC"); + lua_setfield(L, -2, "kty"); + lua_pushstring(L, mbedtls_ecp_curve_info_from_grp_id(ec->grp.id)->name); + lua_setfield(L, -2, "crv"); + lua_pushstring(L, x_b64); + lua_setfield(L, -2, "x"); + lua_pushstring(L, y_b64); + lua_setfield(L, -2, "y"); + + free(x); + free(y); + free(x_b64); + free(y_b64); + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported key type"); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_pk_free(&key); + return 1; +} + +// CSR Creation Function +static int CreateCSR(lua_State *L) { + const char *key_pem = luaL_checkstring(L, 1); + const char *subject_name = luaL_checkstring(L, 2); + const char *san_list = luaL_optstring(L, 3, NULL); + + mbedtls_pk_context key; + mbedtls_x509write_csr req; + char buf[4096]; + int ret; + + mbedtls_pk_init(&key); + mbedtls_x509write_csr_init(&req); + + if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem, strlen(key_pem) + 1, NULL, 0)) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to parse key: %d", ret); + return 2; + } + + mbedtls_x509write_csr_set_subject_name(&req, subject_name); + mbedtls_x509write_csr_set_key(&req, &key); + mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256); + + if (san_list) { + if ((ret = mbedtls_x509write_csr_set_extension(&req, MBEDTLS_OID_SUBJECT_ALT_NAME, MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), (const unsigned char *)san_list, strlen(san_list))) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to set SANs: %d", ret); + return 2; + } + } + + if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf), NULL, NULL)) < 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to write CSR: %d", ret); + return 2; + } + + lua_pushstring(L, buf); + + mbedtls_pk_free(&key); + mbedtls_x509write_csr_free(&req); + + return 1; +} + + +static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, + char **public_key_pem, size_t *public_key_len, + unsigned int key_length) { + int rc; + mbedtls_pk_context key; + mbedtls_pk_init(&key); + + // Initialize as RSA key + if ((rc = mbedtls_pk_setup(&key, + mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { + WARNF("Failed to setup key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return false; + } + + // Generate RSA key + if ((rc = mbedtls_rsa_gen_key(mbedtls_pk_rsa(key), GenerateHardRandom, 0, + key_length, 65537)) != 0) { + WARNF("Failed to generate key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return false; + } + + // Write private key to PEM + *private_key_len = 16000; // Buffer size for private key + *private_key_pem = calloc(1, *private_key_len); + if ((rc = mbedtls_pk_write_key_pem(&key, (unsigned char *)*private_key_pem, + *private_key_len)) != 0) { + WARNF("Failed to write private key (grep -0x%04x)", -rc); + free(*private_key_pem); + mbedtls_pk_free(&key); + return false; + } + *private_key_len = strlen(*private_key_pem); + + // Write public key to PEM + *public_key_len = 8000; // Buffer size for public key + *public_key_pem = calloc(1, *public_key_len); + if ((rc = mbedtls_pk_write_pubkey_pem(&key, (unsigned char *)*public_key_pem, + *public_key_len)) != 0) { + WARNF("Failed to write public key (grep -0x%04x)", -rc); + free(*private_key_pem); + free(*public_key_pem); + mbedtls_pk_free(&key); + return false; + } + *public_key_len = strlen(*public_key_pem); + + mbedtls_pk_free(&key); + return true; +} +/** + * Lua wrapper for RSA key pair generation + * + * Lua function signature: RSAGenerateKeyPair([key_length]) + * @param L Lua state + * @return 2 on success (private_key, public_key), 2 on failure (nil, + * error_message) + */ +static int LuaRSAGenerateKeyPair(lua_State *L) { + char *private_key, *public_key; + size_t private_len, public_len; + int key_length = 2048; // Default RSA key length + + // Get key length from Lua (optional parameter) + if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { + key_length = luaL_checkinteger(L, 1); + // Validate key length (common RSA key lengths are 1024, 2048, 3072, 4096) + if (key_length != 1024 && key_length != 2048 && key_length != 3072 && + key_length != 4096) { + lua_pushnil(L); + lua_pushstring(L, + "Invalid RSA key length. Use 1024, 2048, 3072, or 4096."); + return 2; + } + } + + // Call the C function to generate the key pair + if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, + key_length)) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate RSA key pair"); + return 2; + } + + // Push results to Lua + lua_pushstring(L, private_key); + lua_pushstring(L, public_key); + + // Clean up + free(private_key); + free(public_key); + + return 2; +} + +// RSA +static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, + size_t data_len, size_t *out_len) { + int rc; + + // Parse public key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_public_key(&key, + (const unsigned char *)public_key_pem, + strlen(public_key_pem) + 1)) != 0) { + WARNF("Failed to parse public key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate output buffer + size_t key_size = mbedtls_pk_get_len(&key); + unsigned char *output = calloc(1, key_size); + if (!output) { + mbedtls_pk_free(&key); + return NULL; + } + + // Encrypt data + if ((rc = mbedtls_rsa_pkcs1_encrypt(mbedtls_pk_rsa(key), GenerateHardRandom, + 0, MBEDTLS_RSA_PUBLIC, data_len, data, + output)) != 0) { + WARNF("Encryption failed (grep -0x%04x)", -rc); + free(output); + mbedtls_pk_free(&key); + return NULL; + } + + *out_len = key_size; + mbedtls_pk_free(&key); + return (char *)output; +} +static int LuaRSAEncrypt(lua_State *L) { + const char *public_key = luaL_checkstring(L, 1); + size_t data_len; + const unsigned char *data = + (const unsigned char *)luaL_checklstring(L, 2, &data_len); + size_t out_len; + + char *encrypted = RSAEncrypt(public_key, data, data_len, &out_len); + if (!encrypted) { + lua_pushnil(L); + lua_pushstring(L, "Encryption failed"); + return 2; + } + + lua_pushlstring(L, encrypted, out_len); + free(encrypted); + + return 1; +} + +static char *RSADecrypt(const char *private_key_pem, + const unsigned char *encrypted_data, size_t encrypted_len, + size_t *out_len) { + int rc; + + // Parse private key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0)) != 0) { + WARNF("Failed to parse private key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate output buffer + size_t key_size = mbedtls_pk_get_len(&key); + unsigned char *output = calloc(1, key_size); + if (!output) { + mbedtls_pk_free(&key); + return NULL; + } + + // Decrypt data + size_t output_len = 0; + if ((rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateHardRandom, + 0, MBEDTLS_RSA_PRIVATE, &output_len, + encrypted_data, output, key_size)) != 0) { + WARNF("Decryption failed (grep -0x%04x)", -rc); + free(output); + mbedtls_pk_free(&key); + return NULL; + } + + *out_len = output_len; + mbedtls_pk_free(&key); + return (char *)output; +} +static int LuaRSADecrypt(lua_State *L) { + const char *private_key = luaL_checkstring(L, 1); + size_t encrypted_len; + const unsigned char *encrypted_data = + (const unsigned char *)luaL_checklstring(L, 2, &encrypted_len); + size_t out_len; + + char *decrypted = + RSADecrypt(private_key, encrypted_data, encrypted_len, &out_len); + if (!decrypted) { + lua_pushnil(L); + lua_pushstring(L, "Decryption failed"); + return 2; + } + + lua_pushlstring(L, decrypted, out_len); + free(decrypted); + + return 1; +} + +// RSA Signing +static char *RSASign(const char *private_key_pem, const unsigned char *data, + size_t data_len, const char *hash_algo_str, size_t *sig_len) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 32; // Default for SHA-256 + unsigned char *signature; + mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default + + // Determine hash algorithm + if (hash_algo_str) { + if (strcasecmp(hash_algo_str, "sha256") == 0) { + hash_algo = MBEDTLS_MD_SHA256; + hash_len = 32; + } else if (strcasecmp(hash_algo_str, "sha384") == 0) { + hash_algo = MBEDTLS_MD_SHA384; + hash_len = 48; + } else if (strcasecmp(hash_algo_str, "sha512") == 0) { + hash_algo = MBEDTLS_MD_SHA512; + hash_len = 64; + } else { + return NULL; // Unsupported hash algorithm + } + } + + // Parse private key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0)) != 0) { + WARNF("Failed to parse private key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Hash the message + if ((rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, + hash)) != 0) { + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate buffer for signature + signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE); + if (!signature) { + mbedtls_pk_free(&key); + return NULL; + } + + // Sign the hash + if ((rc = mbedtls_pk_sign(&key, hash_algo, hash, hash_len, signature, sig_len, + GenerateHardRandom, 0)) != 0) { + free(signature); + mbedtls_pk_free(&key); + return NULL; + } + + // Clean up + mbedtls_pk_free(&key); + + return (char *)signature; +} +static int LuaRSASign(lua_State *L) { + size_t msg_len, key_len; + const char *msg, *key_pem, *hash_algo_str = NULL; + unsigned char *signature; + size_t sig_len = 0; + + // Get parameters from Lua + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + + // Optional hash algorithm parameter + if (!lua_isnoneornil(L, 3)) { + hash_algo_str = luaL_checkstring(L, 3); + } + + // Call the C implementation + signature = (unsigned char *)RSASign(key_pem, (const unsigned char *)msg, + msg_len, hash_algo_str, &sig_len); + + if (!signature) { + return luaL_error(L, "failed to sign message"); + } + + // Return the signature as a Lua string + lua_pushlstring(L, (char *)signature, sig_len); + + // Clean up + free(signature); + + return 1; +} + +static int RSAVerify(const char *public_key_pem, const unsigned char *data, + size_t data_len, const unsigned char *signature, + size_t sig_len, const char *hash_algo_str) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 32; // Default for SHA-256 + mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default + + // Determine hash algorithm + if (hash_algo_str) { + if (strcasecmp(hash_algo_str, "sha256") == 0) { + hash_algo = MBEDTLS_MD_SHA256; + hash_len = 32; + } else if (strcasecmp(hash_algo_str, "sha384") == 0) { + hash_algo = MBEDTLS_MD_SHA384; + hash_len = 48; + } else if (strcasecmp(hash_algo_str, "sha512") == 0) { + hash_algo = MBEDTLS_MD_SHA512; + hash_len = 64; + } else { + return -1; // Unsupported hash algorithm + } + } + + // Parse public key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_public_key(&key, + (const unsigned char *)public_key_pem, + strlen(public_key_pem) + 1)) != 0) { + WARNF("Failed to parse public key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return -1; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("Key is not an RSA key"); + mbedtls_pk_free(&key); + return -1; + } + + // Hash the message + if ((rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, + hash)) != 0) { + mbedtls_pk_free(&key); + return -1; + } + + // Verify the signature + rc = mbedtls_pk_verify(&key, hash_algo, hash, hash_len, signature, sig_len); + + // Clean up + mbedtls_pk_free(&key); + + return rc; // 0 means success (valid signature) +} +static int LuaRSAVerify(lua_State *L) { + size_t msg_len, key_len, sig_len; + const char *msg, *key_pem, *signature, *hash_algo_str = NULL; + int result; + + // Get parameters from Lua + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + signature = luaL_checklstring(L, 3, &sig_len); + + // Optional hash algorithm parameter + if (!lua_isnoneornil(L, 4)) { + hash_algo_str = luaL_checkstring(L, 4); + } + + // Call the C implementation + result = RSAVerify(key_pem, (const unsigned char *)msg, msg_len, + (const unsigned char *)signature, sig_len, hash_algo_str); + + // Return boolean result (0 means valid signature) + lua_pushboolean(L, result == 0); + + return 1; +} + + +// Supported curves mapping +typedef struct { + const char *name; + mbedtls_ecp_group_id id; +} curve_map_t; + +static const curve_map_t supported_curves[] = { + {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, + {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, + {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, + {"secp192r1", MBEDTLS_ECP_DP_SECP192R1}, + {"secp224r1", MBEDTLS_ECP_DP_SECP224R1}, + {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, + {NULL, 0}}; + +typedef enum { SHA256, SHA384, SHA512 } hash_algorithm_t; + +static mbedtls_md_type_t hash_to_md_type(hash_algorithm_t hash_alg) { + switch (hash_alg) { + case SHA256: + return MBEDTLS_MD_SHA256; + case SHA384: + return MBEDTLS_MD_SHA384; + case SHA512: + return MBEDTLS_MD_SHA512; + default: + return MBEDTLS_MD_SHA256; // Default to SHA-256 + } +} + +static size_t get_hash_size(hash_algorithm_t hash_alg) { + switch (hash_alg) { + case SHA256: + return 32; + case SHA384: + return 48; + case SHA512: + return 64; + default: + return 32; // Default to SHA-256 + } +} + +static hash_algorithm_t string_to_hash_alg(const char *hash_name) { + if (!hash_name || !*hash_name) { + return SHA256; // Default to SHA-256 if no name provided + } + + if (strcasecmp(hash_name, "sha256") == 0 || + strcasecmp(hash_name, "sha-256") == 0) { + return SHA256; + } else if (strcasecmp(hash_name, "sha384") == 0 || + strcasecmp(hash_name, "sha-384") == 0) { + return SHA384; + } else if (strcasecmp(hash_name, "sha512") == 0 || + strcasecmp(hash_name, "sha-512") == 0) { + return SHA512; + } else { + WARNF("(ecdsa) Unknown hash algorithm '%s', using SHA-256", hash_name); + return SHA256; + } +} + +static int LuaListHashAlgorithms(lua_State *L) { + lua_newtable(L); + + lua_pushstring(L, "SHA256"); + lua_rawseti(L, -2, 1); + + lua_pushstring(L, "SHA384"); + lua_rawseti(L, -2, 2); + + lua_pushstring(L, "SHA512"); + lua_rawseti(L, -2, 3); + + // Add hyphenated versions + lua_pushstring(L, "SHA-256"); + lua_rawseti(L, -2, 4); + + lua_pushstring(L, "SHA-384"); + lua_rawseti(L, -2, 5); + + lua_pushstring(L, "SHA-512"); + lua_rawseti(L, -2, 6); + + return 1; +} +// List available curves +static int LuaListCurves(lua_State *L) { + const curve_map_t *curve = supported_curves; + int i = 1; + + lua_newtable(L); + + while (curve->name != NULL) { + lua_pushstring(L, curve->name); + lua_rawseti(L, -2, i++); + curve++; + } + + return 1; +} + +static int compute_hash(hash_algorithm_t hash_alg, const unsigned char *input, + size_t input_len, unsigned char *output, + size_t output_size) { + mbedtls_md_context_t md_ctx; + const mbedtls_md_info_t *md_info; + int ret; + + mbedtls_md_type_t md_type = hash_to_md_type(hash_alg); + md_info = mbedtls_md_info_from_type(md_type); + if (md_info == NULL) { + WARNF("(ecdsa) Unsupported hash algorithm"); + return -1; + } + + if (output_size < mbedtls_md_get_size(md_info)) { + WARNF("(ecdsa) Output buffer too small for hash"); + return -1; + } + + mbedtls_md_init(&md_ctx); + + ret = mbedtls_md_setup(&md_ctx, md_info, 0); // 0 = non-HMAC + if (ret != 0) { + WARNF("(ecdsa) Failed to set up hash context: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_starts(&md_ctx); + if (ret != 0) { + WARNF("(ecdsa) Failed to start hash operation: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_update(&md_ctx, input, input_len); + if (ret != 0) { + WARNF("(ecdsa) Failed to update hash: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_finish(&md_ctx, output); + if (ret != 0) { + WARNF("(ecdsa) Failed to finish hash: -0x%04x", -ret); + goto cleanup; + } + +cleanup: + mbedtls_md_free(&md_ctx); + return ret; +} + +// Find curve ID by name +static mbedtls_ecp_group_id find_curve_by_name(const char *name) { + const curve_map_t *curve = supported_curves; + + while (curve->name != NULL) { + if (strcasecmp(curve->name, name) == 0) { + return curve->id; + } + curve++; + } + + return MBEDTLS_ECP_DP_NONE; +} + +// Generate an ECDSA key pair and return in PEM format +static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, + char **pub_key_pem) { + mbedtls_pk_context key; + unsigned char output_buf[16000]; + int ret; + mbedtls_ecp_group_id curve_id; + + // Initialize output parameters to NULL in case of early return + if (priv_key_pem) + *priv_key_pem = NULL; + if (pub_key_pem) + *pub_key_pem = NULL; + + // Use secp256r1 as default if curve_name is NULL or empty + if (curve_name == NULL || curve_name[0] == '\0') { + curve_id = MBEDTLS_ECP_DP_SECP256R1; + VERBOSEF("(ecdsa) No curve specified, using default: secp256r1"); + } else { + // Find the curve by name + curve_id = find_curve_by_name(curve_name); + if (curve_id == MBEDTLS_ECP_DP_NONE) { + WARNF("(ecdsa) Unknown curve: %s, using default: secp256r1", curve_name); + curve_id = MBEDTLS_ECP_DP_SECP256R1; + } + } + + mbedtls_pk_init(&key); + + // Generate the key with the specified curve + ret = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)); + if (ret != 0) { + WARNF("(ecdsa) Failed to setup key: -0x%04x", -ret); + goto cleanup; + } + + ret = + mbedtls_ecp_gen_key(curve_id, mbedtls_pk_ec(key), GenerateHardRandom, 0); + if (ret != 0) { + WARNF("(ecdsa) Failed to generate key: -0x%04x", -ret); + goto cleanup; + } + + // Generate private key PEM + if (priv_key_pem != NULL) { + memset(output_buf, 0, sizeof(output_buf)); + ret = mbedtls_pk_write_key_pem(&key, output_buf, sizeof(output_buf)); + if (ret != 0) { + WARNF("(ecdsa) Failed to write private key: -0x%04x", -ret); + goto cleanup; + } + *priv_key_pem = strdup((char *)output_buf); + if (*priv_key_pem == NULL) { + WARNF("(ecdsa) Failed to allocate memory for private key PEM"); + ret = -1; + goto cleanup; + } + } + + // Generate public key PEM + if (pub_key_pem != NULL) { + memset(output_buf, 0, sizeof(output_buf)); + ret = mbedtls_pk_write_pubkey_pem(&key, output_buf, sizeof(output_buf)); + if (ret != 0) { + WARNF("(ecdsa) Failed to write public key: -0x%04x", -ret); + goto cleanup; + } + *pub_key_pem = strdup((char *)output_buf); + if (*pub_key_pem == NULL) { + WARNF("(ecdsa) Failed to allocate memory for public key PEM"); + ret = -1; + goto cleanup; + } + } + +cleanup: + mbedtls_pk_free(&key); + if (ret != 0) { + // Clean up on error + if (priv_key_pem && *priv_key_pem) { + free(*priv_key_pem); + *priv_key_pem = NULL; + } + if (pub_key_pem && *pub_key_pem) { + free(*pub_key_pem); + *pub_key_pem = NULL; + } + } + return ret; +} +// Lua binding for generating ECDSA keys +static int LuaECDSAGenerateKeyPair(lua_State *L) { + const char *curve_name = NULL; + char *priv_key_pem = NULL; + char *pub_key_pem = NULL; + + // Check if curve name is provided + if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { + curve_name = luaL_checkstring(L, 1); + } + // If not provided, generate_key_pem will use the default + + int ret = ECDSAGenerateKeyPair(curve_name, &priv_key_pem, &pub_key_pem); + + if (ret == 0) { + lua_pushstring(L, priv_key_pem); + lua_pushstring(L, pub_key_pem); + free(priv_key_pem); + free(pub_key_pem); + return 2; + } else { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } +} + +// Sign a message using an ECDSA private key in PEM format +static int ECDSASign(const char *priv_key_pem, const char *message, + hash_algorithm_t hash_alg, unsigned char **signature, + size_t *sig_len) { + mbedtls_pk_context key; + unsigned char hash[64]; // Max hash size (SHA-512) + size_t hash_size; + int ret; + + *signature = NULL; + *sig_len = 0; + + if (!priv_key_pem) { + WARNF("(ecdsa) Private key is NULL"); + return -1; + } + + // Get the length of the PEM string (excluding null terminator) + size_t key_len = strlen(priv_key_pem); + if (key_len == 0) { + WARNF("(ecdsa) Private key is empty"); + return -1; + } + + // Get hash size for the selected algorithm + hash_size = get_hash_size(hash_alg); + + mbedtls_pk_init(&key); + + // Parse the private key from PEM directly without creating a copy + ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem, + key_len + 1, NULL, 0); + + if (ret != 0) { + WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); + goto cleanup; + } + + // Compute hash of the message using the specified algorithm + ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), + hash, sizeof(hash)); + if (ret != 0) { + WARNF("(ecdsa) Failed to compute message hash"); + goto cleanup; + } + + // Allocate memory for signature (max size for ECDSA) + *signature = malloc(MBEDTLS_ECDSA_MAX_LEN); + if (*signature == NULL) { + WARNF("(ecdsa) Failed to allocate memory for signature"); + ret = -1; + goto cleanup; + } + + // Sign the hash using GenerateHardRandom + ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, + *signature, sig_len, GenerateHardRandom, 0); + + if (ret != 0) { + WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); + free(*signature); + *signature = NULL; + *sig_len = 0; + goto cleanup; + } + +cleanup: + mbedtls_pk_free(&key); + return ret; +} // Lua binding for signing a message +static int LuaECDSASign(lua_State *L) { + const char *hash_name = luaL_optstring(L, 1, "sha256"); // Default to SHA-256 + const char *message = luaL_checkstring(L, 2); + const char *priv_key_pem = luaL_checkstring(L, 3); + + hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); + + unsigned char *signature = NULL; + size_t sig_len = 0; + + int ret = ECDSASign(priv_key_pem, message, hash_alg, &signature, &sig_len); + + if (ret == 0) { + lua_pushlstring(L, (const char *)signature, sig_len); + free(signature); + } else { + lua_pushnil(L); + } + + return 1; +} + +// Verify a signature using an ECDSA public key in PEM format +static int ECDSAVerify(const char *pub_key_pem, const char *message, + const unsigned char *signature, size_t sig_len, + hash_algorithm_t hash_alg) { + mbedtls_pk_context key; + unsigned char hash[64]; // Max hash size (SHA-512) + size_t hash_size; + int ret; + + if (!pub_key_pem) { + WARNF("(ecdsa) Public key is NULL"); + return -1; + } + + // Get the length of the PEM string (excluding null terminator) + size_t key_len = strlen(pub_key_pem); + if (key_len == 0) { + WARNF("(ecdsa) Public key is empty"); + return -1; + } + + // Get hash size for the selected algorithm + hash_size = get_hash_size(hash_alg); + + mbedtls_pk_init(&key); + + // Parse the public key from PEM + ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pub_key_pem, + key_len + 1); + if (ret != 0) { + WARNF("(ecdsa) Failed to parse public key: -0x%04x", -ret); + goto cleanup; + } + + // Compute hash of the message using the specified algorithm + ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), + hash, sizeof(hash)); + if (ret != 0) { + WARNF("(ecdsa) Failed to compute message hash"); + goto cleanup; + } + + // Verify the signature + ret = mbedtls_pk_verify(&key, hash_to_md_type(hash_alg), hash, hash_size, + signature, sig_len); + if (ret != 0) { + WARNF("(ecdsa) Signature verification failed: -0x%04x", -ret); + goto cleanup; + } + +cleanup: + mbedtls_pk_free(&key); + return ret; +} +static int LuaECDSAVerify(lua_State *L) { + const char *hash_name = luaL_optstring(L, 1, "sha256"); // Default to SHA-256 + const char *message = luaL_checkstring(L, 2); + const char *pub_key_pem = luaL_checkstring(L, 3); + size_t sig_len; + const unsigned char *signature = + (const unsigned char *)luaL_checklstring(L, 4, &sig_len); + + hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); + + int ret = ECDSAVerify(pub_key_pem, message, signature, sig_len, hash_alg); + + lua_pushboolean(L, ret == 0); + return 1; +} + +static int LuaCryptoSign(lua_State *L) { + const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") + lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + + if (strcasecmp(dtype, "rsa") == 0) { + return LuaRSASign(L); + } else if (strcasecmp(dtype, "ecdsa") == 0) { + return LuaECDSASign(L); + } else { + return luaL_error(L, "Unsupported signature type: %s", dtype); + } +} + +static int LuaCryptoVerify(lua_State *L) { + const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") + lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + + if (strcasecmp(dtype, "rsa") == 0) { + return LuaRSAVerify(L); + } else if (strcasecmp(dtype, "ecdsa") == 0) { + return LuaECDSAVerify(L); + } else { + return luaL_error(L, "Unsupported signature type: %s", dtype); + } +} + +static int LuaCryptoEncrypt(lua_State *L) { + const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa") + lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + + if (strcasecmp(cipher, "rsa") == 0) { + return LuaRSAEncrypt(L); + } else { + return luaL_error(L, "Unsupported cipher type: %s", cipher); + } +} + +static int LuaCryptoDecrypt(lua_State *L) { + const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa") + lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + + if (strcasecmp(cipher, "rsa") == 0) { + return LuaRSADecrypt(L); + } else { + return luaL_error(L, "Unsupported cipher type: %s", cipher); + } +} + +static int LuaCryptoGenerateKeyPair(lua_State *L) { + const char *key_type = luaL_checkstring(L, 1); // Key type (e.g., "rsa", "ecdsa") + lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + + if (strcasecmp(key_type, "rsa") == 0) { + return LuaRSAGenerateKeyPair(L); + } else if (strcasecmp(key_type, "ecdsa") == 0) { + return LuaECDSAGenerateKeyPair(L); + } else { + return luaL_error(L, "Unsupported key type: %s", key_type); + } +} + +static const luaL_Reg kLuaCrypto[] = { + {"sign", LuaCryptoSign}, // + {"verify", LuaCryptoVerify}, // + {"encrypt", LuaCryptoEncrypt}, // + {"decrypt", LuaCryptoDecrypt}, // + {"generatekeypair", LuaCryptoGenerateKeyPair}, // + {"PemToJwk", PemToJwk}, // + {"csrGenerate", CreateCSR}, // + {0}, // +}; + +int LuaCrypto(lua_State *L) { + luaL_newlib(L, kLuaCrypto); + return 1; +} diff --git a/tool/net/lcrypto.h b/tool/net/lcrypto.h new file mode 100644 index 000000000..e1e11ed65 --- /dev/null +++ b/tool/net/lcrypto.h @@ -0,0 +1,10 @@ +#ifndef COSMOPOLITAN_TOOL_NET_LCRYPTO_H_ +#define COSMOPOLITAN_TOOL_NET_LCRYPTO_H_ +#include "third_party/lua/lauxlib.h" +COSMOPOLITAN_C_START_ + +int LuaCrypto(lua_State *L); +int luaopen_lcrypto(lua_State *L); + +COSMOPOLITAN_C_END_ +#endif /* COSMOPOLITAN_TOOL_NET_LCRYPTO_H_ */ diff --git a/tool/net/lfuncs.h b/tool/net/lfuncs.h index 7bc3fc748..4fcbd0fa5 100644 --- a/tool/net/lfuncs.h +++ b/tool/net/lfuncs.h @@ -8,6 +8,7 @@ int LuaMaxmind(lua_State *); int LuaRe(lua_State *); int luaopen_argon2(lua_State *); int luaopen_lsqlite3(lua_State *); +int LuaCrypto(lua_State *); int LuaBarf(lua_State *); int LuaBenchmark(lua_State *); diff --git a/tool/net/redbean.c b/tool/net/redbean.c index 93816d1aa..c7b9de601 100644 --- a/tool/net/redbean.c +++ b/tool/net/redbean.c @@ -5426,6 +5426,7 @@ static const luaL_Reg kLuaLibs[] = { {"path", LuaPath}, // {"re", LuaRe}, // {"unix", LuaUnix}, // + {"crypto", LuaCrypto}, // }; static void LuaSetArgv(lua_State *L) { From 47c01b548a1b876725af1b5d15dedcc769b2b195 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Mon, 2 Jun 2025 14:28:19 +1200 Subject: [PATCH 02/18] Add LuaCrypto compatible functions (plus some auxiliary functions) as per #1136 --- test/tool/net/lcrypto_test.lua | 83 ++++++++++++++++++++++ tool/net/lcrypto.c | 126 +++++++++++++++------------------ 2 files changed, 142 insertions(+), 67 deletions(-) create mode 100644 test/tool/net/lcrypto_test.lua diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua new file mode 100644 index 000000000..c89483086 --- /dev/null +++ b/test/tool/net/lcrypto_test.lua @@ -0,0 +1,83 @@ +-- Helper function to print test results +local function assert_equal(actual, expected, message) + if actual ~= expected then + error(message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) + else + print("PASS: " .. message) + end +end + +-- Test RSA key pair generation +local function test_rsa_keypair_generation() + local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + assert_equal(type(priv_key), "string", "RSA private key generation") + assert_equal(type(pub_key), "string", "RSA public key generation") +end + +-- Test ECDSA key pair generation +local function test_ecdsa_keypair_generation() + local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") + assert_equal(type(priv_key), "string", "ECDSA private key generation") + assert_equal(type(pub_key), "string", "ECDSA public key generation") +end + +-- Test RSA encryption and decryption +local function test_rsa_encryption_decryption() + local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + local message = "Hello, RSA!" + local encrypted = crypto.encrypt("rsa", pub_key, message) + assert_equal(type(encrypted), "string", "RSA encryption") + local decrypted = crypto.decrypt("rsa", priv_key, encrypted) + assert_equal(decrypted, message, "RSA decryption") +end + +-- Test RSA signing and verification +local function test_rsa_signing_verification() + local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + local message = "Sign this message" + local signature = crypto.sign("rsa", priv_key, message, "sha256") + assert_equal(type(signature), "string", "RSA signing") + local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256") + assert_equal(is_valid, true, "RSA signature verification") +end + +-- Test ECDSA signing and verification +local function test_ecdsa_signing_verification() + local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") + local message = "Sign this message with ECDSA" + local signature = crypto.sign("ecdsa", priv_key, message, "sha256") + assert_equal(type(signature), "string", "ECDSA signing") + local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") + assert_equal(is_valid, true, "ECDSA signature verification") +end + +-- Test CSR generation +local function test_csr_generation() + local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + local subject_name = "CN=example.com,O=Example Org,C=US" + local csr = crypto.csrGenerate(priv_key, subject_name) + assert_equal(type(csr), "string", "CSR generation") +end + +-- Test PemToJwk conversion +local function test_pem_to_jwk() + local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + local jwk = crypto.PemToJwk(pub_key) + assert_equal(type(jwk), "table", "PEM to JWK conversion") + assert_equal(jwk.kty, "RSA", "JWK key type") +end + +-- Run all tests +local function run_tests() + print("Running tests for lcrypto...") + test_rsa_keypair_generation() + test_ecdsa_keypair_generation() + test_rsa_encryption_decryption() + test_rsa_signing_verification() + test_ecdsa_signing_verification() + test_csr_generation() + test_pem_to_jwk() + print("All tests passed!") +end + +run_tests() diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index c33e73a03..dd909e175 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -911,76 +911,68 @@ static int LuaECDSAGenerateKeyPair(lua_State *L) { static int ECDSASign(const char *priv_key_pem, const char *message, hash_algorithm_t hash_alg, unsigned char **signature, size_t *sig_len) { - mbedtls_pk_context key; - unsigned char hash[64]; // Max hash size (SHA-512) - size_t hash_size; - int ret; + mbedtls_pk_context key; + unsigned char hash[64]; // Max hash size (SHA-512) + size_t hash_size; + int ret; - *signature = NULL; - *sig_len = 0; - - if (!priv_key_pem) { - WARNF("(ecdsa) Private key is NULL"); - return -1; - } - - // Get the length of the PEM string (excluding null terminator) - size_t key_len = strlen(priv_key_pem); - if (key_len == 0) { - WARNF("(ecdsa) Private key is empty"); - return -1; - } - - // Get hash size for the selected algorithm - hash_size = get_hash_size(hash_alg); - - mbedtls_pk_init(&key); - - // Parse the private key from PEM directly without creating a copy - ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem, - key_len + 1, NULL, 0); - - if (ret != 0) { - WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); - goto cleanup; - } - - // Compute hash of the message using the specified algorithm - ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), - hash, sizeof(hash)); - if (ret != 0) { - WARNF("(ecdsa) Failed to compute message hash"); - goto cleanup; - } - - // Allocate memory for signature (max size for ECDSA) - *signature = malloc(MBEDTLS_ECDSA_MAX_LEN); - if (*signature == NULL) { - WARNF("(ecdsa) Failed to allocate memory for signature"); - ret = -1; - goto cleanup; - } - - // Sign the hash using GenerateHardRandom - ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, - *signature, sig_len, GenerateHardRandom, 0); - - if (ret != 0) { - WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); - free(*signature); *signature = NULL; *sig_len = 0; - goto cleanup; - } -cleanup: - mbedtls_pk_free(&key); - return ret; -} // Lua binding for signing a message + if (!priv_key_pem || strlen(priv_key_pem) == 0) { + WARNF("(ecdsa) Private key is NULL or empty"); + return -1; + } + + mbedtls_pk_init(&key); + + // Parse the private key from PEM (PKCS#8 format) + ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem, + strlen(priv_key_pem) + 1, NULL, 0); + if (ret != 0) { + WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); + mbedtls_pk_free(&key); + return -1; + } + + // Compute hash of the message + hash_size = get_hash_size(hash_alg); + ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), + hash, sizeof(hash)); + if (ret != 0) { + WARNF("(ecdsa) Failed to compute message hash"); + mbedtls_pk_free(&key); + return -1; + } + + // Allocate memory for the signature + *signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE); + if (*signature == NULL) { + WARNF("(ecdsa) Failed to allocate memory for signature"); + mbedtls_pk_free(&key); + return -1; + } + + // Sign the hash + ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, + *signature, sig_len, GenerateHardRandom, NULL); + if (ret != 0) { + WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); + free(*signature); + *signature = NULL; + *sig_len = 0; + mbedtls_pk_free(&key); + return -1; + } + + mbedtls_pk_free(&key); + return 0; +} +// Lua binding for signing a message static int LuaECDSASign(lua_State *L) { - const char *hash_name = luaL_optstring(L, 1, "sha256"); // Default to SHA-256 + const char *hash_name = luaL_optstring(L, 3, "sha256"); // Default to SHA-256 const char *message = luaL_checkstring(L, 2); - const char *priv_key_pem = luaL_checkstring(L, 3); + const char *priv_key_pem = luaL_checkstring(L, 1); hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); @@ -1054,12 +1046,12 @@ cleanup: return ret; } static int LuaECDSAVerify(lua_State *L) { - const char *hash_name = luaL_optstring(L, 1, "sha256"); // Default to SHA-256 + const char *pub_key_pem = luaL_checkstring(L, 1); const char *message = luaL_checkstring(L, 2); - const char *pub_key_pem = luaL_checkstring(L, 3); size_t sig_len; const unsigned char *signature = - (const unsigned char *)luaL_checklstring(L, 4, &sig_len); + (const unsigned char *)luaL_checklstring(L, 3, &sig_len); + const char *hash_name = luaL_optstring(L, 4, "sha256"); // Default to SHA-256 hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); From e1403ff9a96fe35be333032b91e6197033158834 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Mon, 2 Jun 2025 15:07:00 +1200 Subject: [PATCH 03/18] Align function naming --- tool/net/lcrypto.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index dd909e175..dc72c8bed 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -17,7 +17,7 @@ #include "tool/net/luacheck.h" // Updated PemToJwk to parse PEM keys and convert them into JWK format -static int PemToJwk(lua_State *L) { +static int convertPemToJwk(lua_State *L) { const char *pem_key = luaL_checkstring(L, 1); mbedtls_pk_context key; @@ -166,7 +166,7 @@ static int PemToJwk(lua_State *L) { } // CSR Creation Function -static int CreateCSR(lua_State *L) { +static int generateCsr(lua_State *L) { const char *key_pem = luaL_checkstring(L, 1); const char *subject_name = luaL_checkstring(L, 2); const char *san_list = luaL_optstring(L, 3, NULL); @@ -1128,8 +1128,8 @@ static const luaL_Reg kLuaCrypto[] = { {"encrypt", LuaCryptoEncrypt}, // {"decrypt", LuaCryptoDecrypt}, // {"generatekeypair", LuaCryptoGenerateKeyPair}, // - {"PemToJwk", PemToJwk}, // - {"csrGenerate", CreateCSR}, // + {"convertPemToJwk", convertPemToJwk}, // + {"generateCsr", generateCsr}, // {0}, // }; From c1be35a820a59a258af5c521bad8ae707d2de217 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Mon, 2 Jun 2025 15:30:37 +1200 Subject: [PATCH 04/18] Make key type optional in crypto.generateKeyPair. Defaults to rsa --- tool/net/lcrypto.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index dc72c8bed..1761bdd20 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -1110,7 +1110,7 @@ static int LuaCryptoDecrypt(lua_State *L) { } static int LuaCryptoGenerateKeyPair(lua_State *L) { - const char *key_type = luaL_checkstring(L, 1); // Key type (e.g., "rsa", "ecdsa") + const char *key_type = luaL_optstring(L, 1, "rsa"); // Key type (e.g., "rsa", "ecdsa") lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching if (strcasecmp(key_type, "rsa") == 0) { From b4b7c9e5b71044a88917b7da4ef90d18f9afe678 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Mon, 2 Jun 2025 15:35:51 +1200 Subject: [PATCH 05/18] Fix tests --- test/tool/net/lcrypto_test.lua | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index c89483086..34418533a 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -55,14 +55,14 @@ end local function test_csr_generation() local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) local subject_name = "CN=example.com,O=Example Org,C=US" - local csr = crypto.csrGenerate(priv_key, subject_name) + local csr = crypto.generateCsr(priv_key, subject_name) assert_equal(type(csr), "string", "CSR generation") end -- Test PemToJwk conversion local function test_pem_to_jwk() local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local jwk = crypto.PemToJwk(pub_key) + local jwk = crypto.convertPemToJwk(pub_key) assert_equal(type(jwk), "table", "PEM to JWK conversion") assert_equal(jwk.kty, "RSA", "JWK key type") end From 558214598fa381d486b2bb421ddd432a663e4fa2 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Mon, 2 Jun 2025 15:40:18 +1200 Subject: [PATCH 06/18] Remove old reference --- tool/net/lcrypto.h | 1 - 1 file changed, 1 deletion(-) diff --git a/tool/net/lcrypto.h b/tool/net/lcrypto.h index e1e11ed65..0e1cac872 100644 --- a/tool/net/lcrypto.h +++ b/tool/net/lcrypto.h @@ -4,7 +4,6 @@ COSMOPOLITAN_C_START_ int LuaCrypto(lua_State *L); -int luaopen_lcrypto(lua_State *L); COSMOPOLITAN_C_END_ #endif /* COSMOPOLITAN_TOOL_NET_LCRYPTO_H_ */ From 9e121882d033463c62e0e3de36289dbc858b926b Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Mon, 2 Jun 2025 16:19:01 +1200 Subject: [PATCH 07/18] PROPERLY make arguments optional in crypto.generatekeypair thanks to @pkulchenko --- tool/net/lcrypto.c | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index 1761bdd20..f5d2bced2 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -1110,8 +1110,12 @@ static int LuaCryptoDecrypt(lua_State *L) { } static int LuaCryptoGenerateKeyPair(lua_State *L) { - const char *key_type = luaL_optstring(L, 1, "rsa"); // Key type (e.g., "rsa", "ecdsa") - lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + const char *key_type = "rsa"; // Key type (e.g., "rsa", "ecdsa") + + if (! lua_isinteger(L, 1) && ! lua_isnoneornil(L, 1)) { + key_type = luaL_checkstring(L, 1); // Get key type from first argumen + lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + } if (strcasecmp(key_type, "rsa") == 0) { return LuaRSAGenerateKeyPair(L); From 19541f95bd745faad5156d4a7e909b59a662ddd4 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Tue, 3 Jun 2025 20:38:10 +1200 Subject: [PATCH 08/18] Add AES based encryption and decryption in multiple modes (CBC, CTR & GCM) Improve test coverage --- test/tool/net/lcrypto_test.lua | 208 ++++++++-- third_party/mbedtls/config.h | 2 +- tool/net/lcrypto.c | 682 +++++++++++++++++++++++++++------ 3 files changed, 755 insertions(+), 137 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 34418533a..52a7a5521 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -1,14 +1,23 @@ -- Helper function to print test results -local function assert_equal(actual, expected, message) +local function assert_equal(actual, expected, plaintext) if actual ~= expected then - error(message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) + error(plaintext .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) else - print("PASS: " .. message) + print("PASS: " .. plaintext) + end +end + +local function assert_not_equal(actual, not_expected, plaintext) + if actual == not_expected then + error(plaintext .. ": did not expect " .. tostring(not_expected)) + else + print("PASS: " .. plaintext) end end -- Test RSA key pair generation local function test_rsa_keypair_generation() + print('\27[1;7mTest RSA key pair generation \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) assert_equal(type(priv_key), "string", "RSA private key generation") assert_equal(type(pub_key), "string", "RSA public key generation") @@ -16,6 +25,7 @@ end -- Test ECDSA key pair generation local function test_ecdsa_keypair_generation() + print('\n\27[1;7mTest ECDSA key pair generation \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") assert_equal(type(priv_key), "string", "ECDSA private key generation") assert_equal(type(pub_key), "string", "ECDSA public key generation") @@ -23,61 +33,207 @@ end -- Test RSA encryption and decryption local function test_rsa_encryption_decryption() + print('\n\27[1;7mTest RSA encryption and decryption \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local message = "Hello, RSA!" - local encrypted = crypto.encrypt("rsa", pub_key, message) + local plaintext = "Hello, RSA!" + local encrypted = crypto.encrypt("rsa", pub_key, plaintext) assert_equal(type(encrypted), "string", "RSA encryption") local decrypted = crypto.decrypt("rsa", priv_key, encrypted) - assert_equal(decrypted, message, "RSA decryption") + assert_equal(decrypted, plaintext, "RSA decryption") end -- Test RSA signing and verification local function test_rsa_signing_verification() + print('\n\27[1;7mTest RSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local message = "Sign this message" - local signature = crypto.sign("rsa", priv_key, message, "sha256") + local plaintext = "Sign this plaintext" + local signature = crypto.sign("rsa", priv_key, plaintext, "sha256") assert_equal(type(signature), "string", "RSA signing") - local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256") + local is_valid = crypto.verify("rsa", pub_key, plaintext, signature, "sha256") assert_equal(is_valid, true, "RSA signature verification") end -- Test ECDSA signing and verification local function test_ecdsa_signing_verification() + print('\n\27[1;7mTest ECDSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") - local message = "Sign this message with ECDSA" - local signature = crypto.sign("ecdsa", priv_key, message, "sha256") + local plaintext = "Sign this plaintext with ECDSA" + local signature = crypto.sign("ecdsa", priv_key, plaintext, "sha256") assert_equal(type(signature), "string", "ECDSA signing") - local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") + local is_valid = crypto.verify("ecdsa", pub_key, plaintext, signature, "sha256") assert_equal(is_valid, true, "ECDSA signature verification") end --- Test CSR generation -local function test_csr_generation() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local subject_name = "CN=example.com,O=Example Org,C=US" - local csr = crypto.generateCsr(priv_key, subject_name) - assert_equal(type(csr), "string", "CSR generation") +-- Test AES key generation +local function test_aes_key_generation() + print('\n\27[1;7mTest AES key generation \27[0m') + local key = crypto.generatekeypair('aes', 256) -- 256-bit key + assert_equal(type(key), "string", "AES key generation") + assert_equal(#key, 32, "AES key length (256 bits)") +end + +-- Test AES encryption and decryption (CBC mode) +local function test_aes_encryption_decryption() + print('\n\27[1;7mTest AES encryption and decryption (CBC mode) \27[0m') + local key = crypto.generatekeypair('aes',256) -- 256-bit key + local plaintext = "Hello, AES CBC!" + + -- Encrypt without providing IV (should auto-generate IV) + print('\27[1mAES encryption (auto IV)\27[0m') + local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil) + assert_equal(type(encrypted), "string", "AES encryption (CBC, auto IV)") + assert_equal(type(iv), "string", "AES IV (auto-generated)") + + -- Decrypt + print('\n\27[1mAES decryption (auto IV)\27[0m') + local decrypted = crypto.decrypt("aes", key, encrypted, iv) + assert_equal(decrypted, plaintext, "AES decryption (CBC, auto IV)") + + -- Encrypt with explicit IV + print('\n\27[1mAES encryption (explicit IV)\27[0m') + local iv2 = GetRandomBytes(16) + local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2) + assert_equal(type(encrypted2), "string", "AES encryption (CBC, explicit IV)") + assert_equal(iv_used, iv2, "AES IV (explicit)") + + print('\n\27[1mAES decryption (explicit IV)\27[0m') + local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2) + assert_equal(decrypted2, plaintext, "AES decryption (CBC, explicit IV)") +end + +-- Test AES encryption and decryption (CTR mode) +local function test_aes_encryption_decryption_ctr() + print('\n\27[1;7mTest AES encryption and decryption (CTR mode) \27[0m') + local key = crypto.generatekeypair('aes',256) + local plaintext = "Hello, AES CTR!" + + -- Encrypt without providing IV (should auto-generate IV) + print('\27[1mAES encryption (auto IV)\27[0m') + local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil, "ctr") + assert_equal(type(encrypted), "string", "AES encryption (CTR, auto IV)") + assert_equal(type(iv), "string", "AES IV (auto-generated, CTR)") + + -- Decrypt + print('\n\27[1mAES decryption (auto IV)\27[0m') + local decrypted = crypto.decrypt("aes", key, encrypted, iv, "ctr") + assert_equal(decrypted, plaintext, "AES decryption (CTR, auto IV)") + + -- Encrypt with explicit IV + print('\n\27[1mAES encryption (explicit IV)\27[0m') + local iv2 = GetRandomBytes(16) + local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2, "ctr") + assert_equal(type(encrypted2), "string", "AES encryption (CTR, explicit IV)") + assert_equal(iv_used, iv2, "AES IV (explicit, CTR)") + + print('\n\27[1mAES decryption (explicit IV)\27[0m') + local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "ctr") + assert_equal(decrypted2, plaintext, "AES decryption (CTR, explicit IV)") +end + +-- Test AES encryption and decryption (GCM mode) +local function test_aes_encryption_decryption_gcm() + print('\n\27[1;7mTest AES encryption and decryption (GCM mode) \27[0m') + local key = crypto.generatekeypair('aes',256) + local plaintext = "Hello, AES GCM!" + + -- Encrypt without providing IV (should auto-generate IV) + print('\27[1mAES encryption (auto IV)\27[0m') + local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, nil, "gcm") + assert_equal(type(encrypted), "string", "AES encryption (GCM, auto IV)") + assert_equal(type(iv), "string", "AES IV (auto-generated, GCM)") + assert_equal(type(tag), "string", "AES GCM tag (auto IV)") + + -- Decrypt + print('\n\27[1mAES decryption (auto IV)\27[0m') + local decrypted = crypto.decrypt("aes", key, encrypted, iv, "gcm", nil, tag) + assert_equal(decrypted, plaintext, "AES decryption (GCM, auto IV)") + + -- Encrypt with explicit IV + print('\n\27[1mAES encryption (explicit IV)\27[0m') + local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard + local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, iv2, "gcm") + assert_equal(type(encrypted2), "string", "AES encryption (GCM, explicit IV)") + assert_equal(iv_used, iv2, "AES IV (explicit, GCM)") + assert_equal(type(tag2), "string", "AES GCM tag (explicit IV)") + + print('\n\27[1mAES decryption (explicit IV)\27[0m') + local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "gcm", nil, tag2) + assert_equal(decrypted2, plaintext, "AES decryption (GCM, explicit IV)") end -- Test PemToJwk conversion local function test_pem_to_jwk() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local jwk = crypto.convertPemToJwk(pub_key) - assert_equal(type(jwk), "table", "PEM to JWK conversion") - assert_equal(jwk.kty, "RSA", "JWK key type") + print('\n\27[1;7mTest PEM to JWK conversion \27[0m') + local priv_key, pub_key = crypto.generatekeypair() + print('\27[1mRSA Private key to JWK conversion\27[0m') + local priv_jwk = crypto.convertPemToJwk(priv_key) + assert_equal(type(priv_jwk), "table", "PEM to JWK conversion") + assert_equal(priv_jwk.kty, "RSA", "JWK key type") + + print('\n\27[1mRSA Public key to JWK conversion\27[0m') + local pub_jwk = crypto.convertPemToJwk(pub_key) + assert_equal(type(pub_jwk), "table", "PEM to JWK conversion") + assert_equal(pub_jwk.kty, "RSA", "JWK key type") + + -- Test ECDSA keys + local priv_key, pub_key = crypto.generatekeypair('ecdsa') + print('\n\27[1mECDSA Private key to JWK conversion\27[0m') + local priv_jwk = crypto.convertPemToJwk(priv_key) + assert_equal(type(priv_jwk), "table", "PEM to JWK conversion") + assert_equal(priv_jwk.kty, "EC", "JWK key type") + + print('\n\27[1mECDSA Public key to JWK conversion\27[0m') + local pub_jwk = crypto.convertPemToJwk(pub_key) + assert_equal(type(pub_jwk), "table", "PEM to JWK conversion") + assert_equal(pub_jwk.kty, "EC", "JWK key type") +end + +-- Test CSR generation +local function test_csr_generation() + print('\n\27[1;7mTest CSR generation \27[0m') + local priv_key, _ = crypto.generatekeypair() + local subject_name = "CN=example.com,O=Example Org,C=US" + local san = "DNS:example.com, DNS:www.example.com, IP:192.168.1.1" + + local csr = crypto.GenerateCsr(priv_key, subject_name) + assert_equal(type(csr), "string", "CSR generation with subject name") + + csr = crypto.GenerateCsr(priv_key, subject_name, san) + assert_equal(type(csr), "string", "CSR generation with subject name and san") + + csr = crypto.GenerateCsr(priv_key, nil, san) + assert_equal(type(csr), "string", "CSR generation with nil subject name and san") + + csr = crypto.GenerateCsr(priv_key, '', san) + assert_equal(type(csr), "string", "CSR generation with empty subject name and san") + + -- These should fail + csr = crypto.GenerateCsr(priv_key, '') + assert_not_equal(type(csr), "string", "CSR generation with empty subject name and no san is rejected") + + csr = crypto.GenerateCsr(priv_key) + assert_not_equal(type(csr), "string", "CSR generation with nil subject name and no san is rejected") end -- Run all tests local function run_tests() print("Running tests for lcrypto...") test_rsa_keypair_generation() - test_ecdsa_keypair_generation() - test_rsa_encryption_decryption() test_rsa_signing_verification() + test_rsa_encryption_decryption() + test_ecdsa_keypair_generation() test_ecdsa_signing_verification() - test_csr_generation() + test_aes_key_generation() + test_aes_encryption_decryption() + test_aes_encryption_decryption_ctr() + test_aes_encryption_decryption_gcm() test_pem_to_jwk() + test_csr_generation() + print('') print("All tests passed!") + EXIT=0 + return EXIT end -run_tests() +EXIT=70 +os.exit(run_tests()) diff --git a/third_party/mbedtls/config.h b/third_party/mbedtls/config.h index 88087503e..d181060b1 100644 --- a/third_party/mbedtls/config.h +++ b/third_party/mbedtls/config.h @@ -40,9 +40,9 @@ #define MBEDTLS_GCM_C #ifndef TINY #define MBEDTLS_CIPHER_MODE_CBC +#define MBEDTLS_CIPHER_MODE_CTR /*#define MBEDTLS_CCM_C*/ /*#define MBEDTLS_CIPHER_MODE_CFB*/ -/*#define MBEDTLS_CIPHER_MODE_CTR*/ /*#define MBEDTLS_CIPHER_MODE_OFB*/ /*#define MBEDTLS_CIPHER_MODE_XTS*/ #endif diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index f5d2bced2..d5ca2890f 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -9,6 +9,10 @@ #include "third_party/mbedtls/oid.h" #include "third_party/mbedtls/md.h" #include "third_party/mbedtls/base64.h" +#include "third_party/mbedtls/aes.h" +#include "third_party/mbedtls/ctr_drbg.h" +#include "third_party/mbedtls/entropy.h" +#include "third_party/mbedtls/gcm.h" // Standard C library and redbean utilities #include "libc/errno.h" @@ -16,8 +20,8 @@ #include "libc/str/str.h" #include "tool/net/luacheck.h" -// Updated PemToJwk to parse PEM keys and convert them into JWK format -static int convertPemToJwk(lua_State *L) { +// Parse PEM keys and convert them into JWK format +static int LuaConvertPemToJwk(lua_State *L) { const char *pem_key = luaL_checkstring(L, 1); mbedtls_pk_context key; @@ -166,11 +170,23 @@ static int convertPemToJwk(lua_State *L) { } // CSR Creation Function -static int generateCsr(lua_State *L) { +static int LuaGenerateCSR(lua_State *L) { const char *key_pem = luaL_checkstring(L, 1); - const char *subject_name = luaL_checkstring(L, 2); + const char *subject_name; const char *san_list = luaL_optstring(L, 3, NULL); + if (lua_isnoneornil(L, 2)) { + subject_name = ""; + } else { + subject_name = luaL_checkstring(L, 2); + } + + + if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') { + lua_pushnil(L); + lua_pushstring(L, "Subject name or SANs are required"); + return 2; + } mbedtls_pk_context key; mbedtls_x509write_csr req; char buf[4096]; @@ -211,7 +227,9 @@ static int generateCsr(lua_State *L) { return 1; } +// RSA +// Generate RSA Key Pair static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, char **public_key_pem, size_t *public_key_len, unsigned int key_length) { @@ -263,6 +281,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, mbedtls_pk_free(&key); return true; } + /** * Lua wrapper for RSA key pair generation * @@ -272,43 +291,38 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, * error_message) */ static int LuaRSAGenerateKeyPair(lua_State *L) { - char *private_key, *public_key; - size_t private_len, public_len; - int key_length = 2048; // Default RSA key length - - // Get key length from Lua (optional parameter) - if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { - key_length = luaL_checkinteger(L, 1); - // Validate key length (common RSA key lengths are 1024, 2048, 3072, 4096) - if (key_length != 1024 && key_length != 2048 && key_length != 3072 && - key_length != 4096) { - lua_pushnil(L); - lua_pushstring(L, - "Invalid RSA key length. Use 1024, 2048, 3072, or 4096."); - return 2; + int bits = 2048; + // If no arguments, or first argument is nil, default to 2048 + if (lua_gettop(L) == 0 || lua_isnoneornil(L, 1)) { + bits = 2048; + } else if (lua_gettop(L) == 1 && lua_type(L, 1) == LUA_TNUMBER) { + bits = (int)lua_tointeger(L, 1); + } else { + bits = (int)luaL_optinteger(L, 2, 2048); } - } - // Call the C function to generate the key pair - if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, - key_length)) { - lua_pushnil(L); - lua_pushstring(L, "Failed to generate RSA key pair"); + char *private_key, *public_key; + size_t private_len, public_len; + + // Call the C function to generate the key pair + if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, bits)) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate RSA key pair"); + return 2; + } + + // Push results to Lua + lua_pushstring(L, private_key); + lua_pushstring(L, public_key); + + // Clean up + free(private_key); + free(public_key); + return 2; - } - - // Push results to Lua - lua_pushstring(L, private_key); - lua_pushstring(L, public_key); - - // Clean up - free(private_key); - free(public_key); - - return 2; } -// RSA + static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, size_t data_len, size_t *out_len) { int rc; @@ -622,7 +636,7 @@ static int LuaRSAVerify(lua_State *L) { return 1; } - +// Elliptic Curve Cryptography Functions // Supported curves mapping typedef struct { const char *name; @@ -710,6 +724,7 @@ static int LuaListHashAlgorithms(lua_State *L) { return 1; } + // List available curves static int LuaListCurves(lua_State *L) { const curve_map_t *curve = supported_curves; @@ -911,68 +926,77 @@ static int LuaECDSAGenerateKeyPair(lua_State *L) { static int ECDSASign(const char *priv_key_pem, const char *message, hash_algorithm_t hash_alg, unsigned char **signature, size_t *sig_len) { - mbedtls_pk_context key; - unsigned char hash[64]; // Max hash size (SHA-512) - size_t hash_size; - int ret; + mbedtls_pk_context key; + unsigned char hash[64]; // Max hash size (SHA-512) + size_t hash_size; + int ret; + *signature = NULL; + *sig_len = 0; + + if (!priv_key_pem) { + WARNF("(ecdsa) Private key is NULL"); + return -1; + } + + // Get the length of the PEM string (excluding null terminator) + size_t key_len = strlen(priv_key_pem); + if (key_len == 0) { + WARNF("(ecdsa) Private key is empty"); + return -1; + } + + // Get hash size for the selected algorithm + hash_size = get_hash_size(hash_alg); + + mbedtls_pk_init(&key); + + // Parse the private key from PEM directly without creating a copy + ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem, + key_len + 1, NULL, 0); + + if (ret != 0) { + WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); + goto cleanup; + } + + // Compute hash of the message using the specified algorithm + ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), + hash, sizeof(hash)); + if (ret != 0) { + WARNF("(ecdsa) Failed to compute message hash"); + goto cleanup; + } + + // Allocate memory for signature (max size for ECDSA) + *signature = malloc(MBEDTLS_ECDSA_MAX_LEN); + if (*signature == NULL) { + WARNF("(ecdsa) Failed to allocate memory for signature"); + ret = -1; + goto cleanup; + } + + // Sign the hash using GenerateHardRandom + ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, + *signature, sig_len, GenerateHardRandom, 0); + + if (ret != 0) { + WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); + free(*signature); *signature = NULL; *sig_len = 0; + goto cleanup; + } - if (!priv_key_pem || strlen(priv_key_pem) == 0) { - WARNF("(ecdsa) Private key is NULL or empty"); - return -1; - } - - mbedtls_pk_init(&key); - - // Parse the private key from PEM (PKCS#8 format) - ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem, - strlen(priv_key_pem) + 1, NULL, 0); - if (ret != 0) { - WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); - mbedtls_pk_free(&key); - return -1; - } - - // Compute hash of the message - hash_size = get_hash_size(hash_alg); - ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), - hash, sizeof(hash)); - if (ret != 0) { - WARNF("(ecdsa) Failed to compute message hash"); - mbedtls_pk_free(&key); - return -1; - } - - // Allocate memory for the signature - *signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE); - if (*signature == NULL) { - WARNF("(ecdsa) Failed to allocate memory for signature"); - mbedtls_pk_free(&key); - return -1; - } - - // Sign the hash - ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, - *signature, sig_len, GenerateHardRandom, NULL); - if (ret != 0) { - WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); - free(*signature); - *signature = NULL; - *sig_len = 0; - mbedtls_pk_free(&key); - return -1; - } - - mbedtls_pk_free(&key); - return 0; -} -// Lua binding for signing a message +cleanup: + mbedtls_pk_free(&key); + return ret; +} // Lua binding for signing a message static int LuaECDSASign(lua_State *L) { - const char *hash_name = luaL_optstring(L, 3, "sha256"); // Default to SHA-256 - const char *message = luaL_checkstring(L, 2); + // Correct order: priv_key, message, hash_name (default sha256) const char *priv_key_pem = luaL_checkstring(L, 1); + const char *message = luaL_checkstring(L, 2); + const char *hash_name = luaL_optstring(L, 3, "sha256"); hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); @@ -1046,12 +1070,12 @@ cleanup: return ret; } static int LuaECDSAVerify(lua_State *L) { + // Correct order: pub_key, message, signature, hash_name (default sha256) const char *pub_key_pem = luaL_checkstring(L, 1); const char *message = luaL_checkstring(L, 2); size_t sig_len; - const unsigned char *signature = - (const unsigned char *)luaL_checklstring(L, 3, &sig_len); - const char *hash_name = luaL_optstring(L, 4, "sha256"); // Default to SHA-256 + const unsigned char *signature = (const unsigned char *)luaL_checklstring(L, 3, &sig_len); + const char *hash_name = luaL_optstring(L, 4, "sha256"); hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); @@ -1061,6 +1085,437 @@ static int LuaECDSAVerify(lua_State *L) { return 1; } + +// AES +// AES key generation helper +static int LuaAesGenerateKey(lua_State *L) { + int keybits = 128; + if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { + keybits = luaL_checkinteger(L, 1); + } + int keylen = keybits / 8; + if ((keybits != 128 && keybits != 192 && keybits != 256) || (keylen != 16 && keylen != 24 && keylen != 32)) { + lua_pushnil(L); + lua_pushstring(L, "AES key length must be 128, 192, or 256 bits"); + return 2; + } + unsigned char key[32]; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + const char *pers = "aes_keygen"; + int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const unsigned char *)pers, strlen(pers)); + if (ret != 0) { + lua_pushnil(L); + lua_pushstring(L, "Failed to initialize RNG for AES key"); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + return 2; + } + ret = mbedtls_ctr_drbg_random(&ctr_drbg, key, keylen); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + if (ret != 0) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate random AES key"); + return 2; + } + lua_pushlstring(L, (const char *)key, keylen); + return 1; +} + +// AES encryption supporting CBC, GCM, and CTR modes +static int LuaAesEncrypt(lua_State *L) { + // Accept IV as the 3rd argument (after key, plaintext) + size_t keylen, ivlen = 0, ptlen; + const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); + const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); + const unsigned char *iv = NULL; + unsigned char *gen_iv = NULL; + int iv_was_generated = 0; + + const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided + int ret = 0; + unsigned char *output = NULL; + int is_gcm = 0, is_ctr = 0, is_cbc = 0; + + if (strcasecmp(mode, "cbc") == 0) { + is_cbc = 1; + } else if (strcasecmp(mode, "gcm") == 0) { + is_gcm = 1; + } else if (strcasecmp(mode, "ctr") == 0) { + is_ctr = 1; + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } + + // If IV is not provided (arg3 is nil or missing), auto-generate + if (lua_isnoneornil(L, 3)) { + // For GCM, standard is 12 bytes, but allow 12-16 + if (is_gcm) { + ivlen = 12; + } else { + ivlen = 16; + } + gen_iv = malloc(ivlen); + if (!gen_iv) { + lua_pushnil(L); + lua_pushstring(L, "Failed to allocate IV"); + return 2; + } + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0); + mbedtls_ctr_drbg_random(&ctr_drbg, gen_iv, ivlen); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + iv = gen_iv; + iv_was_generated = 1; + } else { + // IV provided + iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen); + // Do not force ivlen to 16 here! Accept actual length for GCM (12-16) + if (is_cbc || is_ctr) { + if (ivlen != 16) { + lua_pushnil(L); + lua_pushstring(L, "AES IV must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (ivlen < 12 || ivlen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM IV/nonce must be 12-16 bytes"); + return 2; + } + } + iv_was_generated = 0; + } + + if (is_cbc) { + // PKCS7 padding + size_t block_size = 16; + size_t padlen = block_size - (ptlen % block_size); + size_t ctlen = ptlen + padlen; + unsigned char *input = malloc(ctlen); + if (!input) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + memcpy(input, plaintext, ptlen); + memset(input + ptlen, (unsigned char)padlen, padlen); + output = malloc(ctlen); + if (!output) { + free(input); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(input); + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char iv_copy[16]; + memcpy(iv_copy, iv, 16); + ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy, input, output); + mbedtls_aes_free(&aes); + free(input); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CBC encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + lua_pushlstring(L, (const char *)iv, ivlen); + free(output); + if (iv_was_generated) free(gen_iv); + return 2; + } else if (is_ctr) { + // CTR mode: no padding + output = malloc(ptlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char nonce_counter[16]; + unsigned char stream_block[16]; + size_t nc_off = 0; + memcpy(nonce_counter, iv, 16); + memset(stream_block, 0, 16); + ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter, stream_block, plaintext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CTR encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ptlen); + lua_pushlstring(L, (const char *)iv, ivlen); + free(output); + if (iv_was_generated) free(gen_iv); + return 2; + } else if (is_gcm) { + // GCM mode: authenticated encryption + size_t taglen = 16; + unsigned char tag[16]; + output = malloc(ptlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_gcm_context gcm; + mbedtls_gcm_init(&gcm); + ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_gcm_free(&gcm); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES GCM key"); + return 2; + } + // Use actual ivlen, not hardcoded 16 + ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen, NULL, 0, plaintext, output, taglen, tag); + mbedtls_gcm_free(&gcm); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES GCM encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ptlen); + lua_pushlstring(L, (const char *)iv, ivlen); + lua_pushlstring(L, (const char *)tag, taglen); + free(output); + if (iv_was_generated) free(gen_iv); + return 3; + } + lua_pushnil(L); + lua_pushstring(L, "Internal error in AES encrypt"); + return 2; +} + +// AES decryption supporting CBC, GCM, and CTR modes +static int LuaAesDecrypt(lua_State *L) { + size_t keylen, ctlen, ivlen; + const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); + const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); + const unsigned char *iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen); + const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided + const unsigned char *aad = NULL; + const unsigned char *tag = NULL; + size_t aadlen = 0, taglen = 0; + int is_gcm = 0, is_ctr = 0, is_cbc = 0; + + if (strcasecmp(mode, "cbc") == 0) { + is_cbc = 1; + } else if (strcasecmp(mode, "gcm") == 0) { + is_gcm = 1; + } else if (strcasecmp(mode, "ctr") == 0) { + is_ctr = 1; + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } + + // Validate key length (16, 24, 32 bytes) + if (keylen != 16 && keylen != 24 && keylen != 32) { + lua_pushnil(L); + lua_pushstring(L, "AES key must be 16, 24, or 32 bytes"); + return 2; + } + // Validate IV/nonce length + if (is_cbc || is_ctr) { + if (ivlen != 16) { + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (ivlen < 12 || ivlen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; + } + } + + // GCM: require tag and optional AAD + if (is_gcm) { + if (!lua_isnoneornil(L, 5)) { + aad = (const unsigned char *)luaL_checklstring(L, 5, &aadlen); + } + if (!lua_isnoneornil(L, 6)) { + tag = (const unsigned char *)luaL_checklstring(L, 6, &taglen); + if (taglen < 12 || taglen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM tag must be 12-16 bytes"); + return 2; + } + } else { + lua_pushnil(L); + lua_pushstring(L, "AES GCM tag required as 6th argument"); + return 2; + } + } + + int ret = 0; + unsigned char *output = NULL; + + if (is_cbc) { + // Ciphertext must be a multiple of block size + if (ctlen == 0 || (ctlen % 16) != 0) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext length must be a multiple of 16"); + return 2; + } + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_dec(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES decryption key"); + return 2; + } + unsigned char iv_copy[16]; + memcpy(iv_copy, iv, 16); + ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy, ciphertext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CBC decryption failed"); + return 2; + } + // PKCS7 unpadding + if (ctlen == 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Decrypted data is empty"); + return 2; + } + unsigned char pad = output[ctlen - 1]; + if (pad == 0 || pad > 16) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Invalid PKCS7 padding"); + return 2; + } + for (size_t i = 0; i < pad; ++i) { + if (output[ctlen - 1 - i] != pad) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Invalid PKCS7 padding"); + return 2; + } + } + size_t ptlen = ctlen - pad; + lua_pushlstring(L, (const char *)output, ptlen); + free(output); + return 1; + } else if (is_ctr) { + // CTR mode: no padding + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char nonce_counter[16]; + unsigned char stream_block[16]; + size_t nc_off = 0; + memcpy(nonce_counter, iv, 16); + memset(stream_block, 0, 16); + ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter, stream_block, ciphertext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CTR decryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + free(output); + return 1; + } else if (is_gcm) { + // GCM mode: authenticated decryption + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_gcm_context gcm; + mbedtls_gcm_init(&gcm); + ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_gcm_free(&gcm); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES GCM key"); + return 2; + } + ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag, taglen, ciphertext, output); + mbedtls_gcm_free(&gcm); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES GCM decryption failed or authentication failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + free(output); + return 1; + } + lua_pushnil(L); + lua_pushstring(L, "Internal error in AES decrypt"); + return 2; +} + +// LuaCrypto compatible API static int LuaCryptoSign(lua_State *L) { const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching @@ -1088,41 +1543,48 @@ static int LuaCryptoVerify(lua_State *L) { } static int LuaCryptoEncrypt(lua_State *L) { - const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa") + const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes") lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching if (strcasecmp(cipher, "rsa") == 0) { return LuaRSAEncrypt(L); + } else if (strcasecmp(cipher, "aes") == 0) { + return LuaAesEncrypt(L); } else { return luaL_error(L, "Unsupported cipher type: %s", cipher); } } static int LuaCryptoDecrypt(lua_State *L) { - const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa") + const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes") lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching if (strcasecmp(cipher, "rsa") == 0) { return LuaRSADecrypt(L); + } else if (strcasecmp(cipher, "aes") == 0) { + return LuaAesDecrypt(L); } else { return luaL_error(L, "Unsupported cipher type: %s", cipher); } } static int LuaCryptoGenerateKeyPair(lua_State *L) { - const char *key_type = "rsa"; // Key type (e.g., "rsa", "ecdsa") - - if (! lua_isinteger(L, 1) && ! lua_isnoneornil(L, 1)) { - key_type = luaL_checkstring(L, 1); // Get key type from first argumen - lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching - } - - if (strcasecmp(key_type, "rsa") == 0) { + // If the first argument is a number, treat as RSA key length + if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) { + // Call LuaRSAGenerateKeyPair with the number as the key length return LuaRSAGenerateKeyPair(L); - } else if (strcasecmp(key_type, "ecdsa") == 0) { + } + // Otherwise, get the key type from the first argument, default to "rsa" if not provided + const char *type = luaL_optstring(L, 1, "rsa"); + lua_remove(L, 1); + if (strcasecmp(type, "rsa") == 0) { + return LuaRSAGenerateKeyPair(L); + } else if (strcasecmp(type, "ecdsa") == 0) { return LuaECDSAGenerateKeyPair(L); + } else if (strcasecmp(type, "aes") == 0) { + return LuaAesGenerateKey(L); } else { - return luaL_error(L, "Unsupported key type: %s", key_type); + return luaL_error(L, "Unsupported key type: %s", type); } } @@ -1132,8 +1594,8 @@ static const luaL_Reg kLuaCrypto[] = { {"encrypt", LuaCryptoEncrypt}, // {"decrypt", LuaCryptoDecrypt}, // {"generatekeypair", LuaCryptoGenerateKeyPair}, // - {"convertPemToJwk", convertPemToJwk}, // - {"generateCsr", generateCsr}, // + {"convertPemToJwk", LuaConvertPemToJwk}, // + {"GenerateCsr", LuaGenerateCSR}, // {0}, // }; From 2d2a8a2d7dd79cdef9fd2447b68c24175cebbf11 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Wed, 4 Jun 2025 07:44:04 +1200 Subject: [PATCH 09/18] Unconditionally compile AES Cleanup language on the test file --- test/tool/net/lcrypto_test.lua | 24 ++++++++++++------------ third_party/mbedtls/config.h | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 52a7a5521..5b6c6a640 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -1,17 +1,17 @@ -- Helper function to print test results -local function assert_equal(actual, expected, plaintext) +local function assert_equal(actual, expected, message) if actual ~= expected then - error(plaintext .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) + error(message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) else - print("PASS: " .. plaintext) + print("PASS: " .. message) end end -local function assert_not_equal(actual, not_expected, plaintext) +local function assert_not_equal(actual, not_expected, message) if actual == not_expected then - error(plaintext .. ": did not expect " .. tostring(not_expected)) + error(message .. ": did not expect " .. tostring(not_expected)) else - print("PASS: " .. plaintext) + print("PASS: " .. message) end end @@ -46,10 +46,10 @@ end local function test_rsa_signing_verification() print('\n\27[1;7mTest RSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local plaintext = "Sign this plaintext" - local signature = crypto.sign("rsa", priv_key, plaintext, "sha256") + local message = "Sign this message" + local signature = crypto.sign("rsa", priv_key, message, "sha256") assert_equal(type(signature), "string", "RSA signing") - local is_valid = crypto.verify("rsa", pub_key, plaintext, signature, "sha256") + local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256") assert_equal(is_valid, true, "RSA signature verification") end @@ -57,10 +57,10 @@ end local function test_ecdsa_signing_verification() print('\n\27[1;7mTest ECDSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") - local plaintext = "Sign this plaintext with ECDSA" - local signature = crypto.sign("ecdsa", priv_key, plaintext, "sha256") + local message = "Sign this message with ECDSA" + local signature = crypto.sign("ecdsa", priv_key, message, "sha256") assert_equal(type(signature), "string", "ECDSA signing") - local is_valid = crypto.verify("ecdsa", pub_key, plaintext, signature, "sha256") + local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") assert_equal(is_valid, true, "ECDSA signature verification") end diff --git a/third_party/mbedtls/config.h b/third_party/mbedtls/config.h index d181060b1..c4e457749 100644 --- a/third_party/mbedtls/config.h +++ b/third_party/mbedtls/config.h @@ -38,9 +38,9 @@ /* block modes */ #define MBEDTLS_GCM_C -#ifndef TINY #define MBEDTLS_CIPHER_MODE_CBC #define MBEDTLS_CIPHER_MODE_CTR +#ifndef TINY /*#define MBEDTLS_CCM_C*/ /*#define MBEDTLS_CIPHER_MODE_CFB*/ /*#define MBEDTLS_CIPHER_MODE_OFB*/ From 5c47674d27ff6824f8a7564b662656871ec70744 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Wed, 4 Jun 2025 16:25:12 +1200 Subject: [PATCH 10/18] Improve tests Add definitions Align function name --- test/tool/net/lcrypto_test.lua | 112 ++++++++++++++++++--------------- tool/net/definitions.lua | 67 ++++++++++++++++++++ tool/net/lcrypto.c | 2 +- 3 files changed, 128 insertions(+), 53 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 5b6c6a640..fe8335cfd 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -1,7 +1,7 @@ -- Helper function to print test results local function assert_equal(actual, expected, message) if actual ~= expected then - error(message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) + error("FAIL: " .. message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) else print("PASS: " .. message) end @@ -19,146 +19,153 @@ end local function test_rsa_keypair_generation() print('\27[1;7mTest RSA key pair generation \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - assert_equal(type(priv_key), "string", "RSA private key generation") - assert_equal(type(pub_key), "string", "RSA public key generation") + assert_equal(type(priv_key), "string", "Private key type") + assert_equal(type(pub_key), "string", "Public key type") end -- Test ECDSA key pair generation local function test_ecdsa_keypair_generation() print('\n\27[1;7mTest ECDSA key pair generation \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") - assert_equal(type(priv_key), "string", "ECDSA private key generation") - assert_equal(type(pub_key), "string", "ECDSA public key generation") + assert_equal(type(priv_key), "string", "Private key type") + assert_equal(type(pub_key), "string", "Public key type") end -- Test RSA encryption and decryption local function test_rsa_encryption_decryption() print('\n\27[1;7mTest RSA encryption and decryption \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") local plaintext = "Hello, RSA!" local encrypted = crypto.encrypt("rsa", pub_key, plaintext) - assert_equal(type(encrypted), "string", "RSA encryption") + assert_equal(type(encrypted), "string", "Ciphertext type") local decrypted = crypto.decrypt("rsa", priv_key, encrypted) - assert_equal(decrypted, plaintext, "RSA decryption") + assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") end -- Test RSA signing and verification local function test_rsa_signing_verification() print('\n\27[1;7mTest RSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") local message = "Sign this message" local signature = crypto.sign("rsa", priv_key, message, "sha256") - assert_equal(type(signature), "string", "RSA signing") + assert_equal(type(signature), "string", "Signature type") local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256") - assert_equal(is_valid, true, "RSA signature verification") + assert_equal(is_valid, true, "Signature verification") end -- Test ECDSA signing and verification local function test_ecdsa_signing_verification() print('\n\27[1;7mTest ECDSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") local message = "Sign this message with ECDSA" local signature = crypto.sign("ecdsa", priv_key, message, "sha256") - assert_equal(type(signature), "string", "ECDSA signing") + assert_equal(type(signature), "string", "Signature type") local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") - assert_equal(is_valid, true, "ECDSA signature verification") + assert_equal(is_valid, true, "Signature verification") end -- Test AES key generation local function test_aes_key_generation() print('\n\27[1;7mTest AES key generation \27[0m') local key = crypto.generatekeypair('aes', 256) -- 256-bit key - assert_equal(type(key), "string", "AES key generation") - assert_equal(#key, 32, "AES key length (256 bits)") + assert_equal(type(key), "string", "Key type") + assert_equal(#key, 32, "Key length (256 bits)") end -- Test AES encryption and decryption (CBC mode) local function test_aes_encryption_decryption() print('\n\27[1;7mTest AES encryption and decryption (CBC mode) \27[0m') - local key = crypto.generatekeypair('aes',256) -- 256-bit key + local key = crypto.generatekeypair('aes', 256) -- 256-bit key local plaintext = "Hello, AES CBC!" -- Encrypt without providing IV (should auto-generate IV) print('\27[1mAES encryption (auto IV)\27[0m') local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil) - assert_equal(type(encrypted), "string", "AES encryption (CBC, auto IV)") - assert_equal(type(iv), "string", "AES IV (auto-generated)") + assert_equal(type(encrypted), "string", "Ciphertext type") + assert_equal(type(iv), "string", "IV type") -- Decrypt print('\n\27[1mAES decryption (auto IV)\27[0m') local decrypted = crypto.decrypt("aes", key, encrypted, iv) - assert_equal(decrypted, plaintext, "AES decryption (CBC, auto IV)") + assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV print('\n\27[1mAES encryption (explicit IV)\27[0m') local iv2 = GetRandomBytes(16) local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2) - assert_equal(type(encrypted2), "string", "AES encryption (CBC, explicit IV)") - assert_equal(iv_used, iv2, "AES IV (explicit)") + assert_equal(type(encrypted2), "string", "Ciphertext type") + assert_equal(iv_used, iv2, "IV match") print('\n\27[1mAES decryption (explicit IV)\27[0m') local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2) - assert_equal(decrypted2, plaintext, "AES decryption (CBC, explicit IV)") + assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (CTR mode) local function test_aes_encryption_decryption_ctr() print('\n\27[1;7mTest AES encryption and decryption (CTR mode) \27[0m') - local key = crypto.generatekeypair('aes',256) + local key = crypto.generatekeypair('aes', 256) local plaintext = "Hello, AES CTR!" -- Encrypt without providing IV (should auto-generate IV) print('\27[1mAES encryption (auto IV)\27[0m') local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil, "ctr") - assert_equal(type(encrypted), "string", "AES encryption (CTR, auto IV)") - assert_equal(type(iv), "string", "AES IV (auto-generated, CTR)") + assert_equal(type(encrypted), "string", "Ciphertext type") + assert_equal(type(iv), "string", "IV type") -- Decrypt print('\n\27[1mAES decryption (auto IV)\27[0m') local decrypted = crypto.decrypt("aes", key, encrypted, iv, "ctr") - assert_equal(decrypted, plaintext, "AES decryption (CTR, auto IV)") + assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV print('\n\27[1mAES encryption (explicit IV)\27[0m') local iv2 = GetRandomBytes(16) local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2, "ctr") - assert_equal(type(encrypted2), "string", "AES encryption (CTR, explicit IV)") - assert_equal(iv_used, iv2, "AES IV (explicit, CTR)") + assert_equal(type(encrypted2), "string", "Ciphertext type") + assert_equal(iv_used, iv2, "IV match") print('\n\27[1mAES decryption (explicit IV)\27[0m') local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "ctr") - assert_equal(decrypted2, plaintext, "AES decryption (CTR, explicit IV)") + assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (GCM mode) local function test_aes_encryption_decryption_gcm() print('\n\27[1;7mTest AES encryption and decryption (GCM mode) \27[0m') - local key = crypto.generatekeypair('aes',256) + local key = crypto.generatekeypair('aes', 256) local plaintext = "Hello, AES GCM!" -- Encrypt without providing IV (should auto-generate IV) print('\27[1mAES encryption (auto IV)\27[0m') local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, nil, "gcm") - assert_equal(type(encrypted), "string", "AES encryption (GCM, auto IV)") - assert_equal(type(iv), "string", "AES IV (auto-generated, GCM)") - assert_equal(type(tag), "string", "AES GCM tag (auto IV)") + assert_equal(#plaintext, #encrypted, "Ciphertext length matches plaintext") + assert_equal(type(encrypted), "string", "Ciphertext type") + assert_equal(type(iv), "string", "IV type") + assert_equal(type(tag), "string", "Tag type") -- Decrypt print('\n\27[1mAES decryption (auto IV)\27[0m') local decrypted = crypto.decrypt("aes", key, encrypted, iv, "gcm", nil, tag) - assert_equal(decrypted, plaintext, "AES decryption (GCM, auto IV)") + assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV print('\n\27[1mAES encryption (explicit IV)\27[0m') local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, iv2, "gcm") - assert_equal(type(encrypted2), "string", "AES encryption (GCM, explicit IV)") - assert_equal(iv_used, iv2, "AES IV (explicit, GCM)") - assert_equal(type(tag2), "string", "AES GCM tag (explicit IV)") + assert_equal(type(encrypted2), "string", "Ciphertext type") + assert_equal(iv_used, iv2, "IV match") + assert_equal(type(tag2), "string", "Tag type") print('\n\27[1mAES decryption (explicit IV)\27[0m') local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "gcm", nil, tag2) - assert_equal(decrypted2, plaintext, "AES decryption (GCM, explicit IV)") + assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end -- Test PemToJwk conversion @@ -167,25 +174,25 @@ local function test_pem_to_jwk() local priv_key, pub_key = crypto.generatekeypair() print('\27[1mRSA Private key to JWK conversion\27[0m') local priv_jwk = crypto.convertPemToJwk(priv_key) - assert_equal(type(priv_jwk), "table", "PEM to JWK conversion") - assert_equal(priv_jwk.kty, "RSA", "JWK key type") + assert_equal(type(priv_jwk), "table", "JWK type") + assert_equal(priv_jwk.kty, "RSA", "kty is correct") print('\n\27[1mRSA Public key to JWK conversion\27[0m') local pub_jwk = crypto.convertPemToJwk(pub_key) - assert_equal(type(pub_jwk), "table", "PEM to JWK conversion") - assert_equal(pub_jwk.kty, "RSA", "JWK key type") + assert_equal(type(pub_jwk), "table", "JWK type") + assert_equal(pub_jwk.kty, "RSA", "kty is correct") -- Test ECDSA keys local priv_key, pub_key = crypto.generatekeypair('ecdsa') print('\n\27[1mECDSA Private key to JWK conversion\27[0m') local priv_jwk = crypto.convertPemToJwk(priv_key) - assert_equal(type(priv_jwk), "table", "PEM to JWK conversion") - assert_equal(priv_jwk.kty, "EC", "JWK key type") + assert_equal(type(priv_jwk), "table", "JWK type") + assert_equal(priv_jwk.kty, "EC", "kty is correct") print('\n\27[1mECDSA Public key to JWK conversion\27[0m') local pub_jwk = crypto.convertPemToJwk(pub_key) - assert_equal(type(pub_jwk), "table", "PEM to JWK conversion") - assert_equal(pub_jwk.kty, "EC", "JWK key type") + assert_equal(type(pub_jwk), "table", "JWK type") + assert_equal(pub_jwk.kty, "EC", "kty is correct") end -- Test CSR generation @@ -194,24 +201,25 @@ local function test_csr_generation() local priv_key, _ = crypto.generatekeypair() local subject_name = "CN=example.com,O=Example Org,C=US" local san = "DNS:example.com, DNS:www.example.com, IP:192.168.1.1" + assert(type(priv_key) == "string", "Private key type") - local csr = crypto.GenerateCsr(priv_key, subject_name) + local csr = crypto.generateCsr(priv_key, subject_name) assert_equal(type(csr), "string", "CSR generation with subject name") - csr = crypto.GenerateCsr(priv_key, subject_name, san) + csr = crypto.generateCsr(priv_key, subject_name, san) assert_equal(type(csr), "string", "CSR generation with subject name and san") - csr = crypto.GenerateCsr(priv_key, nil, san) + csr = crypto.generateCsr(priv_key, nil, san) assert_equal(type(csr), "string", "CSR generation with nil subject name and san") - csr = crypto.GenerateCsr(priv_key, '', san) + csr = crypto.generateCsr(priv_key, '', san) assert_equal(type(csr), "string", "CSR generation with empty subject name and san") -- These should fail - csr = crypto.GenerateCsr(priv_key, '') + csr = crypto.generateCsr(priv_key, '') assert_not_equal(type(csr), "string", "CSR generation with empty subject name and no san is rejected") - csr = crypto.GenerateCsr(priv_key) + csr = crypto.generateCsr(priv_key) assert_not_equal(type(csr), "string", "CSR generation with nil subject name and no san is rejected") end @@ -231,9 +239,9 @@ local function run_tests() test_csr_generation() print('') print("All tests passed!") - EXIT=0 + EXIT = 0 return EXIT end -EXIT=70 +EXIT = 70 os.exit(run_tests()) diff --git a/tool/net/definitions.lua b/tool/net/definitions.lua index 3732416b0..661e0ea27 100644 --- a/tool/net/definitions.lua +++ b/tool/net/definitions.lua @@ -8048,6 +8048,73 @@ kUrlPlus = nil ---@type integer to transcode ISO-8859-1 input into UTF-8. See `ParseUrl`. kUrlLatin1 = nil + +--- This module provides cryptographic operations. + +--- The crypto module for cryptographic operations +crypto = {} + +--- Converts a PEM-encoded key to JWK format +---@param pem string PEM-encoded key +---@return table?, string? JWK table or nil on error +---@return string? error message +function crypto.convertPemToJwk(pem) end + +--- Generates a Certificate Signing Request (CSR) +---@param key_pem string PEM-encoded private key +---@param subject_name string? X.509 subject name +---@param san_list string? Subject Alternative Names +---@return string?, string? CSR in PEM format or nil on error and error message +function crypto.generateCsr(key_pem, subject_name, san_list) end + +--- Signs data using a private key +---@param key_type string "rsa" or "ecdsa" +---@param private_key string PEM-encoded private key +---@param message string Data to sign +---@param hash_algo string? Hash algorithm (default: SHA-256) +---@return string?, string? Signature or nil on error and error message +function crypto.sign(key_type, private_key, message, hash_algo) end + +--- Verifies a signature +---@param key_type string "rsa" or "ecdsa" +---@param public_key string PEM-encoded public key +---@param message string Original message +---@param signature string Signature to verify +---@param hash_algo string? Hash algorithm (default: SHA-256) +---@return boolean?, string? True if valid or nil on error and error message +function crypto.verify(key_type, public_key, message, signature, hash_algo) end + +--- Encrypts data +---@param cipher_type string "rsa" or "aes" +---@param key string Public key or symmetric key +---@param plaintext string Data to encrypt +---@param mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") +---@param iv string? Initialization Vector for AES +---@param aad string? Additional data for AES-GCM +---@return string? Encrypted data or nil on error +---@return string? IV or error message +---@return string? Authentication tag for GCM mode +function crypto.encrypt(cipher_type, key, plaintext, mode, iv, aad) end + +--- Decrypts data +---@param cipher_type string "rsa" or "aes" +---@param key string Private key or symmetric key +---@param ciphertext string Data to decrypt +---@param iv string? Initialization Vector for AES +---@param mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") +---@param tag string? Authentication tag for AES-GCM +---@param aad string? Additional data for AES-GCM +---@return string?, string? Decrypted data or nil on error and error message +function crypto.decrypt(cipher_type, key, ciphertext, iv, mode, tag, aad) end + +--- Generates cryptographic keys +---@param key_type string? "rsa", "ecdsa", or "aes" +---@param key_size_or_curve number|string? Key size or curve name +---@return string? Private key or nil on error +---@return string? Public key (nil for AES) or error message +function crypto.generatekeypair(key_type, key_size_or_curve) end + + --[[ ──────────────────────────────────────────────────────────────────────────────── LEGAL diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index d5ca2890f..c473df4cc 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -1595,7 +1595,7 @@ static const luaL_Reg kLuaCrypto[] = { {"decrypt", LuaCryptoDecrypt}, // {"generatekeypair", LuaCryptoGenerateKeyPair}, // {"convertPemToJwk", LuaConvertPemToJwk}, // - {"GenerateCsr", LuaGenerateCSR}, // + {"generateCsr", LuaGenerateCSR}, // {0}, // }; From cef06a5b22e6fe7a83e6518ea6adb687703753c4 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Wed, 4 Jun 2025 17:04:24 +1200 Subject: [PATCH 11/18] Quiet tests --- test/tool/net/lcrypto_test.lua | 36 +--------------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index fe8335cfd..173f83168 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -2,22 +2,17 @@ local function assert_equal(actual, expected, message) if actual ~= expected then error("FAIL: " .. message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) - else - print("PASS: " .. message) end end local function assert_not_equal(actual, not_expected, message) if actual == not_expected then error(message .. ": did not expect " .. tostring(not_expected)) - else - print("PASS: " .. message) end end -- Test RSA key pair generation local function test_rsa_keypair_generation() - print('\27[1;7mTest RSA key pair generation \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) assert_equal(type(priv_key), "string", "Private key type") assert_equal(type(pub_key), "string", "Public key type") @@ -25,7 +20,6 @@ end -- Test ECDSA key pair generation local function test_ecdsa_keypair_generation() - print('\n\27[1;7mTest ECDSA key pair generation \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") assert_equal(type(priv_key), "string", "Private key type") assert_equal(type(pub_key), "string", "Public key type") @@ -33,7 +27,6 @@ end -- Test RSA encryption and decryption local function test_rsa_encryption_decryption() - print('\n\27[1;7mTest RSA encryption and decryption \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") @@ -46,7 +39,6 @@ end -- Test RSA signing and verification local function test_rsa_signing_verification() - print('\n\27[1;7mTest RSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") @@ -59,7 +51,6 @@ end -- Test ECDSA signing and verification local function test_ecdsa_signing_verification() - print('\n\27[1;7mTest ECDSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") @@ -72,7 +63,6 @@ end -- Test AES key generation local function test_aes_key_generation() - print('\n\27[1;7mTest AES key generation \27[0m') local key = crypto.generatekeypair('aes', 256) -- 256-bit key assert_equal(type(key), "string", "Key type") assert_equal(#key, 32, "Key length (256 bits)") @@ -80,70 +70,58 @@ end -- Test AES encryption and decryption (CBC mode) local function test_aes_encryption_decryption() - print('\n\27[1;7mTest AES encryption and decryption (CBC mode) \27[0m') local key = crypto.generatekeypair('aes', 256) -- 256-bit key local plaintext = "Hello, AES CBC!" -- Encrypt without providing IV (should auto-generate IV) - print('\27[1mAES encryption (auto IV)\27[0m') local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil) assert_equal(type(encrypted), "string", "Ciphertext type") assert_equal(type(iv), "string", "IV type") -- Decrypt - print('\n\27[1mAES decryption (auto IV)\27[0m') local decrypted = crypto.decrypt("aes", key, encrypted, iv) assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV - print('\n\27[1mAES encryption (explicit IV)\27[0m') local iv2 = GetRandomBytes(16) local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2) assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(iv_used, iv2, "IV match") - print('\n\27[1mAES decryption (explicit IV)\27[0m') local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2) assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (CTR mode) local function test_aes_encryption_decryption_ctr() - print('\n\27[1;7mTest AES encryption and decryption (CTR mode) \27[0m') local key = crypto.generatekeypair('aes', 256) local plaintext = "Hello, AES CTR!" -- Encrypt without providing IV (should auto-generate IV) - print('\27[1mAES encryption (auto IV)\27[0m') local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil, "ctr") assert_equal(type(encrypted), "string", "Ciphertext type") assert_equal(type(iv), "string", "IV type") -- Decrypt - print('\n\27[1mAES decryption (auto IV)\27[0m') local decrypted = crypto.decrypt("aes", key, encrypted, iv, "ctr") assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV - print('\n\27[1mAES encryption (explicit IV)\27[0m') local iv2 = GetRandomBytes(16) local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2, "ctr") assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(iv_used, iv2, "IV match") - print('\n\27[1mAES decryption (explicit IV)\27[0m') local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "ctr") assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (GCM mode) local function test_aes_encryption_decryption_gcm() - print('\n\27[1;7mTest AES encryption and decryption (GCM mode) \27[0m') local key = crypto.generatekeypair('aes', 256) local plaintext = "Hello, AES GCM!" -- Encrypt without providing IV (should auto-generate IV) - print('\27[1mAES encryption (auto IV)\27[0m') local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, nil, "gcm") assert_equal(#plaintext, #encrypted, "Ciphertext length matches plaintext") assert_equal(type(encrypted), "string", "Ciphertext type") @@ -151,45 +129,37 @@ local function test_aes_encryption_decryption_gcm() assert_equal(type(tag), "string", "Tag type") -- Decrypt - print('\n\27[1mAES decryption (auto IV)\27[0m') local decrypted = crypto.decrypt("aes", key, encrypted, iv, "gcm", nil, tag) assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV - print('\n\27[1mAES encryption (explicit IV)\27[0m') local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, iv2, "gcm") assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(iv_used, iv2, "IV match") assert_equal(type(tag2), "string", "Tag type") - print('\n\27[1mAES decryption (explicit IV)\27[0m') local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "gcm", nil, tag2) assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end -- Test PemToJwk conversion local function test_pem_to_jwk() - print('\n\27[1;7mTest PEM to JWK conversion \27[0m') local priv_key, pub_key = crypto.generatekeypair() - print('\27[1mRSA Private key to JWK conversion\27[0m') local priv_jwk = crypto.convertPemToJwk(priv_key) assert_equal(type(priv_jwk), "table", "JWK type") assert_equal(priv_jwk.kty, "RSA", "kty is correct") - print('\n\27[1mRSA Public key to JWK conversion\27[0m') local pub_jwk = crypto.convertPemToJwk(pub_key) assert_equal(type(pub_jwk), "table", "JWK type") assert_equal(pub_jwk.kty, "RSA", "kty is correct") -- Test ECDSA keys local priv_key, pub_key = crypto.generatekeypair('ecdsa') - print('\n\27[1mECDSA Private key to JWK conversion\27[0m') local priv_jwk = crypto.convertPemToJwk(priv_key) assert_equal(type(priv_jwk), "table", "JWK type") assert_equal(priv_jwk.kty, "EC", "kty is correct") - print('\n\27[1mECDSA Public key to JWK conversion\27[0m') local pub_jwk = crypto.convertPemToJwk(pub_key) assert_equal(type(pub_jwk), "table", "JWK type") assert_equal(pub_jwk.kty, "EC", "kty is correct") @@ -197,11 +167,10 @@ end -- Test CSR generation local function test_csr_generation() - print('\n\27[1;7mTest CSR generation \27[0m') local priv_key, _ = crypto.generatekeypair() local subject_name = "CN=example.com,O=Example Org,C=US" local san = "DNS:example.com, DNS:www.example.com, IP:192.168.1.1" - assert(type(priv_key) == "string", "Private key type") + assert_equal(type(priv_key), "string", "Private key type") local csr = crypto.generateCsr(priv_key, subject_name) assert_equal(type(csr), "string", "CSR generation with subject name") @@ -225,7 +194,6 @@ end -- Run all tests local function run_tests() - print("Running tests for lcrypto...") test_rsa_keypair_generation() test_rsa_signing_verification() test_rsa_encryption_decryption() @@ -237,8 +205,6 @@ local function run_tests() test_aes_encryption_decryption_gcm() test_pem_to_jwk() test_csr_generation() - print('') - print("All tests passed!") EXIT = 0 return EXIT end From d06d0879b86131b3489b6898c6ee73f9e487cad9 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Wed, 4 Jun 2025 21:13:09 +1200 Subject: [PATCH 12/18] Change of API for Encrypt and Decrypt. The options are now passed in a table instead of positional parameters. This is not LuaCrypto compatible but it is a nicer interface. --- test/tool/net/lcrypto_test.lua | 27 ++-- tool/net/lcrypto.c | 225 +++++++++++++++++++++------------ 2 files changed, 155 insertions(+), 97 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 173f83168..077cd03c3 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -69,7 +69,7 @@ local function test_aes_key_generation() end -- Test AES encryption and decryption (CBC mode) -local function test_aes_encryption_decryption() +local function test_aes_encryption_decryption_cbc() local key = crypto.generatekeypair('aes', 256) -- 256-bit key local plaintext = "Hello, AES CBC!" @@ -79,16 +79,16 @@ local function test_aes_encryption_decryption() assert_equal(type(iv), "string", "IV type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, iv) + local decrypted = crypto.decrypt("aes", key, encrypted, {mode="cbc",iv=iv}) assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(16) - local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2) + local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, {mode="cbc",iv=iv2}) assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(iv_used, iv2, "IV match") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2) + local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="cbc",iv=iv2}) assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end @@ -98,48 +98,49 @@ local function test_aes_encryption_decryption_ctr() local plaintext = "Hello, AES CTR!" -- Encrypt without providing IV (should auto-generate IV) - local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil, "ctr") + local encrypted, iv = crypto.encrypt("aes", key, plaintext, {mode="ctr"}) assert_equal(type(encrypted), "string", "Ciphertext type") assert_equal(type(iv), "string", "IV type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, iv, "ctr") + local decrypted = crypto.decrypt("aes", key, encrypted, {mode="ctr", iv=iv}) assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(16) - local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2, "ctr") + local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, {mode="ctr", iv=iv2}) assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(iv_used, iv2, "IV match") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "ctr") + local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="ctr", iv=iv2}) assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (GCM mode) local function test_aes_encryption_decryption_gcm() local key = crypto.generatekeypair('aes', 256) + assert_equal(type(key), "string", "key type") local plaintext = "Hello, AES GCM!" -- Encrypt without providing IV (should auto-generate IV) - local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, nil, "gcm") + local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, {mode="gcm"}) assert_equal(#plaintext, #encrypted, "Ciphertext length matches plaintext") assert_equal(type(encrypted), "string", "Ciphertext type") assert_equal(type(iv), "string", "IV type") assert_equal(type(tag), "string", "Tag type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, iv, "gcm", nil, tag) + local decrypted = crypto.decrypt("aes", key, encrypted, {mode="gcm",iv=iv,tag=tag}) assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard - local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, iv2, "gcm") + local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, {mode="gcm",iv=iv2}) assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(iv_used, iv2, "IV match") assert_equal(type(tag2), "string", "Tag type") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "gcm", nil, tag2) + local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="gcm",iv=iv2,tag=tag2}) assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") end @@ -200,7 +201,7 @@ local function run_tests() test_ecdsa_keypair_generation() test_ecdsa_signing_verification() test_aes_key_generation() - test_aes_encryption_decryption() + test_aes_encryption_decryption_cbc() test_aes_encryption_decryption_ctr() test_aes_encryption_decryption_gcm() test_pem_to_jwk() diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index c473df4cc..07d8c120a 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -322,6 +322,18 @@ static int LuaRSAGenerateKeyPair(lua_State *L) { return 2; } +// Helper to get string field from options table for RSA +// static const char *parse_rsa_options(lua_State *L, int options_idx) { +// const char *padding = "pkcs1"; // default +// if (lua_istable(L, options_idx)) { +// lua_getfield(L, options_idx, "padding"); +// if (lua_isstring(L, -1)) { +// padding = lua_tostring(L, -1); +// } +// lua_pop(L, 1); +// } +// return padding; +// } static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, size_t data_len, size_t *out_len) { @@ -368,23 +380,25 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, return (char *)output; } static int LuaRSAEncrypt(lua_State *L) { - const char *public_key = luaL_checkstring(L, 1); - size_t data_len; - const unsigned char *data = - (const unsigned char *)luaL_checklstring(L, 2, &data_len); - size_t out_len; + // Args: key, plaintext, options table + size_t keylen, ptlen; + const char *key = luaL_checklstring(L, 1, &keylen); + const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); + // int options_idx = 3; + // const char *padding = parse_rsa_options(L, options_idx); + size_t out_len; - char *encrypted = RSAEncrypt(public_key, data, data_len, &out_len); - if (!encrypted) { - lua_pushnil(L); - lua_pushstring(L, "Encryption failed"); - return 2; - } + char *encrypted = RSAEncrypt(key, plaintext, ptlen, &out_len); + if (!encrypted) { + lua_pushnil(L); + lua_pushstring(L, "Encryption failed"); + return 2; + } - lua_pushlstring(L, encrypted, out_len); - free(encrypted); + lua_pushlstring(L, encrypted, out_len); + free(encrypted); - return 1; + return 1; } static char *RSADecrypt(const char *private_key_pem, @@ -433,24 +447,25 @@ static char *RSADecrypt(const char *private_key_pem, return (char *)output; } static int LuaRSADecrypt(lua_State *L) { - const char *private_key = luaL_checkstring(L, 1); - size_t encrypted_len; - const unsigned char *encrypted_data = - (const unsigned char *)luaL_checklstring(L, 2, &encrypted_len); - size_t out_len; + // Args: key, ciphertext, options table + size_t keylen, ctlen; + const char *key = luaL_checklstring(L, 1, &keylen); + const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); + // int options_idx = 3; + // const char *padding = parse_rsa_options(L, options_idx); + size_t out_len; - char *decrypted = - RSADecrypt(private_key, encrypted_data, encrypted_len, &out_len); - if (!decrypted) { - lua_pushnil(L); - lua_pushstring(L, "Decryption failed"); - return 2; - } + char *decrypted = RSADecrypt(key, ciphertext, ctlen, &out_len); + if (!decrypted) { + lua_pushnil(L); + lua_pushstring(L, "Decryption failed"); + return 2; + } - lua_pushlstring(L, decrypted, out_len); - free(decrypted); + lua_pushlstring(L, decrypted, out_len); + free(decrypted); - return 1; + return 1; } // RSA Signing @@ -1087,6 +1102,7 @@ static int LuaECDSAVerify(lua_State *L) { // AES + // AES key generation helper static int LuaAesGenerateKey(lua_State *L) { int keybits = 128; @@ -1125,21 +1141,87 @@ static int LuaAesGenerateKey(lua_State *L) { return 1; } +// Helper to get string field from options table +typedef struct { + const char *mode; + const unsigned char *iv; + size_t ivlen; +} aes_options_t; + +static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) { + opts->mode = "cbc"; + opts->iv = NULL; + opts->ivlen = 0; + if (lua_istable(L, options_idx)) { + lua_getfield(L, options_idx, "mode"); + if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1); + lua_pop(L, 1); + lua_getfield(L, options_idx, "iv"); + if (lua_isstring(L, -1)) { + opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen); + } + lua_pop(L, 1); + } +} + +// Helper for AES decrypt options +typedef struct { + const char *mode; + const unsigned char *iv; + size_t ivlen; + const unsigned char *tag; + size_t taglen; + const unsigned char *aad; + size_t aadlen; +} aes_decrypt_options_t; + +static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt_options_t *opts) { + opts->mode = "cbc"; + opts->iv = NULL; + opts->ivlen = 0; + opts->tag = NULL; + opts->taglen = 0; + opts->aad = NULL; + opts->aadlen = 0; + if (lua_istable(L, options_idx)) { + lua_getfield(L, options_idx, "mode"); + if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1); + lua_pop(L, 1); + lua_getfield(L, options_idx, "iv"); + if (lua_isstring(L, -1)) { + opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen); + } + lua_pop(L, 1); + lua_getfield(L, options_idx, "tag"); + if (lua_isstring(L, -1)) { + opts->tag = (const unsigned char *)lua_tolstring(L, -1, &opts->taglen); + } + lua_pop(L, 1); + lua_getfield(L, options_idx, "aad"); + if (lua_isstring(L, -1)) { + opts->aad = (const unsigned char *)lua_tolstring(L, -1, &opts->aadlen); + } + lua_pop(L, 1); + } +} + // AES encryption supporting CBC, GCM, and CTR modes static int LuaAesEncrypt(lua_State *L) { - // Accept IV as the 3rd argument (after key, plaintext) - size_t keylen, ivlen = 0, ptlen; + // Args: key, plaintext, options table + size_t keylen, ptlen; const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); - const unsigned char *iv = NULL; + int options_idx = 3; + aes_options_t opts; + parse_aes_options(L, options_idx, &opts); + const char *mode = opts.mode; + const unsigned char *iv = opts.iv; + size_t ivlen = opts.ivlen; unsigned char *gen_iv = NULL; int iv_was_generated = 0; - - const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided int ret = 0; unsigned char *output = NULL; int is_gcm = 0, is_ctr = 0, is_cbc = 0; - if (strcasecmp(mode, "cbc") == 0) { is_cbc = 1; } else if (strcasecmp(mode, "gcm") == 0) { @@ -1151,10 +1233,8 @@ static int LuaAesEncrypt(lua_State *L) { lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); return 2; } - - // If IV is not provided (arg3 is nil or missing), auto-generate - if (lua_isnoneornil(L, 3)) { - // For GCM, standard is 12 bytes, but allow 12-16 + // If IV is not provided, auto-generate + if (!iv) { if (is_gcm) { ivlen = 12; } else { @@ -1176,26 +1256,7 @@ static int LuaAesEncrypt(lua_State *L) { mbedtls_entropy_free(&entropy); iv = gen_iv; iv_was_generated = 1; - } else { - // IV provided - iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen); - // Do not force ivlen to 16 here! Accept actual length for GCM (12-16) - if (is_cbc || is_ctr) { - if (ivlen != 16) { - lua_pushnil(L); - lua_pushstring(L, "AES IV must be 16 bytes for CBC/CTR"); - return 2; - } - } else if (is_gcm) { - if (ivlen < 12 || ivlen > 16) { - lua_pushnil(L); - lua_pushstring(L, "AES GCM IV/nonce must be 12-16 bytes"); - return 2; - } - } - iv_was_generated = 0; } - if (is_cbc) { // PKCS7 padding size_t block_size = 16; @@ -1322,16 +1383,21 @@ static int LuaAesEncrypt(lua_State *L) { // AES decryption supporting CBC, GCM, and CTR modes static int LuaAesDecrypt(lua_State *L) { - size_t keylen, ctlen, ivlen; + // Args: key, ciphertext, options table + size_t keylen, ctlen; const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); - const unsigned char *iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen); - const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided - const unsigned char *aad = NULL; - const unsigned char *tag = NULL; - size_t aadlen = 0, taglen = 0; + int options_idx = 3; + aes_decrypt_options_t opts; + parse_aes_decrypt_options(L, options_idx, &opts); + const char *mode = opts.mode; + const unsigned char *iv = opts.iv; + size_t ivlen = opts.ivlen; + const unsigned char *tag = opts.tag; + size_t taglen = opts.taglen; + const unsigned char *aad = opts.aad; + size_t aadlen = opts.aadlen; int is_gcm = 0, is_ctr = 0, is_cbc = 0; - if (strcasecmp(mode, "cbc") == 0) { is_cbc = 1; } else if (strcasecmp(mode, "gcm") == 0) { @@ -1343,7 +1409,6 @@ static int LuaAesDecrypt(lua_State *L) { lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); return 2; } - // Validate key length (16, 24, 32 bytes) if (keylen != 16 && keylen != 24 && keylen != 32) { lua_pushnil(L); @@ -1367,19 +1432,9 @@ static int LuaAesDecrypt(lua_State *L) { // GCM: require tag and optional AAD if (is_gcm) { - if (!lua_isnoneornil(L, 5)) { - aad = (const unsigned char *)luaL_checklstring(L, 5, &aadlen); - } - if (!lua_isnoneornil(L, 6)) { - tag = (const unsigned char *)luaL_checklstring(L, 6, &taglen); - if (taglen < 12 || taglen > 16) { - lua_pushnil(L); - lua_pushstring(L, "AES GCM tag must be 12-16 bytes"); - return 2; - } - } else { + if (!tag || taglen < 12 || taglen > 16) { lua_pushnil(L); - lua_pushstring(L, "AES GCM tag required as 6th argument"); + lua_pushstring(L, "AES GCM tag must be 12-16 bytes"); return 2; } } @@ -1543,10 +1598,12 @@ static int LuaCryptoVerify(lua_State *L) { } static int LuaCryptoEncrypt(lua_State *L) { - const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes") - lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching - + // Args: cipher_type, key, msg, options table + const char *cipher = luaL_checkstring(L, 1); + // Remove cipher_type from stack, so key is at 1, msg at 2, options at 3 + lua_remove(L, 1); if (strcasecmp(cipher, "rsa") == 0) { + // Update LuaRSAEncrypt to accept (key, msg, options) return LuaRSAEncrypt(L); } else if (strcasecmp(cipher, "aes") == 0) { return LuaAesEncrypt(L); @@ -1556,9 +1613,9 @@ static int LuaCryptoEncrypt(lua_State *L) { } static int LuaCryptoDecrypt(lua_State *L) { - const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes") - lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching - + // Args: cipher_type, key, ciphertext, options table + const char *cipher = luaL_checkstring(L, 1); + lua_remove(L, 1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3 if (strcasecmp(cipher, "rsa") == 0) { return LuaRSADecrypt(L); } else if (strcasecmp(cipher, "aes") == 0) { From e35f99c7db1aaefbf489d055f1ebd001440b3bb2 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Wed, 4 Jun 2025 21:42:17 +1200 Subject: [PATCH 13/18] Merge parse aes parse options functions --- tool/net/lcrypto.c | 1582 ++++++++++++++++++++++---------------------- 1 file changed, 794 insertions(+), 788 deletions(-) diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index 07d8c120a..2b4ba2fb6 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -1,18 +1,18 @@ #include "libc/log/log.h" #include "net/https/https.h" #include "third_party/lua/lauxlib.h" +#include "third_party/mbedtls/aes.h" +#include "third_party/mbedtls/base64.h" +#include "third_party/mbedtls/ctr_drbg.h" +#include "third_party/mbedtls/ecdsa.h" +#include "third_party/mbedtls/entropy.h" #include "third_party/mbedtls/error.h" +#include "third_party/mbedtls/gcm.h" +#include "third_party/mbedtls/md.h" +#include "third_party/mbedtls/oid.h" #include "third_party/mbedtls/pk.h" #include "third_party/mbedtls/rsa.h" -#include "third_party/mbedtls/ecdsa.h" #include "third_party/mbedtls/x509_csr.h" -#include "third_party/mbedtls/oid.h" -#include "third_party/mbedtls/md.h" -#include "third_party/mbedtls/base64.h" -#include "third_party/mbedtls/aes.h" -#include "third_party/mbedtls/ctr_drbg.h" -#include "third_party/mbedtls/entropy.h" -#include "third_party/mbedtls/gcm.h" // Standard C library and redbean utilities #include "libc/errno.h" @@ -22,217 +22,227 @@ // Parse PEM keys and convert them into JWK format static int LuaConvertPemToJwk(lua_State *L) { - const char *pem_key = luaL_checkstring(L, 1); + const char *pem_key = luaL_checkstring(L, 1); - mbedtls_pk_context key; - mbedtls_pk_init(&key); - int ret; - - // Parse the PEM key - if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key, strlen(pem_key) + 1, NULL, 0)) != 0 && - (ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key, strlen(pem_key) + 1)) != 0) { - lua_pushnil(L); - lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret); - mbedtls_pk_free(&key); - return 2; - } - - lua_newtable(L); // Create a new Lua table - - if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_RSA) { - // Handle RSA keys - const mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); - size_t n_len = mbedtls_mpi_size(&rsa->N); - size_t e_len = mbedtls_mpi_size(&rsa->E); - - unsigned char *n = malloc(n_len); - unsigned char *e = malloc(e_len); - - if (!n || !e) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - free(n); - free(e); - mbedtls_pk_free(&key); - return 2; - } - - mbedtls_mpi_write_binary(&rsa->N, n, n_len); - mbedtls_mpi_write_binary(&rsa->E, e, e_len); - - char *n_b64 = NULL, *e_b64 = NULL; - size_t n_b64_len, e_b64_len; - - mbedtls_base64_encode(NULL, 0, &n_b64_len, n, n_len); - mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); - - n_b64 = malloc(n_b64_len + 1); - e_b64 = malloc(e_b64_len + 1); - - if (!n_b64 || !e_b64) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - free(n); - free(e); - free(n_b64); - free(e_b64); - mbedtls_pk_free(&key); - return 2; - } - - mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len); - mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len); - - n_b64[n_b64_len] = '\0'; - e_b64[e_b64_len] = '\0'; - - lua_pushstring(L, "RSA"); - lua_setfield(L, -2, "kty"); - lua_pushstring(L, n_b64); - lua_setfield(L, -2, "n"); - lua_pushstring(L, e_b64); - lua_setfield(L, -2, "e"); - - free(n); - free(e); - free(n_b64); - free(e_b64); - } else if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_ECKEY) { - // Handle ECDSA keys - const mbedtls_ecp_keypair *ec = mbedtls_pk_ec(key); - const mbedtls_ecp_point *Q = &ec->Q; - size_t x_len = (ec->grp.pbits + 7) / 8; - size_t y_len = (ec->grp.pbits + 7) / 8; - - unsigned char *x = malloc(x_len); - unsigned char *y = malloc(y_len); - - if (!x || !y) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - free(x); - free(y); - mbedtls_pk_free(&key); - return 2; - } - - mbedtls_mpi_write_binary(&Q->X, x, x_len); - mbedtls_mpi_write_binary(&Q->Y, y, y_len); - - char *x_b64 = NULL, *y_b64 = NULL; - size_t x_b64_len, y_b64_len; - - mbedtls_base64_encode(NULL, 0, &x_b64_len, x, x_len); - mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len); - - x_b64 = malloc(x_b64_len + 1); - y_b64 = malloc(y_b64_len + 1); - - if (!x_b64 || !y_b64) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - free(x); - free(y); - free(x_b64); - free(y_b64); - mbedtls_pk_free(&key); - return 2; - } - - mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len); - mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); - - x_b64[x_b64_len] = '\0'; - y_b64[y_b64_len] = '\0'; - - lua_pushstring(L, "EC"); - lua_setfield(L, -2, "kty"); - lua_pushstring(L, mbedtls_ecp_curve_info_from_grp_id(ec->grp.id)->name); - lua_setfield(L, -2, "crv"); - lua_pushstring(L, x_b64); - lua_setfield(L, -2, "x"); - lua_pushstring(L, y_b64); - lua_setfield(L, -2, "y"); - - free(x); - free(y); - free(x_b64); - free(y_b64); - } else { - lua_pushnil(L); - lua_pushstring(L, "Unsupported key type"); - mbedtls_pk_free(&key); - return 2; - } + mbedtls_pk_context key; + mbedtls_pk_init(&key); + int ret; + // Parse the PEM key + if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key, + strlen(pem_key) + 1, NULL, 0)) != 0 && + (ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key, + strlen(pem_key) + 1)) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret); mbedtls_pk_free(&key); - return 1; + return 2; + } + + lua_newtable(L); // Create a new Lua table + + if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_RSA) { + // Handle RSA keys + const mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + size_t n_len = mbedtls_mpi_size(&rsa->N); + size_t e_len = mbedtls_mpi_size(&rsa->E); + + unsigned char *n = malloc(n_len); + unsigned char *e = malloc(e_len); + + if (!n || !e) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(n); + free(e); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_mpi_write_binary(&rsa->N, n, n_len); + mbedtls_mpi_write_binary(&rsa->E, e, e_len); + + char *n_b64 = NULL, *e_b64 = NULL; + size_t n_b64_len, e_b64_len; + + mbedtls_base64_encode(NULL, 0, &n_b64_len, n, n_len); + mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); + + n_b64 = malloc(n_b64_len + 1); + e_b64 = malloc(e_b64_len + 1); + + if (!n_b64 || !e_b64) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(n); + free(e); + free(n_b64); + free(e_b64); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, + n_len); + mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, + e_len); + + n_b64[n_b64_len] = '\0'; + e_b64[e_b64_len] = '\0'; + + lua_pushstring(L, "RSA"); + lua_setfield(L, -2, "kty"); + lua_pushstring(L, n_b64); + lua_setfield(L, -2, "n"); + lua_pushstring(L, e_b64); + lua_setfield(L, -2, "e"); + + free(n); + free(e); + free(n_b64); + free(e_b64); + } else if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_ECKEY) { + // Handle ECDSA keys + const mbedtls_ecp_keypair *ec = mbedtls_pk_ec(key); + const mbedtls_ecp_point *Q = &ec->Q; + size_t x_len = (ec->grp.pbits + 7) / 8; + size_t y_len = (ec->grp.pbits + 7) / 8; + + unsigned char *x = malloc(x_len); + unsigned char *y = malloc(y_len); + + if (!x || !y) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(x); + free(y); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_mpi_write_binary(&Q->X, x, x_len); + mbedtls_mpi_write_binary(&Q->Y, y, y_len); + + char *x_b64 = NULL, *y_b64 = NULL; + size_t x_b64_len, y_b64_len; + + mbedtls_base64_encode(NULL, 0, &x_b64_len, x, x_len); + mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len); + + x_b64 = malloc(x_b64_len + 1); + y_b64 = malloc(y_b64_len + 1); + + if (!x_b64 || !y_b64) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(x); + free(y); + free(x_b64); + free(y_b64); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, + x_len); + mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, + y_len); + + x_b64[x_b64_len] = '\0'; + y_b64[y_b64_len] = '\0'; + + lua_pushstring(L, "EC"); + lua_setfield(L, -2, "kty"); + lua_pushstring(L, mbedtls_ecp_curve_info_from_grp_id(ec->grp.id)->name); + lua_setfield(L, -2, "crv"); + lua_pushstring(L, x_b64); + lua_setfield(L, -2, "x"); + lua_pushstring(L, y_b64); + lua_setfield(L, -2, "y"); + + free(x); + free(y); + free(x_b64); + free(y_b64); + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported key type"); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_pk_free(&key); + return 1; } // CSR Creation Function static int LuaGenerateCSR(lua_State *L) { - const char *key_pem = luaL_checkstring(L, 1); - const char *subject_name; - const char *san_list = luaL_optstring(L, 3, NULL); + const char *key_pem = luaL_checkstring(L, 1); + const char *subject_name; + const char *san_list = luaL_optstring(L, 3, NULL); - if (lua_isnoneornil(L, 2)) { - subject_name = ""; - } else { - subject_name = luaL_checkstring(L, 2); + if (lua_isnoneornil(L, 2)) { + subject_name = ""; + } else { + subject_name = luaL_checkstring(L, 2); + } + + if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') { + lua_pushnil(L); + lua_pushstring(L, "Subject name or SANs are required"); + return 2; + } + mbedtls_pk_context key; + mbedtls_x509write_csr req; + char buf[4096]; + int ret; + + mbedtls_pk_init(&key); + mbedtls_x509write_csr_init(&req); + + if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem, + strlen(key_pem) + 1, NULL, 0)) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to parse key: %d", ret); + return 2; + } + + mbedtls_x509write_csr_set_subject_name(&req, subject_name); + mbedtls_x509write_csr_set_key(&req, &key); + mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256); + + if (san_list) { + if ((ret = mbedtls_x509write_csr_set_extension( + &req, MBEDTLS_OID_SUBJECT_ALT_NAME, + MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), + (const unsigned char *)san_list, strlen(san_list))) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to set SANs: %d", ret); + return 2; } + } + if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf), + NULL, NULL)) < 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to write CSR: %d", ret); + return 2; + } - if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') { - lua_pushnil(L); - lua_pushstring(L, "Subject name or SANs are required"); - return 2; - } - mbedtls_pk_context key; - mbedtls_x509write_csr req; - char buf[4096]; - int ret; + lua_pushstring(L, buf); - mbedtls_pk_init(&key); - mbedtls_x509write_csr_init(&req); + mbedtls_pk_free(&key); + mbedtls_x509write_csr_free(&req); - if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem, strlen(key_pem) + 1, NULL, 0)) != 0) { - lua_pushnil(L); - lua_pushfstring(L, "Failed to parse key: %d", ret); - return 2; - } - - mbedtls_x509write_csr_set_subject_name(&req, subject_name); - mbedtls_x509write_csr_set_key(&req, &key); - mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256); - - if (san_list) { - if ((ret = mbedtls_x509write_csr_set_extension(&req, MBEDTLS_OID_SUBJECT_ALT_NAME, MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), (const unsigned char *)san_list, strlen(san_list))) != 0) { - lua_pushnil(L); - lua_pushfstring(L, "Failed to set SANs: %d", ret); - return 2; - } - } - - if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf), NULL, NULL)) < 0) { - lua_pushnil(L); - lua_pushfstring(L, "Failed to write CSR: %d", ret); - return 2; - } - - lua_pushstring(L, buf); - - mbedtls_pk_free(&key); - mbedtls_x509write_csr_free(&req); - - return 1; + return 1; } // RSA // Generate RSA Key Pair static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, - char **public_key_pem, size_t *public_key_len, - unsigned int key_length) { + char **public_key_pem, size_t *public_key_len, + unsigned int key_length) { int rc; mbedtls_pk_context key; mbedtls_pk_init(&key); @@ -281,45 +291,37 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, mbedtls_pk_free(&key); return true; } - -/** - * Lua wrapper for RSA key pair generation - * - * Lua function signature: RSAGenerateKeyPair([key_length]) - * @param L Lua state - * @return 2 on success (private_key, public_key), 2 on failure (nil, - * error_message) - */ static int LuaRSAGenerateKeyPair(lua_State *L) { - int bits = 2048; - // If no arguments, or first argument is nil, default to 2048 - if (lua_gettop(L) == 0 || lua_isnoneornil(L, 1)) { - bits = 2048; - } else if (lua_gettop(L) == 1 && lua_type(L, 1) == LUA_TNUMBER) { - bits = (int)lua_tointeger(L, 1); - } else { - bits = (int)luaL_optinteger(L, 2, 2048); - } + int bits = 2048; + // If no arguments, or first argument is nil, default to 2048 + if (lua_gettop(L) == 0 || lua_isnoneornil(L, 1)) { + bits = 2048; + } else if (lua_gettop(L) == 1 && lua_type(L, 1) == LUA_TNUMBER) { + bits = (int)lua_tointeger(L, 1); + } else { + bits = (int)luaL_optinteger(L, 2, 2048); + } - char *private_key, *public_key; - size_t private_len, public_len; - - // Call the C function to generate the key pair - if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, bits)) { - lua_pushnil(L); - lua_pushstring(L, "Failed to generate RSA key pair"); - return 2; - } - - // Push results to Lua - lua_pushstring(L, private_key); - lua_pushstring(L, public_key); - - // Clean up - free(private_key); - free(public_key); + char *private_key, *public_key; + size_t private_len, public_len; + // Call the C function to generate the key pair + if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, + bits)) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate RSA key pair"); return 2; + } + + // Push results to Lua + lua_pushstring(L, private_key); + lua_pushstring(L, public_key); + + // Clean up + free(private_key); + free(public_key); + + return 2; } // Helper to get string field from options table for RSA @@ -336,7 +338,7 @@ static int LuaRSAGenerateKeyPair(lua_State *L) { // } static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, - size_t data_len, size_t *out_len) { + size_t data_len, size_t *out_len) { int rc; // Parse public key @@ -380,30 +382,31 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, return (char *)output; } static int LuaRSAEncrypt(lua_State *L) { - // Args: key, plaintext, options table - size_t keylen, ptlen; - const char *key = luaL_checklstring(L, 1, &keylen); - const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); - // int options_idx = 3; - // const char *padding = parse_rsa_options(L, options_idx); - size_t out_len; + // Args: key, plaintext, options table + size_t keylen, ptlen; + const char *key = luaL_checklstring(L, 1, &keylen); + const unsigned char *plaintext = + (const unsigned char *)luaL_checklstring(L, 2, &ptlen); + // int options_idx = 3; + // const char *padding = parse_rsa_options(L, options_idx); + size_t out_len; - char *encrypted = RSAEncrypt(key, plaintext, ptlen, &out_len); - if (!encrypted) { - lua_pushnil(L); - lua_pushstring(L, "Encryption failed"); - return 2; - } + char *encrypted = RSAEncrypt(key, plaintext, ptlen, &out_len); + if (!encrypted) { + lua_pushnil(L); + lua_pushstring(L, "Encryption failed"); + return 2; + } - lua_pushlstring(L, encrypted, out_len); - free(encrypted); + lua_pushlstring(L, encrypted, out_len); + free(encrypted); - return 1; + return 1; } static char *RSADecrypt(const char *private_key_pem, - const unsigned char *encrypted_data, size_t encrypted_len, - size_t *out_len) { + const unsigned char *encrypted_data, + size_t encrypted_len, size_t *out_len) { int rc; // Parse private key @@ -447,30 +450,32 @@ static char *RSADecrypt(const char *private_key_pem, return (char *)output; } static int LuaRSADecrypt(lua_State *L) { - // Args: key, ciphertext, options table - size_t keylen, ctlen; - const char *key = luaL_checklstring(L, 1, &keylen); - const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); - // int options_idx = 3; - // const char *padding = parse_rsa_options(L, options_idx); - size_t out_len; + // Args: key, ciphertext, options table + size_t keylen, ctlen; + const char *key = luaL_checklstring(L, 1, &keylen); + const unsigned char *ciphertext = + (const unsigned char *)luaL_checklstring(L, 2, &ctlen); + // int options_idx = 3; + // const char *padding = parse_rsa_options(L, options_idx); + size_t out_len; - char *decrypted = RSADecrypt(key, ciphertext, ctlen, &out_len); - if (!decrypted) { - lua_pushnil(L); - lua_pushstring(L, "Decryption failed"); - return 2; - } + char *decrypted = RSADecrypt(key, ciphertext, ctlen, &out_len); + if (!decrypted) { + lua_pushnil(L); + lua_pushstring(L, "Decryption failed"); + return 2; + } - lua_pushlstring(L, decrypted, out_len); - free(decrypted); + lua_pushlstring(L, decrypted, out_len); + free(decrypted); - return 1; + return 1; } // RSA Signing static char *RSASign(const char *private_key_pem, const unsigned char *data, - size_t data_len, const char *hash_algo_str, size_t *sig_len) { + size_t data_len, const char *hash_algo_str, + size_t *sig_len) { int rc; unsigned char hash[64]; // Large enough for SHA-512 size_t hash_len = 32; // Default for SHA-256 @@ -554,7 +559,7 @@ static int LuaRSASign(lua_State *L) { // Call the C implementation signature = (unsigned char *)RSASign(key_pem, (const unsigned char *)msg, - msg_len, hash_algo_str, &sig_len); + msg_len, hash_algo_str, &sig_len); if (!signature) { return luaL_error(L, "failed to sign message"); @@ -570,8 +575,8 @@ static int LuaRSASign(lua_State *L) { } static int RSAVerify(const char *public_key_pem, const unsigned char *data, - size_t data_len, const unsigned char *signature, - size_t sig_len, const char *hash_algo_str) { + size_t data_len, const unsigned char *signature, + size_t sig_len, const char *hash_algo_str) { int rc; unsigned char hash[64]; // Large enough for SHA-512 size_t hash_len = 32; // Default for SHA-256 @@ -643,7 +648,7 @@ static int LuaRSAVerify(lua_State *L) { // Call the C implementation result = RSAVerify(key_pem, (const unsigned char *)msg, msg_len, - (const unsigned char *)signature, sig_len, hash_algo_str); + (const unsigned char *)signature, sig_len, hash_algo_str); // Return boolean result (0 means valid signature) lua_pushboolean(L, result == 0); @@ -822,7 +827,7 @@ static mbedtls_ecp_group_id find_curve_by_name(const char *name) { // Generate an ECDSA key pair and return in PEM format static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, - char **pub_key_pem) { + char **pub_key_pem) { mbedtls_pk_context key; unsigned char output_buf[16000]; int ret; @@ -939,8 +944,8 @@ static int LuaECDSAGenerateKeyPair(lua_State *L) { // Sign a message using an ECDSA private key in PEM format static int ECDSASign(const char *priv_key_pem, const char *message, - hash_algorithm_t hash_alg, unsigned char **signature, - size_t *sig_len) { + hash_algorithm_t hash_alg, unsigned char **signature, + size_t *sig_len) { mbedtls_pk_context key; unsigned char hash[64]; // Max hash size (SHA-512) size_t hash_size; @@ -1032,8 +1037,8 @@ static int LuaECDSASign(lua_State *L) { // Verify a signature using an ECDSA public key in PEM format static int ECDSAVerify(const char *pub_key_pem, const char *message, - const unsigned char *signature, size_t sig_len, - hash_algorithm_t hash_alg) { + const unsigned char *signature, size_t sig_len, + hash_algorithm_t hash_alg) { mbedtls_pk_context key; unsigned char hash[64]; // Max hash size (SHA-512) size_t hash_size; @@ -1089,7 +1094,8 @@ static int LuaECDSAVerify(lua_State *L) { const char *pub_key_pem = luaL_checkstring(L, 1); const char *message = luaL_checkstring(L, 2); size_t sig_len; - const unsigned char *signature = (const unsigned char *)luaL_checklstring(L, 3, &sig_len); + const unsigned char *signature = + (const unsigned char *)luaL_checklstring(L, 3, &sig_len); const char *hash_name = luaL_optstring(L, 4, "sha256"); hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); @@ -1100,560 +1106,560 @@ static int LuaECDSAVerify(lua_State *L) { return 1; } - // AES // AES key generation helper static int LuaAesGenerateKey(lua_State *L) { - int keybits = 128; - if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { - keybits = luaL_checkinteger(L, 1); - } - int keylen = keybits / 8; - if ((keybits != 128 && keybits != 192 && keybits != 256) || (keylen != 16 && keylen != 24 && keylen != 32)) { - lua_pushnil(L); - lua_pushstring(L, "AES key length must be 128, 192, or 256 bits"); - return 2; - } - unsigned char key[32]; - mbedtls_entropy_context entropy; - mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_entropy_init(&entropy); - mbedtls_ctr_drbg_init(&ctr_drbg); - const char *pers = "aes_keygen"; - int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const unsigned char *)pers, strlen(pers)); - if (ret != 0) { - lua_pushnil(L); - lua_pushstring(L, "Failed to initialize RNG for AES key"); - mbedtls_ctr_drbg_free(&ctr_drbg); - mbedtls_entropy_free(&entropy); - return 2; - } - ret = mbedtls_ctr_drbg_random(&ctr_drbg, key, keylen); + int keybits = 128; + if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { + keybits = luaL_checkinteger(L, 1); + } + int keylen = keybits / 8; + if ((keybits != 128 && keybits != 192 && keybits != 256) || + (keylen != 16 && keylen != 24 && keylen != 32)) { + lua_pushnil(L); + lua_pushstring(L, "AES key length must be 128, 192, or 256 bits"); + return 2; + } + unsigned char key[32]; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + const char *pers = "aes_keygen"; + int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, + (const unsigned char *)pers, strlen(pers)); + if (ret != 0) { + lua_pushnil(L); + lua_pushstring(L, "Failed to initialize RNG for AES key"); mbedtls_ctr_drbg_free(&ctr_drbg); mbedtls_entropy_free(&entropy); - if (ret != 0) { - lua_pushnil(L); - lua_pushstring(L, "Failed to generate random AES key"); - return 2; - } - lua_pushlstring(L, (const char *)key, keylen); - return 1; + return 2; + } + ret = mbedtls_ctr_drbg_random(&ctr_drbg, key, keylen); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + if (ret != 0) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate random AES key"); + return 2; + } + lua_pushlstring(L, (const char *)key, keylen); + return 1; } // Helper to get string field from options table typedef struct { - const char *mode; - const unsigned char *iv; - size_t ivlen; + const char *mode; + const unsigned char *iv; + size_t ivlen; + const unsigned char *tag; + size_t taglen; + const unsigned char *aad; + size_t aadlen; } aes_options_t; -static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) { - opts->mode = "cbc"; - opts->iv = NULL; - opts->ivlen = 0; - if (lua_istable(L, options_idx)) { - lua_getfield(L, options_idx, "mode"); - if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1); - lua_pop(L, 1); - lua_getfield(L, options_idx, "iv"); - if (lua_isstring(L, -1)) { - opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen); - } - lua_pop(L, 1); +static void parse_aes_options(lua_State *L, int options_idx, + aes_options_t *opts) { + opts->mode = "cbc"; + opts->iv = NULL; + opts->ivlen = 0; + opts->tag = NULL; + opts->taglen = 0; + opts->aad = NULL; + opts->aadlen = 0; + if (lua_istable(L, options_idx)) { + lua_getfield(L, options_idx, "mode"); + if (!lua_isnil(L, -1)) + opts->mode = lua_tostring(L, -1); + lua_pop(L, 1); + lua_getfield(L, options_idx, "iv"); + if (lua_isstring(L, -1)) { + opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen); } -} - -// Helper for AES decrypt options -typedef struct { - const char *mode; - const unsigned char *iv; - size_t ivlen; - const unsigned char *tag; - size_t taglen; - const unsigned char *aad; - size_t aadlen; -} aes_decrypt_options_t; - -static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt_options_t *opts) { - opts->mode = "cbc"; - opts->iv = NULL; - opts->ivlen = 0; - opts->tag = NULL; - opts->taglen = 0; - opts->aad = NULL; - opts->aadlen = 0; - if (lua_istable(L, options_idx)) { - lua_getfield(L, options_idx, "mode"); - if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1); - lua_pop(L, 1); - lua_getfield(L, options_idx, "iv"); - if (lua_isstring(L, -1)) { - opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen); - } - lua_pop(L, 1); - lua_getfield(L, options_idx, "tag"); - if (lua_isstring(L, -1)) { - opts->tag = (const unsigned char *)lua_tolstring(L, -1, &opts->taglen); - } - lua_pop(L, 1); - lua_getfield(L, options_idx, "aad"); - if (lua_isstring(L, -1)) { - opts->aad = (const unsigned char *)lua_tolstring(L, -1, &opts->aadlen); - } - lua_pop(L, 1); + lua_pop(L, 1); + lua_getfield(L, options_idx, "tag"); + if (lua_isstring(L, -1)) { + opts->tag = (const unsigned char *)lua_tolstring(L, -1, &opts->taglen); } + lua_pop(L, 1); + lua_getfield(L, options_idx, "aad"); + if (lua_isstring(L, -1)) { + opts->aad = (const unsigned char *)lua_tolstring(L, -1, &opts->aadlen); + } + lua_pop(L, 1); + } } // AES encryption supporting CBC, GCM, and CTR modes static int LuaAesEncrypt(lua_State *L) { - // Args: key, plaintext, options table - size_t keylen, ptlen; - const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); - const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); - int options_idx = 3; - aes_options_t opts; - parse_aes_options(L, options_idx, &opts); - const char *mode = opts.mode; - const unsigned char *iv = opts.iv; - size_t ivlen = opts.ivlen; - unsigned char *gen_iv = NULL; - int iv_was_generated = 0; - int ret = 0; - unsigned char *output = NULL; - int is_gcm = 0, is_ctr = 0, is_cbc = 0; - if (strcasecmp(mode, "cbc") == 0) { - is_cbc = 1; - } else if (strcasecmp(mode, "gcm") == 0) { - is_gcm = 1; - } else if (strcasecmp(mode, "ctr") == 0) { - is_ctr = 1; - } else { - lua_pushnil(L); - lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); - return 2; - } - // If IV is not provided, auto-generate - if (!iv) { - if (is_gcm) { - ivlen = 12; - } else { - ivlen = 16; - } - gen_iv = malloc(ivlen); - if (!gen_iv) { - lua_pushnil(L); - lua_pushstring(L, "Failed to allocate IV"); - return 2; - } - mbedtls_entropy_context entropy; - mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_entropy_init(&entropy); - mbedtls_ctr_drbg_init(&ctr_drbg); - mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0); - mbedtls_ctr_drbg_random(&ctr_drbg, gen_iv, ivlen); - mbedtls_ctr_drbg_free(&ctr_drbg); - mbedtls_entropy_free(&entropy); - iv = gen_iv; - iv_was_generated = 1; - } - if (is_cbc) { - // PKCS7 padding - size_t block_size = 16; - size_t padlen = block_size - (ptlen % block_size); - size_t ctlen = ptlen + padlen; - unsigned char *input = malloc(ctlen); - if (!input) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - return 2; - } - memcpy(input, plaintext, ptlen); - memset(input + ptlen, (unsigned char)padlen, padlen); - output = malloc(ctlen); - if (!output) { - free(input); - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - return 2; - } - mbedtls_aes_context aes; - mbedtls_aes_init(&aes); - ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); - if (ret != 0) { - free(input); - free(output); - mbedtls_aes_free(&aes); - lua_pushnil(L); - lua_pushstring(L, "Failed to set AES encryption key"); - return 2; - } - unsigned char iv_copy[16]; - memcpy(iv_copy, iv, 16); - ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy, input, output); - mbedtls_aes_free(&aes); - free(input); - if (ret != 0) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "AES CBC encryption failed"); - return 2; - } - lua_pushlstring(L, (const char *)output, ctlen); - lua_pushlstring(L, (const char *)iv, ivlen); - free(output); - if (iv_was_generated) free(gen_iv); - return 2; - } else if (is_ctr) { - // CTR mode: no padding - output = malloc(ptlen); - if (!output) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - return 2; - } - mbedtls_aes_context aes; - mbedtls_aes_init(&aes); - ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); - if (ret != 0) { - free(output); - mbedtls_aes_free(&aes); - lua_pushnil(L); - lua_pushstring(L, "Failed to set AES encryption key"); - return 2; - } - unsigned char nonce_counter[16]; - unsigned char stream_block[16]; - size_t nc_off = 0; - memcpy(nonce_counter, iv, 16); - memset(stream_block, 0, 16); - ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter, stream_block, plaintext, output); - mbedtls_aes_free(&aes); - if (ret != 0) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "AES CTR encryption failed"); - return 2; - } - lua_pushlstring(L, (const char *)output, ptlen); - lua_pushlstring(L, (const char *)iv, ivlen); - free(output); - if (iv_was_generated) free(gen_iv); - return 2; - } else if (is_gcm) { - // GCM mode: authenticated encryption - size_t taglen = 16; - unsigned char tag[16]; - output = malloc(ptlen); - if (!output) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - return 2; - } - mbedtls_gcm_context gcm; - mbedtls_gcm_init(&gcm); - ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); - if (ret != 0) { - free(output); - mbedtls_gcm_free(&gcm); - lua_pushnil(L); - lua_pushstring(L, "Failed to set AES GCM key"); - return 2; - } - // Use actual ivlen, not hardcoded 16 - ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen, NULL, 0, plaintext, output, taglen, tag); - mbedtls_gcm_free(&gcm); - if (ret != 0) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "AES GCM encryption failed"); - return 2; - } - lua_pushlstring(L, (const char *)output, ptlen); - lua_pushlstring(L, (const char *)iv, ivlen); - lua_pushlstring(L, (const char *)tag, taglen); - free(output); - if (iv_was_generated) free(gen_iv); - return 3; - } + // Args: key, plaintext, options table + size_t keylen, ptlen; + const unsigned char *key = + (const unsigned char *)luaL_checklstring(L, 1, &keylen); + const unsigned char *plaintext = + (const unsigned char *)luaL_checklstring(L, 2, &ptlen); + int options_idx = 3; + aes_options_t opts; + parse_aes_options(L, options_idx, &opts); + const char *mode = opts.mode; + const unsigned char *iv = opts.iv; + size_t ivlen = opts.ivlen; + unsigned char *gen_iv = NULL; + int iv_was_generated = 0; + int ret = 0; + unsigned char *output = NULL; + int is_gcm = 0, is_ctr = 0, is_cbc = 0; + if (strcasecmp(mode, "cbc") == 0) { + is_cbc = 1; + } else if (strcasecmp(mode, "gcm") == 0) { + is_gcm = 1; + } else if (strcasecmp(mode, "ctr") == 0) { + is_ctr = 1; + } else { lua_pushnil(L); - lua_pushstring(L, "Internal error in AES encrypt"); + lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); return 2; + } + // If IV is not provided, auto-generate + if (!iv) { + if (is_gcm) { + ivlen = 12; + } else { + ivlen = 16; + } + gen_iv = malloc(ivlen); + if (!gen_iv) { + lua_pushnil(L); + lua_pushstring(L, "Failed to allocate IV"); + return 2; + } + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0); + mbedtls_ctr_drbg_random(&ctr_drbg, gen_iv, ivlen); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + iv = gen_iv; + iv_was_generated = 1; + } + if (is_cbc) { + // PKCS7 padding + size_t block_size = 16; + size_t padlen = block_size - (ptlen % block_size); + size_t ctlen = ptlen + padlen; + unsigned char *input = malloc(ctlen); + if (!input) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + memcpy(input, plaintext, ptlen); + memset(input + ptlen, (unsigned char)padlen, padlen); + output = malloc(ctlen); + if (!output) { + free(input); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(input); + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char iv_copy[16]; + memcpy(iv_copy, iv, 16); + ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy, + input, output); + mbedtls_aes_free(&aes); + free(input); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CBC encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + lua_pushlstring(L, (const char *)iv, ivlen); + free(output); + if (iv_was_generated) + free(gen_iv); + return 2; + } else if (is_ctr) { + // CTR mode: no padding + output = malloc(ptlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char nonce_counter[16]; + unsigned char stream_block[16]; + size_t nc_off = 0; + memcpy(nonce_counter, iv, 16); + memset(stream_block, 0, 16); + ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter, + stream_block, plaintext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CTR encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ptlen); + lua_pushlstring(L, (const char *)iv, ivlen); + free(output); + if (iv_was_generated) + free(gen_iv); + return 2; + } else if (is_gcm) { + // GCM mode: authenticated encryption + size_t taglen = 16; + unsigned char tag[16]; + output = malloc(ptlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_gcm_context gcm; + mbedtls_gcm_init(&gcm); + ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_gcm_free(&gcm); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES GCM key"); + return 2; + } + ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen, + NULL, 0, plaintext, output, taglen, tag); + mbedtls_gcm_free(&gcm); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES GCM encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ptlen); + lua_pushlstring(L, (const char *)iv, ivlen); + lua_pushlstring(L, (const char *)tag, taglen); + free(output); + if (iv_was_generated) + free(gen_iv); + return 3; + } + lua_pushnil(L); + lua_pushstring(L, "Internal error in AES encrypt"); + return 2; } // AES decryption supporting CBC, GCM, and CTR modes static int LuaAesDecrypt(lua_State *L) { - // Args: key, ciphertext, options table - size_t keylen, ctlen; - const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); - const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); - int options_idx = 3; - aes_decrypt_options_t opts; - parse_aes_decrypt_options(L, options_idx, &opts); - const char *mode = opts.mode; - const unsigned char *iv = opts.iv; - size_t ivlen = opts.ivlen; - const unsigned char *tag = opts.tag; - size_t taglen = opts.taglen; - const unsigned char *aad = opts.aad; - size_t aadlen = opts.aadlen; - int is_gcm = 0, is_ctr = 0, is_cbc = 0; - if (strcasecmp(mode, "cbc") == 0) { - is_cbc = 1; - } else if (strcasecmp(mode, "gcm") == 0) { - is_gcm = 1; - } else if (strcasecmp(mode, "ctr") == 0) { - is_ctr = 1; - } else { - lua_pushnil(L); - lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); - return 2; - } - // Validate key length (16, 24, 32 bytes) - if (keylen != 16 && keylen != 24 && keylen != 32) { - lua_pushnil(L); - lua_pushstring(L, "AES key must be 16, 24, or 32 bytes"); - return 2; - } - // Validate IV/nonce length - if (is_cbc || is_ctr) { - if (ivlen != 16) { - lua_pushnil(L); - lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); - return 2; - } - } else if (is_gcm) { - if (ivlen < 12 || ivlen > 16) { - lua_pushnil(L); - lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); - return 2; - } - } - - // GCM: require tag and optional AAD - if (is_gcm) { - if (!tag || taglen < 12 || taglen > 16) { - lua_pushnil(L); - lua_pushstring(L, "AES GCM tag must be 12-16 bytes"); - return 2; - } - } - - int ret = 0; - unsigned char *output = NULL; - - if (is_cbc) { - // Ciphertext must be a multiple of block size - if (ctlen == 0 || (ctlen % 16) != 0) { - lua_pushnil(L); - lua_pushstring(L, "Ciphertext length must be a multiple of 16"); - return 2; - } - output = malloc(ctlen); - if (!output) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - return 2; - } - mbedtls_aes_context aes; - mbedtls_aes_init(&aes); - ret = mbedtls_aes_setkey_dec(&aes, key, keylen * 8); - if (ret != 0) { - free(output); - mbedtls_aes_free(&aes); - lua_pushnil(L); - lua_pushstring(L, "Failed to set AES decryption key"); - return 2; - } - unsigned char iv_copy[16]; - memcpy(iv_copy, iv, 16); - ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy, ciphertext, output); - mbedtls_aes_free(&aes); - if (ret != 0) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "AES CBC decryption failed"); - return 2; - } - // PKCS7 unpadding - if (ctlen == 0) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "Decrypted data is empty"); - return 2; - } - unsigned char pad = output[ctlen - 1]; - if (pad == 0 || pad > 16) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "Invalid PKCS7 padding"); - return 2; - } - for (size_t i = 0; i < pad; ++i) { - if (output[ctlen - 1 - i] != pad) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "Invalid PKCS7 padding"); - return 2; - } - } - size_t ptlen = ctlen - pad; - lua_pushlstring(L, (const char *)output, ptlen); - free(output); - return 1; - } else if (is_ctr) { - // CTR mode: no padding - output = malloc(ctlen); - if (!output) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - return 2; - } - mbedtls_aes_context aes; - mbedtls_aes_init(&aes); - ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); - if (ret != 0) { - free(output); - mbedtls_aes_free(&aes); - lua_pushnil(L); - lua_pushstring(L, "Failed to set AES encryption key"); - return 2; - } - unsigned char nonce_counter[16]; - unsigned char stream_block[16]; - size_t nc_off = 0; - memcpy(nonce_counter, iv, 16); - memset(stream_block, 0, 16); - ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter, stream_block, ciphertext, output); - mbedtls_aes_free(&aes); - if (ret != 0) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "AES CTR decryption failed"); - return 2; - } - lua_pushlstring(L, (const char *)output, ctlen); - free(output); - return 1; - } else if (is_gcm) { - // GCM mode: authenticated decryption - output = malloc(ctlen); - if (!output) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - return 2; - } - mbedtls_gcm_context gcm; - mbedtls_gcm_init(&gcm); - ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); - if (ret != 0) { - free(output); - mbedtls_gcm_free(&gcm); - lua_pushnil(L); - lua_pushstring(L, "Failed to set AES GCM key"); - return 2; - } - ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag, taglen, ciphertext, output); - mbedtls_gcm_free(&gcm); - if (ret != 0) { - free(output); - lua_pushnil(L); - lua_pushstring(L, "AES GCM decryption failed or authentication failed"); - return 2; - } - lua_pushlstring(L, (const char *)output, ctlen); - free(output); - return 1; - } + // Args: key, ciphertext, options table + size_t keylen, ctlen; + const unsigned char *key = + (const unsigned char *)luaL_checklstring(L, 1, &keylen); + const unsigned char *ciphertext = + (const unsigned char *)luaL_checklstring(L, 2, &ctlen); + int options_idx = 3; + aes_options_t opts; + parse_aes_options(L, options_idx, &opts); + const char *mode = opts.mode; + const unsigned char *iv = opts.iv; + size_t ivlen = opts.ivlen; + const unsigned char *tag = opts.tag; + size_t taglen = opts.taglen; + const unsigned char *aad = opts.aad; + size_t aadlen = opts.aadlen; + int is_gcm = 0, is_ctr = 0, is_cbc = 0; + if (strcasecmp(mode, "cbc") == 0) { + is_cbc = 1; + } else if (strcasecmp(mode, "gcm") == 0) { + is_gcm = 1; + } else if (strcasecmp(mode, "ctr") == 0) { + is_ctr = 1; + } else { lua_pushnil(L); - lua_pushstring(L, "Internal error in AES decrypt"); + lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); return 2; + } + // Validate key length (16, 24, 32 bytes) + if (keylen != 16 && keylen != 24 && keylen != 32) { + lua_pushnil(L); + lua_pushstring(L, "AES key must be 16, 24, or 32 bytes"); + return 2; + } + // Validate IV/nonce length + if (is_cbc || is_ctr) { + if (ivlen != 16) { + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (ivlen < 12 || ivlen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; + } + } + + // GCM: require tag and optional AAD + if (is_gcm) { + if (!tag || taglen < 12 || taglen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM tag must be 12-16 bytes"); + return 2; + } + } + + int ret = 0; + unsigned char *output = NULL; + + if (is_cbc) { + // Ciphertext must be a multiple of block size + if (ctlen == 0 || (ctlen % 16) != 0) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext length must be a multiple of 16"); + return 2; + } + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_dec(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES decryption key"); + return 2; + } + unsigned char iv_copy[16]; + memcpy(iv_copy, iv, 16); + ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy, + ciphertext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CBC decryption failed"); + return 2; + } + // PKCS7 unpadding + if (ctlen == 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Decrypted data is empty"); + return 2; + } + unsigned char pad = output[ctlen - 1]; + if (pad == 0 || pad > 16) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Invalid PKCS7 padding"); + return 2; + } + for (size_t i = 0; i < pad; ++i) { + if (output[ctlen - 1 - i] != pad) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Invalid PKCS7 padding"); + return 2; + } + } + size_t ptlen = ctlen - pad; + lua_pushlstring(L, (const char *)output, ptlen); + free(output); + return 1; + } else if (is_ctr) { + // CTR mode: no padding + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char nonce_counter[16]; + unsigned char stream_block[16]; + size_t nc_off = 0; + memcpy(nonce_counter, iv, 16); + memset(stream_block, 0, 16); + ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter, + stream_block, ciphertext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CTR decryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + free(output); + return 1; + } else if (is_gcm) { + // GCM mode: authenticated decryption + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_gcm_context gcm; + mbedtls_gcm_init(&gcm); + ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_gcm_free(&gcm); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES GCM key"); + return 2; + } + ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag, + taglen, ciphertext, output); + mbedtls_gcm_free(&gcm); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES GCM decryption failed or authentication failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + free(output); + return 1; + } + lua_pushnil(L); + lua_pushstring(L, "Internal error in AES decrypt"); + return 2; } // LuaCrypto compatible API static int LuaCryptoSign(lua_State *L) { - const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") - lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + const char *dtype = + luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") + lua_remove(L, 1); // Remove the first argument (key type or cipher type) + // before dispatching - if (strcasecmp(dtype, "rsa") == 0) { - return LuaRSASign(L); - } else if (strcasecmp(dtype, "ecdsa") == 0) { - return LuaECDSASign(L); - } else { - return luaL_error(L, "Unsupported signature type: %s", dtype); - } + if (strcasecmp(dtype, "rsa") == 0) { + return LuaRSASign(L); + } else if (strcasecmp(dtype, "ecdsa") == 0) { + return LuaECDSASign(L); + } else { + return luaL_error(L, "Unsupported signature type: %s", dtype); + } } static int LuaCryptoVerify(lua_State *L) { - const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") - lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching + const char *dtype = + luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") + lua_remove(L, 1); // Remove the first argument (key type or cipher type) + // before dispatching - if (strcasecmp(dtype, "rsa") == 0) { - return LuaRSAVerify(L); - } else if (strcasecmp(dtype, "ecdsa") == 0) { - return LuaECDSAVerify(L); - } else { - return luaL_error(L, "Unsupported signature type: %s", dtype); - } + if (strcasecmp(dtype, "rsa") == 0) { + return LuaRSAVerify(L); + } else if (strcasecmp(dtype, "ecdsa") == 0) { + return LuaECDSAVerify(L); + } else { + return luaL_error(L, "Unsupported signature type: %s", dtype); + } } static int LuaCryptoEncrypt(lua_State *L) { - // Args: cipher_type, key, msg, options table - const char *cipher = luaL_checkstring(L, 1); - // Remove cipher_type from stack, so key is at 1, msg at 2, options at 3 - lua_remove(L, 1); - if (strcasecmp(cipher, "rsa") == 0) { - // Update LuaRSAEncrypt to accept (key, msg, options) - return LuaRSAEncrypt(L); - } else if (strcasecmp(cipher, "aes") == 0) { - return LuaAesEncrypt(L); - } else { - return luaL_error(L, "Unsupported cipher type: %s", cipher); - } + // Args: cipher_type, key, msg, options table + const char *cipher = luaL_checkstring(L, 1); + lua_remove(L, 1); // Remove cipher_type from stack, so key is at 1, msg at 2, + // options at 3 + + if (strcasecmp(cipher, "rsa") == 0) { + return LuaRSAEncrypt(L); + } else if (strcasecmp(cipher, "aes") == 0) { + return LuaAesEncrypt(L); + } else { + return luaL_error(L, "Unsupported cipher type: %s", cipher); + } } static int LuaCryptoDecrypt(lua_State *L) { - // Args: cipher_type, key, ciphertext, options table - const char *cipher = luaL_checkstring(L, 1); - lua_remove(L, 1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3 - if (strcasecmp(cipher, "rsa") == 0) { - return LuaRSADecrypt(L); - } else if (strcasecmp(cipher, "aes") == 0) { - return LuaAesDecrypt(L); - } else { - return luaL_error(L, "Unsupported cipher type: %s", cipher); - } + // Args: cipher_type, key, ciphertext, options table + const char *cipher = luaL_checkstring(L, 1); + lua_remove( + L, + 1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3 + + if (strcasecmp(cipher, "rsa") == 0) { + return LuaRSADecrypt(L); + } else if (strcasecmp(cipher, "aes") == 0) { + return LuaAesDecrypt(L); + } else { + return luaL_error(L, "Unsupported cipher type: %s", cipher); + } } static int LuaCryptoGenerateKeyPair(lua_State *L) { - // If the first argument is a number, treat as RSA key length - if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) { - // Call LuaRSAGenerateKeyPair with the number as the key length - return LuaRSAGenerateKeyPair(L); - } - // Otherwise, get the key type from the first argument, default to "rsa" if not provided - const char *type = luaL_optstring(L, 1, "rsa"); - lua_remove(L, 1); - if (strcasecmp(type, "rsa") == 0) { - return LuaRSAGenerateKeyPair(L); - } else if (strcasecmp(type, "ecdsa") == 0) { - return LuaECDSAGenerateKeyPair(L); - } else if (strcasecmp(type, "aes") == 0) { - return LuaAesGenerateKey(L); - } else { - return luaL_error(L, "Unsupported key type: %s", type); - } + // If the first argument is a number, treat it as RSA key length + if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) { + // Call LuaRSAGenerateKeyPair with the number as the key length + return LuaRSAGenerateKeyPair(L); + } + // Otherwise, get the key type from the first argument, default to "rsa" if + // not provided + const char *type = luaL_optstring(L, 1, "rsa"); + lua_remove(L, 1); + if (strcasecmp(type, "rsa") == 0) { + return LuaRSAGenerateKeyPair(L); + } else if (strcasecmp(type, "ecdsa") == 0) { + return LuaECDSAGenerateKeyPair(L); + } else if (strcasecmp(type, "aes") == 0) { + return LuaAesGenerateKey(L); + } else { + return luaL_error(L, "Unsupported key type: %s", type); + } } static const luaL_Reg kLuaCrypto[] = { - {"sign", LuaCryptoSign}, // - {"verify", LuaCryptoVerify}, // - {"encrypt", LuaCryptoEncrypt}, // - {"decrypt", LuaCryptoDecrypt}, // - {"generatekeypair", LuaCryptoGenerateKeyPair}, // - {"convertPemToJwk", LuaConvertPemToJwk}, // - {"generateCsr", LuaGenerateCSR}, // - {0}, // + {"sign", LuaCryptoSign}, // + {"verify", LuaCryptoVerify}, // + {"encrypt", LuaCryptoEncrypt}, // + {"decrypt", LuaCryptoDecrypt}, // + {"generatekeypair", LuaCryptoGenerateKeyPair}, // + {"convertPemToJwk", LuaConvertPemToJwk}, // + {"generateCsr", LuaGenerateCSR}, // + {0}, // }; int LuaCrypto(lua_State *L) { From a603cc90aecde7f3b75b4507565d1a191c1231be Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Thu, 5 Jun 2025 17:35:40 +1200 Subject: [PATCH 14/18] Use built-in entropy generator and remove dependency on GetHardRandom Add correct jwltopem and pemtojwk function Expand tests --- test/tool/net/lcrypto_test.lua | 35 ++ tool/net/lcrypto.c | 814 ++++++++++++++++++++++----------- 2 files changed, 577 insertions(+), 272 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 077cd03c3..0516a6ec2 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -166,6 +166,40 @@ local function test_pem_to_jwk() assert_equal(pub_jwk.kty, "EC", "kty is correct") end +-- Test JwkToPem conversion +local function test_jwk_to_pem() + local priv_key, pub_key = crypto.generatekeypair() + local priv_jwk = crypto.convertPemToJwk(priv_key) + local pub_jwk = crypto.convertPemToJwk(pub_key) + + local priv_pem = crypto.convertJwkToPem(priv_jwk) + local pub_pem = crypto.convertJwkToPem(pub_jwk) + assert_equal(type(priv_pem), "string", "Private PEM type") + + -- Roundtrip + assert_equal(priv_key,priv_pem, "Private PEM matches original RSA key") + assert_equal(pub_key,pub_pem, "Public PEM matches original RSA key") + + local pub_pem = crypto.convertJwkToPem(pub_jwk) + assert_equal(type(pub_pem), "string", "Public PEM type") + + -- Test ECDSA keys + local priv_key, pub_key = crypto.generatekeypair('ecdsa') + local priv_jwk = crypto.convertPemToJwk(priv_key) + local pub_jwk = crypto.convertPemToJwk(pub_key) + + local priv_pem = crypto.convertJwkToPem(priv_jwk) + local pub_pem = crypto.convertJwkToPem(pub_jwk) + assert_equal(type(priv_pem), "string", "Private PEM type for ECDSA") + + -- Roundtrip + assert_equal(priv_key,priv_pem, "Private PEM matches original ECDSA key") + assert_equal(pub_key,pub_pem, "Public PEM matches original ECDSA key") + + local pub_pem = crypto.convertJwkToPem(pub_jwk) + assert_equal(type(pub_pem), "string", "Public PEM type for ECDSA") +end + -- Test CSR generation local function test_csr_generation() local priv_key, _ = crypto.generatekeypair() @@ -205,6 +239,7 @@ local function run_tests() test_aes_encryption_decryption_ctr() test_aes_encryption_decryption_gcm() test_pem_to_jwk() + test_jwk_to_pem() test_csr_generation() EXIT = 0 return EXIT diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index 2b4ba2fb6..e50304fe4 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -1,6 +1,24 @@ -#include "libc/log/log.h" -#include "net/https/https.h" -#include "third_party/lua/lauxlib.h" +/*-*- mode:c;indent-tabs-mode:nil;c-basic-offset:2;tab-width:8;coding:utf-8 -*-│ +│ vi: set et ft=c ts=2 sts=2 sw=2 fenc=utf-8 :vi │ +╞══════════════════════════════════════════════════════════════════════════════╡ +│ Copyright 2025 Miguel Angel Terron │ +│ │ +│ Permission to use, copy, modify, and/or distribute this software for │ +│ any purpose with or without fee is hereby granted, provided that the │ +│ above copyright notice and this permission notice appear in all copies. │ +│ │ +│ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │ +│ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │ +│ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │ +│ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │ +│ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │ +│ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │ +│ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │ +│ PERFORMANCE OF THIS SOFTWARE. │ +╚─────────────────────────────────────────────────────────────────────────────*/ + +#include "tool/net/luacheck.h" +// mbedTLS #include "third_party/mbedtls/aes.h" #include "third_party/mbedtls/base64.h" #include "third_party/mbedtls/ctr_drbg.h" @@ -14,227 +32,36 @@ #include "third_party/mbedtls/rsa.h" #include "third_party/mbedtls/x509_csr.h" -// Standard C library and redbean utilities -#include "libc/errno.h" -#include "libc/mem/mem.h" -#include "libc/str/str.h" -#include "tool/net/luacheck.h" - -// Parse PEM keys and convert them into JWK format -static int LuaConvertPemToJwk(lua_State *L) { - const char *pem_key = luaL_checkstring(L, 1); - - mbedtls_pk_context key; - mbedtls_pk_init(&key); +// Strong RNG using mbedtls_entropy_context and mbedtls_ctr_drbg_context +int GenerateRandom(void *ctx, unsigned char *output, size_t len) { + static mbedtls_entropy_context entropy; + static mbedtls_ctr_drbg_context ctr_drbg; + static int initialized = 0; int ret; + const char *pers = "redbean_entropy"; - // Parse the PEM key - if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key, - strlen(pem_key) + 1, NULL, 0)) != 0 && - (ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key, - strlen(pem_key) + 1)) != 0) { - lua_pushnil(L); - lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret); - mbedtls_pk_free(&key); - return 2; - } + if (!initialized) { + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); - lua_newtable(L); // Create a new Lua table - - if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_RSA) { - // Handle RSA keys - const mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); - size_t n_len = mbedtls_mpi_size(&rsa->N); - size_t e_len = mbedtls_mpi_size(&rsa->E); - - unsigned char *n = malloc(n_len); - unsigned char *e = malloc(e_len); - - if (!n || !e) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - free(n); - free(e); - mbedtls_pk_free(&key); - return 2; + ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, + (const unsigned char *)pers, strlen(pers)); + if (ret != 0) { + // Clean up on failure + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + return -1; } - - mbedtls_mpi_write_binary(&rsa->N, n, n_len); - mbedtls_mpi_write_binary(&rsa->E, e, e_len); - - char *n_b64 = NULL, *e_b64 = NULL; - size_t n_b64_len, e_b64_len; - - mbedtls_base64_encode(NULL, 0, &n_b64_len, n, n_len); - mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); - - n_b64 = malloc(n_b64_len + 1); - e_b64 = malloc(e_b64_len + 1); - - if (!n_b64 || !e_b64) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - free(n); - free(e); - free(n_b64); - free(e_b64); - mbedtls_pk_free(&key); - return 2; - } - - mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, - n_len); - mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, - e_len); - - n_b64[n_b64_len] = '\0'; - e_b64[e_b64_len] = '\0'; - - lua_pushstring(L, "RSA"); - lua_setfield(L, -2, "kty"); - lua_pushstring(L, n_b64); - lua_setfield(L, -2, "n"); - lua_pushstring(L, e_b64); - lua_setfield(L, -2, "e"); - - free(n); - free(e); - free(n_b64); - free(e_b64); - } else if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_ECKEY) { - // Handle ECDSA keys - const mbedtls_ecp_keypair *ec = mbedtls_pk_ec(key); - const mbedtls_ecp_point *Q = &ec->Q; - size_t x_len = (ec->grp.pbits + 7) / 8; - size_t y_len = (ec->grp.pbits + 7) / 8; - - unsigned char *x = malloc(x_len); - unsigned char *y = malloc(y_len); - - if (!x || !y) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - free(x); - free(y); - mbedtls_pk_free(&key); - return 2; - } - - mbedtls_mpi_write_binary(&Q->X, x, x_len); - mbedtls_mpi_write_binary(&Q->Y, y, y_len); - - char *x_b64 = NULL, *y_b64 = NULL; - size_t x_b64_len, y_b64_len; - - mbedtls_base64_encode(NULL, 0, &x_b64_len, x, x_len); - mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len); - - x_b64 = malloc(x_b64_len + 1); - y_b64 = malloc(y_b64_len + 1); - - if (!x_b64 || !y_b64) { - lua_pushnil(L); - lua_pushstring(L, "Memory allocation failed"); - free(x); - free(y); - free(x_b64); - free(y_b64); - mbedtls_pk_free(&key); - return 2; - } - - mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, - x_len); - mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, - y_len); - - x_b64[x_b64_len] = '\0'; - y_b64[y_b64_len] = '\0'; - - lua_pushstring(L, "EC"); - lua_setfield(L, -2, "kty"); - lua_pushstring(L, mbedtls_ecp_curve_info_from_grp_id(ec->grp.id)->name); - lua_setfield(L, -2, "crv"); - lua_pushstring(L, x_b64); - lua_setfield(L, -2, "x"); - lua_pushstring(L, y_b64); - lua_setfield(L, -2, "y"); - - free(x); - free(y); - free(x_b64); - free(y_b64); - } else { - lua_pushnil(L); - lua_pushstring(L, "Unsupported key type"); - mbedtls_pk_free(&key); - return 2; + initialized = 1; } - - mbedtls_pk_free(&key); - return 1; -} - -// CSR Creation Function -static int LuaGenerateCSR(lua_State *L) { - const char *key_pem = luaL_checkstring(L, 1); - const char *subject_name; - const char *san_list = luaL_optstring(L, 3, NULL); - - if (lua_isnoneornil(L, 2)) { - subject_name = ""; - } else { - subject_name = luaL_checkstring(L, 2); + // mbedtls_ctr_drbg_random returns 0 on success + ret = mbedtls_ctr_drbg_random(&ctr_drbg, output, len); + if (ret != 0) { + // If DRBG fails, reinitialize on next call + initialized = 0; + return -1; } - - if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') { - lua_pushnil(L); - lua_pushstring(L, "Subject name or SANs are required"); - return 2; - } - mbedtls_pk_context key; - mbedtls_x509write_csr req; - char buf[4096]; - int ret; - - mbedtls_pk_init(&key); - mbedtls_x509write_csr_init(&req); - - if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem, - strlen(key_pem) + 1, NULL, 0)) != 0) { - lua_pushnil(L); - lua_pushfstring(L, "Failed to parse key: %d", ret); - return 2; - } - - mbedtls_x509write_csr_set_subject_name(&req, subject_name); - mbedtls_x509write_csr_set_key(&req, &key); - mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256); - - if (san_list) { - if ((ret = mbedtls_x509write_csr_set_extension( - &req, MBEDTLS_OID_SUBJECT_ALT_NAME, - MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), - (const unsigned char *)san_list, strlen(san_list))) != 0) { - lua_pushnil(L); - lua_pushfstring(L, "Failed to set SANs: %d", ret); - return 2; - } - } - - if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf), - NULL, NULL)) < 0) { - lua_pushnil(L); - lua_pushfstring(L, "Failed to write CSR: %d", ret); - return 2; - } - - lua_pushstring(L, buf); - - mbedtls_pk_free(&key); - mbedtls_x509write_csr_free(&req); - - return 1; + return 0; } // RSA @@ -256,7 +83,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, } // Generate RSA key - if ((rc = mbedtls_rsa_gen_key(mbedtls_pk_rsa(key), GenerateHardRandom, 0, + if ((rc = mbedtls_rsa_gen_key(mbedtls_pk_rsa(key), GenerateRandom, 0, key_length, 65537)) != 0) { WARNF("Failed to generate key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); @@ -368,8 +195,8 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, } // Encrypt data - if ((rc = mbedtls_rsa_pkcs1_encrypt(mbedtls_pk_rsa(key), GenerateHardRandom, - 0, MBEDTLS_RSA_PUBLIC, data_len, data, + if ((rc = mbedtls_rsa_pkcs1_encrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, + MBEDTLS_RSA_PUBLIC, data_len, data, output)) != 0) { WARNF("Encryption failed (grep -0x%04x)", -rc); free(output); @@ -436,8 +263,8 @@ static char *RSADecrypt(const char *private_key_pem, // Decrypt data size_t output_len = 0; - if ((rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateHardRandom, - 0, MBEDTLS_RSA_PRIVATE, &output_len, + if ((rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, + MBEDTLS_RSA_PRIVATE, &output_len, encrypted_data, output, key_size)) != 0) { WARNF("Decryption failed (grep -0x%04x)", -rc); free(output); @@ -531,7 +358,7 @@ static char *RSASign(const char *private_key_pem, const unsigned char *data, // Sign the hash if ((rc = mbedtls_pk_sign(&key, hash_algo, hash, hash_len, signature, sig_len, - GenerateHardRandom, 0)) != 0) { + GenerateRandom, 0)) != 0) { free(signature); mbedtls_pk_free(&key); return NULL; @@ -742,6 +569,9 @@ static int LuaListHashAlgorithms(lua_State *L) { lua_pushstring(L, "SHA-512"); lua_rawseti(L, -2, 6); + lua_pushstring(L, "MD5"); + lua_rawseti(L, -2, 7); + return 1; } @@ -861,8 +691,7 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, goto cleanup; } - ret = - mbedtls_ecp_gen_key(curve_id, mbedtls_pk_ec(key), GenerateHardRandom, 0); + ret = mbedtls_ecp_gen_key(curve_id, mbedtls_pk_ec(key), GenerateRandom, 0); if (ret != 0) { WARNF("(ecdsa) Failed to generate key: -0x%04x", -ret); goto cleanup; @@ -915,7 +744,6 @@ cleanup: } return ret; } -// Lua binding for generating ECDSA keys static int LuaECDSAGenerateKeyPair(lua_State *L) { const char *curve_name = NULL; char *priv_key_pem = NULL; @@ -925,7 +753,6 @@ static int LuaECDSAGenerateKeyPair(lua_State *L) { if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { curve_name = luaL_checkstring(L, 1); } - // If not provided, generate_key_pem will use the default int ret = ECDSAGenerateKeyPair(curve_name, &priv_key_pem, &pub_key_pem); @@ -996,9 +823,9 @@ static int ECDSASign(const char *priv_key_pem, const char *message, goto cleanup; } - // Sign the hash using GenerateHardRandom + // Sign the hash using GenerateRandom ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, - *signature, sig_len, GenerateHardRandom, 0); + *signature, sig_len, GenerateRandom, 0); if (ret != 0) { WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); @@ -1122,26 +949,10 @@ static int LuaAesGenerateKey(lua_State *L) { return 2; } unsigned char key[32]; - mbedtls_entropy_context entropy; - mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_entropy_init(&entropy); - mbedtls_ctr_drbg_init(&ctr_drbg); - const char *pers = "aes_keygen"; - int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, - (const unsigned char *)pers, strlen(pers)); - if (ret != 0) { + // Generate random key + if (GenerateRandom(NULL, key, keylen) != 0) { lua_pushnil(L); - lua_pushstring(L, "Failed to initialize RNG for AES key"); - mbedtls_ctr_drbg_free(&ctr_drbg); - mbedtls_entropy_free(&entropy); - return 2; - } - ret = mbedtls_ctr_drbg_random(&ctr_drbg, key, keylen); - mbedtls_ctr_drbg_free(&ctr_drbg); - mbedtls_entropy_free(&entropy); - if (ret != 0) { - lua_pushnil(L); - lua_pushstring(L, "Failed to generate random AES key"); + lua_pushstring(L, "Failed to generate random key"); return 2; } lua_pushlstring(L, (const char *)key, keylen); @@ -1234,14 +1045,13 @@ static int LuaAesEncrypt(lua_State *L) { lua_pushstring(L, "Failed to allocate IV"); return 2; } - mbedtls_entropy_context entropy; - mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_entropy_init(&entropy); - mbedtls_ctr_drbg_init(&ctr_drbg); - mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0); - mbedtls_ctr_drbg_random(&ctr_drbg, gen_iv, ivlen); - mbedtls_ctr_drbg_free(&ctr_drbg); - mbedtls_entropy_free(&entropy); + // Generate random IV + if (GenerateRandom(NULL, gen_iv, ivlen) != 0) { + free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "Failed to generate random IV"); + return 2; + } iv = gen_iv; iv_was_generated = 1; } @@ -1568,12 +1378,470 @@ static int LuaAesDecrypt(lua_State *L) { return 2; } + +// JWK functions + +// Convert JWK (Lua table) to PEM key format +static int LuaConvertJwkToPem(lua_State *L) { + luaL_checktype(L, 1, LUA_TTABLE); + const char *kty; + lua_getfield(L, 1, "kty"); + kty = lua_tostring(L, -1); + if (!kty) { + lua_pushnil(L); + lua_pushstring(L, "Missing 'kty' in JWK"); + return 2; + } + + int ret = -1; + char *pem = NULL; + mbedtls_pk_context pk; + mbedtls_pk_init(&pk); + + if (strcasecmp(kty, "RSA") == 0) { + // RSA JWK: n, e (base64url), optionally d, p, q, dp, dq, qi + lua_getfield(L, 1, "n"); + lua_getfield(L, 1, "e"); + const char *n_b64 = lua_tostring(L, -2); + const char *e_b64 = lua_tostring(L, -1); + // Optional private fields + lua_getfield(L, 1, "d"); + lua_getfield(L, 1, "p"); + lua_getfield(L, 1, "q"); + lua_getfield(L, 1, "dp"); + lua_getfield(L, 1, "dq"); + lua_getfield(L, 1, "qi"); + const char *d_b64 = lua_tostring(L, -6); + const char *p_b64 = lua_tostring(L, -5); + const char *q_b64 = lua_tostring(L, -4); + const char *dp_b64 = lua_tostring(L, -3); + const char *dq_b64 = lua_tostring(L, -2); + const char *qi_b64 = lua_tostring(L, -1); + int has_private = d_b64 && *d_b64; + // Decode base64url to binary + size_t n_len, e_len; + unsigned char n_bin[1024], e_bin[16]; + char *n_b64_std = strdup(n_b64), *e_b64_std = strdup(e_b64); + for (char *p = n_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + for (char *p = e_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + int n_mod = strlen(n_b64_std) % 4; + int e_mod = strlen(e_b64_std) % 4; + if (n_mod) for (int i = 0; i < 4 - n_mod; ++i) strcat(n_b64_std, "="); + if (e_mod) for (int i = 0; i < 4 - e_mod; ++i) strcat(e_b64_std, "="); + if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, (const unsigned char *)n_b64_std, strlen(n_b64_std)) != 0 || + mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len, (const unsigned char *)e_b64_std, strlen(e_b64_std)) != 0) { + free(n_b64_std); free(e_b64_std); + lua_pushnil(L); + lua_pushstring(L, "Base64 decode failed"); + return 2; + } + free(n_b64_std); free(e_b64_std); + // Build RSA context in pk + if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { + lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2; + } + mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk); + mbedtls_rsa_init(rsa, MBEDTLS_RSA_PKCS_V15, 0); + mbedtls_mpi_read_binary(&rsa->N, n_bin, n_len); + mbedtls_mpi_read_binary(&rsa->E, e_bin, e_len); + rsa->len = n_len; + if (has_private) { + // Decode and set private fields + size_t d_len, p_len, q_len, dp_len, dq_len, qi_len; + unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512], dq_bin[512], qi_bin[512]; + // Decode all private fields (skip if NULL) + #define DECODE_B64URL(var, b64, bin, binlen) \ + if (b64 && *b64) { \ + char *b64_std = strdup(b64); \ + for (char *p = b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; \ + int mod = strlen(b64_std) % 4; \ + if (mod) for (int i = 0; i < 4 - mod; ++i) strcat(b64_std, "="); \ + mbedtls_base64_decode(bin, sizeof(bin), &binlen, (const unsigned char *)b64_std, strlen(b64_std)); \ + free(b64_std); \ + } + DECODE_B64URL(d, d_b64, d_bin, d_len); + DECODE_B64URL(p, p_b64, p_bin, p_len); + DECODE_B64URL(q, q_b64, q_bin, q_len); + DECODE_B64URL(dp, dp_b64, dp_bin, dp_len); + DECODE_B64URL(dq, dq_b64, dq_bin, dq_len); + DECODE_B64URL(qi, qi_b64, qi_bin, qi_len); + mbedtls_mpi_read_binary(&rsa->D, d_bin, d_len); + mbedtls_mpi_read_binary(&rsa->P, p_bin, p_len); + mbedtls_mpi_read_binary(&rsa->Q, q_bin, q_len); + mbedtls_mpi_read_binary(&rsa->DP, dp_bin, dp_len); + mbedtls_mpi_read_binary(&rsa->DQ, dq_bin, dq_len); + mbedtls_mpi_read_binary(&rsa->QP, qi_bin, qi_len); + } + // Write PEM + unsigned char buf[4096]; + if (has_private) { + ret = mbedtls_pk_write_key_pem(&pk, buf, sizeof(buf)); + } else { + ret = mbedtls_pk_write_pubkey_pem(&pk, buf, sizeof(buf)); + } + if (ret != 0) { + mbedtls_pk_free(&pk); + lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2; + } + pem = strdup((char *)buf); + mbedtls_pk_free(&pk); + lua_pushstring(L, pem); + free(pem); + return 1; + } else if (strcasecmp(kty, "EC") == 0) { + // EC JWK: crv, x, y (base64url), optionally d + lua_getfield(L, 1, "crv"); + lua_getfield(L, 1, "x"); + lua_getfield(L, 1, "y"); + lua_getfield(L, 1, "d"); + const char *crv = lua_tostring(L, -4); + const char *x_b64 = lua_tostring(L, -3); + const char *y_b64 = lua_tostring(L, -2); + const char *d_b64 = lua_tostring(L, -1); + int has_private = d_b64 && *d_b64; + mbedtls_ecp_group_id gid = find_curve_by_name(crv); + if (gid == MBEDTLS_ECP_DP_NONE) { + lua_pushnil(L); lua_pushstring(L, "Unknown curve"); return 2; + } + size_t x_len, y_len; + unsigned char x_bin[72], y_bin[72]; + char *x_b64_std = strdup(x_b64), *y_b64_std = strdup(y_b64); + for (char *p = x_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + for (char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + int x_mod = strlen(x_b64_std) % 4; + int y_mod = strlen(y_b64_std) % 4; + if (x_mod) for (int i = 0; i < 4 - x_mod; ++i) strcat(x_b64_std, "="); + if (y_mod) for (int i = 0; i < 4 - y_mod; ++i) strcat(y_b64_std, "="); + if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, (const unsigned char *)x_b64_std, strlen(x_b64_std)) != 0 || + mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len, (const unsigned char *)y_b64_std, strlen(y_b64_std)) != 0) { + free(x_b64_std); free(y_b64_std); + lua_pushnil(L); lua_pushstring(L, "Base64 decode failed"); return 2; + } + free(x_b64_std); free(y_b64_std); + if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) { + lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2; + } + mbedtls_ecp_keypair *ec = mbedtls_pk_ec(pk); + mbedtls_ecp_keypair_init(ec); + mbedtls_ecp_group_load(&ec->grp, gid); + mbedtls_mpi_read_binary(&ec->Q.X, x_bin, x_len); + mbedtls_mpi_read_binary(&ec->Q.Y, y_bin, y_len); + mbedtls_mpi_lset(&ec->Q.Z, 1); + if (has_private) { + size_t d_len; + unsigned char d_bin[72]; + DECODE_B64URL(d, d_b64, d_bin, d_len); + mbedtls_mpi_read_binary(&ec->d, d_bin, d_len); + } + unsigned char buf[4096]; + if (has_private) { + ret = mbedtls_pk_write_key_pem(&pk, buf, sizeof(buf)); + } else { + ret = mbedtls_pk_write_pubkey_pem(&pk, buf, sizeof(buf)); + } + if (ret != 0) { + mbedtls_pk_free(&pk); + lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2; + } + pem = strdup((char *)buf); + mbedtls_pk_free(&pk); + lua_pushstring(L, pem); + free(pem); + return 1; + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported kty"); + return 2; + } +} + +// Convert PEM key to JWK (Lua table) format +static int LuaConvertPemToJwk(lua_State *L) { + const char *pem_key = luaL_checkstring(L, 1); + + mbedtls_pk_context key; + mbedtls_pk_init(&key); + int ret; + + // Parse the PEM key + if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key, + strlen(pem_key) + 1, NULL, 0)) != 0 && + (ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key, + strlen(pem_key) + 1)) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret); + mbedtls_pk_free(&key); + return 2; + } + + lua_newtable(L); + + if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_RSA) { + lua_pushstring(L, "RSA"); + lua_setfield(L, -2, "kty"); + const mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + size_t n_len = mbedtls_mpi_size(&rsa->N); + size_t e_len = mbedtls_mpi_size(&rsa->E); + unsigned char *n = malloc(n_len); + unsigned char *e = malloc(e_len); + if (!n || !e) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(n); + free(e); + mbedtls_pk_free(&key); + return 2; + } + mbedtls_mpi_write_binary(&rsa->N, n, n_len); + mbedtls_mpi_write_binary(&rsa->E, e, e_len); + char *n_b64 = NULL, *e_b64 = NULL; + size_t n_b64_len, e_b64_len; + mbedtls_base64_encode(NULL, 0, &n_b64_len, n, n_len); + mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); + n_b64 = malloc(n_b64_len + 1); + e_b64 = malloc(e_b64_len + 1); + mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, + n_len); + mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, + e_len); + n_b64[n_b64_len] = '\0'; + e_b64[e_b64_len] = '\0'; + lua_pushstring(L, n_b64); + lua_setfield(L, -2, "n"); + lua_pushstring(L, e_b64); + lua_setfield(L, -2, "e"); + // If private key, add private fields + if (mbedtls_rsa_check_privkey(rsa) == 0 && rsa->D.p) { + size_t d_len = mbedtls_mpi_size(&rsa->D); + size_t p_len = mbedtls_mpi_size(&rsa->P); + size_t q_len = mbedtls_mpi_size(&rsa->Q); + size_t dp_len = mbedtls_mpi_size(&rsa->DP); + size_t dq_len = mbedtls_mpi_size(&rsa->DQ); + size_t qi_len = mbedtls_mpi_size(&rsa->QP); + unsigned char *d = malloc(d_len), *p = malloc(p_len), *q = malloc(q_len), + *dp = malloc(dp_len), *dq = malloc(dq_len), + *qi = malloc(qi_len); + if (!d || !p || !q || !dp || !dq || !qi) { + free(d); + free(p); + free(q); + free(dp); + free(dq); + free(qi); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + mbedtls_pk_free(&key); + return 2; + } + mbedtls_mpi_write_binary(&rsa->D, d, d_len); + mbedtls_mpi_write_binary(&rsa->P, p, p_len); + mbedtls_mpi_write_binary(&rsa->Q, q, q_len); + mbedtls_mpi_write_binary(&rsa->DP, dp, dp_len); + mbedtls_mpi_write_binary(&rsa->DQ, dq, dq_len); + mbedtls_mpi_write_binary(&rsa->QP, qi, qi_len); + char *d_b64 = NULL, *p_b64 = NULL, *q_b64 = NULL, *dp_b64 = NULL, + *dq_b64 = NULL, *qi_b64 = NULL; + size_t d_b64_len, p_b64_len, q_b64_len, dp_b64_len, dq_b64_len, + qi_b64_len; + mbedtls_base64_encode(NULL, 0, &d_b64_len, d, d_len); + mbedtls_base64_encode(NULL, 0, &p_b64_len, p, p_len); + mbedtls_base64_encode(NULL, 0, &q_b64_len, q, q_len); + mbedtls_base64_encode(NULL, 0, &dp_b64_len, dp, dp_len); + mbedtls_base64_encode(NULL, 0, &dq_b64_len, dq, dq_len); + mbedtls_base64_encode(NULL, 0, &qi_b64_len, qi, qi_len); + d_b64 = malloc(d_b64_len + 1); + p_b64 = malloc(p_b64_len + 1); + q_b64 = malloc(q_b64_len + 1); + dp_b64 = malloc(dp_b64_len + 1); + dq_b64 = malloc(dq_b64_len + 1); + qi_b64 = malloc(qi_b64_len + 1); + mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, + d_len); + mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, + p_len); + mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, + q_len); + mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, + dp, dp_len); + mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, + dq, dq_len); + mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, + qi, qi_len); + d_b64[d_b64_len] = '\0'; + p_b64[p_b64_len] = '\0'; + q_b64[q_b64_len] = '\0'; + dp_b64[dp_b64_len] = '\0'; + dq_b64[dq_b64_len] = '\0'; + qi_b64[qi_b64_len] = '\0'; + lua_pushstring(L, d_b64); + lua_setfield(L, -2, "d"); + lua_pushstring(L, p_b64); + lua_setfield(L, -2, "p"); + lua_pushstring(L, q_b64); + lua_setfield(L, -2, "q"); + lua_pushstring(L, dp_b64); + lua_setfield(L, -2, "dp"); + lua_pushstring(L, dq_b64); + lua_setfield(L, -2, "dq"); + lua_pushstring(L, qi_b64); + lua_setfield(L, -2, "qi"); + free(d); + free(p); + free(q); + free(dp); + free(dq); + free(qi); + free(d_b64); + free(p_b64); + free(q_b64); + free(dp_b64); + free(dq_b64); + free(qi_b64); + } + free(n); + free(e); + free(n_b64); + free(e_b64); + } else if (mbedtls_pk_get_type(&key) == MBEDTLS_PK_ECKEY) { + // Handle ECDSA keys + const mbedtls_ecp_keypair *ec = mbedtls_pk_ec(key); + const mbedtls_ecp_point *Q = &ec->Q; + size_t x_len = (ec->grp.pbits + 7) / 8; + size_t y_len = (ec->grp.pbits + 7) / 8; + unsigned char *x = malloc(x_len); + unsigned char *y = malloc(y_len); + if (!x || !y) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + free(x); + free(y); + mbedtls_pk_free(&key); + return 2; + } + mbedtls_mpi_write_binary(&Q->X, x, x_len); + mbedtls_mpi_write_binary(&Q->Y, y, y_len); + char *x_b64 = NULL, *y_b64 = NULL; + size_t x_b64_len, y_b64_len; + mbedtls_base64_encode(NULL, 0, &x_b64_len, x, x_len); + mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len); + x_b64 = malloc(x_b64_len + 1); + y_b64 = malloc(y_b64_len + 1); + mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len); + mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); + x_b64[x_b64_len] = '\0'; + y_b64[y_b64_len] = '\0'; + // Set kty and crv for EC keys + lua_pushstring(L, "EC"); + lua_setfield(L, -2, "kty"); + const mbedtls_ecp_curve_info *curve_info = mbedtls_ecp_curve_info_from_grp_id(ec->grp.id); + if (curve_info && curve_info->name) { + lua_pushstring(L, curve_info->name); + lua_setfield(L, -2, "crv"); + } else { + lua_pushstring(L, "unknown"); + lua_setfield(L, -2, "crv"); + } + lua_pushstring(L, x_b64); + lua_setfield(L, -2, "x"); + lua_pushstring(L, y_b64); + lua_setfield(L, -2, "y"); + // If private key, add 'd' + if (mbedtls_ecp_check_privkey(&ec->grp, &ec->d) == 0 && ec->d.p) { + size_t d_len = mbedtls_mpi_size(&ec->d); + unsigned char *d = malloc(d_len); + if (!d) { free(x); free(y); free(x_b64); free(y_b64); lua_pushnil(L); lua_pushstring(L, "Memory allocation failed"); mbedtls_pk_free(&key); return 2; } + mbedtls_mpi_write_binary(&ec->d, d, d_len); + char *d_b64 = NULL; + size_t d_b64_len; + mbedtls_base64_encode(NULL, 0, &d_b64_len, d, d_len); + d_b64 = malloc(d_b64_len + 1); + mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); + d_b64[d_b64_len] = '\0'; + lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); + free(d); free(d_b64); + } + free(x); free(y); free(x_b64); free(y_b64); + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported key type"); + mbedtls_pk_free(&key); + return 2; + } + + mbedtls_pk_free(&key); + return 1; +} + + +// CSR creation Function +static int LuaGenerateCSR(lua_State *L) { + const char *key_pem = luaL_checkstring(L, 1); + const char *subject_name; + const char *san_list = luaL_optstring(L, 3, NULL); + + if (lua_isnoneornil(L, 2)) { + subject_name = ""; + } else { + subject_name = luaL_checkstring(L, 2); + } + + if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') { + lua_pushnil(L); + lua_pushstring(L, "Subject name or SANs are required"); + return 2; + } + mbedtls_pk_context key; + mbedtls_x509write_csr req; + char buf[4096]; + int ret; + + mbedtls_pk_init(&key); + mbedtls_x509write_csr_init(&req); + + if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem, + strlen(key_pem) + 1, NULL, 0)) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to parse key: %d", ret); + return 2; + } + + mbedtls_x509write_csr_set_subject_name(&req, subject_name); + mbedtls_x509write_csr_set_key(&req, &key); + mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256); + + if (san_list) { + if ((ret = mbedtls_x509write_csr_set_extension( + &req, MBEDTLS_OID_SUBJECT_ALT_NAME, + MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), + (const unsigned char *)san_list, strlen(san_list))) != 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to set SANs: %d", ret); + return 2; + } + } + + if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf), + NULL, NULL)) < 0) { + lua_pushnil(L); + lua_pushfstring(L, "Failed to write CSR: %d", ret); + return 2; + } + + lua_pushstring(L, buf); + + mbedtls_pk_free(&key); + mbedtls_x509write_csr_free(&req); + + return 1; +} + + // LuaCrypto compatible API static int LuaCryptoSign(lua_State *L) { - const char *dtype = - luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") - lua_remove(L, 1); // Remove the first argument (key type or cipher type) - // before dispatching + // Type of signature (e.g., "rsa", "ecdsa") + const char *dtype = luaL_checkstring(L, 1); + // Remove the first argument (key type or cipher type) before dispatching + lua_remove(L, 1); if (strcasecmp(dtype, "rsa") == 0) { return LuaRSASign(L); @@ -1585,10 +1853,10 @@ static int LuaCryptoSign(lua_State *L) { } static int LuaCryptoVerify(lua_State *L) { - const char *dtype = - luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") - lua_remove(L, 1); // Remove the first argument (key type or cipher type) - // before dispatching + // Type of signature (e.g., "rsa", "ecdsa") + const char *dtype = luaL_checkstring(L, 1); + // Remove the first argument (key type or cipher type) before dispatching + lua_remove(L, 1); if (strcasecmp(dtype, "rsa") == 0) { return LuaRSAVerify(L); @@ -1602,8 +1870,8 @@ static int LuaCryptoVerify(lua_State *L) { static int LuaCryptoEncrypt(lua_State *L) { // Args: cipher_type, key, msg, options table const char *cipher = luaL_checkstring(L, 1); - lua_remove(L, 1); // Remove cipher_type from stack, so key is at 1, msg at 2, - // options at 3 + // Remove cipher_type from stack, so key is at 1, msg at 2, options at 3 + lua_remove(L, 1); if (strcasecmp(cipher, "rsa") == 0) { return LuaRSAEncrypt(L); @@ -1617,9 +1885,8 @@ static int LuaCryptoEncrypt(lua_State *L) { static int LuaCryptoDecrypt(lua_State *L) { // Args: cipher_type, key, ciphertext, options table const char *cipher = luaL_checkstring(L, 1); - lua_remove( - L, - 1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3 + // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3 + lua_remove(L, 1); if (strcasecmp(cipher, "rsa") == 0) { return LuaRSADecrypt(L); @@ -1636,10 +1903,10 @@ static int LuaCryptoGenerateKeyPair(lua_State *L) { // Call LuaRSAGenerateKeyPair with the number as the key length return LuaRSAGenerateKeyPair(L); } - // Otherwise, get the key type from the first argument, default to "rsa" if - // not provided + // Otherwise, get the key type from the first argument, default to "rsa" const char *type = luaL_optstring(L, 1, "rsa"); lua_remove(L, 1); + if (strcasecmp(type, "rsa") == 0) { return LuaRSAGenerateKeyPair(L); } else if (strcasecmp(type, "ecdsa") == 0) { @@ -1651,12 +1918,15 @@ static int LuaCryptoGenerateKeyPair(lua_State *L) { } } + + static const luaL_Reg kLuaCrypto[] = { {"sign", LuaCryptoSign}, // {"verify", LuaCryptoVerify}, // {"encrypt", LuaCryptoEncrypt}, // {"decrypt", LuaCryptoDecrypt}, // {"generatekeypair", LuaCryptoGenerateKeyPair}, // + {"convertJwkToPem", LuaConvertJwkToPem}, // {"convertPemToJwk", LuaConvertPemToJwk}, // {"generateCsr", LuaGenerateCSR}, // {0}, // From 12ff789a699a2bb5ffc9bc3c71f8e71f1774d6b3 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Tue, 24 Jun 2025 17:47:07 +1200 Subject: [PATCH 15/18] Add RSA PSS support Improve error messages Improve parameter validation Correct base64url encoding for JWK Add support for optional claims to convertPemToJwk Expand test coverage Add basic definitions --- test/tool/net/lcrypto_test.lua | 721 ++++++++++++++++++++---- tool/net/definitions.lua | 125 +++-- tool/net/lcrypto.c | 971 +++++++++++++++++++++++---------- 3 files changed, 1381 insertions(+), 436 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 0516a6ec2..27d4fb3fb 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -1,246 +1,753 @@ --- Helper function to print test results -local function assert_equal(actual, expected, message) - if actual ~= expected then - error("FAIL: " .. message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) - end -end - -local function assert_not_equal(actual, not_expected, message) - if actual == not_expected then - error(message .. ": did not expect " .. tostring(not_expected)) - end -end - +---@diagnostic disable: lowercase-global -- Test RSA key pair generation local function test_rsa_keypair_generation() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - assert_equal(type(priv_key), "string", "Private key type") - assert_equal(type(pub_key), "string", "Public key type") + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") end -- Test ECDSA key pair generation local function test_ecdsa_keypair_generation() - local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") - assert_equal(type(priv_key), "string", "Private key type") - assert_equal(type(pub_key), "string", "Public key type") + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") end -- Test RSA encryption and decryption local function test_rsa_encryption_decryption() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") + local plaintext = "Hello, RSA!" - local encrypted = crypto.encrypt("rsa", pub_key, plaintext) - assert_equal(type(encrypted), "string", "Ciphertext type") - local decrypted = crypto.decrypt("rsa", priv_key, encrypted) - assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") + local ciphertext = crypto.encrypt("rsa", pub_key, plaintext) + assert(type(ciphertext) == "string", "Ciphertext type") + + local decrypted_plaintext = crypto.decrypt("rsa", priv_key, ciphertext) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") end -- Test RSA signing and verification local function test_rsa_signing_verification() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") + local message = "Sign this message" local signature = crypto.sign("rsa", priv_key, message, "sha256") - assert_equal(type(signature), "string", "Signature type") + assert(type(signature) == "string", "Signature type") + local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256") - assert_equal(is_valid, true, "Signature verification") + assert(is_valid == true, "Signature verification") +end + +-- Test RSA-PSS signing and verification +local function test_rsapss_signing_verification() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") + + Log(kLogVerbose," - Testing RSA-PSS signing") + local message = "Sign this message with RSA-PSS" + local signature = crypto.sign("rsapss", priv_key, message, "sha256") + assert(type(signature) == "string", "Signature type") + + Log(kLogVerbose," - Testing RSA-PSS verification") + local is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha256") + assert(is_valid == true, "RSA-PSS Signature verification") + + -- Test with different hash algorithm + Log(kLogVerbose," - Testing RSA-PSS with different hash algorithms") + signature = crypto.sign("rsapss", priv_key, message, "sha384") + assert(type(signature) == "string", "SHA-384 Signature type") + is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha384") + assert(is_valid == true, "RSA-PSS SHA-384 Signature verification") + + Log(kLogVerbose," - Testing RSA-PSS with SHA-512") + signature = crypto.sign("rsapss", priv_key, message, "sha512") + assert(type(signature) == "string", "SHA-512 Signature type") + is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha512") + assert(is_valid == true, "RSA-PSS SHA-512 Signature verification") end -- Test ECDSA signing and verification local function test_ecdsa_signing_verification() - local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") + local message = "Sign this message with ECDSA" local signature = crypto.sign("ecdsa", priv_key, message, "sha256") - assert_equal(type(signature), "string", "Signature type") + assert(type(signature) == "string", "Signature type") + local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") - assert_equal(is_valid, true, "Signature verification") + assert(is_valid == true, "Signature verification") end -- Test AES key generation local function test_aes_key_generation() - local key = crypto.generatekeypair('aes', 256) -- 256-bit key - assert_equal(type(key), "string", "Key type") - assert_equal(#key, 32, "Key length (256 bits)") + local key = crypto.generateKeyPair('aes', 256) -- 256-bit key + assert(type(key) == "string", "Key type") + assert(#key == 32, "Key length (256 bits)") end -- Test AES encryption and decryption (CBC mode) local function test_aes_encryption_decryption_cbc() - local key = crypto.generatekeypair('aes', 256) -- 256-bit key + local key = crypto.generateKeyPair('aes', 256) -- 256-bit key local plaintext = "Hello, AES CBC!" -- Encrypt without providing IV (should auto-generate IV) - local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil) - assert_equal(type(encrypted), "string", "Ciphertext type") - assert_equal(type(iv), "string", "IV type") + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, nil) + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, {mode="cbc",iv=iv}) - assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(16) - local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, {mode="cbc",iv=iv2}) - assert_equal(type(encrypted2), "string", "Ciphertext type") - assert_equal(iv_used, iv2, "IV match") + local ciphertext2, iv_used = crypto.encrypt("aes", key, plaintext, { mode = "cbc", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="cbc",iv=iv2}) - assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "cbc", iv = iv2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (CTR mode) local function test_aes_encryption_decryption_ctr() - local key = crypto.generatekeypair('aes', 256) + local key = crypto.generateKeyPair('aes', 256) local plaintext = "Hello, AES CTR!" -- Encrypt without providing IV (should auto-generate IV) - local encrypted, iv = crypto.encrypt("aes", key, plaintext, {mode="ctr"}) - assert_equal(type(encrypted), "string", "Ciphertext type") - assert_equal(type(iv), "string", "IV type") + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, {mode="ctr", iv=iv}) - assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "ctr", iv = iv }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(16) - local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, {mode="ctr", iv=iv2}) - assert_equal(type(encrypted2), "string", "Ciphertext type") - assert_equal(iv_used, iv2, "IV match") + local ciphertext2, iv_used = crypto.encrypt("aes", key, plaintext, { mode = "ctr", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="ctr", iv=iv2}) - assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "ctr", iv = iv2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (GCM mode) local function test_aes_encryption_decryption_gcm() - local key = crypto.generatekeypair('aes', 256) - assert_equal(type(key), "string", "key type") + local key = crypto.generateKeyPair('aes', 256) + assert(type(key) == "string", "key type") local plaintext = "Hello, AES GCM!" -- Encrypt without providing IV (should auto-generate IV) - local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, {mode="gcm"}) - assert_equal(#plaintext, #encrypted, "Ciphertext length matches plaintext") - assert_equal(type(encrypted), "string", "Ciphertext type") - assert_equal(type(iv), "string", "IV type") - assert_equal(type(tag), "string", "Tag type") + local ciphertext, iv, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + assert(#plaintext == #ciphertext, "Ciphertext length matches plaintext") + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") + assert(type(tag) == "string", "Tag type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, {mode="gcm",iv=iv,tag=tag}) - assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "gcm", iv = iv, tag = tag }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard - local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, {mode="gcm",iv=iv2}) - assert_equal(type(encrypted2), "string", "Ciphertext type") - assert_equal(iv_used, iv2, "IV match") - assert_equal(type(tag2), "string", "Tag type") + local ciphertext2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, { mode = "gcm", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") + assert(type(tag2) == "string", "Tag type") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="gcm",iv=iv2,tag=tag2}) - assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "gcm", iv = iv2, tag = tag2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") end -- Test PemToJwk conversion local function test_pem_to_jwk() - local priv_key, pub_key = crypto.generatekeypair() + local priv_key, pub_key = crypto.generateKeyPair() local priv_jwk = crypto.convertPemToJwk(priv_key) - assert_equal(type(priv_jwk), "table", "JWK type") - assert_equal(priv_jwk.kty, "RSA", "kty is correct") + assert(type(priv_jwk) == "table", "JWK type") + assert(priv_jwk.kty == "RSA", "kty is correct") local pub_jwk = crypto.convertPemToJwk(pub_key) - assert_equal(type(pub_jwk), "table", "JWK type") - assert_equal(pub_jwk.kty, "RSA", "kty is correct") + assert(type(pub_jwk) == "table", "JWK type") + assert(pub_jwk.kty == "RSA", "kty is correct") -- Test ECDSA keys - local priv_key, pub_key = crypto.generatekeypair('ecdsa') - local priv_jwk = crypto.convertPemToJwk(priv_key) - assert_equal(type(priv_jwk), "table", "JWK type") - assert_equal(priv_jwk.kty, "EC", "kty is correct") + priv_key, pub_key = crypto.generateKeyPair('ecdsa') + priv_jwk = crypto.convertPemToJwk(priv_key) + assert(type(priv_jwk) == "table", "JWK type") + assert(priv_jwk.kty == "EC", "kty is correct") - local pub_jwk = crypto.convertPemToJwk(pub_key) - assert_equal(type(pub_jwk), "table", "JWK type") - assert_equal(pub_jwk.kty, "EC", "kty is correct") + pub_jwk = crypto.convertPemToJwk(pub_key) + assert(type(pub_jwk) == "table", "JWK type") + assert(pub_jwk.kty == "EC", "kty is correct") end -- Test JwkToPem conversion local function test_jwk_to_pem() - local priv_key, pub_key = crypto.generatekeypair() + local priv_key, pub_key = crypto.generateKeyPair() local priv_jwk = crypto.convertPemToJwk(priv_key) local pub_jwk = crypto.convertPemToJwk(pub_key) local priv_pem = crypto.convertJwkToPem(priv_jwk) local pub_pem = crypto.convertJwkToPem(pub_jwk) - assert_equal(type(priv_pem), "string", "Private PEM type") + assert(type(priv_pem) == "string", "Private PEM type") -- Roundtrip - assert_equal(priv_key,priv_pem, "Private PEM matches original RSA key") - assert_equal(pub_key,pub_pem, "Public PEM matches original RSA key") + assert(priv_key == priv_pem, "Private PEM matches original RSA key") + assert(pub_key == pub_pem, "Public PEM matches original RSA key") - local pub_pem = crypto.convertJwkToPem(pub_jwk) - assert_equal(type(pub_pem), "string", "Public PEM type") + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(pub_pem) == "string", "Public PEM type") -- Test ECDSA keys - local priv_key, pub_key = crypto.generatekeypair('ecdsa') - local priv_jwk = crypto.convertPemToJwk(priv_key) - local pub_jwk = crypto.convertPemToJwk(pub_key) + priv_key, pub_key = crypto.generateKeyPair('ecdsa') + priv_jwk = crypto.convertPemToJwk(priv_key) + pub_jwk = crypto.convertPemToJwk(pub_key) + + priv_pem = crypto.convertJwkToPem(priv_jwk) + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(priv_pem) == "string", "Private PEM type for ECDSA") - local priv_pem = crypto.convertJwkToPem(priv_jwk) - local pub_pem = crypto.convertJwkToPem(pub_jwk) - assert_equal(type(priv_pem), "string", "Private PEM type for ECDSA") - -- Roundtrip - assert_equal(priv_key,priv_pem, "Private PEM matches original ECDSA key") - assert_equal(pub_key,pub_pem, "Public PEM matches original ECDSA key") + assert(priv_key == priv_pem, "Private PEM matches original ECDSA key") + assert(pub_key == pub_pem, "Public PEM matches original ECDSA key") - local pub_pem = crypto.convertJwkToPem(pub_jwk) - assert_equal(type(pub_pem), "string", "Public PEM type for ECDSA") + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(pub_pem) == "string", "Public PEM type for ECDSA") end -- Test CSR generation local function test_csr_generation() - local priv_key, _ = crypto.generatekeypair() + local priv_key, _ = crypto.generateKeyPair() local subject_name = "CN=example.com,O=Example Org,C=US" local san = "DNS:example.com, DNS:www.example.com, IP:192.168.1.1" - assert_equal(type(priv_key), "string", "Private key type") + assert(type(priv_key) == "string", "Private key type") local csr = crypto.generateCsr(priv_key, subject_name) - assert_equal(type(csr), "string", "CSR generation with subject name") + assert(type(csr) == "string", "CSR generation with subject name") csr = crypto.generateCsr(priv_key, subject_name, san) - assert_equal(type(csr), "string", "CSR generation with subject name and san") + assert(type(csr) == "string", "CSR generation with subject name and san") csr = crypto.generateCsr(priv_key, nil, san) - assert_equal(type(csr), "string", "CSR generation with nil subject name and san") + assert(type(csr) == "string", "CSR generation with nil subject name and san") csr = crypto.generateCsr(priv_key, '', san) - assert_equal(type(csr), "string", "CSR generation with empty subject name and san") + assert(type(csr) == "string", "CSR generation with empty subject name and san") -- These should fail csr = crypto.generateCsr(priv_key, '') - assert_not_equal(type(csr), "string", "CSR generation with empty subject name and no san is rejected") + assert(type(csr) ~= "string", "CSR generation with empty subject name and no san is rejected") csr = crypto.generateCsr(priv_key) - assert_not_equal(type(csr), "string", "CSR generation with nil subject name and no san is rejected") + assert(type(csr) ~= "string", "CSR generation with nil subject name and no san is rejected") +end + +-- Test various hash algorithms +local function test_hash_algorithms() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local message = "Test message for hash algorithms" + + -- Test different hash algorithms for RSA signatures + local hash_algorithms = { "sha256", "sha384", "sha512" } + for _, hash in ipairs(hash_algorithms) do + local signature = crypto.sign("rsa", priv_key, message, hash) + assert(type(signature) == "string", "RSA signature with " .. hash) + local is_valid = crypto.verify("rsa", pub_key, message, signature, hash) + assert(is_valid == true, "RSA verification with " .. hash) + + -- Test with RSA-PSS + local signature_pss = crypto.sign("rsapss", priv_key, message, hash) + assert(type(signature_pss) == "string", "RSA-PSS signature with " .. hash) + local is_valid_pss = crypto.verify("rsapss", pub_key, message, signature_pss, hash) + assert(is_valid_pss == true, "RSA-PSS verification with " .. hash) + end + + -- Test ECDSA with different hash algorithms + local ec_priv_key, ec_pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") + for _, hash in ipairs(hash_algorithms) do + local signature = crypto.sign("ecdsa", ec_priv_key, message, hash) + assert(type(signature) == "string", "ECDSA signature with " .. hash) + + local is_valid = crypto.verify("ecdsa", ec_pub_key, message, signature, hash) + assert(is_valid == true, "ECDSA verification with " .. hash) + end +end + +-- Test negative cases for hash algorithms +local function test_negative_hash_algorithms() + local priv_key, pub_key = crypto.generateKeyPair() + local message = "Test message for hash algorithms" + + -- Test with invalid hash algorithm + local ok = pcall(function() return crypto.sign("rsa", priv_key, message, "invalid-hash") end) + assert(ok == false, "Sign with invalid hash should fail") + + -- Test with nil hash algorithm (should default to SHA-256) + local signature = crypto.sign("rsa", priv_key, message) + assert(type(signature) == "string", "Sign with nil hash") + + local is_valid = crypto.verify("rsa", pub_key, message, signature) + assert(is_valid == true, "Verify with nil hash") +end + +-- Negative tests for crypto functions +local function test_negative_keypair_generation() + -- Invalid algorithm + local ok = pcall(function() return crypto.generateKeyPair('invalidalg', 2048) end) + assert(ok == false, "generatekeypair with invalid algorithm should fail") + -- Invalid RSA key size + local pk, _ = crypto.generateKeyPair('rsa', 123) + assert(pk == nil, "generatekeypair with invalid RSA size should fail") + -- Invalid ECDSA curve + pk, _ = crypto.generateKeyPair('ecdsa', 'invalidcurve') + assert(pk == nil, "generatekeypair with invalid ECDSA curve should fail") +end + +local function test_negative_encrypt_decrypt() + local priv_key, pub_key = crypto.generateKeyPair('rsa', 2048) + -- Encrypt with invalid algorithm + local ok, ciphertext = pcall(function() return crypto.encrypt('invalidalg', pub_key, 'data') end) + assert(ok == false, "RSA encrypt with invalid algorithm should fail") + + -- Decrypt with invalid algorithm + ok, _ = pcall(function() return crypto.decrypt('invalidalg', priv_key, 'data') end) + assert(ok == false, "RSA decrypt with invalid algorithm should fail") + + -- Encrypt with invalid key + ciphertext = crypto.encrypt('rsa', 'notakey', 'data') + assert(ciphertext == nil, "RSA encrypt with invalid key should fail") + + -- Decrypt with invalid key + local retval = crypto.decrypt('rsa', 'notakey', 'data') + assert(retval == nil, "RSA decrypt with invalid key should fail") + + -- AES: invalid IV length + local key = crypto.generateKeyPair('aes', 256) + ciphertext = crypto.encrypt('aes', key, 'data', { mode = "cbc", iv = "tooShortIV" }) + assert(ciphertext == nil, "AES encrypt with short IV should fail") + + retval = crypto.decrypt('aes', key, 'data', { mode = "cbc", iv = "tooShortIV" }) + assert(retval == nil, "AES decrypt with short IV should fail") +end + +local function test_negative_sign_verify() + local priv_key, pub_key = crypto.generateKeyPair('rsa', 2048) + -- Sign with invalid algorithm + local ok = pcall(function() return crypto.sign('invalidalg', priv_key, 'msg', 'sha256') end) + assert(ok == false, "RSA sign with invalid algorithm should fail") + + -- Verify with invalid algorithm + ok = pcall(function() return crypto.verify('invalidalg', pub_key, 'msg', 'sig', 'sha256') end) + assert(ok == false, "RSA verify with invalid algorithm should fail") + + -- Sign with invalid key + ok = pcall(function() return crypto.sign('rsa', 'notakey', 'msg', 'sha256') end) + assert(ok == false, "RSA sign with invalid key should fail") + + -- Verify with invalid key + local verified = crypto.verify('rsa', 'notakey', 'msg', 'sig', 'sha256') + assert(verified == false, "verify with invalid key should fail") + + -- Verify with wrong signature (should return false, not error) + local badsig = 'thisisnotavalidsignature' + local result = crypto.verify('rsa', pub_key, 'msg', badsig, 'sha256') + assert(result == false, "RSA verify with wrong signature should return false") +end + +local function test_negative_pem_jwk_conversion() + -- Invalid PEM + local ok = pcall(function() return crypto.convertPemToJwk('notapem') end) + assert(ok == false, "convertPemToJwk with invalid PEM should fail") + + -- Invalid JWK (wrong type, but still a table) + local pem = crypto.convertJwkToPem({ kty = 'INVALID' }) + assert(pem == nil, "convertJwkToPem with invalid JWK should fail") + + -- Missing kty in JWK + pem = crypto.convertJwkToPem({}) + assert(pem == nil, "convertJwkToPem with missing kty should fail") +end + +local function test_negative_csr_generation() + -- Invalid key + local csr = crypto.generateCsr('notakey', 'CN=bad') + assert(csr == nil, "generateCsr with invalid key should fail") +end + +-- Add additional tests for edge cases in crypto functions + +-- Test RSA key size variations +local function test_rsa_key_sizes() + -- Test 2048-bit keys + local priv_key_2048, pub_key_2048 = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key_2048) == "string", "2048-bit private key type") + assert(type(pub_key_2048) == "string", "2048-bit public key type") + + -- Test 4096-bit keys + local priv_key_4096, pub_key_4096 = crypto.generateKeyPair("rsa", 4096) + assert(type(priv_key_4096) == "string", "4096-bit private key type") + assert(type(pub_key_4096) == "string", "4096-bit public key type") + + -- Test signing and verification with different key sizes + local message = "Test message for RSA key sizes" + + local signature_2048 = crypto.sign("rsa", priv_key_2048, message, "sha256") + assert(type(signature_2048) == "string", "2048-bit key signature") + + local is_valid_2048 = crypto.verify("rsa", pub_key_2048, message, signature_2048, "sha256") + assert(is_valid_2048 == true, "2048-bit key verification") + + local signature_4096 = crypto.sign("rsa", priv_key_4096, message, "sha256") + assert(type(signature_4096) == "string", "4096-bit key signature") + + local is_valid_4096 = crypto.verify("rsa", pub_key_4096, message, signature_4096, "sha256") + assert(is_valid_4096 == true, "4096-bit key verification") +end + +-- Test ECDSA curves +local function test_ecdsa_curves() + local curves = { "secp256r1", "secp384r1", "secp521r1" } + local message = "Test message for ECDSA curves" + + for _, curve in ipairs(curves) do + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", curve) + assert(type(priv_key) == "string", curve .. " private key type") + assert(type(pub_key) == "string", curve .. " public key type") + + local signature = crypto.sign("ecdsa", priv_key, message, "sha256") + assert(type(signature) == "string", curve .. " signature") + local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") + assert(is_valid == true, curve .. " verification") + end +end + +-- Test AES key sizes +local function test_aes_key_sizes() + local key_sizes = { 128, 192, 256 } + local plaintext = "Test message for AES key sizes" + + for _, size in ipairs(key_sizes) do + local key = crypto.generateKeyPair("aes", size) + assert(type(key) == "string", size .. "-bit AES key type") + assert(#key == size / 8, size .. "-bit AES key length") + + -- Test CBC mode + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + assert(type(ciphertext) == "string", size .. "-bit AES CBC encryption") + local decrypted_plaintext_cbc = crypto.decrypt("aes", key, ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_plaintext_cbc == plaintext, size .. "-bit AES CBC decryption") + + -- Test CTR mode + local ciphertext_ctr, iv_ctr = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + assert(type(ciphertext_ctr) == "string", size .. "-bit AES CTR encryption") + local decrypted_plaintext_ctr = crypto.decrypt("aes", key, ciphertext_ctr, { mode = "ctr", iv = iv_ctr }) + assert(decrypted_plaintext_ctr == plaintext, size .. "-bit AES CTR decryption") + + -- Test GCM mode + local ciphertext_gcm, iv_gcm, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + assert(type(ciphertext_gcm) == "string", size .. "-bit AES GCM encryption") + local decrypted_plaintext_gcm = crypto.decrypt("aes", key, ciphertext_gcm, { mode = "gcm", iv = iv_gcm, tag = tag }) + assert(decrypted_plaintext_gcm == plaintext, size .. "-bit AES GCM decryption") + end +end + +-- Test AES decryption with corrupted ciphertext and tag +local function test_aes_corruption_handling() + local key = crypto.generateKeyPair('aes', 256) + local plaintext = "Sensitive data for corruption test" + -- CBC mode + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + + -- Corrupt ciphertext + Log(kLogVerbose," - CBC decryption with corrupted ciphertext should fail") + local corrupted = ciphertext:sub(1, #ciphertext - 1) .. string.char((ciphertext:byte(-1) ~ 0xFF) % 256) + local plaintext_cbc = crypto.decrypt("aes", key, corrupted, { mode = "cbc", iv = iv }) + assert(plaintext_cbc == nil, "CBC decryption with corrupted ciphertext should fail") + + -- CTR mode (should not error, but output will be wrong) + Log(kLogVerbose," - CTR decryption with corrupted ciphertext should not match original") + local ciphertext_ctr, iv_ctr = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + local corrupted_ctr = ciphertext_ctr:sub(1, #ciphertext_ctr - 1) .. + string.char((ciphertext_ctr:byte(-1) ~ 0xFF) % 256) + local plaintext_ctr = crypto.decrypt("aes", key, corrupted_ctr, { mode = "ctr", iv = iv_ctr }) + assert(plaintext_ctr ~= plaintext, "CTR decryption with corrupted ciphertext should not match original") + + -- GCM mode (should fail authentication) + Log(kLogVerbose,"GCM decryption with corrupted ciphertext should fail") + local ciphertext_gcm, iv_gcm, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + local corrupted_gcm = ciphertext_gcm:sub(1, #ciphertext_gcm - 1) .. + string.char((ciphertext_gcm:byte(-1) ~ 0xFF) % 256) + local plaintext_gcm = crypto.decrypt("aes", key, corrupted_gcm, { mode = "gcm", iv = iv_gcm, tag = tag }) + assert(plaintext_gcm == nil, "GCM decryption with corrupted ciphertext should fail") + + -- GCM mode with corrupted tag + Log(kLogVerbose,"GCM decryption with corrupted tag should fail") + local badtag = tag:sub(1, #tag - 1) .. string.char((tag:byte(-1) ~ 0xFF) % 256) + local plaintext_gcm2 = crypto.decrypt("aes", key, ciphertext_gcm, { mode = "gcm", iv = iv_gcm, tag = badtag }) + assert(plaintext_gcm2 ~= plaintext, "GCM decryption with corrupted tag should fail") +end + +-- Test AES encryption/decryption with empty plaintext +local function test_aes_empty_plaintext() + local key = crypto.generateKeyPair('aes', 256) + local empty = "" + for _, mode in ipairs({ "cbc", "ctr", "gcm" }) do + local ciphertext, iv, tag = crypto.encrypt("aes", key, empty, { mode = mode }) + assert(type(ciphertext) == "string", "AES " .. mode .. " encrypt empty string") + + local opts = { mode = mode, iv = iv, tag = tag } + if mode ~= "gcm" then opts.tag = nil end + + local plaintext = crypto.decrypt("aes", key, ciphertext, opts) + assert(plaintext == empty, "AES " .. mode .. " decrypt empty string") + end +end + +-- Test sign/verify with empty message +local function test_sign_verify_empty_message() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local signature = crypto.sign("rsa", priv_key, "", "sha256") + assert(type(signature) == "string", "RSA sign empty message") + + local is_valid = crypto.verify("rsa", pub_key, "", signature, "sha256") + assert(is_valid == true, "RSA verify empty message") + + local ec_priv, ec_pub = crypto.generateKeyPair("ecdsa", "secp256r1") + local ec_sig = crypto.sign("ecdsa", ec_priv, "", "sha256") + assert(type(ec_sig) == "string", "ECDSA sign empty message") + + local ec_valid = crypto.verify("ecdsa", ec_pub, "", ec_sig, "sha256") + assert(ec_valid == true, "ECDSA verify empty message") +end + +-- Test JWK to PEM with minimal valid JWKs and missing fields +local function test_jwk_to_pem_minimal() + -- Minimal valid RSA public JWK + local _, pub_key = crypto.generateKeyPair("rsa", 2048) + local pub_jwk = crypto.convertPemToJwk(pub_key) + local minimal_jwk = { kty = pub_jwk.kty, n = pub_jwk.n, e = pub_jwk.e } + Log(kLogVerbose," - Testing minimal JWK to PEM conversion") + local pem = crypto.convertJwkToPem(minimal_jwk) + assert(type(pem) == "string", "Minimal RSA JWK to PEM") + + -- Missing 'n' field + Log(kLogVerbose," - Testing missing 'n' field in JWK to PEM conversion") + local bad_jwk = { kty = "RSA", e = pub_jwk.e } + local pem2 = crypto.convertJwkToPem(bad_jwk) + assert(pem2 == nil, "JWK to PEM with missing n should fail") + + -- Minimal EC public JWK + Log(kLogVerbose," - Testing minimal EC JWK to PEM conversion") + local _, ec_pub = crypto.generateKeyPair("ecdsa", "secp256r1") + local ec_jwk = crypto.convertPemToJwk(ec_pub) + local minimal_ec_jwk = { kty = ec_jwk.kty, crv = ec_jwk.crv, x = ec_jwk.x, y = ec_jwk.y } + local ec_pem = crypto.convertJwkToPem(minimal_ec_jwk) + assert(type(ec_pem) == "string", "Minimal EC JWK to PEM") + + -- Missing 'x' field + Log(kLogVerbose," - Testing missing 'x' field in EC JWK to PEM conversion") + local bad_ec_jwk = { kty = "EC", crv = ec_jwk.crv, y = ec_jwk.y } + local ec_pem2 = crypto.convertJwkToPem(bad_ec_jwk) + assert(ec_pem2 == nil, "EC JWK to PEM with missing x should fail") +end + +-- Test PEM to JWK with corrupted PEM +local function test_pem_to_jwk_corrupted() + local badpem = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA7\n-----END PUBLIC KEY-----" + local ok = pcall(function() return crypto.convertPemToJwk(badpem) end) + assert(ok == false, "PEM to JWK with corrupted PEM should fail") +end + +-- Test CSR generation with missing/invalid subject/SAN +local function test_csr_generation_edge_cases() + local priv_key, _ = crypto.generateKeyPair() + -- Missing subject and SAN + local csr = crypto.generateCsr(priv_key) + assert(csr == nil, "CSR with missing subject and SAN should fail") + -- Invalid SAN type (not validated yet) + -- local csr2, err2 = crypto.generateCsr(priv_key, "CN=foo", 12345) + -- assert(csr2 == nil, "CSR with invalid SAN type should fail") +end + +-- Test unsupported AES mode +local function test_unsupported_aes_mode() + Log(kLogVerbose," - AES decrypt with unsupported mode should fail") + local key = crypto.generateKeyPair('aes', 256) + local ciphertext = crypto.encrypt('aes', key, 'data', { mode = 'ofb' }) + assert(ciphertext == nil, "AES encrypt with unsupported mode should fail") + + local plaintext = crypto.decrypt('aes', key, 'data', { mode = 'ofb', iv = string.rep('A', 16) }) + assert(plaintext == nil, "AES decrypt with unsupported mode should fail") +end + +-- Test encrypting and signing very large messages +local function test_large_message_handling() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local large_message = string.rep("A", 1024 * 1024) -- 1MB + + -- RSA encryption (should fail or be limited by key size) + Log(kLogVerbose," - RSA encrypt large message should fail or be limited") + local ciphertext = crypto.encrypt("rsa", pub_key, large_message) + assert(ciphertext == nil, "RSA encrypt large message should fail or be limited") + + -- AES encryption (should succeed) + Log(kLogVerbose," - AES encrypt large message") + local key = crypto.generateKeyPair('aes', 256) + local aes_ciphertext, iv = crypto.encrypt("aes", key, large_message, { mode = "cbc" }) + assert(type(aes_ciphertext) == "string", "AES encrypt large message") + + local decrypted_large_message = crypto.decrypt("aes", key, aes_ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_large_message == large_message, "AES decrypt large message") + + -- RSA sign large message + Log(kLogVerbose," - RSA verify large message") + local signature = crypto.sign("rsa", priv_key, large_message, "sha256") + assert(type(signature) == "string", "RSA sign large message") + + local is_valid = crypto.verify("rsa", pub_key, large_message, signature, "sha256") + assert(is_valid == true, "RSA verify large message") +end + +-- Test passing non-string values as keys/messages/options +local function test_invalid_types() + local priv_key, _ = crypto.generateKeyPair("rsa", 2048) + local key = crypto.generateKeyPair('aes', 256) + + -- Non-string message + Log(kLogVerbose," - RSA sign with integer message should fail") + local signature = crypto.sign("rsa", priv_key, 12345, "sha256") + assert(signature == nil, "RSA sign with non-string message should fail") + + Log(kLogVerbose," - AES encrypt with boolean message should fail") + ciphertext = crypto.encrypt("aes", key, true, { mode = "cbc" }) + assert(ciphertext == nil, "AES encrypt with boolean message should fail") + + -- Non-string key + Log(kLogVerbose," - RSA sign with table as key should fail") + signature = crypto.sign("rsa", {}, "msg", "sha256") + assert(signature == nil, "RSA sign with table as key should fail") + + -- Non-table options + Log(kLogVerbose," - AES encrypt with number as options should fail") + ciphertext = crypto.encrypt("aes", key, "msg", 123) + assert(ciphertext == nil, "AES encrypt with number as options should fail") +end + +-- Test encrypting with one mode and decrypting with another +local function test_mixed_mode_operations() + local key = crypto.generateKeyPair('aes', 256) + local plaintext = "Mixed mode test" + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "ctr", iv = iv }) + assert(decrypted_plaintext ~= plaintext, "Decrypt CBC ciphertext with CTR mode should not match") + + local ciphertext2, iv2 = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "cbc", iv = iv2 }) + assert(decrypted_plaintext2 ~= plaintext, "Decrypt CTR ciphertext with CBC mode should not match") +end + +-- Test signing/verifying/converting with nil or empty parameters +local function test_nil_empty_parameters() + Log(kLogVerbose," - RSA sign with nil message should fail") + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local signature = crypto.sign("rsa", priv_key, nil, "sha256") + assert(signature == nil, "Sign with nil message should fail") + + Log(kLogVerbose," - RSA verify with nil message should fail") + local is_valid = crypto.verify("rsa", pub_key, nil, "sig", "sha256") + assert(is_valid == nil, "Verify with nil message should fail") + + Log(kLogVerbose," - JWK to PEM with nil should fail") + local ok = pcall(function() return crypto.convertJwkToPem(nil) end) + assert(ok == false, "convertJwkToPem with nil should fail") + + Log(kLogVerbose," - JWK to PEM with empty string should fail") + ok = pcall(function() return crypto.convertJwkToPem("") end) + assert(ok == false, "convertJwkToPem with empty string should fail") end -- Run all tests local function run_tests() + Log(kLogVerbose,"Testing RSA keypair generation...") test_rsa_keypair_generation() + Log(kLogVerbose,"Testing RSA signing and verification...") test_rsa_signing_verification() + Log(kLogVerbose,"Testing RSA-PSS signing and verification...") + test_rsapss_signing_verification() + Log(kLogVerbose,"Testing RSA encryption and decryption...") test_rsa_encryption_decryption() + Log(kLogVerbose,"Testing RSA key size variations...") + test_rsa_key_sizes() + + Log(kLogVerbose,"Testing ECDSA keypair generation...") test_ecdsa_keypair_generation() + Log(kLogVerbose,"Testing ECDSA signing and verification...") test_ecdsa_signing_verification() + Log(kLogVerbose,"Testing ECDSA curves...") + test_ecdsa_curves() + + Log(kLogVerbose,"Testing AES key generation...") test_aes_key_generation() + Log(kLogVerbose,"Testing AES encryption and decryption (CBC mode)...") test_aes_encryption_decryption_cbc() + Log(kLogVerbose,"Testing AES encryption and decryption (CTR mode)...") test_aes_encryption_decryption_ctr() + Log(kLogVerbose,"Testing AES encryption and decryption (GCM mode)...") test_aes_encryption_decryption_gcm() + Log(kLogVerbose,"Testing unsupported AES mode...") + test_unsupported_aes_mode() + Log(kLogVerbose,"Testing AES key sizes...") + test_aes_key_sizes() + Log(kLogVerbose,"Testing AES decryption with corrupted ciphertext and tag...") + test_aes_corruption_handling() + Log(kLogVerbose,"Testing AES encryption/decryption with empty plaintext...") + test_aes_empty_plaintext() + Log(kLogVerbose,"Testing large message encryption and signing...") + test_large_message_handling() + + Log(kLogVerbose,"Testing various hash algorithms...") + test_hash_algorithms() + Log(kLogVerbose,"Testing negative cases for hash algorithms...") + test_negative_hash_algorithms() + Log(kLogVerbose,"Testing sign/verify with empty message...") + test_sign_verify_empty_message() + Log(kLogVerbose,"Testing invalid input types...") + test_invalid_types() + Log(kLogVerbose,"Testing mixed mode encryption/decryption...") + test_mixed_mode_operations() + Log(kLogVerbose,"Testing nil/empty parameters...") + test_nil_empty_parameters() + Log(kLogVerbose,"Testing negative keypair generation...") + test_negative_keypair_generation() + Log(kLogVerbose,"Testing negative encrypt/decrypt...") + test_negative_encrypt_decrypt() + Log(kLogVerbose,"Testing negative sign/verify...") + test_negative_sign_verify() + + Log(kLogVerbose,"Testing PEM to JWK conversion...") test_pem_to_jwk() + Log(kLogVerbose,"Testing PEM to JWK with corrupted PEM...") + test_pem_to_jwk_corrupted() + Log(kLogVerbose,"Testing JWK to PEM conversion...") test_jwk_to_pem() + Log(kLogVerbose,"Testing negative PEM/JWK conversion...") + test_negative_pem_jwk_conversion() + Log(kLogVerbose,"Testing JWK to PEM with minimal valid JWKs and missing fields...") + test_jwk_to_pem_minimal() + Log(kLogVerbose,"Testing CSR generation...") test_csr_generation() + Log(kLogVerbose,"Testing CSR generation with missing/invalid subject/SAN...") + test_csr_generation_edge_cases() + Log(kLogVerbose,"Testing negative CSR generation...") + test_negative_csr_generation() EXIT = 0 return EXIT end diff --git a/tool/net/definitions.lua b/tool/net/definitions.lua index 661e0ea27..c21ad82cc 100644 --- a/tool/net/definitions.lua +++ b/tool/net/definitions.lua @@ -8051,68 +8051,85 @@ kUrlLatin1 = nil --- This module provides cryptographic operations. ---- The crypto module for cryptographic operations crypto = {} ---- Converts a PEM-encoded key to JWK format ----@param pem string PEM-encoded key ----@return table?, string? JWK table or nil on error ----@return string? error message -function crypto.convertPemToJwk(pem) end +--- Signs a message using the specified key type. +--- Supported types: "rsa", "rsa-pss", "ecdsa" +---@param type "rsa"|"rsa-pss"|"rsapss"|"ecdsa" +---@param key string PEM-encoded private key +---@param message string +---@param hash? string Hash algorithm ("sha256", "sha384", "sha512"). Default: "sha256" +---@return string signature +---@overload fun(type: string, key: string, message: string, hash?: string): nil, error: string +function crypto.sign(type, key, message, hash) end ---- Generates a Certificate Signing Request (CSR) ----@param key_pem string PEM-encoded private key ----@param subject_name string? X.509 subject name ----@param san_list string? Subject Alternative Names ----@return string?, string? CSR in PEM format or nil on error and error message -function crypto.generateCsr(key_pem, subject_name, san_list) end +--- Verifies a signature using the specified key type. +--- Supported types: "rsa", "rsa-pss", "ecdsa" +---@param type "rsa"|"rsa-pss"|"rsapss"|"ecdsa" +---@param key string PEM-encoded public key +---@param message string +---@param signature string +---@param hash? string Hash algorithm ("sha256", "sha384", "sha512"). Default: "sha256" +---@return boolean valid +function crypto.verify(type, key, message, signature, hash) end ---- Signs data using a private key ----@param key_type string "rsa" or "ecdsa" ----@param private_key string PEM-encoded private key ----@param message string Data to sign ----@param hash_algo string? Hash algorithm (default: SHA-256) ----@return string?, string? Signature or nil on error and error message -function crypto.sign(key_type, private_key, message, hash_algo) end +--- Encrypts data using the specified cipher. +--- Supported ciphers: "rsa", "aes" +---@param cipher "rsa"|"aes" +---@param key string PEM-encoded public key (RSA) or raw key (AES) +---@param plaintext string +---@param options? table Options for AES: { mode="cbc"|"gcm"|"ctr", iv=string, aad=string } +---@return string ciphertext, string? iv, string? tag +---@overload fun(cipher: string, key: string, plaintext: string, options?: table): nil, error: string +function crypto.encrypt(cipher, key, plaintext, options) end ---- Verifies a signature ----@param key_type string "rsa" or "ecdsa" ----@param public_key string PEM-encoded public key ----@param message string Original message ----@param signature string Signature to verify ----@param hash_algo string? Hash algorithm (default: SHA-256) ----@return boolean?, string? True if valid or nil on error and error message -function crypto.verify(key_type, public_key, message, signature, hash_algo) end +--- Decrypts data using the specified cipher. +--- Supported ciphers: "rsa", "aes" +---@param cipher "rsa"|"aes" +---@param key string PEM-encoded private key (RSA) or raw key (AES) +---@param ciphertext string +---@param options? table Options for AES: { mode="cbc"|"gcm"|"ctr", iv=string, tag=string, aad=string } +---@return string plaintext +---@overload fun(cipher: string, key: string, ciphertext: string, options?: table): nil, error: string +function crypto.decrypt(cipher, key, ciphertext, options) end ---- Encrypts data ----@param cipher_type string "rsa" or "aes" ----@param key string Public key or symmetric key ----@param plaintext string Data to encrypt ----@param mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") ----@param iv string? Initialization Vector for AES ----@param aad string? Additional data for AES-GCM ----@return string? Encrypted data or nil on error ----@return string? IV or error message ----@return string? Authentication tag for GCM mode -function crypto.encrypt(cipher_type, key, plaintext, mode, iv, aad) end +--- Generates a key pair. +--- For RSA: bits = 2048 or 4096. +--- For ECDSA: curve = "secp256r1", "secp384r1", "secp521r1", "curve25519" +--- For AES: bits = 128, 192, or 256. +---@param type "rsa"|"ecdsa"|"aes" +---@param param? integer|string For RSA: bits; for ECDSA: curve name; for AES: bits +---@return string private_key, string public_key|nil +---@overload fun(type: string, param?: integer|string): nil, error: string +function crypto.generateKeyPair(type, param) end ---- Decrypts data ----@param cipher_type string "rsa" or "aes" ----@param key string Private key or symmetric key ----@param ciphertext string Data to decrypt ----@param iv string? Initialization Vector for AES ----@param mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") ----@param tag string? Authentication tag for AES-GCM ----@param aad string? Additional data for AES-GCM ----@return string?, string? Decrypted data or nil on error and error message -function crypto.decrypt(cipher_type, key, ciphertext, iv, mode, tag, aad) end +--- Converts a JWK (JSON Web Key, as a Lua table) to PEM format. +---@param jwk table +---@return string pem +---@overload fun(jwk: table): nil, error: string +function crypto.convertJwkToPem(jwk) end ---- Generates cryptographic keys ----@param key_type string? "rsa", "ecdsa", or "aes" ----@param key_size_or_curve number|string? Key size or curve name ----@return string? Private key or nil on error ----@return string? Public key (nil for AES) or error message -function crypto.generatekeypair(key_type, key_size_or_curve) end +--- Converts a PEM key to JWK (JSON Web Key, as a Lua table). +---@param pem string +---@param claims? table Additional claims to merge (RFC7517 fields) +---@return table jwk +---@overload fun(pem: string, claims?: table): nil, error: string +function crypto.convertPemToJwk(pem, claims) end + +--- Generates a Certificate Signing Request (CSR) from a PEM key. +---@param key string PEM-encoded private key +---@param subject string Subject name (e.g., "CN=example.com") +---@param sans? string Subject Alternative Names (SANs) as a comma-separated string +---@return string csr_pem +---@overload fun(key: string, subject: string, sans?: string): nil, error: string +function crypto.generateCsr(key, subject, sans) end + +-- AES options table for encrypt/decrypt: +---@class CryptoAesOptions +---@field mode? "cbc"|"gcm"|"ctr" +---@field iv? string +---@field tag? string +---@field aad? string --[[ diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index e50304fe4..7ce0c3cd4 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -32,6 +32,163 @@ #include "third_party/mbedtls/rsa.h" #include "third_party/mbedtls/x509_csr.h" +// Supported curves mapping +typedef struct { + const char *name; + mbedtls_ecp_group_id id; +} curve_map_t; + +static const curve_map_t supported_curves[] = { + {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, + {"P-256", MBEDTLS_ECP_DP_SECP256R1}, + {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, + {"P-384", MBEDTLS_ECP_DP_SECP384R1}, + {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, + {"P-521", MBEDTLS_ECP_DP_SECP521R1}, + {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, + {NULL, 0}}; + +// List available curves +static int LuaListCurves(lua_State *L) { + const curve_map_t *curve = supported_curves; + int i = 1; + + lua_newtable(L); + + while (curve->name != NULL) { + lua_pushstring(L, curve->name); + lua_rawseti(L, -2, i++); + curve++; + } + + return 1; +} + +// Remove hash_algorithm_t and hash_to_md_type indirection +static mbedtls_md_type_t string_to_md_type(const char *hash_name) { + if (!hash_name || !*hash_name) { + return MBEDTLS_MD_SHA256; // Default to SHA-256 if no name provided + } + if (strcasecmp(hash_name, "sha256") == 0 || strcasecmp(hash_name, "sha-256") == 0) { + return MBEDTLS_MD_SHA256; + } else if (strcasecmp(hash_name, "sha384") == 0 || strcasecmp(hash_name, "sha-384") == 0) { + return MBEDTLS_MD_SHA384; + } else if (strcasecmp(hash_name, "sha512") == 0 || strcasecmp(hash_name, "sha-512") == 0) { + return MBEDTLS_MD_SHA512; + } else { + WARNF("(crypto) Unknown hash algorithm '%s'", hash_name); + return -1; + } +} + +static size_t get_hash_size_from_md_type(mbedtls_md_type_t md_type) { + switch (md_type) { + case MBEDTLS_MD_SHA256: + return 32; + case MBEDTLS_MD_SHA384: + return 48; + case MBEDTLS_MD_SHA512: + return 64; + default: + return 32; + } +} + +// List available hash algorithms +static int LuaListHashAlgorithms(lua_State *L) { + lua_newtable(L); + + lua_pushstring(L, "SHA256"); + lua_rawseti(L, -2, 1); + + lua_pushstring(L, "SHA384"); + lua_rawseti(L, -2, 2); + + lua_pushstring(L, "SHA512"); + lua_rawseti(L, -2, 3); + + // Add hyphenated versions + lua_pushstring(L, "SHA-256"); + lua_rawseti(L, -2, 4); + + lua_pushstring(L, "SHA-384"); + lua_rawseti(L, -2, 5); + + lua_pushstring(L, "SHA-512"); + lua_rawseti(L, -2, 6); + + lua_pushstring(L, "SHA1"); + lua_rawseti(L, -2, 7); + + lua_pushstring(L, "MD5"); + lua_rawseti(L, -2, 8); + + return 1; +} + +static int compute_hash(mbedtls_md_type_t md_type, const unsigned char *input, + size_t input_len, unsigned char *output, + size_t output_size) { + mbedtls_md_context_t md_ctx; + const mbedtls_md_info_t *md_info; + int ret; + + md_info = mbedtls_md_info_from_type(md_type); + if (md_info == NULL) { + WARNF("(crypto) Unsupported hash algorithm"); + return -1; + } + + if (output_size < mbedtls_md_get_size(md_info)) { + WARNF("(crypto) Output buffer too small for hash"); + return -1; + } + + mbedtls_md_init(&md_ctx); + + ret = mbedtls_md_setup(&md_ctx, md_info, 0); // 0 = non-HMAC + if (ret != 0) { + WARNF("(crypto) Failed to set up hash context: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_starts(&md_ctx); + if (ret != 0) { + WARNF("(crypto) Failed to start hash operation: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_update(&md_ctx, input, input_len); + if (ret != 0) { + WARNF("(crypto) Failed to update hash: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_finish(&md_ctx, output); + if (ret != 0) { + WARNF("(crypto) Failed to finish hash: -0x%04x", -ret); + goto cleanup; + } + +cleanup: + mbedtls_md_free(&md_ctx); + return ret; +} + +// Find curve ID by name +static mbedtls_ecp_group_id find_curve_by_name(const char *name) { + const curve_map_t *curve = supported_curves; + + while (curve->name != NULL) { + if (strcasecmp(curve->name, name) == 0) { + return curve->id; + } + curve++; + } + + return MBEDTLS_ECP_DP_NONE; +} + // Strong RNG using mbedtls_entropy_context and mbedtls_ctr_drbg_context int GenerateRandom(void *ctx, unsigned char *output, size_t len) { static mbedtls_entropy_context entropy; @@ -77,7 +234,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, // Initialize as RSA key if ((rc = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { - WARNF("Failed to setup key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to setup key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return false; } @@ -85,7 +242,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, // Generate RSA key if ((rc = mbedtls_rsa_gen_key(mbedtls_pk_rsa(key), GenerateRandom, 0, key_length, 65537)) != 0) { - WARNF("Failed to generate key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to generate key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return false; } @@ -95,7 +252,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, *private_key_pem = calloc(1, *private_key_len); if ((rc = mbedtls_pk_write_key_pem(&key, (unsigned char *)*private_key_pem, *private_key_len)) != 0) { - WARNF("Failed to write private key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to write private key (grep -0x%04x)", -rc); free(*private_key_pem); mbedtls_pk_free(&key); return false; @@ -107,7 +264,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, *public_key_pem = calloc(1, *public_key_len); if ((rc = mbedtls_pk_write_pubkey_pem(&key, (unsigned char *)*public_key_pem, *public_key_len)) != 0) { - WARNF("Failed to write public key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to write public key (grep -0x%04x)", -rc); free(*private_key_pem); free(*public_key_pem); mbedtls_pk_free(&key); @@ -128,6 +285,12 @@ static int LuaRSAGenerateKeyPair(lua_State *L) { } else { bits = (int)luaL_optinteger(L, 2, 2048); } + // Check if key length is valid (only 2048 or 4096 bits are allowed) + if (bits != 2048 && bits != 4096) { + lua_pushnil(L); + lua_pushfstring(L, "Invalid RSA key length: %d. Only 2048 or 4096 bits key lengths are supported", bits); + return 2; + } char *private_key, *public_key; size_t private_len, public_len; @@ -174,14 +337,14 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, if ((rc = mbedtls_pk_parse_public_key(&key, (const unsigned char *)public_key_pem, strlen(public_key_pem) + 1)) != 0) { - WARNF("Failed to parse public key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return NULL; } // Check if key is RSA if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { - WARNF("Key is not an RSA key"); + WARNF("(crypto) Key is not an RSA key"); mbedtls_pk_free(&key); return NULL; } @@ -198,7 +361,7 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, if ((rc = mbedtls_rsa_pkcs1_encrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, MBEDTLS_RSA_PUBLIC, data_len, data, output)) != 0) { - WARNF("Encryption failed (grep -0x%04x)", -rc); + WARNF("(crypto) Encryption failed (grep -0x%04x)", -rc); free(output); mbedtls_pk_free(&key); return NULL; @@ -211,7 +374,19 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, static int LuaRSAEncrypt(lua_State *L) { // Args: key, plaintext, options table size_t keylen, ptlen; + // Ensure key is a string + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } const char *key = luaL_checklstring(L, 1, &keylen); + // Ensure plaintext is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); // int options_idx = 3; @@ -230,7 +405,6 @@ static int LuaRSAEncrypt(lua_State *L) { return 1; } - static char *RSADecrypt(const char *private_key_pem, const unsigned char *encrypted_data, size_t encrypted_len, size_t *out_len) { @@ -239,16 +413,17 @@ static char *RSADecrypt(const char *private_key_pem, // Parse private key mbedtls_pk_context key; mbedtls_pk_init(&key); - if ((rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, - strlen(private_key_pem) + 1, NULL, 0)) != 0) { - WARNF("Failed to parse private key (grep -0x%04x)", -rc); + rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0); + if (rc != 0) { + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return NULL; } // Check if key is RSA if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { - WARNF("Key is not an RSA key"); + WARNF("(crypto) Key is not an RSA key"); mbedtls_pk_free(&key); return NULL; } @@ -263,10 +438,11 @@ static char *RSADecrypt(const char *private_key_pem, // Decrypt data size_t output_len = 0; - if ((rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, - MBEDTLS_RSA_PRIVATE, &output_len, - encrypted_data, output, key_size)) != 0) { - WARNF("Decryption failed (grep -0x%04x)", -rc); + rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, + MBEDTLS_RSA_PRIVATE, &output_len, + encrypted_data, output, key_size); + if (rc != 0) { + WARNF("(crypto) Decryption failed (grep -0x%04x)", -rc); free(output); mbedtls_pk_free(&key); return NULL; @@ -279,7 +455,19 @@ static char *RSADecrypt(const char *private_key_pem, static int LuaRSADecrypt(lua_State *L) { // Args: key, ciphertext, options table size_t keylen, ctlen; + // Ensure key is a string + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } const char *key = luaL_checklstring(L, 1, &keylen); + // Ensure ciphertext is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext must be a string"); + return 2; + } const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); // int options_idx = 3; @@ -311,18 +499,8 @@ static char *RSASign(const char *private_key_pem, const unsigned char *data, // Determine hash algorithm if (hash_algo_str) { - if (strcasecmp(hash_algo_str, "sha256") == 0) { - hash_algo = MBEDTLS_MD_SHA256; - hash_len = 32; - } else if (strcasecmp(hash_algo_str, "sha384") == 0) { - hash_algo = MBEDTLS_MD_SHA384; - hash_len = 48; - } else if (strcasecmp(hash_algo_str, "sha512") == 0) { - hash_algo = MBEDTLS_MD_SHA512; - hash_len = 64; - } else { - return NULL; // Unsupported hash algorithm - } + hash_algo = string_to_md_type(hash_algo_str); + hash_len = get_hash_size_from_md_type(hash_algo); } // Parse private key @@ -330,14 +508,14 @@ static char *RSASign(const char *private_key_pem, const unsigned char *data, mbedtls_pk_init(&key); if ((rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, strlen(private_key_pem) + 1, NULL, 0)) != 0) { - WARNF("Failed to parse private key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return NULL; } // Check if key is RSA if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { - WARNF("Key is not an RSA key"); + WARNF("(crypto) Key is not an RSA key"); mbedtls_pk_free(&key); return NULL; } @@ -376,6 +554,24 @@ static int LuaRSASign(lua_State *L) { size_t sig_len = 0; // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + if (lua_type(L, 1) == LUA_TTABLE) { + DEBUGF("(crypto) Detected table type for parameter 1"); + lua_pushnil(L); + lua_pushstring(L, "Key must be a string, got a table instead"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + key_pem = luaL_checklstring(L, 1, &key_len); msg = luaL_checklstring(L, 2, &msg_len); @@ -401,6 +597,119 @@ static int LuaRSASign(lua_State *L) { return 1; } +// RSA PSS Signing +static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data, + size_t data_len, const char *hash_algo_str, + size_t *sig_len, int salt_len) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 32; // Default for SHA-256 + unsigned char *signature; + mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default + + // Determine hash algorithm + if (hash_algo_str) { + hash_algo = string_to_md_type(hash_algo_str); + hash_len = get_hash_size_from_md_type(hash_algo); + } + + // Parse private key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0); + if (rc != 0) { + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Hash the message + rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, hash); + if (rc != 0) { + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate buffer for signature + signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE); + if (!signature) { + mbedtls_pk_free(&key); + return NULL; + } + + // Get RSA context + mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, hash_algo); + + // Sign the hash using PSS + rc = mbedtls_rsa_rsassa_pss_sign( + rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, hash_algo, (unsigned int)hash_len, + hash, signature); + if (rc != 0) { + free(signature); + mbedtls_pk_free(&key); + return NULL; + } + + *sig_len = mbedtls_pk_get_len(&key); + + // Clean up + mbedtls_pk_free(&key); + + return (char *)signature; +} + +static int LuaRSAPSSSign(lua_State *L) { + size_t key_len, msg_len; + const char *key_pem, *msg; + unsigned char *signature; + size_t sig_len = 0; + + // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + + // Get parameters from Lua + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + + // Optional hash algorithm parameter + const char *hash_name = luaL_optstring(L, 3, "sha256"); + + // Optional salt length parameter + int salt_len = luaL_optinteger(L, 4, -1); + + // Call the C implementation + signature = (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, + msg_len, hash_name, &sig_len, salt_len); + + if (!signature) { + return luaL_error(L, "failed to sign message (PSS)"); + } + + // Return the signature as a Lua string + lua_pushlstring(L, (char *)signature, sig_len); + free(signature); + return 1; +} + static int RSAVerify(const char *public_key_pem, const unsigned char *data, size_t data_len, const unsigned char *signature, size_t sig_len, const char *hash_algo_str) { @@ -411,18 +720,8 @@ static int RSAVerify(const char *public_key_pem, const unsigned char *data, // Determine hash algorithm if (hash_algo_str) { - if (strcasecmp(hash_algo_str, "sha256") == 0) { - hash_algo = MBEDTLS_MD_SHA256; - hash_len = 32; - } else if (strcasecmp(hash_algo_str, "sha384") == 0) { - hash_algo = MBEDTLS_MD_SHA384; - hash_len = 48; - } else if (strcasecmp(hash_algo_str, "sha512") == 0) { - hash_algo = MBEDTLS_MD_SHA512; - hash_len = 64; - } else { - return -1; // Unsupported hash algorithm - } + hash_algo = string_to_md_type(hash_algo_str); + hash_len = get_hash_size_from_md_type(hash_algo); } // Parse public key @@ -431,14 +730,14 @@ static int RSAVerify(const char *public_key_pem, const unsigned char *data, if ((rc = mbedtls_pk_parse_public_key(&key, (const unsigned char *)public_key_pem, strlen(public_key_pem) + 1)) != 0) { - WARNF("Failed to parse public key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return -1; } // Check if key is RSA if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { - WARNF("Key is not an RSA key"); + WARNF("(crypto) Key is not an RSA key"); mbedtls_pk_free(&key); return -1; } @@ -464,6 +763,17 @@ static int LuaRSAVerify(lua_State *L) { int result; // Get parameters from Lua + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } key_pem = luaL_checklstring(L, 1, &key_len); msg = luaL_checklstring(L, 2, &msg_len); signature = luaL_checklstring(L, 3, &sig_len); @@ -483,177 +793,120 @@ static int LuaRSAVerify(lua_State *L) { return 1; } +// RSA PSS Verification +static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, + size_t data_len, const unsigned char *signature, + size_t sig_len, const char *hash_algo_str, + int expected_salt_len) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 32; // Default for SHA-256 + mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default + + // Determine hash algorithm + if (hash_algo_str) { + hash_algo = string_to_md_type(hash_algo_str); + DEBUGF("(DEBUG) Using hash algorithm: %s", hash_algo_str); + hash_len = get_hash_size_from_md_type(hash_algo); + DEBUGF("(DEBUG) Hash length: %zu", hash_len); + } + + // Parse public key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_public_key(&key, + (const unsigned char *)public_key_pem, + strlen(public_key_pem) + 1)) != 0) { + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return -1; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return -1; + } + + // Hash the message + if ((rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, + hash)) != 0) { + mbedtls_pk_free(&key); + return -1; + } + + // Get RSA context + mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + + // Verify the signature using PSS + rc = mbedtls_rsa_rsassa_pss_verify( + rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, hash_algo, (unsigned int)hash_len, + hash, signature); + + // Clean up + mbedtls_pk_free(&key); + + return rc; // 0 means success (valid signature) +} + +static int LuaRSAPSSVerify(lua_State *L) { + // Args: key, msg, signature, hash_algo (optional), salt_len (optional + // crypto.verify('rsapss', key, msg, signature, hash_algo, salt_len) + size_t msg_len, key_len, sig_len; + const char *msg, *key_pem, *signature; + const char *hash_algo_str = NULL; // Default to SHA-256 + int expected_salt_len = -1; + int result; + + // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + // Ensure signature is a string + if (lua_type(L, 3) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Signature must be a string"); + return 2; + } + // Get parameters from Lua + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + signature = luaL_checklstring(L, 3, &sig_len); + // Optional hash algorithm parameter + if (lua_isstring(L, 4)) { + hash_algo_str = luaL_checkstring(L, 4); + // Optional salt length parameter + expected_salt_len = luaL_optinteger(L, 5, 32); + } else if (lua_isnumber(L, 4)) { + // If it's a number, treat it as the expected salt length + expected_salt_len = (int)lua_tointeger(L, 4); + } + DEBUGF("(DEBUG) Key PEM: %s", key_pem); + DEBUGF("(DEBUG) Message length: %zu", msg_len); + DEBUGF("(DEBUG) Signature length: %zu", sig_len); + DEBUGF("(DEBUG) Hash algorithm: %s", hash_algo_str ? hash_algo_str : "SHA-256"); + DEBUGF("(DEBUG) Expected salt length: %d", expected_salt_len); + DEBUGF("(DEBUG) Signature: %.*s", (int)sig_len, signature); + // Call the C implementation + result = RSAPSSVerify(key_pem, (const unsigned char *)msg, msg_len, + (const unsigned char *)signature, sig_len, hash_algo_str, expected_salt_len); + + // Return boolean result (0 means valid signature) + lua_pushboolean(L, result == 0); + + return 1; +} + // Elliptic Curve Cryptography Functions -// Supported curves mapping -typedef struct { - const char *name; - mbedtls_ecp_group_id id; -} curve_map_t; - -static const curve_map_t supported_curves[] = { - {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, - {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, - {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, - {"secp192r1", MBEDTLS_ECP_DP_SECP192R1}, - {"secp224r1", MBEDTLS_ECP_DP_SECP224R1}, - {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, - {NULL, 0}}; - -typedef enum { SHA256, SHA384, SHA512 } hash_algorithm_t; - -static mbedtls_md_type_t hash_to_md_type(hash_algorithm_t hash_alg) { - switch (hash_alg) { - case SHA256: - return MBEDTLS_MD_SHA256; - case SHA384: - return MBEDTLS_MD_SHA384; - case SHA512: - return MBEDTLS_MD_SHA512; - default: - return MBEDTLS_MD_SHA256; // Default to SHA-256 - } -} - -static size_t get_hash_size(hash_algorithm_t hash_alg) { - switch (hash_alg) { - case SHA256: - return 32; - case SHA384: - return 48; - case SHA512: - return 64; - default: - return 32; // Default to SHA-256 - } -} - -static hash_algorithm_t string_to_hash_alg(const char *hash_name) { - if (!hash_name || !*hash_name) { - return SHA256; // Default to SHA-256 if no name provided - } - - if (strcasecmp(hash_name, "sha256") == 0 || - strcasecmp(hash_name, "sha-256") == 0) { - return SHA256; - } else if (strcasecmp(hash_name, "sha384") == 0 || - strcasecmp(hash_name, "sha-384") == 0) { - return SHA384; - } else if (strcasecmp(hash_name, "sha512") == 0 || - strcasecmp(hash_name, "sha-512") == 0) { - return SHA512; - } else { - WARNF("(ecdsa) Unknown hash algorithm '%s', using SHA-256", hash_name); - return SHA256; - } -} - -static int LuaListHashAlgorithms(lua_State *L) { - lua_newtable(L); - - lua_pushstring(L, "SHA256"); - lua_rawseti(L, -2, 1); - - lua_pushstring(L, "SHA384"); - lua_rawseti(L, -2, 2); - - lua_pushstring(L, "SHA512"); - lua_rawseti(L, -2, 3); - - // Add hyphenated versions - lua_pushstring(L, "SHA-256"); - lua_rawseti(L, -2, 4); - - lua_pushstring(L, "SHA-384"); - lua_rawseti(L, -2, 5); - - lua_pushstring(L, "SHA-512"); - lua_rawseti(L, -2, 6); - - lua_pushstring(L, "MD5"); - lua_rawseti(L, -2, 7); - - return 1; -} - -// List available curves -static int LuaListCurves(lua_State *L) { - const curve_map_t *curve = supported_curves; - int i = 1; - - lua_newtable(L); - - while (curve->name != NULL) { - lua_pushstring(L, curve->name); - lua_rawseti(L, -2, i++); - curve++; - } - - return 1; -} - -static int compute_hash(hash_algorithm_t hash_alg, const unsigned char *input, - size_t input_len, unsigned char *output, - size_t output_size) { - mbedtls_md_context_t md_ctx; - const mbedtls_md_info_t *md_info; - int ret; - - mbedtls_md_type_t md_type = hash_to_md_type(hash_alg); - md_info = mbedtls_md_info_from_type(md_type); - if (md_info == NULL) { - WARNF("(ecdsa) Unsupported hash algorithm"); - return -1; - } - - if (output_size < mbedtls_md_get_size(md_info)) { - WARNF("(ecdsa) Output buffer too small for hash"); - return -1; - } - - mbedtls_md_init(&md_ctx); - - ret = mbedtls_md_setup(&md_ctx, md_info, 0); // 0 = non-HMAC - if (ret != 0) { - WARNF("(ecdsa) Failed to set up hash context: -0x%04x", -ret); - goto cleanup; - } - - ret = mbedtls_md_starts(&md_ctx); - if (ret != 0) { - WARNF("(ecdsa) Failed to start hash operation: -0x%04x", -ret); - goto cleanup; - } - - ret = mbedtls_md_update(&md_ctx, input, input_len); - if (ret != 0) { - WARNF("(ecdsa) Failed to update hash: -0x%04x", -ret); - goto cleanup; - } - - ret = mbedtls_md_finish(&md_ctx, output); - if (ret != 0) { - WARNF("(ecdsa) Failed to finish hash: -0x%04x", -ret); - goto cleanup; - } - -cleanup: - mbedtls_md_free(&md_ctx); - return ret; -} - -// Find curve ID by name -static mbedtls_ecp_group_id find_curve_by_name(const char *name) { - const curve_map_t *curve = supported_curves; - - while (curve->name != NULL) { - if (strcasecmp(curve->name, name) == 0) { - return curve->id; - } - curve++; - } - - return MBEDTLS_ECP_DP_NONE; -} // Generate an ECDSA key pair and return in PEM format static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, @@ -672,13 +925,13 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, // Use secp256r1 as default if curve_name is NULL or empty if (curve_name == NULL || curve_name[0] == '\0') { curve_id = MBEDTLS_ECP_DP_SECP256R1; - VERBOSEF("(ecdsa) No curve specified, using default: secp256r1"); + VERBOSEF("(crypto) No curve specified, using default: secp256r1"); } else { // Find the curve by name curve_id = find_curve_by_name(curve_name); if (curve_id == MBEDTLS_ECP_DP_NONE) { - WARNF("(ecdsa) Unknown curve: %s, using default: secp256r1", curve_name); - curve_id = MBEDTLS_ECP_DP_SECP256R1; + WARNF("(crypto) Unknown curve: '%s'", curve_name); + return -1; } } @@ -687,13 +940,13 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, // Generate the key with the specified curve ret = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)); if (ret != 0) { - WARNF("(ecdsa) Failed to setup key: -0x%04x", -ret); + WARNF("(crypto) Failed to setup key: -0x%04x", -ret); goto cleanup; } ret = mbedtls_ecp_gen_key(curve_id, mbedtls_pk_ec(key), GenerateRandom, 0); if (ret != 0) { - WARNF("(ecdsa) Failed to generate key: -0x%04x", -ret); + WARNF("(crypto) Failed to generate key: -0x%04x", -ret); goto cleanup; } @@ -702,12 +955,12 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, memset(output_buf, 0, sizeof(output_buf)); ret = mbedtls_pk_write_key_pem(&key, output_buf, sizeof(output_buf)); if (ret != 0) { - WARNF("(ecdsa) Failed to write private key: -0x%04x", -ret); + WARNF("(crypto) Failed to write private key: -0x%04x", -ret); goto cleanup; } *priv_key_pem = strdup((char *)output_buf); if (*priv_key_pem == NULL) { - WARNF("(ecdsa) Failed to allocate memory for private key PEM"); + WARNF("(crypto) Failed to allocate memory for private key PEM"); ret = -1; goto cleanup; } @@ -718,12 +971,12 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, memset(output_buf, 0, sizeof(output_buf)); ret = mbedtls_pk_write_pubkey_pem(&key, output_buf, sizeof(output_buf)); if (ret != 0) { - WARNF("(ecdsa) Failed to write public key: -0x%04x", -ret); + WARNF("(crypto) Failed to write public key: -0x%04x", -ret); goto cleanup; } *pub_key_pem = strdup((char *)output_buf); if (*pub_key_pem == NULL) { - WARNF("(ecdsa) Failed to allocate memory for public key PEM"); + WARNF("(crypto) Failed to allocate memory for public key PEM"); ret = -1; goto cleanup; } @@ -771,7 +1024,7 @@ static int LuaECDSAGenerateKeyPair(lua_State *L) { // Sign a message using an ECDSA private key in PEM format static int ECDSASign(const char *priv_key_pem, const char *message, - hash_algorithm_t hash_alg, unsigned char **signature, + mbedtls_md_type_t hash_alg, unsigned char **signature, size_t *sig_len) { mbedtls_pk_context key; unsigned char hash[64]; // Max hash size (SHA-512) @@ -782,19 +1035,19 @@ static int ECDSASign(const char *priv_key_pem, const char *message, *sig_len = 0; if (!priv_key_pem) { - WARNF("(ecdsa) Private key is NULL"); + WARNF("(crypto) Private key is NULL"); return -1; } // Get the length of the PEM string (excluding null terminator) size_t key_len = strlen(priv_key_pem); if (key_len == 0) { - WARNF("(ecdsa) Private key is empty"); + WARNF("(crypto) Private key is empty"); return -1; } // Get hash size for the selected algorithm - hash_size = get_hash_size(hash_alg); + hash_size = get_hash_size_from_md_type(hash_alg); mbedtls_pk_init(&key); @@ -803,7 +1056,7 @@ static int ECDSASign(const char *priv_key_pem, const char *message, key_len + 1, NULL, 0); if (ret != 0) { - WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); + WARNF("(crypto) Failed to parse private key: -0x%04x", -ret); goto cleanup; } @@ -811,24 +1064,24 @@ static int ECDSASign(const char *priv_key_pem, const char *message, ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), hash, sizeof(hash)); if (ret != 0) { - WARNF("(ecdsa) Failed to compute message hash"); + WARNF("(crypto) Failed to compute message hash"); goto cleanup; } // Allocate memory for signature (max size for ECDSA) *signature = malloc(MBEDTLS_ECDSA_MAX_LEN); if (*signature == NULL) { - WARNF("(ecdsa) Failed to allocate memory for signature"); + WARNF("(crypto) Failed to allocate memory for signature"); ret = -1; goto cleanup; } // Sign the hash using GenerateRandom - ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, + ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, *signature, sig_len, GenerateRandom, 0); if (ret != 0) { - WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); + WARNF("(crypto) Failed to sign message: -0x%04x", -ret); free(*signature); *signature = NULL; *sig_len = 0; @@ -845,7 +1098,7 @@ static int LuaECDSASign(lua_State *L) { const char *message = luaL_checkstring(L, 2); const char *hash_name = luaL_optstring(L, 3, "sha256"); - hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); + mbedtls_md_type_t hash_alg = string_to_md_type(hash_name); unsigned char *signature = NULL; size_t sig_len = 0; @@ -865,26 +1118,26 @@ static int LuaECDSASign(lua_State *L) { // Verify a signature using an ECDSA public key in PEM format static int ECDSAVerify(const char *pub_key_pem, const char *message, const unsigned char *signature, size_t sig_len, - hash_algorithm_t hash_alg) { + mbedtls_md_type_t hash_alg) { mbedtls_pk_context key; unsigned char hash[64]; // Max hash size (SHA-512) size_t hash_size; int ret; if (!pub_key_pem) { - WARNF("(ecdsa) Public key is NULL"); + WARNF("(crypto) Public key is NULL"); return -1; } // Get the length of the PEM string (excluding null terminator) size_t key_len = strlen(pub_key_pem); if (key_len == 0) { - WARNF("(ecdsa) Public key is empty"); + WARNF("(crypto) Public key is empty"); return -1; } // Get hash size for the selected algorithm - hash_size = get_hash_size(hash_alg); + hash_size = get_hash_size_from_md_type(hash_alg); mbedtls_pk_init(&key); @@ -892,7 +1145,7 @@ static int ECDSAVerify(const char *pub_key_pem, const char *message, ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pub_key_pem, key_len + 1); if (ret != 0) { - WARNF("(ecdsa) Failed to parse public key: -0x%04x", -ret); + WARNF("(crypto) Failed to parse public key: -0x%04x", -ret); goto cleanup; } @@ -900,15 +1153,15 @@ static int ECDSAVerify(const char *pub_key_pem, const char *message, ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), hash, sizeof(hash)); if (ret != 0) { - WARNF("(ecdsa) Failed to compute message hash"); + WARNF("(crypto) Failed to compute message hash"); goto cleanup; } // Verify the signature - ret = mbedtls_pk_verify(&key, hash_to_md_type(hash_alg), hash, hash_size, + ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, signature, sig_len); if (ret != 0) { - WARNF("(ecdsa) Signature verification failed: -0x%04x", -ret); + WARNF("(crypto) Signature verification failed: -0x%04x", -ret); goto cleanup; } @@ -925,7 +1178,7 @@ static int LuaECDSAVerify(lua_State *L) { (const unsigned char *)luaL_checklstring(L, 3, &sig_len); const char *hash_name = luaL_optstring(L, 4, "sha256"); - hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); + mbedtls_md_type_t hash_alg = string_to_md_type(hash_name); int ret = ECDSAVerify(pub_key_pem, message, signature, sig_len, hash_alg); @@ -970,42 +1223,92 @@ typedef struct { size_t aadlen; } aes_options_t; -static void parse_aes_options(lua_State *L, int options_idx, - aes_options_t *opts) { - opts->mode = "cbc"; +static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) { + opts->mode = NULL; opts->iv = NULL; opts->ivlen = 0; opts->tag = NULL; opts->taglen = 0; opts->aad = NULL; opts->aadlen = 0; + + int mode_field_found = 0; + if (lua_istable(L, options_idx)) { + // Get mode lua_getfield(L, options_idx, "mode"); - if (!lua_isnil(L, -1)) - opts->mode = lua_tostring(L, -1); + if (!lua_isnil(L, -1)) { + mode_field_found = 1; + const char *mode = lua_tostring(L, -1); + if (mode && (strcasecmp(mode, "cbc") == 0 || + strcasecmp(mode, "gcm") == 0 || + strcasecmp(mode, "ctr") == 0)) { + opts->mode = mode; + } else { + opts->mode = NULL; // Invalid mode + } + } lua_pop(L, 1); + + // Get IV lua_getfield(L, options_idx, "iv"); if (lua_isstring(L, -1)) { - opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen); + size_t ivlen; + opts->iv = (const unsigned char *)lua_tolstring(L, -1, &ivlen); + opts->ivlen = ivlen; } lua_pop(L, 1); + + // Get tag (for GCM) lua_getfield(L, options_idx, "tag"); if (lua_isstring(L, -1)) { - opts->tag = (const unsigned char *)lua_tolstring(L, -1, &opts->taglen); + size_t taglen; + opts->tag = (const unsigned char *)lua_tolstring(L, -1, &taglen); + opts->taglen = taglen; } lua_pop(L, 1); + + // Get aad (for GCM) lua_getfield(L, options_idx, "aad"); if (lua_isstring(L, -1)) { - opts->aad = (const unsigned char *)lua_tolstring(L, -1, &opts->aadlen); + size_t aadlen; + opts->aad = (const unsigned char *)lua_tolstring(L, -1, &aadlen); + opts->aadlen = aadlen; } lua_pop(L, 1); } -} + // Only default to cbc if mode field was not found at all + if (!mode_field_found) { + opts->mode = "cbc"; + } +} // AES encryption supporting CBC, GCM, and CTR modes static int LuaAesEncrypt(lua_State *L) { // Args: key, plaintext, options table size_t keylen, ptlen; + + + // Get parameters from Lua + // Ensure key is a string + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure plaintext is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + // Ensure options is a table or nil + if (!lua_istable(L, 3) && !lua_isnil(L, 3)) { + lua_pushnil(L); + lua_pushstring(L, "Options must be a table or nil"); + return 2; + } + const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *plaintext = @@ -1014,6 +1317,11 @@ static int LuaAesEncrypt(lua_State *L) { aes_options_t opts; parse_aes_options(L, options_idx, &opts); const char *mode = opts.mode; + if (!mode) { + lua_pushnil(L); + lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } const unsigned char *iv = opts.iv; size_t ivlen = opts.ivlen; unsigned char *gen_iv = NULL; @@ -1032,6 +1340,7 @@ static int LuaAesEncrypt(lua_State *L) { lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); return 2; } + // If IV is not provided, auto-generate if (!iv) { if (is_gcm) { @@ -1055,6 +1364,26 @@ static int LuaAesEncrypt(lua_State *L) { iv = gen_iv; iv_was_generated = 1; } + + // Validate IV/nonce length + if (is_cbc || is_ctr) { + // Validate IV/nonce length + if (is_cbc || is_ctr) { + if (opts.iv && opts.ivlen != 16) { + if (iv_was_generated) free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) { + if (iv_was_generated) free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; + } + } + } if (is_cbc) { // PKCS7 padding size_t block_size = 16; @@ -1188,14 +1517,32 @@ static int LuaAesEncrypt(lua_State *L) { static int LuaAesDecrypt(lua_State *L) { // Args: key, ciphertext, options table size_t keylen, ctlen; + // Ensure key is a string + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + + // Ensure ciphertext is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext must be a string"); + return 2; + } const unsigned char *key = - (const unsigned char *)luaL_checklstring(L, 1, &keylen); + (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); int options_idx = 3; aes_options_t opts; parse_aes_options(L, options_idx, &opts); const char *mode = opts.mode; + if (!mode) { + lua_pushnil(L); + lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } const unsigned char *iv = opts.iv; size_t ivlen = opts.ivlen; const unsigned char *tag = opts.tag; @@ -1385,6 +1732,13 @@ static int LuaAesDecrypt(lua_State *L) { static int LuaConvertJwkToPem(lua_State *L) { luaL_checktype(L, 1, LUA_TTABLE); const char *kty; + + if (lua_isnoneornil(L, 1) || lua_type(L, 1) != LUA_TTABLE) { + lua_pushnil(L); + lua_pushstring(L, "Expected a JWK table, got nil"); + return 2; + } + lua_getfield(L, 1, "kty"); kty = lua_tostring(L, -1); if (!kty) { @@ -1404,6 +1758,16 @@ static int LuaConvertJwkToPem(lua_State *L) { lua_getfield(L, 1, "e"); const char *n_b64 = lua_tostring(L, -2); const char *e_b64 = lua_tostring(L, -1); + if (!n_b64 || !*n_b64) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'n' in JWK"); + return 2; + } + if (!e_b64 || !*e_b64) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'e' in JWK"); + return 2; + } // Optional private fields lua_getfield(L, 1, "d"); lua_getfield(L, 1, "p"); @@ -1498,6 +1862,15 @@ static int LuaConvertJwkToPem(lua_State *L) { const char *x_b64 = lua_tostring(L, -3); const char *y_b64 = lua_tostring(L, -2); const char *d_b64 = lua_tostring(L, -1); + if (!crv || !*crv) { + lua_pushnil(L); lua_pushstring(L, "Missing or empty 'crv' in JWK"); return 2; + } + if (!x_b64 || !*x_b64) { + lua_pushnil(L); lua_pushstring(L, "Missing or empty 'x' in JWK"); return 2; + } + if (!y_b64 || !*y_b64) { + lua_pushnil(L); lua_pushstring(L, "Missing or empty 'y' in JWK"); return 2; + } int has_private = d_b64 && *d_b64; mbedtls_ecp_group_id gid = find_curve_by_name(crv); if (gid == MBEDTLS_ECP_DP_NONE) { @@ -1507,7 +1880,7 @@ static int LuaConvertJwkToPem(lua_State *L) { unsigned char x_bin[72], y_bin[72]; char *x_b64_std = strdup(x_b64), *y_b64_std = strdup(y_b64); for (char *p = x_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; - for (char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + for ( char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; int x_mod = strlen(x_b64_std) % 4; int y_mod = strlen(y_b64_std) % 4; if (x_mod) for (int i = 0; i < 4 - x_mod; ++i) strcat(x_b64_std, "="); @@ -1542,7 +1915,7 @@ static int LuaConvertJwkToPem(lua_State *L) { if (ret != 0) { mbedtls_pk_free(&pk); lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2; - } + } pem = strdup((char *)buf); mbedtls_pk_free(&pk); lua_pushstring(L, pem); @@ -1556,8 +1929,25 @@ static int LuaConvertJwkToPem(lua_State *L) { } // Convert PEM key to JWK (Lua table) format +static void base64_to_base64url(char *str) { + if (!str) return; + for (char *p = str; *p; p++) { + if (*p == '+') *p = '-'; + else if (*p == '/') *p = '_'; + } + // Remove padding + size_t len = strlen(str); + while (len > 0 && str[len-1] == '=') { + str[--len] = '\0'; + } +} + static int LuaConvertPemToJwk(lua_State *L) { const char *pem_key = luaL_checkstring(L, 1); + int has_claims = 0; + if (!lua_isnoneornil(L, 2) && lua_istable(L, 2)) { + has_claims = 1; + } mbedtls_pk_context key; mbedtls_pk_init(&key); @@ -1600,12 +1990,12 @@ static int LuaConvertPemToJwk(lua_State *L) { mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); n_b64 = malloc(n_b64_len + 1); e_b64 = malloc(e_b64_len + 1); - mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, - n_len); - mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, - e_len); + mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len); n_b64[n_b64_len] = '\0'; + base64_to_base64url(n_b64); + mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len); e_b64[e_b64_len] = '\0'; + base64_to_base64url(e_b64); lua_pushstring(L, n_b64); lua_setfield(L, -2, "n"); lua_pushstring(L, e_b64); @@ -1655,24 +2045,25 @@ static int LuaConvertPemToJwk(lua_State *L) { dp_b64 = malloc(dp_b64_len + 1); dq_b64 = malloc(dq_b64_len + 1); qi_b64 = malloc(qi_b64_len + 1); - mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, - d_len); - mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, - p_len); - mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, - q_len); - mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, - dp, dp_len); - mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, - dq, dq_len); - mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, - qi, qi_len); + mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); + mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, p_len); + mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, q_len); + mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, dp, dp_len); + mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, dq, dq_len); + mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, qi, qi_len); d_b64[d_b64_len] = '\0'; p_b64[p_b64_len] = '\0'; q_b64[q_b64_len] = '\0'; dp_b64[dp_b64_len] = '\0'; dq_b64[dq_b64_len] = '\0'; qi_b64[qi_b64_len] = '\0'; + // Convert all private components to base64url + base64_to_base64url(d_b64); + base64_to_base64url(p_b64); + base64_to_base64url(q_b64); + base64_to_base64url(dp_b64); + base64_to_base64url(dq_b64); + base64_to_base64url(qi_b64); lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); lua_pushstring(L, p_b64); @@ -1727,9 +2118,11 @@ static int LuaConvertPemToJwk(lua_State *L) { x_b64 = malloc(x_b64_len + 1); y_b64 = malloc(y_b64_len + 1); mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len); - mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); x_b64[x_b64_len] = '\0'; + base64_to_base64url(x_b64); + mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); y_b64[y_b64_len] = '\0'; + base64_to_base64url(y_b64); // Set kty and crv for EC keys lua_pushstring(L, "EC"); lua_setfield(L, -2, "kty"); @@ -1757,6 +2150,7 @@ static int LuaConvertPemToJwk(lua_State *L) { d_b64 = malloc(d_b64_len + 1); mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); d_b64[d_b64_len] = '\0'; + base64_to_base64url(d_b64); lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); free(d); free(d_b64); } @@ -1769,10 +2163,33 @@ static int LuaConvertPemToJwk(lua_State *L) { } mbedtls_pk_free(&key); + + // Merge additional claims if provided and compatible with RFC7517 + if (has_claims) { + static const char *allowed[] = {"kty","use","sig","key_ops","alg","kid","x5u","x5c","x5t","x5t#S256",NULL}; + lua_pushnil(L); // first key + while (lua_next(L, 2) != 0) { + const char *k = lua_tostring(L, -2); + int allowed_key = 0; + for (int i = 0; allowed[i]; ++i) { + if (strcmp(k, allowed[i]) == 0) { + allowed_key = 1; + break; + } + } + if (allowed_key) { + lua_pushvalue(L, -2); + lua_insert(L, -2); + lua_settable(L, -4); + } else { + lua_pop(L, 1); + } + } + } + return 1; } - // CSR creation Function static int LuaGenerateCSR(lua_State *L) { const char *key_pem = luaL_checkstring(L, 1); @@ -1838,13 +2255,15 @@ static int LuaGenerateCSR(lua_State *L) { // LuaCrypto compatible API static int LuaCryptoSign(lua_State *L) { - // Type of signature (e.g., "rsa", "ecdsa") + // Type of signature (e.g., "rsa", "ecdsa", "rsa-pss") const char *dtype = luaL_checkstring(L, 1); // Remove the first argument (key type or cipher type) before dispatching lua_remove(L, 1); if (strcasecmp(dtype, "rsa") == 0) { return LuaRSASign(L); + } else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) { + return LuaRSAPSSSign(L); } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSASign(L); } else { @@ -1853,13 +2272,15 @@ static int LuaCryptoSign(lua_State *L) { } static int LuaCryptoVerify(lua_State *L) { - // Type of signature (e.g., "rsa", "ecdsa") + // Type of signature (e.g., "rsa", "ecdsa", "rsa-pss") const char *dtype = luaL_checkstring(L, 1); // Remove the first argument (key type or cipher type) before dispatching lua_remove(L, 1); if (strcasecmp(dtype, "rsa") == 0) { return LuaRSAVerify(L); + } else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) { + return LuaRSAPSSVerify(L); } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSAVerify(L); } else { @@ -1925,7 +2346,7 @@ static const luaL_Reg kLuaCrypto[] = { {"verify", LuaCryptoVerify}, // {"encrypt", LuaCryptoEncrypt}, // {"decrypt", LuaCryptoDecrypt}, // - {"generatekeypair", LuaCryptoGenerateKeyPair}, // + {"generateKeyPair", LuaCryptoGenerateKeyPair}, // {"convertJwkToPem", LuaConvertJwkToPem}, // {"convertPemToJwk", LuaConvertPemToJwk}, // {"generateCsr", LuaGenerateCSR}, // From b8bdccc7fc2fc33c59a9557f92ac743f7d652fd6 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Tue, 24 Jun 2025 20:30:35 +1200 Subject: [PATCH 16/18] Address Paul's comments :) --- tool/net/lcrypto.c | 366 +++++++++++++++++++++++++++++---------------- 1 file changed, 236 insertions(+), 130 deletions(-) diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index 7ce0c3cd4..4cda00e1d 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -39,12 +39,15 @@ typedef struct { } curve_map_t; static const curve_map_t supported_curves[] = { - {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, - {"P-256", MBEDTLS_ECP_DP_SECP256R1}, - {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, - {"P-384", MBEDTLS_ECP_DP_SECP384R1}, - {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, - {"P-521", MBEDTLS_ECP_DP_SECP521R1}, + {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, + {"P256", MBEDTLS_ECP_DP_SECP256R1}, + {"P-256", MBEDTLS_ECP_DP_SECP256R1}, + {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, + {"P384", MBEDTLS_ECP_DP_SECP384R1}, + {"P-384", MBEDTLS_ECP_DP_SECP384R1}, + {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, + {"P521", MBEDTLS_ECP_DP_SECP521R1}, + {"P-521", MBEDTLS_ECP_DP_SECP521R1}, {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, {NULL, 0}}; @@ -69,11 +72,14 @@ static mbedtls_md_type_t string_to_md_type(const char *hash_name) { if (!hash_name || !*hash_name) { return MBEDTLS_MD_SHA256; // Default to SHA-256 if no name provided } - if (strcasecmp(hash_name, "sha256") == 0 || strcasecmp(hash_name, "sha-256") == 0) { + if (strcasecmp(hash_name, "sha256") == 0 || + strcasecmp(hash_name, "sha-256") == 0) { return MBEDTLS_MD_SHA256; - } else if (strcasecmp(hash_name, "sha384") == 0 || strcasecmp(hash_name, "sha-384") == 0) { + } else if (strcasecmp(hash_name, "sha384") == 0 || + strcasecmp(hash_name, "sha-384") == 0) { return MBEDTLS_MD_SHA384; - } else if (strcasecmp(hash_name, "sha512") == 0 || strcasecmp(hash_name, "sha-512") == 0) { + } else if (strcasecmp(hash_name, "sha512") == 0 || + strcasecmp(hash_name, "sha-512") == 0) { return MBEDTLS_MD_SHA512; } else { WARNF("(crypto) Unknown hash algorithm '%s'", hash_name); @@ -81,6 +87,7 @@ static mbedtls_md_type_t string_to_md_type(const char *hash_name) { } } +// Get the size of the hash output based on the mbedtls_md_type_t static size_t get_hash_size_from_md_type(mbedtls_md_type_t md_type) { switch (md_type) { case MBEDTLS_MD_SHA256: @@ -288,7 +295,10 @@ static int LuaRSAGenerateKeyPair(lua_State *L) { // Check if key length is valid (only 2048 or 4096 bits are allowed) if (bits != 2048 && bits != 4096) { lua_pushnil(L); - lua_pushfstring(L, "Invalid RSA key length: %d. Only 2048 or 4096 bits key lengths are supported", bits); + lua_pushfstring(L, + "Invalid RSA key length: %d. Only 2048 or 4096 bits key " + "lengths are supported", + bits); return 2; } @@ -414,7 +424,7 @@ static char *RSADecrypt(const char *private_key_pem, mbedtls_pk_context key; mbedtls_pk_init(&key); rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, - strlen(private_key_pem) + 1, NULL, 0); + strlen(private_key_pem) + 1, NULL, 0); if (rc != 0) { WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); @@ -439,8 +449,8 @@ static char *RSADecrypt(const char *private_key_pem, // Decrypt data size_t output_len = 0; rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, - MBEDTLS_RSA_PRIVATE, &output_len, - encrypted_data, output, key_size); + MBEDTLS_RSA_PRIVATE, &output_len, + encrypted_data, output, key_size); if (rc != 0) { WARNF("(crypto) Decryption failed (grep -0x%04x)", -rc); free(output); @@ -560,10 +570,10 @@ static int LuaRSASign(lua_State *L) { return 2; } if (lua_type(L, 1) == LUA_TTABLE) { - DEBUGF("(crypto) Detected table type for parameter 1"); - lua_pushnil(L); - lua_pushstring(L, "Key must be a string, got a table instead"); - return 2; + DEBUGF("(crypto) Detected table type for parameter 1"); + lua_pushnil(L); + lua_pushstring(L, "Key must be a string, got a table instead"); + return 2; } // Ensure msg is a string if (lua_type(L, 2) != LUA_TSTRING) { @@ -617,7 +627,7 @@ static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data, mbedtls_pk_context key; mbedtls_pk_init(&key); rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, - strlen(private_key_pem) + 1, NULL, 0); + strlen(private_key_pem) + 1, NULL, 0); if (rc != 0) { WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); @@ -650,9 +660,9 @@ static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data, mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, hash_algo); // Sign the hash using PSS - rc = mbedtls_rsa_rsassa_pss_sign( - rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, hash_algo, (unsigned int)hash_len, - hash, signature); + rc = mbedtls_rsa_rsassa_pss_sign(rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, + hash_algo, (unsigned int)hash_len, hash, + signature); if (rc != 0) { free(signature); mbedtls_pk_free(&key); @@ -673,7 +683,7 @@ static int LuaRSAPSSSign(lua_State *L) { unsigned char *signature; size_t sig_len = 0; - // Get parameters from Lua + // Get parameters from Lua if (lua_type(L, 1) != LUA_TSTRING) { lua_pushnil(L); lua_pushstring(L, "Key must be a string"); @@ -697,8 +707,9 @@ static int LuaRSAPSSSign(lua_State *L) { int salt_len = luaL_optinteger(L, 4, -1); // Call the C implementation - signature = (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, - msg_len, hash_name, &sig_len, salt_len); + signature = + (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, msg_len, + hash_name, &sig_len, salt_len); if (!signature) { return luaL_error(L, "failed to sign message (PSS)"); @@ -795,9 +806,9 @@ static int LuaRSAVerify(lua_State *L) { // RSA PSS Verification static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, - size_t data_len, const unsigned char *signature, - size_t sig_len, const char *hash_algo_str, - int expected_salt_len) { + size_t data_len, const unsigned char *signature, + size_t sig_len, const char *hash_algo_str, + int expected_salt_len) { int rc; unsigned char hash[64]; // Large enough for SHA-512 size_t hash_len = 32; // Default for SHA-256 @@ -806,9 +817,7 @@ static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, // Determine hash algorithm if (hash_algo_str) { hash_algo = string_to_md_type(hash_algo_str); - DEBUGF("(DEBUG) Using hash algorithm: %s", hash_algo_str); hash_len = get_hash_size_from_md_type(hash_algo); - DEBUGF("(DEBUG) Hash length: %zu", hash_len); } // Parse public key @@ -840,9 +849,9 @@ static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); // Verify the signature using PSS - rc = mbedtls_rsa_rsassa_pss_verify( - rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, hash_algo, (unsigned int)hash_len, - hash, signature); + rc = mbedtls_rsa_rsassa_pss_verify(rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, + hash_algo, (unsigned int)hash_len, hash, + signature); // Clean up mbedtls_pk_free(&key); @@ -859,7 +868,7 @@ static int LuaRSAPSSVerify(lua_State *L) { int expected_salt_len = -1; int result; - // Get parameters from Lua + // Get parameters from Lua if (lua_type(L, 1) != LUA_TSTRING) { lua_pushnil(L); lua_pushstring(L, "Key must be a string"); @@ -893,12 +902,14 @@ static int LuaRSAPSSVerify(lua_State *L) { DEBUGF("(DEBUG) Key PEM: %s", key_pem); DEBUGF("(DEBUG) Message length: %zu", msg_len); DEBUGF("(DEBUG) Signature length: %zu", sig_len); - DEBUGF("(DEBUG) Hash algorithm: %s", hash_algo_str ? hash_algo_str : "SHA-256"); + DEBUGF("(DEBUG) Hash algorithm: %s", + hash_algo_str ? hash_algo_str : "SHA-256"); DEBUGF("(DEBUG) Expected salt length: %d", expected_salt_len); DEBUGF("(DEBUG) Signature: %.*s", (int)sig_len, signature); // Call the C implementation result = RSAPSSVerify(key_pem, (const unsigned char *)msg, msg_len, - (const unsigned char *)signature, sig_len, hash_algo_str, expected_salt_len); + (const unsigned char *)signature, sig_len, + hash_algo_str, expected_salt_len); // Return boolean result (0 means valid signature) lua_pushboolean(L, result == 0); @@ -1077,8 +1088,8 @@ static int ECDSASign(const char *priv_key_pem, const char *message, } // Sign the hash using GenerateRandom - ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, - *signature, sig_len, GenerateRandom, 0); + ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, *signature, sig_len, + GenerateRandom, 0); if (ret != 0) { WARNF("(crypto) Failed to sign message: -0x%04x", -ret); @@ -1158,8 +1169,7 @@ static int ECDSAVerify(const char *pub_key_pem, const char *message, } // Verify the signature - ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, - signature, sig_len); + ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, signature, sig_len); if (ret != 0) { WARNF("(crypto) Signature verification failed: -0x%04x", -ret); goto cleanup; @@ -1223,7 +1233,8 @@ typedef struct { size_t aadlen; } aes_options_t; -static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) { +static void parse_aes_options(lua_State *L, int options_idx, + aes_options_t *opts) { opts->mode = NULL; opts->iv = NULL; opts->ivlen = 0; @@ -1240,12 +1251,12 @@ static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts if (!lua_isnil(L, -1)) { mode_field_found = 1; const char *mode = lua_tostring(L, -1); - if (mode && (strcasecmp(mode, "cbc") == 0 || - strcasecmp(mode, "gcm") == 0 || - strcasecmp(mode, "ctr") == 0)) { + if (mode && + (strcasecmp(mode, "cbc") == 0 || strcasecmp(mode, "gcm") == 0 || + strcasecmp(mode, "ctr") == 0)) { opts->mode = mode; } else { - opts->mode = NULL; // Invalid mode + opts->mode = NULL; // Invalid mode } } lua_pop(L, 1); @@ -1288,7 +1299,6 @@ static int LuaAesEncrypt(lua_State *L) { // Args: key, plaintext, options table size_t keylen, ptlen; - // Get parameters from Lua // Ensure key is a string if (lua_type(L, 1) != LUA_TSTRING) { @@ -1319,7 +1329,8 @@ static int LuaAesEncrypt(lua_State *L) { const char *mode = opts.mode; if (!mode) { lua_pushnil(L); - lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + lua_pushstring(L, + "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); return 2; } const unsigned char *iv = opts.iv; @@ -1354,6 +1365,7 @@ static int LuaAesEncrypt(lua_State *L) { lua_pushstring(L, "Failed to allocate IV"); return 2; } + // Generate random IV if (GenerateRandom(NULL, gen_iv, ivlen) != 0) { free(gen_iv); @@ -1367,23 +1379,23 @@ static int LuaAesEncrypt(lua_State *L) { // Validate IV/nonce length if (is_cbc || is_ctr) { - // Validate IV/nonce length - if (is_cbc || is_ctr) { - if (opts.iv && opts.ivlen != 16) { - if (iv_was_generated) free(gen_iv); - lua_pushnil(L); - lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); - return 2; - } - } else if (is_gcm) { - if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) { - if (iv_was_generated) free(gen_iv); - lua_pushnil(L); - lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); - return 2; - } + if (opts.iv && opts.ivlen != 16) { + if (iv_was_generated) + free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; } + } else if (is_gcm) { + if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) { + if (iv_was_generated) + free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; } + } + if (is_cbc) { // PKCS7 padding size_t block_size = 16; @@ -1531,7 +1543,7 @@ static int LuaAesDecrypt(lua_State *L) { return 2; } const unsigned char *key = - (const unsigned char *)luaL_checklstring(L, 1, &keylen); + (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); int options_idx = 3; @@ -1540,7 +1552,8 @@ static int LuaAesDecrypt(lua_State *L) { const char *mode = opts.mode; if (!mode) { lua_pushnil(L); - lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + lua_pushstring(L, + "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); return 2; } const unsigned char *iv = opts.iv; @@ -1725,7 +1738,6 @@ static int LuaAesDecrypt(lua_State *L) { return 2; } - // JWK functions // Convert JWK (Lua table) to PEM key format @@ -1786,23 +1798,44 @@ static int LuaConvertJwkToPem(lua_State *L) { size_t n_len, e_len; unsigned char n_bin[1024], e_bin[16]; char *n_b64_std = strdup(n_b64), *e_b64_std = strdup(e_b64); - for (char *p = n_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; - for (char *p = e_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + for (char *p = n_b64_std; *p; ++p) + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; + for (char *p = e_b64_std; *p; ++p) + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; int n_mod = strlen(n_b64_std) % 4; int e_mod = strlen(e_b64_std) % 4; - if (n_mod) for (int i = 0; i < 4 - n_mod; ++i) strcat(n_b64_std, "="); - if (e_mod) for (int i = 0; i < 4 - e_mod; ++i) strcat(e_b64_std, "="); - if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, (const unsigned char *)n_b64_std, strlen(n_b64_std)) != 0 || - mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len, (const unsigned char *)e_b64_std, strlen(e_b64_std)) != 0) { - free(n_b64_std); free(e_b64_std); + if (n_mod) + for (int i = 0; i < 4 - n_mod; ++i) + strcat(n_b64_std, "="); + if (e_mod) + for (int i = 0; i < 4 - e_mod; ++i) + strcat(e_b64_std, "="); + if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, + (const unsigned char *)n_b64_std, + strlen(n_b64_std)) != 0 || + mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len, + (const unsigned char *)e_b64_std, + strlen(e_b64_std)) != 0) { + free(n_b64_std); + free(e_b64_std); lua_pushnil(L); lua_pushstring(L, "Base64 decode failed"); return 2; } - free(n_b64_std); free(e_b64_std); + free(n_b64_std); + free(e_b64_std); // Build RSA context in pk - if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { - lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2; + if ((ret = mbedtls_pk_setup( + &pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { + lua_pushnil(L); + lua_pushstring(L, "mbedtls_pk_setup failed"); + return 2; } mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk); mbedtls_rsa_init(rsa, MBEDTLS_RSA_PKCS_V15, 0); @@ -1812,17 +1845,25 @@ static int LuaConvertJwkToPem(lua_State *L) { if (has_private) { // Decode and set private fields size_t d_len, p_len, q_len, dp_len, dq_len, qi_len; - unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512], dq_bin[512], qi_bin[512]; - // Decode all private fields (skip if NULL) - #define DECODE_B64URL(var, b64, bin, binlen) \ - if (b64 && *b64) { \ - char *b64_std = strdup(b64); \ - for (char *p = b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; \ - int mod = strlen(b64_std) % 4; \ - if (mod) for (int i = 0; i < 4 - mod; ++i) strcat(b64_std, "="); \ - mbedtls_base64_decode(bin, sizeof(bin), &binlen, (const unsigned char *)b64_std, strlen(b64_std)); \ - free(b64_std); \ - } + unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512], + dq_bin[512], qi_bin[512]; +// Decode all private fields (skip if NULL) +#define DECODE_B64URL(var, b64, bin, binlen) \ + if (b64 && *b64) { \ + char *b64_std = strdup(b64); \ + for (char *p = b64_std; *p; ++p) \ + if (*p == '-') \ + *p = '+'; \ + else if (*p == '_') \ + *p = '/'; \ + int mod = strlen(b64_std) % 4; \ + if (mod) \ + for (int i = 0; i < 4 - mod; ++i) \ + strcat(b64_std, "="); \ + mbedtls_base64_decode(bin, sizeof(bin), &binlen, \ + (const unsigned char *)b64_std, strlen(b64_std)); \ + free(b64_std); \ + } DECODE_B64URL(d, d_b64, d_bin, d_len); DECODE_B64URL(p, p_b64, p_bin, p_len); DECODE_B64URL(q, q_b64, q_bin, q_len); @@ -1845,7 +1886,9 @@ static int LuaConvertJwkToPem(lua_State *L) { } if (ret != 0) { mbedtls_pk_free(&pk); - lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2; + lua_pushnil(L); + lua_pushstring(L, "PEM write failed"); + return 2; } pem = strdup((char *)buf); mbedtls_pk_free(&pk); @@ -1863,36 +1906,67 @@ static int LuaConvertJwkToPem(lua_State *L) { const char *y_b64 = lua_tostring(L, -2); const char *d_b64 = lua_tostring(L, -1); if (!crv || !*crv) { - lua_pushnil(L); lua_pushstring(L, "Missing or empty 'crv' in JWK"); return 2; + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'crv' in JWK"); + return 2; } if (!x_b64 || !*x_b64) { - lua_pushnil(L); lua_pushstring(L, "Missing or empty 'x' in JWK"); return 2; + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'x' in JWK"); + return 2; } if (!y_b64 || !*y_b64) { - lua_pushnil(L); lua_pushstring(L, "Missing or empty 'y' in JWK"); return 2; + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'y' in JWK"); + return 2; } int has_private = d_b64 && *d_b64; mbedtls_ecp_group_id gid = find_curve_by_name(crv); if (gid == MBEDTLS_ECP_DP_NONE) { - lua_pushnil(L); lua_pushstring(L, "Unknown curve"); return 2; + lua_pushnil(L); + lua_pushstring(L, "Unknown curve"); + return 2; } size_t x_len, y_len; unsigned char x_bin[72], y_bin[72]; char *x_b64_std = strdup(x_b64), *y_b64_std = strdup(y_b64); - for (char *p = x_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; - for ( char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + for (char *p = x_b64_std; *p; ++p) + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; + for (char *p = y_b64_std; *p; ++p) + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; int x_mod = strlen(x_b64_std) % 4; int y_mod = strlen(y_b64_std) % 4; - if (x_mod) for (int i = 0; i < 4 - x_mod; ++i) strcat(x_b64_std, "="); - if (y_mod) for (int i = 0; i < 4 - y_mod; ++i) strcat(y_b64_std, "="); - if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, (const unsigned char *)x_b64_std, strlen(x_b64_std)) != 0 || - mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len, (const unsigned char *)y_b64_std, strlen(y_b64_std)) != 0) { - free(x_b64_std); free(y_b64_std); - lua_pushnil(L); lua_pushstring(L, "Base64 decode failed"); return 2; + if (x_mod) + for (int i = 0; i < 4 - x_mod; ++i) + strcat(x_b64_std, "="); + if (y_mod) + for (int i = 0; i < 4 - y_mod; ++i) + strcat(y_b64_std, "="); + if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, + (const unsigned char *)x_b64_std, + strlen(x_b64_std)) != 0 || + mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len, + (const unsigned char *)y_b64_std, + strlen(y_b64_std)) != 0) { + free(x_b64_std); + free(y_b64_std); + lua_pushnil(L); + lua_pushstring(L, "Base64 decode failed"); + return 2; } - free(x_b64_std); free(y_b64_std); - if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) { - lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2; + free(x_b64_std); + free(y_b64_std); + if ((ret = mbedtls_pk_setup( + &pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) { + lua_pushnil(L); + lua_pushstring(L, "mbedtls_pk_setup failed"); + return 2; } mbedtls_ecp_keypair *ec = mbedtls_pk_ec(pk); mbedtls_ecp_keypair_init(ec); @@ -1914,8 +1988,10 @@ static int LuaConvertJwkToPem(lua_State *L) { } if (ret != 0) { mbedtls_pk_free(&pk); - lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2; - } + lua_pushnil(L); + lua_pushstring(L, "PEM write failed"); + return 2; + } pem = strdup((char *)buf); mbedtls_pk_free(&pk); lua_pushstring(L, pem); @@ -1930,14 +2006,17 @@ static int LuaConvertJwkToPem(lua_State *L) { // Convert PEM key to JWK (Lua table) format static void base64_to_base64url(char *str) { - if (!str) return; + if (!str) + return; for (char *p = str; *p; p++) { - if (*p == '+') *p = '-'; - else if (*p == '/') *p = '_'; + if (*p == '+') + *p = '-'; + else if (*p == '/') + *p = '_'; } // Remove padding size_t len = strlen(str); - while (len > 0 && str[len-1] == '=') { + while (len > 0 && str[len - 1] == '=') { str[--len] = '\0'; } } @@ -1990,10 +2069,12 @@ static int LuaConvertPemToJwk(lua_State *L) { mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); n_b64 = malloc(n_b64_len + 1); e_b64 = malloc(e_b64_len + 1); - mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len); + mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, + n_len); n_b64[n_b64_len] = '\0'; base64_to_base64url(n_b64); - mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len); + mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, + e_len); e_b64[e_b64_len] = '\0'; base64_to_base64url(e_b64); lua_pushstring(L, n_b64); @@ -2045,12 +2126,18 @@ static int LuaConvertPemToJwk(lua_State *L) { dp_b64 = malloc(dp_b64_len + 1); dq_b64 = malloc(dq_b64_len + 1); qi_b64 = malloc(qi_b64_len + 1); - mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); - mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, p_len); - mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, q_len); - mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, dp, dp_len); - mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, dq, dq_len); - mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, qi, qi_len); + mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, + d_len); + mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, + p_len); + mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, + q_len); + mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, + dp, dp_len); + mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, + dq, dq_len); + mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, + qi, qi_len); d_b64[d_b64_len] = '\0'; p_b64[p_b64_len] = '\0'; q_b64[q_b64_len] = '\0'; @@ -2117,16 +2204,19 @@ static int LuaConvertPemToJwk(lua_State *L) { mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len); x_b64 = malloc(x_b64_len + 1); y_b64 = malloc(y_b64_len + 1); - mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len); + mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, + x_len); x_b64[x_b64_len] = '\0'; base64_to_base64url(x_b64); - mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); + mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, + y_len); y_b64[y_b64_len] = '\0'; base64_to_base64url(y_b64); // Set kty and crv for EC keys lua_pushstring(L, "EC"); lua_setfield(L, -2, "kty"); - const mbedtls_ecp_curve_info *curve_info = mbedtls_ecp_curve_info_from_grp_id(ec->grp.id); + const mbedtls_ecp_curve_info *curve_info = + mbedtls_ecp_curve_info_from_grp_id(ec->grp.id); if (curve_info && curve_info->name) { lua_pushstring(L, curve_info->name); lua_setfield(L, -2, "crv"); @@ -2142,19 +2232,34 @@ static int LuaConvertPemToJwk(lua_State *L) { if (mbedtls_ecp_check_privkey(&ec->grp, &ec->d) == 0 && ec->d.p) { size_t d_len = mbedtls_mpi_size(&ec->d); unsigned char *d = malloc(d_len); - if (!d) { free(x); free(y); free(x_b64); free(y_b64); lua_pushnil(L); lua_pushstring(L, "Memory allocation failed"); mbedtls_pk_free(&key); return 2; } + if (!d) { + free(x); + free(y); + free(x_b64); + free(y_b64); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + mbedtls_pk_free(&key); + return 2; + } mbedtls_mpi_write_binary(&ec->d, d, d_len); char *d_b64 = NULL; size_t d_b64_len; mbedtls_base64_encode(NULL, 0, &d_b64_len, d, d_len); d_b64 = malloc(d_b64_len + 1); - mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); + mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, + d_len); d_b64[d_b64_len] = '\0'; base64_to_base64url(d_b64); - lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); - free(d); free(d_b64); + lua_pushstring(L, d_b64); + lua_setfield(L, -2, "d"); + free(d); + free(d_b64); } - free(x); free(y); free(x_b64); free(y_b64); + free(x); + free(y); + free(x_b64); + free(y_b64); } else { lua_pushnil(L); lua_pushstring(L, "Unsupported key type"); @@ -2166,8 +2271,10 @@ static int LuaConvertPemToJwk(lua_State *L) { // Merge additional claims if provided and compatible with RFC7517 if (has_claims) { - static const char *allowed[] = {"kty","use","sig","key_ops","alg","kid","x5u","x5c","x5t","x5t#S256",NULL}; - lua_pushnil(L); // first key + static const char *allowed[] = {"kty", "use", "sig", "key_ops", + "alg", "kid", "x5u", "x5c", + "x5t", "x5t#S256", NULL}; + lua_pushnil(L); // first key while (lua_next(L, 2) != 0) { const char *k = lua_tostring(L, -2); int allowed_key = 0; @@ -2252,7 +2359,6 @@ static int LuaGenerateCSR(lua_State *L) { return 1; } - // LuaCrypto compatible API static int LuaCryptoSign(lua_State *L) { // Type of signature (e.g., "rsa", "ecdsa", "rsa-pss") @@ -2262,7 +2368,8 @@ static int LuaCryptoSign(lua_State *L) { if (strcasecmp(dtype, "rsa") == 0) { return LuaRSASign(L); - } else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) { + } else if (strcasecmp(dtype, "rsa-pss") == 0 || + strcasecmp(dtype, "rsapss") == 0) { return LuaRSAPSSSign(L); } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSASign(L); @@ -2279,7 +2386,8 @@ static int LuaCryptoVerify(lua_State *L) { if (strcasecmp(dtype, "rsa") == 0) { return LuaRSAVerify(L); - } else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) { + } else if (strcasecmp(dtype, "rsa-pss") == 0 || + strcasecmp(dtype, "rsapss") == 0) { return LuaRSAPSSVerify(L); } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSAVerify(L); @@ -2339,8 +2447,6 @@ static int LuaCryptoGenerateKeyPair(lua_State *L) { } } - - static const luaL_Reg kLuaCrypto[] = { {"sign", LuaCryptoSign}, // {"verify", LuaCryptoVerify}, // From 77ff5a24f9c0856bd473943d6df61df8eeb5d585 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Wed, 25 Jun 2025 16:53:48 +1200 Subject: [PATCH 17/18] Cleanup hash validation Improve JWK functions Enable PKCS1 v2.1 support in MBEDTLS --- third_party/mbedtls/config.h | 4 +- tool/net/lcrypto.c | 524 +++++++++++++++++------------------ 2 files changed, 258 insertions(+), 270 deletions(-) diff --git a/third_party/mbedtls/config.h b/third_party/mbedtls/config.h index c4e457749..1a6fc18c9 100644 --- a/third_party/mbedtls/config.h +++ b/third_party/mbedtls/config.h @@ -395,7 +395,9 @@ * * This enables support for RSAES-OAEP and RSASSA-PSS operations. */ -/*#define MBEDTLS_PKCS1_V21*/ +#ifndef TINY +#define MBEDTLS_PKCS1_V21 +#endif /** * \def MBEDTLS_RSA_NO_CRT diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index 4cda00e1d..6e4ec032e 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -32,6 +32,7 @@ #include "third_party/mbedtls/rsa.h" #include "third_party/mbedtls/x509_csr.h" +// Elliptic curves // Supported curves mapping typedef struct { const char *name; @@ -39,23 +40,25 @@ typedef struct { } curve_map_t; static const curve_map_t supported_curves[] = { - {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, - {"P256", MBEDTLS_ECP_DP_SECP256R1}, - {"P-256", MBEDTLS_ECP_DP_SECP256R1}, - {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, - {"P384", MBEDTLS_ECP_DP_SECP384R1}, - {"P-384", MBEDTLS_ECP_DP_SECP384R1}, - {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, - {"P521", MBEDTLS_ECP_DP_SECP521R1}, - {"P-521", MBEDTLS_ECP_DP_SECP521R1}, - {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, + {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, // + {"P256", MBEDTLS_ECP_DP_SECP256R1}, // + {"P-256", MBEDTLS_ECP_DP_SECP256R1}, // + {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, // + {"P384", MBEDTLS_ECP_DP_SECP384R1}, // + {"P-384", MBEDTLS_ECP_DP_SECP384R1}, // + {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, // + {"P521", MBEDTLS_ECP_DP_SECP521R1}, // + {"P-521", MBEDTLS_ECP_DP_SECP521R1}, // + {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, // +#ifndef TINY + {"curve448", MBEDTLS_ECP_DP_CURVE448}, // +#endif {NULL, 0}}; // List available curves static int LuaListCurves(lua_State *L) { const curve_map_t *curve = supported_curves; int i = 1; - lua_newtable(L); while (curve->name != NULL) { @@ -63,28 +66,68 @@ static int LuaListCurves(lua_State *L) { lua_rawseti(L, -2, i++); curve++; } - return 1; } -// Remove hash_algorithm_t and hash_to_md_type indirection -static mbedtls_md_type_t string_to_md_type(const char *hash_name) { - if (!hash_name || !*hash_name) { - return MBEDTLS_MD_SHA256; // Default to SHA-256 if no name provided +// Find curve ID by name +static mbedtls_ecp_group_id find_curve_by_name(const char *name) { + const curve_map_t *curve = supported_curves; + + while (curve->name != NULL) { + if (strcasecmp(curve->name, name) == 0) { + return curve->id; + } + curve++; } - if (strcasecmp(hash_name, "sha256") == 0 || - strcasecmp(hash_name, "sha-256") == 0) { - return MBEDTLS_MD_SHA256; - } else if (strcasecmp(hash_name, "sha384") == 0 || - strcasecmp(hash_name, "sha-384") == 0) { - return MBEDTLS_MD_SHA384; - } else if (strcasecmp(hash_name, "sha512") == 0 || - strcasecmp(hash_name, "sha-512") == 0) { - return MBEDTLS_MD_SHA512; - } else { - WARNF("(crypto) Unknown hash algorithm '%s'", hash_name); - return -1; + return MBEDTLS_ECP_DP_NONE; +} + +// Message digests +// Supported digests mapping +typedef struct { + const char *name; + mbedtls_md_type_t id; +} digest_map_t; + +static const digest_map_t supported_digests[] = { + {"MD5", MBEDTLS_MD_MD5}, // + {"SHA1", MBEDTLS_MD_SHA1}, // + {"SHA-1", MBEDTLS_MD_SHA1}, // + {"SHA224", MBEDTLS_MD_SHA224}, // + {"SHA-224", MBEDTLS_MD_SHA224}, // + {"SHA256", MBEDTLS_MD_SHA256}, // + {"SHA-256", MBEDTLS_MD_SHA256}, // + {"SHA384", MBEDTLS_MD_SHA384}, // + {"SHA-384", MBEDTLS_MD_SHA384}, // + {"SHA512", MBEDTLS_MD_SHA512}, // + {"SHA-512", MBEDTLS_MD_SHA512}, // + {NULL, 0}}; + +// List available digests +static int LuaListDigests(lua_State *L) { + const digest_map_t *digest = supported_digests; + int i = 1; + lua_newtable(L); + + while (digest->name != NULL) { + lua_pushstring(L, digest->name); + lua_rawseti(L, -2, i++); + digest++; } + return 1; +} + +// Find digest ID by name +static mbedtls_md_type_t find_digest_by_name(const char *name) { + const digest_map_t *digest = supported_digests; + + while (digest->name != NULL) { + if (strcasecmp(digest->name, name) == 0) { + return digest->id; + } + digest++; + } + return MBEDTLS_MD_NONE; } // Get the size of the hash output based on the mbedtls_md_type_t @@ -96,43 +139,16 @@ static size_t get_hash_size_from_md_type(mbedtls_md_type_t md_type) { return 48; case MBEDTLS_MD_SHA512: return 64; + case MBEDTLS_MD_SHA1: + return 20; + case MBEDTLS_MD_MD5: + return 16; default: return 32; } } -// List available hash algorithms -static int LuaListHashAlgorithms(lua_State *L) { - lua_newtable(L); - - lua_pushstring(L, "SHA256"); - lua_rawseti(L, -2, 1); - - lua_pushstring(L, "SHA384"); - lua_rawseti(L, -2, 2); - - lua_pushstring(L, "SHA512"); - lua_rawseti(L, -2, 3); - - // Add hyphenated versions - lua_pushstring(L, "SHA-256"); - lua_rawseti(L, -2, 4); - - lua_pushstring(L, "SHA-384"); - lua_rawseti(L, -2, 5); - - lua_pushstring(L, "SHA-512"); - lua_rawseti(L, -2, 6); - - lua_pushstring(L, "SHA1"); - lua_rawseti(L, -2, 7); - - lua_pushstring(L, "MD5"); - lua_rawseti(L, -2, 8); - - return 1; -} - +// Compute hash using mbedtls static int compute_hash(mbedtls_md_type_t md_type, const unsigned char *input, size_t input_len, unsigned char *output, size_t output_size) { @@ -182,20 +198,6 @@ cleanup: return ret; } -// Find curve ID by name -static mbedtls_ecp_group_id find_curve_by_name(const char *name) { - const curve_map_t *curve = supported_curves; - - while (curve->name != NULL) { - if (strcasecmp(curve->name, name) == 0) { - return curve->id; - } - curve++; - } - - return MBEDTLS_ECP_DP_NONE; -} - // Strong RNG using mbedtls_entropy_context and mbedtls_ctr_drbg_context int GenerateRandom(void *ctx, unsigned char *output, size_t len) { static mbedtls_entropy_context entropy; @@ -499,18 +501,26 @@ static int LuaRSADecrypt(lua_State *L) { // RSA Signing static char *RSASign(const char *private_key_pem, const unsigned char *data, - size_t data_len, const char *hash_algo_str, - size_t *sig_len) { + size_t data_len, const char *hash_name, size_t *sig_len) { int rc; unsigned char hash[64]; // Large enough for SHA-512 - size_t hash_len = 32; // Default for SHA-256 + size_t hash_len = 0; + mbedtls_md_type_t hash_algo; unsigned char *signature; - mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default // Determine hash algorithm - if (hash_algo_str) { - hash_algo = string_to_md_type(hash_algo_str); - hash_len = get_hash_size_from_md_type(hash_algo); + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return NULL; + } else { + hash_len = get_hash_size_from_md_type(hash_algo); + } } // Parse private key @@ -559,7 +569,7 @@ static char *RSASign(const char *private_key_pem, const unsigned char *data, } static int LuaRSASign(lua_State *L) { size_t msg_len, key_len; - const char *msg, *key_pem, *hash_algo_str = NULL; + const char *msg, *key_pem, *hash_name = NULL; unsigned char *signature; size_t sig_len = 0; @@ -570,7 +580,6 @@ static int LuaRSASign(lua_State *L) { return 2; } if (lua_type(L, 1) == LUA_TTABLE) { - DEBUGF("(crypto) Detected table type for parameter 1"); lua_pushnil(L); lua_pushstring(L, "Key must be a string, got a table instead"); return 2; @@ -587,12 +596,12 @@ static int LuaRSASign(lua_State *L) { // Optional hash algorithm parameter if (!lua_isnoneornil(L, 3)) { - hash_algo_str = luaL_checkstring(L, 3); + hash_name = luaL_checkstring(L, 3); } // Call the C implementation signature = (unsigned char *)RSASign(key_pem, (const unsigned char *)msg, - msg_len, hash_algo_str, &sig_len); + msg_len, hash_name, &sig_len); if (!signature) { return luaL_error(L, "failed to sign message"); @@ -609,18 +618,27 @@ static int LuaRSASign(lua_State *L) { // RSA PSS Signing static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data, - size_t data_len, const char *hash_algo_str, - size_t *sig_len, int salt_len) { + size_t data_len, const char *hash_name, size_t *sig_len, + int salt_len) { int rc; unsigned char hash[64]; // Large enough for SHA-512 - size_t hash_len = 32; // Default for SHA-256 + size_t hash_len = 0; + mbedtls_md_type_t hash_algo; unsigned char *signature; - mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default // Determine hash algorithm - if (hash_algo_str) { - hash_algo = string_to_md_type(hash_algo_str); - hash_len = get_hash_size_from_md_type(hash_algo); + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return NULL; + } else { + hash_len = get_hash_size_from_md_type(hash_algo); + } } // Parse private key @@ -723,16 +741,25 @@ static int LuaRSAPSSSign(lua_State *L) { static int RSAVerify(const char *public_key_pem, const unsigned char *data, size_t data_len, const unsigned char *signature, - size_t sig_len, const char *hash_algo_str) { + size_t sig_len, const char *hash_name) { int rc; - unsigned char hash[64]; // Large enough for SHA-512 - size_t hash_len = 32; // Default for SHA-256 - mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 0; + mbedtls_md_type_t hash_algo; // Determine hash algorithm - if (hash_algo_str) { - hash_algo = string_to_md_type(hash_algo_str); - hash_len = get_hash_size_from_md_type(hash_algo); + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return -1; + } else { + hash_len = get_hash_size_from_md_type(hash_algo); + } } // Parse public key @@ -770,7 +797,7 @@ static int RSAVerify(const char *public_key_pem, const unsigned char *data, } static int LuaRSAVerify(lua_State *L) { size_t msg_len, key_len, sig_len; - const char *msg, *key_pem, *signature, *hash_algo_str = NULL; + const char *msg, *key_pem, *signature, *hash_name = NULL; int result; // Get parameters from Lua @@ -791,12 +818,12 @@ static int LuaRSAVerify(lua_State *L) { // Optional hash algorithm parameter if (!lua_isnoneornil(L, 4)) { - hash_algo_str = luaL_checkstring(L, 4); + hash_name = luaL_checkstring(L, 4); } // Call the C implementation result = RSAVerify(key_pem, (const unsigned char *)msg, msg_len, - (const unsigned char *)signature, sig_len, hash_algo_str); + (const unsigned char *)signature, sig_len, hash_name); // Return boolean result (0 means valid signature) lua_pushboolean(L, result == 0); @@ -807,17 +834,26 @@ static int LuaRSAVerify(lua_State *L) { // RSA PSS Verification static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, size_t data_len, const unsigned char *signature, - size_t sig_len, const char *hash_algo_str, + size_t sig_len, const char *hash_name, int expected_salt_len) { int rc; - unsigned char hash[64]; // Large enough for SHA-512 - size_t hash_len = 32; // Default for SHA-256 - mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 0; + mbedtls_md_type_t hash_algo; // Determine hash algorithm - if (hash_algo_str) { - hash_algo = string_to_md_type(hash_algo_str); - hash_len = get_hash_size_from_md_type(hash_algo); + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return -1; + } else { + hash_len = get_hash_size_from_md_type(hash_algo); + } } // Parse public key @@ -864,7 +900,7 @@ static int LuaRSAPSSVerify(lua_State *L) { // crypto.verify('rsapss', key, msg, signature, hash_algo, salt_len) size_t msg_len, key_len, sig_len; const char *msg, *key_pem, *signature; - const char *hash_algo_str = NULL; // Default to SHA-256 + const char *hash_name = NULL; // Default to SHA-256 int expected_salt_len = -1; int result; @@ -892,24 +928,17 @@ static int LuaRSAPSSVerify(lua_State *L) { signature = luaL_checklstring(L, 3, &sig_len); // Optional hash algorithm parameter if (lua_isstring(L, 4)) { - hash_algo_str = luaL_checkstring(L, 4); + hash_name = luaL_checkstring(L, 4); // Optional salt length parameter expected_salt_len = luaL_optinteger(L, 5, 32); } else if (lua_isnumber(L, 4)) { // If it's a number, treat it as the expected salt length expected_salt_len = (int)lua_tointeger(L, 4); } - DEBUGF("(DEBUG) Key PEM: %s", key_pem); - DEBUGF("(DEBUG) Message length: %zu", msg_len); - DEBUGF("(DEBUG) Signature length: %zu", sig_len); - DEBUGF("(DEBUG) Hash algorithm: %s", - hash_algo_str ? hash_algo_str : "SHA-256"); - DEBUGF("(DEBUG) Expected salt length: %d", expected_salt_len); - DEBUGF("(DEBUG) Signature: %.*s", (int)sig_len, signature); // Call the C implementation result = RSAPSSVerify(key_pem, (const unsigned char *)msg, msg_len, - (const unsigned char *)signature, sig_len, - hash_algo_str, expected_salt_len); + (const unsigned char *)signature, sig_len, hash_name, + expected_salt_len); // Return boolean result (0 means valid signature) lua_pushboolean(L, result == 0); @@ -1108,13 +1137,24 @@ static int LuaECDSASign(lua_State *L) { const char *priv_key_pem = luaL_checkstring(L, 1); const char *message = luaL_checkstring(L, 2); const char *hash_name = luaL_optstring(L, 3, "sha256"); + mbedtls_md_type_t hash_algo; - mbedtls_md_type_t hash_alg = string_to_md_type(hash_name); + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + return -1; + } + } unsigned char *signature = NULL; size_t sig_len = 0; - int ret = ECDSASign(priv_key_pem, message, hash_alg, &signature, &sig_len); + int ret = ECDSASign(priv_key_pem, message, hash_algo, &signature, &sig_len); if (ret == 0) { lua_pushlstring(L, (const char *)signature, sig_len); @@ -1187,10 +1227,21 @@ static int LuaECDSAVerify(lua_State *L) { const unsigned char *signature = (const unsigned char *)luaL_checklstring(L, 3, &sig_len); const char *hash_name = luaL_optstring(L, 4, "sha256"); + mbedtls_md_type_t hash_algo; - mbedtls_md_type_t hash_alg = string_to_md_type(hash_name); + if (hash_name == NULL || hash_name[0] == '\0') { + hash_algo = MBEDTLS_MD_SHA256; + VERBOSEF("(crypto) No digest specified, using default: SHA256"); + } else { + // Find the digest by name + hash_algo = find_digest_by_name(hash_name); + if (hash_algo == MBEDTLS_MD_NONE) { + WARNF("(crypto) Unknown digest: '%s'", hash_name); + lua_pushboolean(L, false); + } + } - int ret = ECDSAVerify(pub_key_pem, message, signature, sig_len, hash_alg); + int ret = ECDSAVerify(pub_key_pem, message, signature, sig_len, hash_algo); lua_pushboolean(L, ret == 0); return 1; @@ -1739,8 +1790,59 @@ static int LuaAesDecrypt(lua_State *L) { } // JWK functions +// Helper: convert base64url to standard base64 (in-place) +static void base64url_to_base64(char *input) { + if (!input) + return; + // Replace URL-safe characters with standard base64 characters + for (char *p = input; *p; ++p) { + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; + } + // Add padding if necessary + size_t len = strlen(input); + int mod = len % 4; + if (mod) { + for (int i = 0; i < 4 - mod; ++i) + input[len + i] = '='; + input[len + 4 - mod] = '\0'; + } +} -// Convert JWK (Lua table) to PEM key format +// Helper: convert standard base64 to base64url (in-place) +static void base64_to_base64url(char *input) { + if (!input) + return; + size_t len = strlen(input); + // Replace standard base64 characters with URL-safe characters + for (size_t i = 0; i < len; i++) { + if (input[i] == '+') + input[i] = '-'; + else if (input[i] == '/') + input[i] = '_'; + } + // Remove padding + while (len > 0 && input[len - 1] == '=') { + input[--len] = '\0'; + } +} + +// Helper: encode binary to base64url +static char *b64url_encode(const unsigned char *data, size_t len) { + size_t b64_len; + mbedtls_base64_encode(NULL, 0, &b64_len, data, len); + char *b64 = malloc(b64_len + 1); + if (!b64) + return NULL; + mbedtls_base64_encode((unsigned char *)b64, b64_len, &b64_len, data, len); + b64[b64_len] = '\0'; + base64_to_base64url(b64); + return b64; +} + +// Convert JWK key to PEM (string) format static int LuaConvertJwkToPem(lua_State *L) { luaL_checktype(L, 1, LUA_TTABLE); const char *kty; @@ -1797,25 +1899,10 @@ static int LuaConvertJwkToPem(lua_State *L) { // Decode base64url to binary size_t n_len, e_len; unsigned char n_bin[1024], e_bin[16]; - char *n_b64_std = strdup(n_b64), *e_b64_std = strdup(e_b64); - for (char *p = n_b64_std; *p; ++p) - if (*p == '-') - *p = '+'; - else if (*p == '_') - *p = '/'; - for (char *p = e_b64_std; *p; ++p) - if (*p == '-') - *p = '+'; - else if (*p == '_') - *p = '/'; - int n_mod = strlen(n_b64_std) % 4; - int e_mod = strlen(e_b64_std) % 4; - if (n_mod) - for (int i = 0; i < 4 - n_mod; ++i) - strcat(n_b64_std, "="); - if (e_mod) - for (int i = 0; i < 4 - e_mod; ++i) - strcat(e_b64_std, "="); + char *n_b64_std = strdup(n_b64); + char *e_b64_std = strdup(e_b64); + base64url_to_base64(n_b64_std); + base64url_to_base64(e_b64_std); if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, (const unsigned char *)n_b64_std, strlen(n_b64_std)) != 0 || @@ -1832,6 +1919,7 @@ static int LuaConvertJwkToPem(lua_State *L) { free(e_b64_std); // Build RSA context in pk if ((ret = mbedtls_pk_setup( + &pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); @@ -1851,15 +1939,7 @@ static int LuaConvertJwkToPem(lua_State *L) { #define DECODE_B64URL(var, b64, bin, binlen) \ if (b64 && *b64) { \ char *b64_std = strdup(b64); \ - for (char *p = b64_std; *p; ++p) \ - if (*p == '-') \ - *p = '+'; \ - else if (*p == '_') \ - *p = '/'; \ - int mod = strlen(b64_std) % 4; \ - if (mod) \ - for (int i = 0; i < 4 - mod; ++i) \ - strcat(b64_std, "="); \ + base64url_to_base64(b64_std); \ mbedtls_base64_decode(bin, sizeof(bin), &binlen, \ (const unsigned char *)b64_std, strlen(b64_std)); \ free(b64_std); \ @@ -1929,25 +2009,10 @@ static int LuaConvertJwkToPem(lua_State *L) { } size_t x_len, y_len; unsigned char x_bin[72], y_bin[72]; - char *x_b64_std = strdup(x_b64), *y_b64_std = strdup(y_b64); - for (char *p = x_b64_std; *p; ++p) - if (*p == '-') - *p = '+'; - else if (*p == '_') - *p = '/'; - for (char *p = y_b64_std; *p; ++p) - if (*p == '-') - *p = '+'; - else if (*p == '_') - *p = '/'; - int x_mod = strlen(x_b64_std) % 4; - int y_mod = strlen(y_b64_std) % 4; - if (x_mod) - for (int i = 0; i < 4 - x_mod; ++i) - strcat(x_b64_std, "="); - if (y_mod) - for (int i = 0; i < 4 - y_mod; ++i) - strcat(y_b64_std, "="); + char *x_b64_std = strdup(x_b64); + char *y_b64_std = strdup(y_b64); + base64url_to_base64(x_b64_std); + base64url_to_base64(y_b64_std); if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, (const unsigned char *)x_b64_std, strlen(x_b64_std)) != 0 || @@ -2004,23 +2069,6 @@ static int LuaConvertJwkToPem(lua_State *L) { } } -// Convert PEM key to JWK (Lua table) format -static void base64_to_base64url(char *str) { - if (!str) - return; - for (char *p = str; *p; p++) { - if (*p == '+') - *p = '-'; - else if (*p == '/') - *p = '_'; - } - // Remove padding - size_t len = strlen(str); - while (len > 0 && str[len - 1] == '=') { - str[--len] = '\0'; - } -} - static int LuaConvertPemToJwk(lua_State *L) { const char *pem_key = luaL_checkstring(L, 1); int has_claims = 0; @@ -2063,20 +2111,8 @@ static int LuaConvertPemToJwk(lua_State *L) { } mbedtls_mpi_write_binary(&rsa->N, n, n_len); mbedtls_mpi_write_binary(&rsa->E, e, e_len); - char *n_b64 = NULL, *e_b64 = NULL; - size_t n_b64_len, e_b64_len; - mbedtls_base64_encode(NULL, 0, &n_b64_len, n, n_len); - mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); - n_b64 = malloc(n_b64_len + 1); - e_b64 = malloc(e_b64_len + 1); - mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, - n_len); - n_b64[n_b64_len] = '\0'; - base64_to_base64url(n_b64); - mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, - e_len); - e_b64[e_b64_len] = '\0'; - base64_to_base64url(e_b64); + char *n_b64 = b64url_encode(n, n_len); + char *e_b64 = b64url_encode(e, e_len); lua_pushstring(L, n_b64); lua_setfield(L, -2, "n"); lua_pushstring(L, e_b64); @@ -2110,47 +2146,12 @@ static int LuaConvertPemToJwk(lua_State *L) { mbedtls_mpi_write_binary(&rsa->DP, dp, dp_len); mbedtls_mpi_write_binary(&rsa->DQ, dq, dq_len); mbedtls_mpi_write_binary(&rsa->QP, qi, qi_len); - char *d_b64 = NULL, *p_b64 = NULL, *q_b64 = NULL, *dp_b64 = NULL, - *dq_b64 = NULL, *qi_b64 = NULL; - size_t d_b64_len, p_b64_len, q_b64_len, dp_b64_len, dq_b64_len, - qi_b64_len; - mbedtls_base64_encode(NULL, 0, &d_b64_len, d, d_len); - mbedtls_base64_encode(NULL, 0, &p_b64_len, p, p_len); - mbedtls_base64_encode(NULL, 0, &q_b64_len, q, q_len); - mbedtls_base64_encode(NULL, 0, &dp_b64_len, dp, dp_len); - mbedtls_base64_encode(NULL, 0, &dq_b64_len, dq, dq_len); - mbedtls_base64_encode(NULL, 0, &qi_b64_len, qi, qi_len); - d_b64 = malloc(d_b64_len + 1); - p_b64 = malloc(p_b64_len + 1); - q_b64 = malloc(q_b64_len + 1); - dp_b64 = malloc(dp_b64_len + 1); - dq_b64 = malloc(dq_b64_len + 1); - qi_b64 = malloc(qi_b64_len + 1); - mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, - d_len); - mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, - p_len); - mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, - q_len); - mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, - dp, dp_len); - mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, - dq, dq_len); - mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, - qi, qi_len); - d_b64[d_b64_len] = '\0'; - p_b64[p_b64_len] = '\0'; - q_b64[q_b64_len] = '\0'; - dp_b64[dp_b64_len] = '\0'; - dq_b64[dq_b64_len] = '\0'; - qi_b64[qi_b64_len] = '\0'; - // Convert all private components to base64url - base64_to_base64url(d_b64); - base64_to_base64url(p_b64); - base64_to_base64url(q_b64); - base64_to_base64url(dp_b64); - base64_to_base64url(dq_b64); - base64_to_base64url(qi_b64); + char *d_b64 = b64url_encode(d, d_len); + char *p_b64 = b64url_encode(p, p_len); + char *q_b64 = b64url_encode(q, q_len); + char *dp_b64 = b64url_encode(dp, dp_len); + char *dq_b64 = b64url_encode(dq, dq_len); + char *qi_b64 = b64url_encode(qi, qi_len); lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); lua_pushstring(L, p_b64); @@ -2198,20 +2199,8 @@ static int LuaConvertPemToJwk(lua_State *L) { } mbedtls_mpi_write_binary(&Q->X, x, x_len); mbedtls_mpi_write_binary(&Q->Y, y, y_len); - char *x_b64 = NULL, *y_b64 = NULL; - size_t x_b64_len, y_b64_len; - mbedtls_base64_encode(NULL, 0, &x_b64_len, x, x_len); - mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len); - x_b64 = malloc(x_b64_len + 1); - y_b64 = malloc(y_b64_len + 1); - mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, - x_len); - x_b64[x_b64_len] = '\0'; - base64_to_base64url(x_b64); - mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, - y_len); - y_b64[y_b64_len] = '\0'; - base64_to_base64url(y_b64); + char *x_b64 = b64url_encode(x, x_len); + char *y_b64 = b64url_encode(y, y_len); // Set kty and crv for EC keys lua_pushstring(L, "EC"); lua_setfield(L, -2, "kty"); @@ -2243,14 +2232,7 @@ static int LuaConvertPemToJwk(lua_State *L) { return 2; } mbedtls_mpi_write_binary(&ec->d, d, d_len); - char *d_b64 = NULL; - size_t d_b64_len; - mbedtls_base64_encode(NULL, 0, &d_b64_len, d, d_len); - d_b64 = malloc(d_b64_len + 1); - mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, - d_len); - d_b64[d_b64_len] = '\0'; - base64_to_base64url(d_b64); + char *d_b64 = b64url_encode(d, d_len); lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); free(d); @@ -2368,10 +2350,14 @@ static int LuaCryptoSign(lua_State *L) { if (strcasecmp(dtype, "rsa") == 0) { return LuaRSASign(L); - } else if (strcasecmp(dtype, "rsa-pss") == 0 || - strcasecmp(dtype, "rsapss") == 0) { + } +#ifndef TINY + else if (strcasecmp(dtype, "rsa-pss") == 0 || + strcasecmp(dtype, "rsapss") == 0) { return LuaRSAPSSSign(L); - } else if (strcasecmp(dtype, "ecdsa") == 0) { + } +#endif + else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSASign(L); } else { return luaL_error(L, "Unsupported signature type: %s", dtype); From e7fc09b4a493c76c4c845e4e5320fb9d5cd9dc26 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Sat, 28 Jun 2025 14:14:02 +1200 Subject: [PATCH 18/18] All modes available on "tiny" builds --- test/tool/net/lcrypto_test.lua | 14 +-- tool/net/definitions.lua | 157 +++++++++++++++------------------ tool/net/lcrypto.c | 126 +++++++++++++++++++++++--- 3 files changed, 194 insertions(+), 103 deletions(-) diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 27d4fb3fb..d21321e4b 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -1,4 +1,3 @@ ----@diagnostic disable: lowercase-global -- Test RSA key pair generation local function test_rsa_keypair_generation() local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) @@ -678,14 +677,14 @@ end local function run_tests() Log(kLogVerbose,"Testing RSA keypair generation...") test_rsa_keypair_generation() - Log(kLogVerbose,"Testing RSA signing and verification...") - test_rsa_signing_verification() - Log(kLogVerbose,"Testing RSA-PSS signing and verification...") - test_rsapss_signing_verification() - Log(kLogVerbose,"Testing RSA encryption and decryption...") - test_rsa_encryption_decryption() Log(kLogVerbose,"Testing RSA key size variations...") test_rsa_key_sizes() + Log(kLogVerbose,"Testing RSA signing and verification...") + test_rsa_signing_verification() + Log(kLogVerbose,"Testing RSA encryption and decryption...") + test_rsa_encryption_decryption() + Log(kLogVerbose,"Testing RSA-PSS signing and verification...") + test_rsapss_signing_verification() Log(kLogVerbose,"Testing ECDSA keypair generation...") test_ecdsa_keypair_generation() @@ -753,4 +752,5 @@ local function run_tests() end EXIT = 70 + os.exit(run_tests()) diff --git a/tool/net/definitions.lua b/tool/net/definitions.lua index c21ad82cc..11e545a60 100644 --- a/tool/net/definitions.lua +++ b/tool/net/definitions.lua @@ -746,7 +746,7 @@ function EscapeHtml(str) end ---@param path string? function LaunchBrowser(path) end ----@param ip uint32 +---@param ip integer|string ---@return string # a string describing the IP address. This is currently Class A granular. It can tell you if traffic originated from private networks, ARIN, APNIC, DOD, etc. ---@nodiscard function CategorizeIp(ip) end @@ -1142,10 +1142,10 @@ function FormatHttpDateTime(seconds) end --- Turns integer like `0x01020304` into a string like `"1.2.3.4"`. See also --- `ParseIp` for the inverse operation. ----@param uint32 integer +---@param ip integer ---@return string ---@nodiscard -function FormatIp(uint32) end +function FormatIp(ip) end --- Returns client ip4 address and port, e.g. `0x01020304`,`31337` would represent --- `1.2.3.4:31337`. This is the same as `GetClientAddr` except it will use the @@ -1363,25 +1363,25 @@ function HidePath(prefix) end ---@nodiscard function IsHiddenPath(path) end ----@param uint32 integer +---@param ip integer|string|string ---@return boolean # `true` if IP address is not a private network (`10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`) and is not localhost (`127.0.0.0/8`). --- Note: we intentionally regard TEST-NET IPs as public. ---@nodiscard -function IsPublicIp(uint32) end +function IsPublicIp(ip) end ----@param uint32 integer +---@param ip integer|string|string ---@return boolean # `true` if IP address is part of a private network (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16). ---@nodiscard -function IsPrivateIp(uint32) end +function IsPrivateIp(ip) end ---@return boolean # `true` if the client IP address (returned by GetRemoteAddr) is part of the localhost network (127.0.0.0/8). ---@nodiscard function IsLoopbackClient() end ----@param uint32 integer +---@param ip integer|string|string ---@return boolean # true if IP address is part of the localhost network (127.0.0.0/8). ---@nodiscard -function IsLoopbackIp(uint32) end +function IsLoopbackIp(ip) end ---@param path string ---@return boolean # `true` if ZIP artifact at path is stored on disk using DEFLATE compression. @@ -1615,7 +1615,7 @@ function GetCryptoHash(name, payload, key) end --- to the system-configured DNS resolution service. Please note that in MODE=tiny --- the HOSTS.TXT and DNS resolution isn't included, and therefore an IP must be --- provided. ----@param ip integer +---@param ip integer|string|string ---@overload fun(host:string) function ProgramAddr(ip) end @@ -1669,8 +1669,8 @@ function ProgramTimeout(milliseconds) end --- Hard-codes the port number on which to listen, which can be any number in the --- range `1..65535`, or alternatively `0` to ask the operating system to choose a --- port, which may be revealed later on by `GetServerAddr` or the `-z` flag to stdout. ----@param uint16 integer -function ProgramPort(uint16) end +---@param port integer +function ProgramPort(port) end --- Sets the maximum HTTP message payload size in bytes. The --- default is very conservatively set to 65536 so this is @@ -2169,7 +2169,7 @@ function bin(int) end --- unspecified format describing the error. Calls to this function may be wrapped --- in `assert()` if an exception is desired. ---@param hostname string ----@return uint32 ip uint32 +---@return string ---@nodiscard ---@overload fun(hostname: string): nil, error: string function ResolveIp(hostname) end @@ -2183,7 +2183,7 @@ function ResolveIp(hostname) end --- The network interface addresses used by the host machine are always --- considered trustworthy, e.g. 127.0.0.1. This may change soon, if we --- decide to export a `GetHostIps()` API which queries your NIC devices. ----@param ip integer +---@param ip integer|string ---@return boolean function IsTrustedIp(ip) end @@ -2213,7 +2213,7 @@ function IsTrustedIp(ip) end --- --- Although you might want consider trusting redbean's open source --- freedom embracing solution to DDOS protection instead! ----@param ip integer +---@param ip integer|string ---@param cidr integer? function ProgramTrustedIp(ip, cidr) end @@ -8051,85 +8051,70 @@ kUrlLatin1 = nil --- This module provides cryptographic operations. +--- The crypto module for cryptographic operations crypto = {} ---- Signs a message using the specified key type. ---- Supported types: "rsa", "rsa-pss", "ecdsa" ----@param type "rsa"|"rsa-pss"|"rsapss"|"ecdsa" ----@param key string PEM-encoded private key ----@param message string ----@param hash? string Hash algorithm ("sha256", "sha384", "sha512"). Default: "sha256" ----@return string signature ----@overload fun(type: string, key: string, message: string, hash?: string): nil, error: string -function crypto.sign(type, key, message, hash) end +--- Converts a PEM-encoded key to JWK format +---@param pem string PEM-encoded key +---@return table?, string? JWK table or nil on error +---@return string? error message +function crypto.convertPemToJwk(pem) end ---- Verifies a signature using the specified key type. ---- Supported types: "rsa", "rsa-pss", "ecdsa" ----@param type "rsa"|"rsa-pss"|"rsapss"|"ecdsa" ----@param key string PEM-encoded public key ----@param message string ----@param signature string ----@param hash? string Hash algorithm ("sha256", "sha384", "sha512"). Default: "sha256" ----@return boolean valid -function crypto.verify(type, key, message, signature, hash) end +--- Generates a Certificate Signing Request (CSR) +---@param key_pem string PEM-encoded private key +---@param subject_name string? X.509 subject name +---@param san_list string? Subject Alternative Names +---@return string?, string? CSR in PEM format or nil on error and error message +function crypto.generateCsr(key_pem, subject_name, san_list) end ---- Encrypts data using the specified cipher. ---- Supported ciphers: "rsa", "aes" ----@param cipher "rsa"|"aes" ----@param key string PEM-encoded public key (RSA) or raw key (AES) ----@param plaintext string ----@param options? table Options for AES: { mode="cbc"|"gcm"|"ctr", iv=string, aad=string } ----@return string ciphertext, string? iv, string? tag ----@overload fun(cipher: string, key: string, plaintext: string, options?: table): nil, error: string -function crypto.encrypt(cipher, key, plaintext, options) end +--- Signs data using a private key +---@param key_type string "rsa" or "ecdsa" +---@param private_key string PEM-encoded private key +---@param message string Data to sign +---@param hash_algo string? Hash algorithm (default: SHA-256) +---@return string?, string? Signature or nil on error and error message +function crypto.sign(key_type, private_key, message, hash_algo) end ---- Decrypts data using the specified cipher. ---- Supported ciphers: "rsa", "aes" ----@param cipher "rsa"|"aes" ----@param key string PEM-encoded private key (RSA) or raw key (AES) ----@param ciphertext string ----@param options? table Options for AES: { mode="cbc"|"gcm"|"ctr", iv=string, tag=string, aad=string } ----@return string plaintext ----@overload fun(cipher: string, key: string, ciphertext: string, options?: table): nil, error: string -function crypto.decrypt(cipher, key, ciphertext, options) end +--- Verifies a signature +---@param key_type string "rsa" or "ecdsa" +---@param public_key string PEM-encoded public key +---@param message string Original message +---@param signature string Signature to verify +---@param hash_algo string? Hash algorithm (default: SHA-256) +---@return boolean?, string? True if valid or nil on error and error message +function crypto.verify(key_type, public_key, message, signature, hash_algo) end ---- Generates a key pair. ---- For RSA: bits = 2048 or 4096. ---- For ECDSA: curve = "secp256r1", "secp384r1", "secp521r1", "curve25519" ---- For AES: bits = 128, 192, or 256. ----@param type "rsa"|"ecdsa"|"aes" ----@param param? integer|string For RSA: bits; for ECDSA: curve name; for AES: bits ----@return string private_key, string public_key|nil ----@overload fun(type: string, param?: integer|string): nil, error: string -function crypto.generateKeyPair(type, param) end +--- Encrypts data +---@param cipher_type string "rsa" or "aes" +---@param key string Public key or symmetric key +---@param plaintext string Data to encrypt +---@param options table Table with optional parameters: +--- options.mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") +--- options.iv string? Initialization Vector for AES +--- options.aad string? Additional data for AES-GCM +---@return string? Encrypted data or nil on error +---@return string? IV or error message +---@return string? Authentication tag for GCM mode +function crypto.encrypt(cipher_type, key, plaintext, options) end ---- Converts a JWK (JSON Web Key, as a Lua table) to PEM format. ----@param jwk table ----@return string pem ----@overload fun(jwk: table): nil, error: string -function crypto.convertJwkToPem(jwk) end +--- Decrypts data +---@param cipher_type string "rsa" or "aes" +---@param key string Private key or symmetric key +---@param ciphertext string Data to decrypt +---@param options table Table with optional parameters: +--- options.iv string? Initialization Vector for AES +--- options.mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") +--- options.tag string? Authentication tag for AES-GCM +--- options.aad string? Additional data for AES-GCM +---@return string?, string? Decrypted data or nil on error and error message +function crypto.decrypt(cipher_type, key, ciphertext, options) end ---- Converts a PEM key to JWK (JSON Web Key, as a Lua table). ----@param pem string ----@param claims? table Additional claims to merge (RFC7517 fields) ----@return table jwk ----@overload fun(pem: string, claims?: table): nil, error: string -function crypto.convertPemToJwk(pem, claims) end - ---- Generates a Certificate Signing Request (CSR) from a PEM key. ----@param key string PEM-encoded private key ----@param subject string Subject name (e.g., "CN=example.com") ----@param sans? string Subject Alternative Names (SANs) as a comma-separated string ----@return string csr_pem ----@overload fun(key: string, subject: string, sans?: string): nil, error: string -function crypto.generateCsr(key, subject, sans) end - --- AES options table for encrypt/decrypt: ----@class CryptoAesOptions ----@field mode? "cbc"|"gcm"|"ctr" ----@field iv? string ----@field tag? string ----@field aad? string +--- Generates cryptographic keys +---@param key_type string? "rsa", "ecdsa", or "aes" +---@param key_size_or_curve number|string? Key size or curve name +---@return string? Private key or nil on error +---@return string? Public key (nil for AES) or error message +function crypto.generatekeypair(key_type, key_size_or_curve) end --[[ diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index 6e4ec032e..f97920893 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -50,9 +50,7 @@ static const curve_map_t supported_curves[] = { {"P521", MBEDTLS_ECP_DP_SECP521R1}, // {"P-521", MBEDTLS_ECP_DP_SECP521R1}, // {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, // -#ifndef TINY - {"curve448", MBEDTLS_ECP_DP_CURVE448}, // -#endif + {"curve448", MBEDTLS_ECP_DP_CURVE448}, // {NULL, 0}}; // List available curves @@ -198,6 +196,24 @@ cleanup: return ret; } +// Ciphers +typedef struct { + const char *name; + mbedtls_cipher_id_t id; +} ciphers_map_t; + +static const ciphers_map_t supported_ciphers[] = { + {"AES-128-CBC", MBEDTLS_CIPHER_AES_128_CBC}, // + {"AES-192-CBC", MBEDTLS_CIPHER_AES_192_CBC}, // + {"AES-256-CBC", MBEDTLS_CIPHER_AES_256_CBC}, // + {"AES-128-CTR", MBEDTLS_CIPHER_AES_128_CTR}, // + {"AES-192-CTR", MBEDTLS_CIPHER_AES_192_CTR}, // + {"AES-256-CTR", MBEDTLS_CIPHER_AES_256_CTR}, // + {"AES-128-GCM", MBEDTLS_CIPHER_AES_128_GCM}, // + {"AES-192-GCM", MBEDTLS_CIPHER_AES_192_GCM}, // + {"AES-256-GCM", MBEDTLS_CIPHER_AES_256_GCM}, // + {NULL, 0}}; + // Strong RNG using mbedtls_entropy_context and mbedtls_ctr_drbg_context int GenerateRandom(void *ctx, unsigned char *output, size_t len) { static mbedtls_entropy_context entropy; @@ -2350,14 +2366,10 @@ static int LuaCryptoSign(lua_State *L) { if (strcasecmp(dtype, "rsa") == 0) { return LuaRSASign(L); - } -#ifndef TINY - else if (strcasecmp(dtype, "rsa-pss") == 0 || - strcasecmp(dtype, "rsapss") == 0) { + } else if (strcasecmp(dtype, "rsa-pss") == 0 || + strcasecmp(dtype, "rsapss") == 0) { return LuaRSAPSSSign(L); - } -#endif - else if (strcasecmp(dtype, "ecdsa") == 0) { + } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSASign(L); } else { return luaL_error(L, "Unsupported signature type: %s", dtype); @@ -2433,6 +2445,99 @@ static int LuaCryptoGenerateKeyPair(lua_State *L) { } } +// Returns a Lua table array of supported digests and ciphers (strings), +// depending on the type argument: +// "ciphers" - returns list of ciphers supported by crypto.encrypt and +// crypto.decrypt "digests" - returns list of digests in supported_digests +// "curves" - returns list of curves in supported_curves +// If no argument is provided, returns a table with all three types +static int LuaList(lua_State *L) { + // Create a new table to hold the result + lua_newtable(L); + + // No argument provided - return all types in a structured table + if (lua_isnoneornil(L, 1)) { + // Create subtable for digests + lua_pushstring(L, "digests"); + lua_newtable(L); + const digest_map_t *digest = supported_digests; + int i = 1; + while (digest->name != NULL) { + lua_pushstring(L, digest->name); + lua_rawseti(L, -2, i++); + digest++; + } + lua_settable(L, -3); + + // Create subtable for curves + lua_pushstring(L, "curves"); + lua_newtable(L); + const curve_map_t *curve = supported_curves; + i = 1; + while (curve->name != NULL) { + lua_pushstring(L, curve->name); + lua_rawseti(L, -2, i++); + curve++; + } + lua_settable(L, -3); + + // Create subtable for ciphers + lua_pushstring(L, "ciphers"); + lua_newtable(L); + const ciphers_map_t *cipher = supported_ciphers; + i = 1; + while (cipher->name != NULL) { + lua_pushstring(L, cipher->name); + lua_rawseti(L, -2, i++); + cipher++; + } + lua_settable(L, -3); + + return 1; + } + + // Argument provided - handle specific type + const char *type = luaL_checkstring(L, 1); + + if (strcasecmp(type, "curves") == 0) { + // List all available curves + const curve_map_t *curve = supported_curves; + int i = 1; + + while (curve->name != NULL) { + lua_pushstring(L, curve->name); + lua_rawseti(L, -2, i++); + curve++; + } + } else if (strcasecmp(type, "digests") == 0) { + // List all available digests + const digest_map_t *digest = supported_digests; + int i = 1; + + while (digest->name != NULL) { + lua_pushstring(L, digest->name); + lua_rawseti(L, -2, i++); + digest++; + } + } else if (strcasecmp(type, "ciphers") == 0) { + // List all available ciphers + const ciphers_map_t *cipher = supported_ciphers; + int i = 1; + + while (cipher->name != NULL) { + lua_pushstring(L, cipher->name); + lua_rawseti(L, -2, i++); + cipher++; + } + } else { + // Invalid type, return empty table + lua_pushstring(L, "Invalid type. Use 'ciphers', 'digests', or 'curves'"); + lua_setfield(L, -2, "error"); + } + + return 1; // Return the table +} + static const luaL_Reg kLuaCrypto[] = { {"sign", LuaCryptoSign}, // {"verify", LuaCryptoVerify}, // @@ -2442,6 +2547,7 @@ static const luaL_Reg kLuaCrypto[] = { {"convertJwkToPem", LuaConvertJwkToPem}, // {"convertPemToJwk", LuaConvertPemToJwk}, // {"generateCsr", LuaGenerateCSR}, // + {"list", LuaList}, // {0}, // };