From b8bdccc7fc2fc33c59a9557f92ac743f7d652fd6 Mon Sep 17 00:00:00 2001 From: Miguel Terron Date: Tue, 24 Jun 2025 20:30:35 +1200 Subject: [PATCH] Address Paul's comments :) --- tool/net/lcrypto.c | 366 +++++++++++++++++++++++++++++---------------- 1 file changed, 236 insertions(+), 130 deletions(-) diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index 7ce0c3cd4..4cda00e1d 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -39,12 +39,15 @@ typedef struct { } curve_map_t; static const curve_map_t supported_curves[] = { - {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, - {"P-256", MBEDTLS_ECP_DP_SECP256R1}, - {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, - {"P-384", MBEDTLS_ECP_DP_SECP384R1}, - {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, - {"P-521", MBEDTLS_ECP_DP_SECP521R1}, + {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, + {"P256", MBEDTLS_ECP_DP_SECP256R1}, + {"P-256", MBEDTLS_ECP_DP_SECP256R1}, + {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, + {"P384", MBEDTLS_ECP_DP_SECP384R1}, + {"P-384", MBEDTLS_ECP_DP_SECP384R1}, + {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, + {"P521", MBEDTLS_ECP_DP_SECP521R1}, + {"P-521", MBEDTLS_ECP_DP_SECP521R1}, {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, {NULL, 0}}; @@ -69,11 +72,14 @@ static mbedtls_md_type_t string_to_md_type(const char *hash_name) { if (!hash_name || !*hash_name) { return MBEDTLS_MD_SHA256; // Default to SHA-256 if no name provided } - if (strcasecmp(hash_name, "sha256") == 0 || strcasecmp(hash_name, "sha-256") == 0) { + if (strcasecmp(hash_name, "sha256") == 0 || + strcasecmp(hash_name, "sha-256") == 0) { return MBEDTLS_MD_SHA256; - } else if (strcasecmp(hash_name, "sha384") == 0 || strcasecmp(hash_name, "sha-384") == 0) { + } else if (strcasecmp(hash_name, "sha384") == 0 || + strcasecmp(hash_name, "sha-384") == 0) { return MBEDTLS_MD_SHA384; - } else if (strcasecmp(hash_name, "sha512") == 0 || strcasecmp(hash_name, "sha-512") == 0) { + } else if (strcasecmp(hash_name, "sha512") == 0 || + strcasecmp(hash_name, "sha-512") == 0) { return MBEDTLS_MD_SHA512; } else { WARNF("(crypto) Unknown hash algorithm '%s'", hash_name); @@ -81,6 +87,7 @@ static mbedtls_md_type_t string_to_md_type(const char *hash_name) { } } +// Get the size of the hash output based on the mbedtls_md_type_t static size_t get_hash_size_from_md_type(mbedtls_md_type_t md_type) { switch (md_type) { case MBEDTLS_MD_SHA256: @@ -288,7 +295,10 @@ static int LuaRSAGenerateKeyPair(lua_State *L) { // Check if key length is valid (only 2048 or 4096 bits are allowed) if (bits != 2048 && bits != 4096) { lua_pushnil(L); - lua_pushfstring(L, "Invalid RSA key length: %d. Only 2048 or 4096 bits key lengths are supported", bits); + lua_pushfstring(L, + "Invalid RSA key length: %d. Only 2048 or 4096 bits key " + "lengths are supported", + bits); return 2; } @@ -414,7 +424,7 @@ static char *RSADecrypt(const char *private_key_pem, mbedtls_pk_context key; mbedtls_pk_init(&key); rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, - strlen(private_key_pem) + 1, NULL, 0); + strlen(private_key_pem) + 1, NULL, 0); if (rc != 0) { WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); @@ -439,8 +449,8 @@ static char *RSADecrypt(const char *private_key_pem, // Decrypt data size_t output_len = 0; rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, - MBEDTLS_RSA_PRIVATE, &output_len, - encrypted_data, output, key_size); + MBEDTLS_RSA_PRIVATE, &output_len, + encrypted_data, output, key_size); if (rc != 0) { WARNF("(crypto) Decryption failed (grep -0x%04x)", -rc); free(output); @@ -560,10 +570,10 @@ static int LuaRSASign(lua_State *L) { return 2; } if (lua_type(L, 1) == LUA_TTABLE) { - DEBUGF("(crypto) Detected table type for parameter 1"); - lua_pushnil(L); - lua_pushstring(L, "Key must be a string, got a table instead"); - return 2; + DEBUGF("(crypto) Detected table type for parameter 1"); + lua_pushnil(L); + lua_pushstring(L, "Key must be a string, got a table instead"); + return 2; } // Ensure msg is a string if (lua_type(L, 2) != LUA_TSTRING) { @@ -617,7 +627,7 @@ static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data, mbedtls_pk_context key; mbedtls_pk_init(&key); rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, - strlen(private_key_pem) + 1, NULL, 0); + strlen(private_key_pem) + 1, NULL, 0); if (rc != 0) { WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); @@ -650,9 +660,9 @@ static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data, mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, hash_algo); // Sign the hash using PSS - rc = mbedtls_rsa_rsassa_pss_sign( - rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, hash_algo, (unsigned int)hash_len, - hash, signature); + rc = mbedtls_rsa_rsassa_pss_sign(rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, + hash_algo, (unsigned int)hash_len, hash, + signature); if (rc != 0) { free(signature); mbedtls_pk_free(&key); @@ -673,7 +683,7 @@ static int LuaRSAPSSSign(lua_State *L) { unsigned char *signature; size_t sig_len = 0; - // Get parameters from Lua + // Get parameters from Lua if (lua_type(L, 1) != LUA_TSTRING) { lua_pushnil(L); lua_pushstring(L, "Key must be a string"); @@ -697,8 +707,9 @@ static int LuaRSAPSSSign(lua_State *L) { int salt_len = luaL_optinteger(L, 4, -1); // Call the C implementation - signature = (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, - msg_len, hash_name, &sig_len, salt_len); + signature = + (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, msg_len, + hash_name, &sig_len, salt_len); if (!signature) { return luaL_error(L, "failed to sign message (PSS)"); @@ -795,9 +806,9 @@ static int LuaRSAVerify(lua_State *L) { // RSA PSS Verification static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, - size_t data_len, const unsigned char *signature, - size_t sig_len, const char *hash_algo_str, - int expected_salt_len) { + size_t data_len, const unsigned char *signature, + size_t sig_len, const char *hash_algo_str, + int expected_salt_len) { int rc; unsigned char hash[64]; // Large enough for SHA-512 size_t hash_len = 32; // Default for SHA-256 @@ -806,9 +817,7 @@ static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, // Determine hash algorithm if (hash_algo_str) { hash_algo = string_to_md_type(hash_algo_str); - DEBUGF("(DEBUG) Using hash algorithm: %s", hash_algo_str); hash_len = get_hash_size_from_md_type(hash_algo); - DEBUGF("(DEBUG) Hash length: %zu", hash_len); } // Parse public key @@ -840,9 +849,9 @@ static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); // Verify the signature using PSS - rc = mbedtls_rsa_rsassa_pss_verify( - rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, hash_algo, (unsigned int)hash_len, - hash, signature); + rc = mbedtls_rsa_rsassa_pss_verify(rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, + hash_algo, (unsigned int)hash_len, hash, + signature); // Clean up mbedtls_pk_free(&key); @@ -859,7 +868,7 @@ static int LuaRSAPSSVerify(lua_State *L) { int expected_salt_len = -1; int result; - // Get parameters from Lua + // Get parameters from Lua if (lua_type(L, 1) != LUA_TSTRING) { lua_pushnil(L); lua_pushstring(L, "Key must be a string"); @@ -893,12 +902,14 @@ static int LuaRSAPSSVerify(lua_State *L) { DEBUGF("(DEBUG) Key PEM: %s", key_pem); DEBUGF("(DEBUG) Message length: %zu", msg_len); DEBUGF("(DEBUG) Signature length: %zu", sig_len); - DEBUGF("(DEBUG) Hash algorithm: %s", hash_algo_str ? hash_algo_str : "SHA-256"); + DEBUGF("(DEBUG) Hash algorithm: %s", + hash_algo_str ? hash_algo_str : "SHA-256"); DEBUGF("(DEBUG) Expected salt length: %d", expected_salt_len); DEBUGF("(DEBUG) Signature: %.*s", (int)sig_len, signature); // Call the C implementation result = RSAPSSVerify(key_pem, (const unsigned char *)msg, msg_len, - (const unsigned char *)signature, sig_len, hash_algo_str, expected_salt_len); + (const unsigned char *)signature, sig_len, + hash_algo_str, expected_salt_len); // Return boolean result (0 means valid signature) lua_pushboolean(L, result == 0); @@ -1077,8 +1088,8 @@ static int ECDSASign(const char *priv_key_pem, const char *message, } // Sign the hash using GenerateRandom - ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, - *signature, sig_len, GenerateRandom, 0); + ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, *signature, sig_len, + GenerateRandom, 0); if (ret != 0) { WARNF("(crypto) Failed to sign message: -0x%04x", -ret); @@ -1158,8 +1169,7 @@ static int ECDSAVerify(const char *pub_key_pem, const char *message, } // Verify the signature - ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, - signature, sig_len); + ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, signature, sig_len); if (ret != 0) { WARNF("(crypto) Signature verification failed: -0x%04x", -ret); goto cleanup; @@ -1223,7 +1233,8 @@ typedef struct { size_t aadlen; } aes_options_t; -static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) { +static void parse_aes_options(lua_State *L, int options_idx, + aes_options_t *opts) { opts->mode = NULL; opts->iv = NULL; opts->ivlen = 0; @@ -1240,12 +1251,12 @@ static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts if (!lua_isnil(L, -1)) { mode_field_found = 1; const char *mode = lua_tostring(L, -1); - if (mode && (strcasecmp(mode, "cbc") == 0 || - strcasecmp(mode, "gcm") == 0 || - strcasecmp(mode, "ctr") == 0)) { + if (mode && + (strcasecmp(mode, "cbc") == 0 || strcasecmp(mode, "gcm") == 0 || + strcasecmp(mode, "ctr") == 0)) { opts->mode = mode; } else { - opts->mode = NULL; // Invalid mode + opts->mode = NULL; // Invalid mode } } lua_pop(L, 1); @@ -1288,7 +1299,6 @@ static int LuaAesEncrypt(lua_State *L) { // Args: key, plaintext, options table size_t keylen, ptlen; - // Get parameters from Lua // Ensure key is a string if (lua_type(L, 1) != LUA_TSTRING) { @@ -1319,7 +1329,8 @@ static int LuaAesEncrypt(lua_State *L) { const char *mode = opts.mode; if (!mode) { lua_pushnil(L); - lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + lua_pushstring(L, + "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); return 2; } const unsigned char *iv = opts.iv; @@ -1354,6 +1365,7 @@ static int LuaAesEncrypt(lua_State *L) { lua_pushstring(L, "Failed to allocate IV"); return 2; } + // Generate random IV if (GenerateRandom(NULL, gen_iv, ivlen) != 0) { free(gen_iv); @@ -1367,23 +1379,23 @@ static int LuaAesEncrypt(lua_State *L) { // Validate IV/nonce length if (is_cbc || is_ctr) { - // Validate IV/nonce length - if (is_cbc || is_ctr) { - if (opts.iv && opts.ivlen != 16) { - if (iv_was_generated) free(gen_iv); - lua_pushnil(L); - lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); - return 2; - } - } else if (is_gcm) { - if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) { - if (iv_was_generated) free(gen_iv); - lua_pushnil(L); - lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); - return 2; - } + if (opts.iv && opts.ivlen != 16) { + if (iv_was_generated) + free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; } + } else if (is_gcm) { + if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) { + if (iv_was_generated) + free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; } + } + if (is_cbc) { // PKCS7 padding size_t block_size = 16; @@ -1531,7 +1543,7 @@ static int LuaAesDecrypt(lua_State *L) { return 2; } const unsigned char *key = - (const unsigned char *)luaL_checklstring(L, 1, &keylen); + (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); int options_idx = 3; @@ -1540,7 +1552,8 @@ static int LuaAesDecrypt(lua_State *L) { const char *mode = opts.mode; if (!mode) { lua_pushnil(L); - lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + lua_pushstring(L, + "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); return 2; } const unsigned char *iv = opts.iv; @@ -1725,7 +1738,6 @@ static int LuaAesDecrypt(lua_State *L) { return 2; } - // JWK functions // Convert JWK (Lua table) to PEM key format @@ -1786,23 +1798,44 @@ static int LuaConvertJwkToPem(lua_State *L) { size_t n_len, e_len; unsigned char n_bin[1024], e_bin[16]; char *n_b64_std = strdup(n_b64), *e_b64_std = strdup(e_b64); - for (char *p = n_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; - for (char *p = e_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + for (char *p = n_b64_std; *p; ++p) + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; + for (char *p = e_b64_std; *p; ++p) + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; int n_mod = strlen(n_b64_std) % 4; int e_mod = strlen(e_b64_std) % 4; - if (n_mod) for (int i = 0; i < 4 - n_mod; ++i) strcat(n_b64_std, "="); - if (e_mod) for (int i = 0; i < 4 - e_mod; ++i) strcat(e_b64_std, "="); - if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, (const unsigned char *)n_b64_std, strlen(n_b64_std)) != 0 || - mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len, (const unsigned char *)e_b64_std, strlen(e_b64_std)) != 0) { - free(n_b64_std); free(e_b64_std); + if (n_mod) + for (int i = 0; i < 4 - n_mod; ++i) + strcat(n_b64_std, "="); + if (e_mod) + for (int i = 0; i < 4 - e_mod; ++i) + strcat(e_b64_std, "="); + if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, + (const unsigned char *)n_b64_std, + strlen(n_b64_std)) != 0 || + mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len, + (const unsigned char *)e_b64_std, + strlen(e_b64_std)) != 0) { + free(n_b64_std); + free(e_b64_std); lua_pushnil(L); lua_pushstring(L, "Base64 decode failed"); return 2; } - free(n_b64_std); free(e_b64_std); + free(n_b64_std); + free(e_b64_std); // Build RSA context in pk - if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { - lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2; + if ((ret = mbedtls_pk_setup( + &pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { + lua_pushnil(L); + lua_pushstring(L, "mbedtls_pk_setup failed"); + return 2; } mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk); mbedtls_rsa_init(rsa, MBEDTLS_RSA_PKCS_V15, 0); @@ -1812,17 +1845,25 @@ static int LuaConvertJwkToPem(lua_State *L) { if (has_private) { // Decode and set private fields size_t d_len, p_len, q_len, dp_len, dq_len, qi_len; - unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512], dq_bin[512], qi_bin[512]; - // Decode all private fields (skip if NULL) - #define DECODE_B64URL(var, b64, bin, binlen) \ - if (b64 && *b64) { \ - char *b64_std = strdup(b64); \ - for (char *p = b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; \ - int mod = strlen(b64_std) % 4; \ - if (mod) for (int i = 0; i < 4 - mod; ++i) strcat(b64_std, "="); \ - mbedtls_base64_decode(bin, sizeof(bin), &binlen, (const unsigned char *)b64_std, strlen(b64_std)); \ - free(b64_std); \ - } + unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512], + dq_bin[512], qi_bin[512]; +// Decode all private fields (skip if NULL) +#define DECODE_B64URL(var, b64, bin, binlen) \ + if (b64 && *b64) { \ + char *b64_std = strdup(b64); \ + for (char *p = b64_std; *p; ++p) \ + if (*p == '-') \ + *p = '+'; \ + else if (*p == '_') \ + *p = '/'; \ + int mod = strlen(b64_std) % 4; \ + if (mod) \ + for (int i = 0; i < 4 - mod; ++i) \ + strcat(b64_std, "="); \ + mbedtls_base64_decode(bin, sizeof(bin), &binlen, \ + (const unsigned char *)b64_std, strlen(b64_std)); \ + free(b64_std); \ + } DECODE_B64URL(d, d_b64, d_bin, d_len); DECODE_B64URL(p, p_b64, p_bin, p_len); DECODE_B64URL(q, q_b64, q_bin, q_len); @@ -1845,7 +1886,9 @@ static int LuaConvertJwkToPem(lua_State *L) { } if (ret != 0) { mbedtls_pk_free(&pk); - lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2; + lua_pushnil(L); + lua_pushstring(L, "PEM write failed"); + return 2; } pem = strdup((char *)buf); mbedtls_pk_free(&pk); @@ -1863,36 +1906,67 @@ static int LuaConvertJwkToPem(lua_State *L) { const char *y_b64 = lua_tostring(L, -2); const char *d_b64 = lua_tostring(L, -1); if (!crv || !*crv) { - lua_pushnil(L); lua_pushstring(L, "Missing or empty 'crv' in JWK"); return 2; + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'crv' in JWK"); + return 2; } if (!x_b64 || !*x_b64) { - lua_pushnil(L); lua_pushstring(L, "Missing or empty 'x' in JWK"); return 2; + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'x' in JWK"); + return 2; } if (!y_b64 || !*y_b64) { - lua_pushnil(L); lua_pushstring(L, "Missing or empty 'y' in JWK"); return 2; + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'y' in JWK"); + return 2; } int has_private = d_b64 && *d_b64; mbedtls_ecp_group_id gid = find_curve_by_name(crv); if (gid == MBEDTLS_ECP_DP_NONE) { - lua_pushnil(L); lua_pushstring(L, "Unknown curve"); return 2; + lua_pushnil(L); + lua_pushstring(L, "Unknown curve"); + return 2; } size_t x_len, y_len; unsigned char x_bin[72], y_bin[72]; char *x_b64_std = strdup(x_b64), *y_b64_std = strdup(y_b64); - for (char *p = x_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; - for ( char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + for (char *p = x_b64_std; *p; ++p) + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; + for (char *p = y_b64_std; *p; ++p) + if (*p == '-') + *p = '+'; + else if (*p == '_') + *p = '/'; int x_mod = strlen(x_b64_std) % 4; int y_mod = strlen(y_b64_std) % 4; - if (x_mod) for (int i = 0; i < 4 - x_mod; ++i) strcat(x_b64_std, "="); - if (y_mod) for (int i = 0; i < 4 - y_mod; ++i) strcat(y_b64_std, "="); - if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, (const unsigned char *)x_b64_std, strlen(x_b64_std)) != 0 || - mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len, (const unsigned char *)y_b64_std, strlen(y_b64_std)) != 0) { - free(x_b64_std); free(y_b64_std); - lua_pushnil(L); lua_pushstring(L, "Base64 decode failed"); return 2; + if (x_mod) + for (int i = 0; i < 4 - x_mod; ++i) + strcat(x_b64_std, "="); + if (y_mod) + for (int i = 0; i < 4 - y_mod; ++i) + strcat(y_b64_std, "="); + if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, + (const unsigned char *)x_b64_std, + strlen(x_b64_std)) != 0 || + mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len, + (const unsigned char *)y_b64_std, + strlen(y_b64_std)) != 0) { + free(x_b64_std); + free(y_b64_std); + lua_pushnil(L); + lua_pushstring(L, "Base64 decode failed"); + return 2; } - free(x_b64_std); free(y_b64_std); - if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) { - lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2; + free(x_b64_std); + free(y_b64_std); + if ((ret = mbedtls_pk_setup( + &pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) { + lua_pushnil(L); + lua_pushstring(L, "mbedtls_pk_setup failed"); + return 2; } mbedtls_ecp_keypair *ec = mbedtls_pk_ec(pk); mbedtls_ecp_keypair_init(ec); @@ -1914,8 +1988,10 @@ static int LuaConvertJwkToPem(lua_State *L) { } if (ret != 0) { mbedtls_pk_free(&pk); - lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2; - } + lua_pushnil(L); + lua_pushstring(L, "PEM write failed"); + return 2; + } pem = strdup((char *)buf); mbedtls_pk_free(&pk); lua_pushstring(L, pem); @@ -1930,14 +2006,17 @@ static int LuaConvertJwkToPem(lua_State *L) { // Convert PEM key to JWK (Lua table) format static void base64_to_base64url(char *str) { - if (!str) return; + if (!str) + return; for (char *p = str; *p; p++) { - if (*p == '+') *p = '-'; - else if (*p == '/') *p = '_'; + if (*p == '+') + *p = '-'; + else if (*p == '/') + *p = '_'; } // Remove padding size_t len = strlen(str); - while (len > 0 && str[len-1] == '=') { + while (len > 0 && str[len - 1] == '=') { str[--len] = '\0'; } } @@ -1990,10 +2069,12 @@ static int LuaConvertPemToJwk(lua_State *L) { mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); n_b64 = malloc(n_b64_len + 1); e_b64 = malloc(e_b64_len + 1); - mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len); + mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, + n_len); n_b64[n_b64_len] = '\0'; base64_to_base64url(n_b64); - mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len); + mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, + e_len); e_b64[e_b64_len] = '\0'; base64_to_base64url(e_b64); lua_pushstring(L, n_b64); @@ -2045,12 +2126,18 @@ static int LuaConvertPemToJwk(lua_State *L) { dp_b64 = malloc(dp_b64_len + 1); dq_b64 = malloc(dq_b64_len + 1); qi_b64 = malloc(qi_b64_len + 1); - mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); - mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, p_len); - mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, q_len); - mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, dp, dp_len); - mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, dq, dq_len); - mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, qi, qi_len); + mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, + d_len); + mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, + p_len); + mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, + q_len); + mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, + dp, dp_len); + mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, + dq, dq_len); + mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, + qi, qi_len); d_b64[d_b64_len] = '\0'; p_b64[p_b64_len] = '\0'; q_b64[q_b64_len] = '\0'; @@ -2117,16 +2204,19 @@ static int LuaConvertPemToJwk(lua_State *L) { mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len); x_b64 = malloc(x_b64_len + 1); y_b64 = malloc(y_b64_len + 1); - mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len); + mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, + x_len); x_b64[x_b64_len] = '\0'; base64_to_base64url(x_b64); - mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); + mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, + y_len); y_b64[y_b64_len] = '\0'; base64_to_base64url(y_b64); // Set kty and crv for EC keys lua_pushstring(L, "EC"); lua_setfield(L, -2, "kty"); - const mbedtls_ecp_curve_info *curve_info = mbedtls_ecp_curve_info_from_grp_id(ec->grp.id); + const mbedtls_ecp_curve_info *curve_info = + mbedtls_ecp_curve_info_from_grp_id(ec->grp.id); if (curve_info && curve_info->name) { lua_pushstring(L, curve_info->name); lua_setfield(L, -2, "crv"); @@ -2142,19 +2232,34 @@ static int LuaConvertPemToJwk(lua_State *L) { if (mbedtls_ecp_check_privkey(&ec->grp, &ec->d) == 0 && ec->d.p) { size_t d_len = mbedtls_mpi_size(&ec->d); unsigned char *d = malloc(d_len); - if (!d) { free(x); free(y); free(x_b64); free(y_b64); lua_pushnil(L); lua_pushstring(L, "Memory allocation failed"); mbedtls_pk_free(&key); return 2; } + if (!d) { + free(x); + free(y); + free(x_b64); + free(y_b64); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + mbedtls_pk_free(&key); + return 2; + } mbedtls_mpi_write_binary(&ec->d, d, d_len); char *d_b64 = NULL; size_t d_b64_len; mbedtls_base64_encode(NULL, 0, &d_b64_len, d, d_len); d_b64 = malloc(d_b64_len + 1); - mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); + mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, + d_len); d_b64[d_b64_len] = '\0'; base64_to_base64url(d_b64); - lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); - free(d); free(d_b64); + lua_pushstring(L, d_b64); + lua_setfield(L, -2, "d"); + free(d); + free(d_b64); } - free(x); free(y); free(x_b64); free(y_b64); + free(x); + free(y); + free(x_b64); + free(y_b64); } else { lua_pushnil(L); lua_pushstring(L, "Unsupported key type"); @@ -2166,8 +2271,10 @@ static int LuaConvertPemToJwk(lua_State *L) { // Merge additional claims if provided and compatible with RFC7517 if (has_claims) { - static const char *allowed[] = {"kty","use","sig","key_ops","alg","kid","x5u","x5c","x5t","x5t#S256",NULL}; - lua_pushnil(L); // first key + static const char *allowed[] = {"kty", "use", "sig", "key_ops", + "alg", "kid", "x5u", "x5c", + "x5t", "x5t#S256", NULL}; + lua_pushnil(L); // first key while (lua_next(L, 2) != 0) { const char *k = lua_tostring(L, -2); int allowed_key = 0; @@ -2252,7 +2359,6 @@ static int LuaGenerateCSR(lua_State *L) { return 1; } - // LuaCrypto compatible API static int LuaCryptoSign(lua_State *L) { // Type of signature (e.g., "rsa", "ecdsa", "rsa-pss") @@ -2262,7 +2368,8 @@ static int LuaCryptoSign(lua_State *L) { if (strcasecmp(dtype, "rsa") == 0) { return LuaRSASign(L); - } else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) { + } else if (strcasecmp(dtype, "rsa-pss") == 0 || + strcasecmp(dtype, "rsapss") == 0) { return LuaRSAPSSSign(L); } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSASign(L); @@ -2279,7 +2386,8 @@ static int LuaCryptoVerify(lua_State *L) { if (strcasecmp(dtype, "rsa") == 0) { return LuaRSAVerify(L); - } else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) { + } else if (strcasecmp(dtype, "rsa-pss") == 0 || + strcasecmp(dtype, "rsapss") == 0) { return LuaRSAPSSVerify(L); } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSAVerify(L); @@ -2339,8 +2447,6 @@ static int LuaCryptoGenerateKeyPair(lua_State *L) { } } - - static const luaL_Reg kLuaCrypto[] = { {"sign", LuaCryptoSign}, // {"verify", LuaCryptoVerify}, //