Address Paul's comments :)

This commit is contained in:
Miguel Terron 2025-06-24 20:30:35 +12:00
parent 12ff789a69
commit b8bdccc7fc

View file

@ -39,12 +39,15 @@ typedef struct {
} curve_map_t;
static const curve_map_t supported_curves[] = {
{"secp256r1", MBEDTLS_ECP_DP_SECP256R1},
{"P-256", MBEDTLS_ECP_DP_SECP256R1},
{"secp384r1", MBEDTLS_ECP_DP_SECP384R1},
{"P-384", MBEDTLS_ECP_DP_SECP384R1},
{"secp521r1", MBEDTLS_ECP_DP_SECP521R1},
{"P-521", MBEDTLS_ECP_DP_SECP521R1},
{"secp256r1", MBEDTLS_ECP_DP_SECP256R1},
{"P256", MBEDTLS_ECP_DP_SECP256R1},
{"P-256", MBEDTLS_ECP_DP_SECP256R1},
{"secp384r1", MBEDTLS_ECP_DP_SECP384R1},
{"P384", MBEDTLS_ECP_DP_SECP384R1},
{"P-384", MBEDTLS_ECP_DP_SECP384R1},
{"secp521r1", MBEDTLS_ECP_DP_SECP521R1},
{"P521", MBEDTLS_ECP_DP_SECP521R1},
{"P-521", MBEDTLS_ECP_DP_SECP521R1},
{"curve25519", MBEDTLS_ECP_DP_CURVE25519},
{NULL, 0}};
@ -69,11 +72,14 @@ static mbedtls_md_type_t string_to_md_type(const char *hash_name) {
if (!hash_name || !*hash_name) {
return MBEDTLS_MD_SHA256; // Default to SHA-256 if no name provided
}
if (strcasecmp(hash_name, "sha256") == 0 || strcasecmp(hash_name, "sha-256") == 0) {
if (strcasecmp(hash_name, "sha256") == 0 ||
strcasecmp(hash_name, "sha-256") == 0) {
return MBEDTLS_MD_SHA256;
} else if (strcasecmp(hash_name, "sha384") == 0 || strcasecmp(hash_name, "sha-384") == 0) {
} else if (strcasecmp(hash_name, "sha384") == 0 ||
strcasecmp(hash_name, "sha-384") == 0) {
return MBEDTLS_MD_SHA384;
} else if (strcasecmp(hash_name, "sha512") == 0 || strcasecmp(hash_name, "sha-512") == 0) {
} else if (strcasecmp(hash_name, "sha512") == 0 ||
strcasecmp(hash_name, "sha-512") == 0) {
return MBEDTLS_MD_SHA512;
} else {
WARNF("(crypto) Unknown hash algorithm '%s'", hash_name);
@ -81,6 +87,7 @@ static mbedtls_md_type_t string_to_md_type(const char *hash_name) {
}
}
// Get the size of the hash output based on the mbedtls_md_type_t
static size_t get_hash_size_from_md_type(mbedtls_md_type_t md_type) {
switch (md_type) {
case MBEDTLS_MD_SHA256:
@ -288,7 +295,10 @@ static int LuaRSAGenerateKeyPair(lua_State *L) {
// Check if key length is valid (only 2048 or 4096 bits are allowed)
if (bits != 2048 && bits != 4096) {
lua_pushnil(L);
lua_pushfstring(L, "Invalid RSA key length: %d. Only 2048 or 4096 bits key lengths are supported", bits);
lua_pushfstring(L,
"Invalid RSA key length: %d. Only 2048 or 4096 bits key "
"lengths are supported",
bits);
return 2;
}
@ -414,7 +424,7 @@ static char *RSADecrypt(const char *private_key_pem,
mbedtls_pk_context key;
mbedtls_pk_init(&key);
rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem,
strlen(private_key_pem) + 1, NULL, 0);
strlen(private_key_pem) + 1, NULL, 0);
if (rc != 0) {
WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc);
mbedtls_pk_free(&key);
@ -439,8 +449,8 @@ static char *RSADecrypt(const char *private_key_pem,
// Decrypt data
size_t output_len = 0;
rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0,
MBEDTLS_RSA_PRIVATE, &output_len,
encrypted_data, output, key_size);
MBEDTLS_RSA_PRIVATE, &output_len,
encrypted_data, output, key_size);
if (rc != 0) {
WARNF("(crypto) Decryption failed (grep -0x%04x)", -rc);
free(output);
@ -560,10 +570,10 @@ static int LuaRSASign(lua_State *L) {
return 2;
}
if (lua_type(L, 1) == LUA_TTABLE) {
DEBUGF("(crypto) Detected table type for parameter 1");
lua_pushnil(L);
lua_pushstring(L, "Key must be a string, got a table instead");
return 2;
DEBUGF("(crypto) Detected table type for parameter 1");
lua_pushnil(L);
lua_pushstring(L, "Key must be a string, got a table instead");
return 2;
}
// Ensure msg is a string
if (lua_type(L, 2) != LUA_TSTRING) {
@ -617,7 +627,7 @@ static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data,
mbedtls_pk_context key;
mbedtls_pk_init(&key);
rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem,
strlen(private_key_pem) + 1, NULL, 0);
strlen(private_key_pem) + 1, NULL, 0);
if (rc != 0) {
WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc);
mbedtls_pk_free(&key);
@ -650,9 +660,9 @@ static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data,
mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, hash_algo);
// Sign the hash using PSS
rc = mbedtls_rsa_rsassa_pss_sign(
rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, hash_algo, (unsigned int)hash_len,
hash, signature);
rc = mbedtls_rsa_rsassa_pss_sign(rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE,
hash_algo, (unsigned int)hash_len, hash,
signature);
if (rc != 0) {
free(signature);
mbedtls_pk_free(&key);
@ -673,7 +683,7 @@ static int LuaRSAPSSSign(lua_State *L) {
unsigned char *signature;
size_t sig_len = 0;
// Get parameters from Lua
// Get parameters from Lua
if (lua_type(L, 1) != LUA_TSTRING) {
lua_pushnil(L);
lua_pushstring(L, "Key must be a string");
@ -697,8 +707,9 @@ static int LuaRSAPSSSign(lua_State *L) {
int salt_len = luaL_optinteger(L, 4, -1);
// Call the C implementation
signature = (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg,
msg_len, hash_name, &sig_len, salt_len);
signature =
(unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, msg_len,
hash_name, &sig_len, salt_len);
if (!signature) {
return luaL_error(L, "failed to sign message (PSS)");
@ -795,9 +806,9 @@ static int LuaRSAVerify(lua_State *L) {
// RSA PSS Verification
static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data,
size_t data_len, const unsigned char *signature,
size_t sig_len, const char *hash_algo_str,
int expected_salt_len) {
size_t data_len, const unsigned char *signature,
size_t sig_len, const char *hash_algo_str,
int expected_salt_len) {
int rc;
unsigned char hash[64]; // Large enough for SHA-512
size_t hash_len = 32; // Default for SHA-256
@ -806,9 +817,7 @@ static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data,
// Determine hash algorithm
if (hash_algo_str) {
hash_algo = string_to_md_type(hash_algo_str);
DEBUGF("(DEBUG) Using hash algorithm: %s", hash_algo_str);
hash_len = get_hash_size_from_md_type(hash_algo);
DEBUGF("(DEBUG) Hash length: %zu", hash_len);
}
// Parse public key
@ -840,9 +849,9 @@ static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data,
mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key);
// Verify the signature using PSS
rc = mbedtls_rsa_rsassa_pss_verify(
rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, hash_algo, (unsigned int)hash_len,
hash, signature);
rc = mbedtls_rsa_rsassa_pss_verify(rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC,
hash_algo, (unsigned int)hash_len, hash,
signature);
// Clean up
mbedtls_pk_free(&key);
@ -859,7 +868,7 @@ static int LuaRSAPSSVerify(lua_State *L) {
int expected_salt_len = -1;
int result;
// Get parameters from Lua
// Get parameters from Lua
if (lua_type(L, 1) != LUA_TSTRING) {
lua_pushnil(L);
lua_pushstring(L, "Key must be a string");
@ -893,12 +902,14 @@ static int LuaRSAPSSVerify(lua_State *L) {
DEBUGF("(DEBUG) Key PEM: %s", key_pem);
DEBUGF("(DEBUG) Message length: %zu", msg_len);
DEBUGF("(DEBUG) Signature length: %zu", sig_len);
DEBUGF("(DEBUG) Hash algorithm: %s", hash_algo_str ? hash_algo_str : "SHA-256");
DEBUGF("(DEBUG) Hash algorithm: %s",
hash_algo_str ? hash_algo_str : "SHA-256");
DEBUGF("(DEBUG) Expected salt length: %d", expected_salt_len);
DEBUGF("(DEBUG) Signature: %.*s", (int)sig_len, signature);
// Call the C implementation
result = RSAPSSVerify(key_pem, (const unsigned char *)msg, msg_len,
(const unsigned char *)signature, sig_len, hash_algo_str, expected_salt_len);
(const unsigned char *)signature, sig_len,
hash_algo_str, expected_salt_len);
// Return boolean result (0 means valid signature)
lua_pushboolean(L, result == 0);
@ -1077,8 +1088,8 @@ static int ECDSASign(const char *priv_key_pem, const char *message,
}
// Sign the hash using GenerateRandom
ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size,
*signature, sig_len, GenerateRandom, 0);
ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, *signature, sig_len,
GenerateRandom, 0);
if (ret != 0) {
WARNF("(crypto) Failed to sign message: -0x%04x", -ret);
@ -1158,8 +1169,7 @@ static int ECDSAVerify(const char *pub_key_pem, const char *message,
}
// Verify the signature
ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size,
signature, sig_len);
ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, signature, sig_len);
if (ret != 0) {
WARNF("(crypto) Signature verification failed: -0x%04x", -ret);
goto cleanup;
@ -1223,7 +1233,8 @@ typedef struct {
size_t aadlen;
} aes_options_t;
static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) {
static void parse_aes_options(lua_State *L, int options_idx,
aes_options_t *opts) {
opts->mode = NULL;
opts->iv = NULL;
opts->ivlen = 0;
@ -1240,12 +1251,12 @@ static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts
if (!lua_isnil(L, -1)) {
mode_field_found = 1;
const char *mode = lua_tostring(L, -1);
if (mode && (strcasecmp(mode, "cbc") == 0 ||
strcasecmp(mode, "gcm") == 0 ||
strcasecmp(mode, "ctr") == 0)) {
if (mode &&
(strcasecmp(mode, "cbc") == 0 || strcasecmp(mode, "gcm") == 0 ||
strcasecmp(mode, "ctr") == 0)) {
opts->mode = mode;
} else {
opts->mode = NULL; // Invalid mode
opts->mode = NULL; // Invalid mode
}
}
lua_pop(L, 1);
@ -1288,7 +1299,6 @@ static int LuaAesEncrypt(lua_State *L) {
// Args: key, plaintext, options table
size_t keylen, ptlen;
// Get parameters from Lua
// Ensure key is a string
if (lua_type(L, 1) != LUA_TSTRING) {
@ -1319,7 +1329,8 @@ static int LuaAesEncrypt(lua_State *L) {
const char *mode = opts.mode;
if (!mode) {
lua_pushnil(L);
lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'.");
lua_pushstring(L,
"Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'.");
return 2;
}
const unsigned char *iv = opts.iv;
@ -1354,6 +1365,7 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushstring(L, "Failed to allocate IV");
return 2;
}
// Generate random IV
if (GenerateRandom(NULL, gen_iv, ivlen) != 0) {
free(gen_iv);
@ -1367,23 +1379,23 @@ static int LuaAesEncrypt(lua_State *L) {
// Validate IV/nonce length
if (is_cbc || is_ctr) {
// Validate IV/nonce length
if (is_cbc || is_ctr) {
if (opts.iv && opts.ivlen != 16) {
if (iv_was_generated) free(gen_iv);
lua_pushnil(L);
lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR");
return 2;
}
} else if (is_gcm) {
if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) {
if (iv_was_generated) free(gen_iv);
lua_pushnil(L);
lua_pushstring(L, "AES GCM nonce must be 12-16 bytes");
return 2;
}
if (opts.iv && opts.ivlen != 16) {
if (iv_was_generated)
free(gen_iv);
lua_pushnil(L);
lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR");
return 2;
}
} else if (is_gcm) {
if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) {
if (iv_was_generated)
free(gen_iv);
lua_pushnil(L);
lua_pushstring(L, "AES GCM nonce must be 12-16 bytes");
return 2;
}
}
if (is_cbc) {
// PKCS7 padding
size_t block_size = 16;
@ -1531,7 +1543,7 @@ static int LuaAesDecrypt(lua_State *L) {
return 2;
}
const unsigned char *key =
(const unsigned char *)luaL_checklstring(L, 1, &keylen);
(const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *ciphertext =
(const unsigned char *)luaL_checklstring(L, 2, &ctlen);
int options_idx = 3;
@ -1540,7 +1552,8 @@ static int LuaAesDecrypt(lua_State *L) {
const char *mode = opts.mode;
if (!mode) {
lua_pushnil(L);
lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'.");
lua_pushstring(L,
"Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'.");
return 2;
}
const unsigned char *iv = opts.iv;
@ -1725,7 +1738,6 @@ static int LuaAesDecrypt(lua_State *L) {
return 2;
}
// JWK functions
// Convert JWK (Lua table) to PEM key format
@ -1786,23 +1798,44 @@ static int LuaConvertJwkToPem(lua_State *L) {
size_t n_len, e_len;
unsigned char n_bin[1024], e_bin[16];
char *n_b64_std = strdup(n_b64), *e_b64_std = strdup(e_b64);
for (char *p = n_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/';
for (char *p = e_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/';
for (char *p = n_b64_std; *p; ++p)
if (*p == '-')
*p = '+';
else if (*p == '_')
*p = '/';
for (char *p = e_b64_std; *p; ++p)
if (*p == '-')
*p = '+';
else if (*p == '_')
*p = '/';
int n_mod = strlen(n_b64_std) % 4;
int e_mod = strlen(e_b64_std) % 4;
if (n_mod) for (int i = 0; i < 4 - n_mod; ++i) strcat(n_b64_std, "=");
if (e_mod) for (int i = 0; i < 4 - e_mod; ++i) strcat(e_b64_std, "=");
if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, (const unsigned char *)n_b64_std, strlen(n_b64_std)) != 0 ||
mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len, (const unsigned char *)e_b64_std, strlen(e_b64_std)) != 0) {
free(n_b64_std); free(e_b64_std);
if (n_mod)
for (int i = 0; i < 4 - n_mod; ++i)
strcat(n_b64_std, "=");
if (e_mod)
for (int i = 0; i < 4 - e_mod; ++i)
strcat(e_b64_std, "=");
if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len,
(const unsigned char *)n_b64_std,
strlen(n_b64_std)) != 0 ||
mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len,
(const unsigned char *)e_b64_std,
strlen(e_b64_std)) != 0) {
free(n_b64_std);
free(e_b64_std);
lua_pushnil(L);
lua_pushstring(L, "Base64 decode failed");
return 2;
}
free(n_b64_std); free(e_b64_std);
free(n_b64_std);
free(e_b64_std);
// Build RSA context in pk
if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) {
lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2;
if ((ret = mbedtls_pk_setup(
&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) {
lua_pushnil(L);
lua_pushstring(L, "mbedtls_pk_setup failed");
return 2;
}
mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk);
mbedtls_rsa_init(rsa, MBEDTLS_RSA_PKCS_V15, 0);
@ -1812,17 +1845,25 @@ static int LuaConvertJwkToPem(lua_State *L) {
if (has_private) {
// Decode and set private fields
size_t d_len, p_len, q_len, dp_len, dq_len, qi_len;
unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512], dq_bin[512], qi_bin[512];
// Decode all private fields (skip if NULL)
#define DECODE_B64URL(var, b64, bin, binlen) \
if (b64 && *b64) { \
char *b64_std = strdup(b64); \
for (char *p = b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; \
int mod = strlen(b64_std) % 4; \
if (mod) for (int i = 0; i < 4 - mod; ++i) strcat(b64_std, "="); \
mbedtls_base64_decode(bin, sizeof(bin), &binlen, (const unsigned char *)b64_std, strlen(b64_std)); \
free(b64_std); \
}
unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512],
dq_bin[512], qi_bin[512];
// Decode all private fields (skip if NULL)
#define DECODE_B64URL(var, b64, bin, binlen) \
if (b64 && *b64) { \
char *b64_std = strdup(b64); \
for (char *p = b64_std; *p; ++p) \
if (*p == '-') \
*p = '+'; \
else if (*p == '_') \
*p = '/'; \
int mod = strlen(b64_std) % 4; \
if (mod) \
for (int i = 0; i < 4 - mod; ++i) \
strcat(b64_std, "="); \
mbedtls_base64_decode(bin, sizeof(bin), &binlen, \
(const unsigned char *)b64_std, strlen(b64_std)); \
free(b64_std); \
}
DECODE_B64URL(d, d_b64, d_bin, d_len);
DECODE_B64URL(p, p_b64, p_bin, p_len);
DECODE_B64URL(q, q_b64, q_bin, q_len);
@ -1845,7 +1886,9 @@ static int LuaConvertJwkToPem(lua_State *L) {
}
if (ret != 0) {
mbedtls_pk_free(&pk);
lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2;
lua_pushnil(L);
lua_pushstring(L, "PEM write failed");
return 2;
}
pem = strdup((char *)buf);
mbedtls_pk_free(&pk);
@ -1863,36 +1906,67 @@ static int LuaConvertJwkToPem(lua_State *L) {
const char *y_b64 = lua_tostring(L, -2);
const char *d_b64 = lua_tostring(L, -1);
if (!crv || !*crv) {
lua_pushnil(L); lua_pushstring(L, "Missing or empty 'crv' in JWK"); return 2;
lua_pushnil(L);
lua_pushstring(L, "Missing or empty 'crv' in JWK");
return 2;
}
if (!x_b64 || !*x_b64) {
lua_pushnil(L); lua_pushstring(L, "Missing or empty 'x' in JWK"); return 2;
lua_pushnil(L);
lua_pushstring(L, "Missing or empty 'x' in JWK");
return 2;
}
if (!y_b64 || !*y_b64) {
lua_pushnil(L); lua_pushstring(L, "Missing or empty 'y' in JWK"); return 2;
lua_pushnil(L);
lua_pushstring(L, "Missing or empty 'y' in JWK");
return 2;
}
int has_private = d_b64 && *d_b64;
mbedtls_ecp_group_id gid = find_curve_by_name(crv);
if (gid == MBEDTLS_ECP_DP_NONE) {
lua_pushnil(L); lua_pushstring(L, "Unknown curve"); return 2;
lua_pushnil(L);
lua_pushstring(L, "Unknown curve");
return 2;
}
size_t x_len, y_len;
unsigned char x_bin[72], y_bin[72];
char *x_b64_std = strdup(x_b64), *y_b64_std = strdup(y_b64);
for (char *p = x_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/';
for ( char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/';
for (char *p = x_b64_std; *p; ++p)
if (*p == '-')
*p = '+';
else if (*p == '_')
*p = '/';
for (char *p = y_b64_std; *p; ++p)
if (*p == '-')
*p = '+';
else if (*p == '_')
*p = '/';
int x_mod = strlen(x_b64_std) % 4;
int y_mod = strlen(y_b64_std) % 4;
if (x_mod) for (int i = 0; i < 4 - x_mod; ++i) strcat(x_b64_std, "=");
if (y_mod) for (int i = 0; i < 4 - y_mod; ++i) strcat(y_b64_std, "=");
if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, (const unsigned char *)x_b64_std, strlen(x_b64_std)) != 0 ||
mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len, (const unsigned char *)y_b64_std, strlen(y_b64_std)) != 0) {
free(x_b64_std); free(y_b64_std);
lua_pushnil(L); lua_pushstring(L, "Base64 decode failed"); return 2;
if (x_mod)
for (int i = 0; i < 4 - x_mod; ++i)
strcat(x_b64_std, "=");
if (y_mod)
for (int i = 0; i < 4 - y_mod; ++i)
strcat(y_b64_std, "=");
if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len,
(const unsigned char *)x_b64_std,
strlen(x_b64_std)) != 0 ||
mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len,
(const unsigned char *)y_b64_std,
strlen(y_b64_std)) != 0) {
free(x_b64_std);
free(y_b64_std);
lua_pushnil(L);
lua_pushstring(L, "Base64 decode failed");
return 2;
}
free(x_b64_std); free(y_b64_std);
if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) {
lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2;
free(x_b64_std);
free(y_b64_std);
if ((ret = mbedtls_pk_setup(
&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) {
lua_pushnil(L);
lua_pushstring(L, "mbedtls_pk_setup failed");
return 2;
}
mbedtls_ecp_keypair *ec = mbedtls_pk_ec(pk);
mbedtls_ecp_keypair_init(ec);
@ -1914,8 +1988,10 @@ static int LuaConvertJwkToPem(lua_State *L) {
}
if (ret != 0) {
mbedtls_pk_free(&pk);
lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2;
}
lua_pushnil(L);
lua_pushstring(L, "PEM write failed");
return 2;
}
pem = strdup((char *)buf);
mbedtls_pk_free(&pk);
lua_pushstring(L, pem);
@ -1930,14 +2006,17 @@ static int LuaConvertJwkToPem(lua_State *L) {
// Convert PEM key to JWK (Lua table) format
static void base64_to_base64url(char *str) {
if (!str) return;
if (!str)
return;
for (char *p = str; *p; p++) {
if (*p == '+') *p = '-';
else if (*p == '/') *p = '_';
if (*p == '+')
*p = '-';
else if (*p == '/')
*p = '_';
}
// Remove padding
size_t len = strlen(str);
while (len > 0 && str[len-1] == '=') {
while (len > 0 && str[len - 1] == '=') {
str[--len] = '\0';
}
}
@ -1990,10 +2069,12 @@ static int LuaConvertPemToJwk(lua_State *L) {
mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len);
n_b64 = malloc(n_b64_len + 1);
e_b64 = malloc(e_b64_len + 1);
mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len);
mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n,
n_len);
n_b64[n_b64_len] = '\0';
base64_to_base64url(n_b64);
mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len);
mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e,
e_len);
e_b64[e_b64_len] = '\0';
base64_to_base64url(e_b64);
lua_pushstring(L, n_b64);
@ -2045,12 +2126,18 @@ static int LuaConvertPemToJwk(lua_State *L) {
dp_b64 = malloc(dp_b64_len + 1);
dq_b64 = malloc(dq_b64_len + 1);
qi_b64 = malloc(qi_b64_len + 1);
mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len);
mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, p_len);
mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, q_len);
mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, dp, dp_len);
mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, dq, dq_len);
mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, qi, qi_len);
mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d,
d_len);
mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p,
p_len);
mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q,
q_len);
mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len,
dp, dp_len);
mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len,
dq, dq_len);
mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len,
qi, qi_len);
d_b64[d_b64_len] = '\0';
p_b64[p_b64_len] = '\0';
q_b64[q_b64_len] = '\0';
@ -2117,16 +2204,19 @@ static int LuaConvertPemToJwk(lua_State *L) {
mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len);
x_b64 = malloc(x_b64_len + 1);
y_b64 = malloc(y_b64_len + 1);
mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len);
mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x,
x_len);
x_b64[x_b64_len] = '\0';
base64_to_base64url(x_b64);
mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len);
mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y,
y_len);
y_b64[y_b64_len] = '\0';
base64_to_base64url(y_b64);
// Set kty and crv for EC keys
lua_pushstring(L, "EC");
lua_setfield(L, -2, "kty");
const mbedtls_ecp_curve_info *curve_info = mbedtls_ecp_curve_info_from_grp_id(ec->grp.id);
const mbedtls_ecp_curve_info *curve_info =
mbedtls_ecp_curve_info_from_grp_id(ec->grp.id);
if (curve_info && curve_info->name) {
lua_pushstring(L, curve_info->name);
lua_setfield(L, -2, "crv");
@ -2142,19 +2232,34 @@ static int LuaConvertPemToJwk(lua_State *L) {
if (mbedtls_ecp_check_privkey(&ec->grp, &ec->d) == 0 && ec->d.p) {
size_t d_len = mbedtls_mpi_size(&ec->d);
unsigned char *d = malloc(d_len);
if (!d) { free(x); free(y); free(x_b64); free(y_b64); lua_pushnil(L); lua_pushstring(L, "Memory allocation failed"); mbedtls_pk_free(&key); return 2; }
if (!d) {
free(x);
free(y);
free(x_b64);
free(y_b64);
lua_pushnil(L);
lua_pushstring(L, "Memory allocation failed");
mbedtls_pk_free(&key);
return 2;
}
mbedtls_mpi_write_binary(&ec->d, d, d_len);
char *d_b64 = NULL;
size_t d_b64_len;
mbedtls_base64_encode(NULL, 0, &d_b64_len, d, d_len);
d_b64 = malloc(d_b64_len + 1);
mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len);
mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d,
d_len);
d_b64[d_b64_len] = '\0';
base64_to_base64url(d_b64);
lua_pushstring(L, d_b64); lua_setfield(L, -2, "d");
free(d); free(d_b64);
lua_pushstring(L, d_b64);
lua_setfield(L, -2, "d");
free(d);
free(d_b64);
}
free(x); free(y); free(x_b64); free(y_b64);
free(x);
free(y);
free(x_b64);
free(y_b64);
} else {
lua_pushnil(L);
lua_pushstring(L, "Unsupported key type");
@ -2166,8 +2271,10 @@ static int LuaConvertPemToJwk(lua_State *L) {
// Merge additional claims if provided and compatible with RFC7517
if (has_claims) {
static const char *allowed[] = {"kty","use","sig","key_ops","alg","kid","x5u","x5c","x5t","x5t#S256",NULL};
lua_pushnil(L); // first key
static const char *allowed[] = {"kty", "use", "sig", "key_ops",
"alg", "kid", "x5u", "x5c",
"x5t", "x5t#S256", NULL};
lua_pushnil(L); // first key
while (lua_next(L, 2) != 0) {
const char *k = lua_tostring(L, -2);
int allowed_key = 0;
@ -2252,7 +2359,6 @@ static int LuaGenerateCSR(lua_State *L) {
return 1;
}
// LuaCrypto compatible API
static int LuaCryptoSign(lua_State *L) {
// Type of signature (e.g., "rsa", "ecdsa", "rsa-pss")
@ -2262,7 +2368,8 @@ static int LuaCryptoSign(lua_State *L) {
if (strcasecmp(dtype, "rsa") == 0) {
return LuaRSASign(L);
} else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) {
} else if (strcasecmp(dtype, "rsa-pss") == 0 ||
strcasecmp(dtype, "rsapss") == 0) {
return LuaRSAPSSSign(L);
} else if (strcasecmp(dtype, "ecdsa") == 0) {
return LuaECDSASign(L);
@ -2279,7 +2386,8 @@ static int LuaCryptoVerify(lua_State *L) {
if (strcasecmp(dtype, "rsa") == 0) {
return LuaRSAVerify(L);
} else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) {
} else if (strcasecmp(dtype, "rsa-pss") == 0 ||
strcasecmp(dtype, "rsapss") == 0) {
return LuaRSAPSSVerify(L);
} else if (strcasecmp(dtype, "ecdsa") == 0) {
return LuaECDSAVerify(L);
@ -2339,8 +2447,6 @@ static int LuaCryptoGenerateKeyPair(lua_State *L) {
}
}
static const luaL_Reg kLuaCrypto[] = {
{"sign", LuaCryptoSign}, //
{"verify", LuaCryptoVerify}, //