mirror of
https://github.com/jart/cosmopolitan.git
synced 2025-07-12 14:09:12 +00:00
Address Paul's comments :)
This commit is contained in:
parent
12ff789a69
commit
b8bdccc7fc
1 changed files with 236 additions and 130 deletions
|
@ -39,12 +39,15 @@ typedef struct {
|
|||
} curve_map_t;
|
||||
|
||||
static const curve_map_t supported_curves[] = {
|
||||
{"secp256r1", MBEDTLS_ECP_DP_SECP256R1},
|
||||
{"P-256", MBEDTLS_ECP_DP_SECP256R1},
|
||||
{"secp384r1", MBEDTLS_ECP_DP_SECP384R1},
|
||||
{"P-384", MBEDTLS_ECP_DP_SECP384R1},
|
||||
{"secp521r1", MBEDTLS_ECP_DP_SECP521R1},
|
||||
{"P-521", MBEDTLS_ECP_DP_SECP521R1},
|
||||
{"secp256r1", MBEDTLS_ECP_DP_SECP256R1},
|
||||
{"P256", MBEDTLS_ECP_DP_SECP256R1},
|
||||
{"P-256", MBEDTLS_ECP_DP_SECP256R1},
|
||||
{"secp384r1", MBEDTLS_ECP_DP_SECP384R1},
|
||||
{"P384", MBEDTLS_ECP_DP_SECP384R1},
|
||||
{"P-384", MBEDTLS_ECP_DP_SECP384R1},
|
||||
{"secp521r1", MBEDTLS_ECP_DP_SECP521R1},
|
||||
{"P521", MBEDTLS_ECP_DP_SECP521R1},
|
||||
{"P-521", MBEDTLS_ECP_DP_SECP521R1},
|
||||
{"curve25519", MBEDTLS_ECP_DP_CURVE25519},
|
||||
{NULL, 0}};
|
||||
|
||||
|
@ -69,11 +72,14 @@ static mbedtls_md_type_t string_to_md_type(const char *hash_name) {
|
|||
if (!hash_name || !*hash_name) {
|
||||
return MBEDTLS_MD_SHA256; // Default to SHA-256 if no name provided
|
||||
}
|
||||
if (strcasecmp(hash_name, "sha256") == 0 || strcasecmp(hash_name, "sha-256") == 0) {
|
||||
if (strcasecmp(hash_name, "sha256") == 0 ||
|
||||
strcasecmp(hash_name, "sha-256") == 0) {
|
||||
return MBEDTLS_MD_SHA256;
|
||||
} else if (strcasecmp(hash_name, "sha384") == 0 || strcasecmp(hash_name, "sha-384") == 0) {
|
||||
} else if (strcasecmp(hash_name, "sha384") == 0 ||
|
||||
strcasecmp(hash_name, "sha-384") == 0) {
|
||||
return MBEDTLS_MD_SHA384;
|
||||
} else if (strcasecmp(hash_name, "sha512") == 0 || strcasecmp(hash_name, "sha-512") == 0) {
|
||||
} else if (strcasecmp(hash_name, "sha512") == 0 ||
|
||||
strcasecmp(hash_name, "sha-512") == 0) {
|
||||
return MBEDTLS_MD_SHA512;
|
||||
} else {
|
||||
WARNF("(crypto) Unknown hash algorithm '%s'", hash_name);
|
||||
|
@ -81,6 +87,7 @@ static mbedtls_md_type_t string_to_md_type(const char *hash_name) {
|
|||
}
|
||||
}
|
||||
|
||||
// Get the size of the hash output based on the mbedtls_md_type_t
|
||||
static size_t get_hash_size_from_md_type(mbedtls_md_type_t md_type) {
|
||||
switch (md_type) {
|
||||
case MBEDTLS_MD_SHA256:
|
||||
|
@ -288,7 +295,10 @@ static int LuaRSAGenerateKeyPair(lua_State *L) {
|
|||
// Check if key length is valid (only 2048 or 4096 bits are allowed)
|
||||
if (bits != 2048 && bits != 4096) {
|
||||
lua_pushnil(L);
|
||||
lua_pushfstring(L, "Invalid RSA key length: %d. Only 2048 or 4096 bits key lengths are supported", bits);
|
||||
lua_pushfstring(L,
|
||||
"Invalid RSA key length: %d. Only 2048 or 4096 bits key "
|
||||
"lengths are supported",
|
||||
bits);
|
||||
return 2;
|
||||
}
|
||||
|
||||
|
@ -414,7 +424,7 @@ static char *RSADecrypt(const char *private_key_pem,
|
|||
mbedtls_pk_context key;
|
||||
mbedtls_pk_init(&key);
|
||||
rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem,
|
||||
strlen(private_key_pem) + 1, NULL, 0);
|
||||
strlen(private_key_pem) + 1, NULL, 0);
|
||||
if (rc != 0) {
|
||||
WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc);
|
||||
mbedtls_pk_free(&key);
|
||||
|
@ -439,8 +449,8 @@ static char *RSADecrypt(const char *private_key_pem,
|
|||
// Decrypt data
|
||||
size_t output_len = 0;
|
||||
rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0,
|
||||
MBEDTLS_RSA_PRIVATE, &output_len,
|
||||
encrypted_data, output, key_size);
|
||||
MBEDTLS_RSA_PRIVATE, &output_len,
|
||||
encrypted_data, output, key_size);
|
||||
if (rc != 0) {
|
||||
WARNF("(crypto) Decryption failed (grep -0x%04x)", -rc);
|
||||
free(output);
|
||||
|
@ -560,10 +570,10 @@ static int LuaRSASign(lua_State *L) {
|
|||
return 2;
|
||||
}
|
||||
if (lua_type(L, 1) == LUA_TTABLE) {
|
||||
DEBUGF("(crypto) Detected table type for parameter 1");
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Key must be a string, got a table instead");
|
||||
return 2;
|
||||
DEBUGF("(crypto) Detected table type for parameter 1");
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Key must be a string, got a table instead");
|
||||
return 2;
|
||||
}
|
||||
// Ensure msg is a string
|
||||
if (lua_type(L, 2) != LUA_TSTRING) {
|
||||
|
@ -617,7 +627,7 @@ static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data,
|
|||
mbedtls_pk_context key;
|
||||
mbedtls_pk_init(&key);
|
||||
rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem,
|
||||
strlen(private_key_pem) + 1, NULL, 0);
|
||||
strlen(private_key_pem) + 1, NULL, 0);
|
||||
if (rc != 0) {
|
||||
WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc);
|
||||
mbedtls_pk_free(&key);
|
||||
|
@ -650,9 +660,9 @@ static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data,
|
|||
mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, hash_algo);
|
||||
|
||||
// Sign the hash using PSS
|
||||
rc = mbedtls_rsa_rsassa_pss_sign(
|
||||
rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, hash_algo, (unsigned int)hash_len,
|
||||
hash, signature);
|
||||
rc = mbedtls_rsa_rsassa_pss_sign(rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE,
|
||||
hash_algo, (unsigned int)hash_len, hash,
|
||||
signature);
|
||||
if (rc != 0) {
|
||||
free(signature);
|
||||
mbedtls_pk_free(&key);
|
||||
|
@ -673,7 +683,7 @@ static int LuaRSAPSSSign(lua_State *L) {
|
|||
unsigned char *signature;
|
||||
size_t sig_len = 0;
|
||||
|
||||
// Get parameters from Lua
|
||||
// Get parameters from Lua
|
||||
if (lua_type(L, 1) != LUA_TSTRING) {
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Key must be a string");
|
||||
|
@ -697,8 +707,9 @@ static int LuaRSAPSSSign(lua_State *L) {
|
|||
int salt_len = luaL_optinteger(L, 4, -1);
|
||||
|
||||
// Call the C implementation
|
||||
signature = (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg,
|
||||
msg_len, hash_name, &sig_len, salt_len);
|
||||
signature =
|
||||
(unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, msg_len,
|
||||
hash_name, &sig_len, salt_len);
|
||||
|
||||
if (!signature) {
|
||||
return luaL_error(L, "failed to sign message (PSS)");
|
||||
|
@ -795,9 +806,9 @@ static int LuaRSAVerify(lua_State *L) {
|
|||
|
||||
// RSA PSS Verification
|
||||
static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data,
|
||||
size_t data_len, const unsigned char *signature,
|
||||
size_t sig_len, const char *hash_algo_str,
|
||||
int expected_salt_len) {
|
||||
size_t data_len, const unsigned char *signature,
|
||||
size_t sig_len, const char *hash_algo_str,
|
||||
int expected_salt_len) {
|
||||
int rc;
|
||||
unsigned char hash[64]; // Large enough for SHA-512
|
||||
size_t hash_len = 32; // Default for SHA-256
|
||||
|
@ -806,9 +817,7 @@ static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data,
|
|||
// Determine hash algorithm
|
||||
if (hash_algo_str) {
|
||||
hash_algo = string_to_md_type(hash_algo_str);
|
||||
DEBUGF("(DEBUG) Using hash algorithm: %s", hash_algo_str);
|
||||
hash_len = get_hash_size_from_md_type(hash_algo);
|
||||
DEBUGF("(DEBUG) Hash length: %zu", hash_len);
|
||||
}
|
||||
|
||||
// Parse public key
|
||||
|
@ -840,9 +849,9 @@ static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data,
|
|||
mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key);
|
||||
|
||||
// Verify the signature using PSS
|
||||
rc = mbedtls_rsa_rsassa_pss_verify(
|
||||
rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, hash_algo, (unsigned int)hash_len,
|
||||
hash, signature);
|
||||
rc = mbedtls_rsa_rsassa_pss_verify(rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC,
|
||||
hash_algo, (unsigned int)hash_len, hash,
|
||||
signature);
|
||||
|
||||
// Clean up
|
||||
mbedtls_pk_free(&key);
|
||||
|
@ -859,7 +868,7 @@ static int LuaRSAPSSVerify(lua_State *L) {
|
|||
int expected_salt_len = -1;
|
||||
int result;
|
||||
|
||||
// Get parameters from Lua
|
||||
// Get parameters from Lua
|
||||
if (lua_type(L, 1) != LUA_TSTRING) {
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Key must be a string");
|
||||
|
@ -893,12 +902,14 @@ static int LuaRSAPSSVerify(lua_State *L) {
|
|||
DEBUGF("(DEBUG) Key PEM: %s", key_pem);
|
||||
DEBUGF("(DEBUG) Message length: %zu", msg_len);
|
||||
DEBUGF("(DEBUG) Signature length: %zu", sig_len);
|
||||
DEBUGF("(DEBUG) Hash algorithm: %s", hash_algo_str ? hash_algo_str : "SHA-256");
|
||||
DEBUGF("(DEBUG) Hash algorithm: %s",
|
||||
hash_algo_str ? hash_algo_str : "SHA-256");
|
||||
DEBUGF("(DEBUG) Expected salt length: %d", expected_salt_len);
|
||||
DEBUGF("(DEBUG) Signature: %.*s", (int)sig_len, signature);
|
||||
// Call the C implementation
|
||||
result = RSAPSSVerify(key_pem, (const unsigned char *)msg, msg_len,
|
||||
(const unsigned char *)signature, sig_len, hash_algo_str, expected_salt_len);
|
||||
(const unsigned char *)signature, sig_len,
|
||||
hash_algo_str, expected_salt_len);
|
||||
|
||||
// Return boolean result (0 means valid signature)
|
||||
lua_pushboolean(L, result == 0);
|
||||
|
@ -1077,8 +1088,8 @@ static int ECDSASign(const char *priv_key_pem, const char *message,
|
|||
}
|
||||
|
||||
// Sign the hash using GenerateRandom
|
||||
ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size,
|
||||
*signature, sig_len, GenerateRandom, 0);
|
||||
ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, *signature, sig_len,
|
||||
GenerateRandom, 0);
|
||||
|
||||
if (ret != 0) {
|
||||
WARNF("(crypto) Failed to sign message: -0x%04x", -ret);
|
||||
|
@ -1158,8 +1169,7 @@ static int ECDSAVerify(const char *pub_key_pem, const char *message,
|
|||
}
|
||||
|
||||
// Verify the signature
|
||||
ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size,
|
||||
signature, sig_len);
|
||||
ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, signature, sig_len);
|
||||
if (ret != 0) {
|
||||
WARNF("(crypto) Signature verification failed: -0x%04x", -ret);
|
||||
goto cleanup;
|
||||
|
@ -1223,7 +1233,8 @@ typedef struct {
|
|||
size_t aadlen;
|
||||
} aes_options_t;
|
||||
|
||||
static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) {
|
||||
static void parse_aes_options(lua_State *L, int options_idx,
|
||||
aes_options_t *opts) {
|
||||
opts->mode = NULL;
|
||||
opts->iv = NULL;
|
||||
opts->ivlen = 0;
|
||||
|
@ -1240,12 +1251,12 @@ static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts
|
|||
if (!lua_isnil(L, -1)) {
|
||||
mode_field_found = 1;
|
||||
const char *mode = lua_tostring(L, -1);
|
||||
if (mode && (strcasecmp(mode, "cbc") == 0 ||
|
||||
strcasecmp(mode, "gcm") == 0 ||
|
||||
strcasecmp(mode, "ctr") == 0)) {
|
||||
if (mode &&
|
||||
(strcasecmp(mode, "cbc") == 0 || strcasecmp(mode, "gcm") == 0 ||
|
||||
strcasecmp(mode, "ctr") == 0)) {
|
||||
opts->mode = mode;
|
||||
} else {
|
||||
opts->mode = NULL; // Invalid mode
|
||||
opts->mode = NULL; // Invalid mode
|
||||
}
|
||||
}
|
||||
lua_pop(L, 1);
|
||||
|
@ -1288,7 +1299,6 @@ static int LuaAesEncrypt(lua_State *L) {
|
|||
// Args: key, plaintext, options table
|
||||
size_t keylen, ptlen;
|
||||
|
||||
|
||||
// Get parameters from Lua
|
||||
// Ensure key is a string
|
||||
if (lua_type(L, 1) != LUA_TSTRING) {
|
||||
|
@ -1319,7 +1329,8 @@ static int LuaAesEncrypt(lua_State *L) {
|
|||
const char *mode = opts.mode;
|
||||
if (!mode) {
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'.");
|
||||
lua_pushstring(L,
|
||||
"Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'.");
|
||||
return 2;
|
||||
}
|
||||
const unsigned char *iv = opts.iv;
|
||||
|
@ -1354,6 +1365,7 @@ static int LuaAesEncrypt(lua_State *L) {
|
|||
lua_pushstring(L, "Failed to allocate IV");
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Generate random IV
|
||||
if (GenerateRandom(NULL, gen_iv, ivlen) != 0) {
|
||||
free(gen_iv);
|
||||
|
@ -1367,23 +1379,23 @@ static int LuaAesEncrypt(lua_State *L) {
|
|||
|
||||
// Validate IV/nonce length
|
||||
if (is_cbc || is_ctr) {
|
||||
// Validate IV/nonce length
|
||||
if (is_cbc || is_ctr) {
|
||||
if (opts.iv && opts.ivlen != 16) {
|
||||
if (iv_was_generated) free(gen_iv);
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR");
|
||||
return 2;
|
||||
}
|
||||
} else if (is_gcm) {
|
||||
if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) {
|
||||
if (iv_was_generated) free(gen_iv);
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "AES GCM nonce must be 12-16 bytes");
|
||||
return 2;
|
||||
}
|
||||
if (opts.iv && opts.ivlen != 16) {
|
||||
if (iv_was_generated)
|
||||
free(gen_iv);
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR");
|
||||
return 2;
|
||||
}
|
||||
} else if (is_gcm) {
|
||||
if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) {
|
||||
if (iv_was_generated)
|
||||
free(gen_iv);
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "AES GCM nonce must be 12-16 bytes");
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_cbc) {
|
||||
// PKCS7 padding
|
||||
size_t block_size = 16;
|
||||
|
@ -1531,7 +1543,7 @@ static int LuaAesDecrypt(lua_State *L) {
|
|||
return 2;
|
||||
}
|
||||
const unsigned char *key =
|
||||
(const unsigned char *)luaL_checklstring(L, 1, &keylen);
|
||||
(const unsigned char *)luaL_checklstring(L, 1, &keylen);
|
||||
const unsigned char *ciphertext =
|
||||
(const unsigned char *)luaL_checklstring(L, 2, &ctlen);
|
||||
int options_idx = 3;
|
||||
|
@ -1540,7 +1552,8 @@ static int LuaAesDecrypt(lua_State *L) {
|
|||
const char *mode = opts.mode;
|
||||
if (!mode) {
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'.");
|
||||
lua_pushstring(L,
|
||||
"Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'.");
|
||||
return 2;
|
||||
}
|
||||
const unsigned char *iv = opts.iv;
|
||||
|
@ -1725,7 +1738,6 @@ static int LuaAesDecrypt(lua_State *L) {
|
|||
return 2;
|
||||
}
|
||||
|
||||
|
||||
// JWK functions
|
||||
|
||||
// Convert JWK (Lua table) to PEM key format
|
||||
|
@ -1786,23 +1798,44 @@ static int LuaConvertJwkToPem(lua_State *L) {
|
|||
size_t n_len, e_len;
|
||||
unsigned char n_bin[1024], e_bin[16];
|
||||
char *n_b64_std = strdup(n_b64), *e_b64_std = strdup(e_b64);
|
||||
for (char *p = n_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/';
|
||||
for (char *p = e_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/';
|
||||
for (char *p = n_b64_std; *p; ++p)
|
||||
if (*p == '-')
|
||||
*p = '+';
|
||||
else if (*p == '_')
|
||||
*p = '/';
|
||||
for (char *p = e_b64_std; *p; ++p)
|
||||
if (*p == '-')
|
||||
*p = '+';
|
||||
else if (*p == '_')
|
||||
*p = '/';
|
||||
int n_mod = strlen(n_b64_std) % 4;
|
||||
int e_mod = strlen(e_b64_std) % 4;
|
||||
if (n_mod) for (int i = 0; i < 4 - n_mod; ++i) strcat(n_b64_std, "=");
|
||||
if (e_mod) for (int i = 0; i < 4 - e_mod; ++i) strcat(e_b64_std, "=");
|
||||
if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len, (const unsigned char *)n_b64_std, strlen(n_b64_std)) != 0 ||
|
||||
mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len, (const unsigned char *)e_b64_std, strlen(e_b64_std)) != 0) {
|
||||
free(n_b64_std); free(e_b64_std);
|
||||
if (n_mod)
|
||||
for (int i = 0; i < 4 - n_mod; ++i)
|
||||
strcat(n_b64_std, "=");
|
||||
if (e_mod)
|
||||
for (int i = 0; i < 4 - e_mod; ++i)
|
||||
strcat(e_b64_std, "=");
|
||||
if (mbedtls_base64_decode(n_bin, sizeof(n_bin), &n_len,
|
||||
(const unsigned char *)n_b64_std,
|
||||
strlen(n_b64_std)) != 0 ||
|
||||
mbedtls_base64_decode(e_bin, sizeof(e_bin), &e_len,
|
||||
(const unsigned char *)e_b64_std,
|
||||
strlen(e_b64_std)) != 0) {
|
||||
free(n_b64_std);
|
||||
free(e_b64_std);
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Base64 decode failed");
|
||||
return 2;
|
||||
}
|
||||
free(n_b64_std); free(e_b64_std);
|
||||
free(n_b64_std);
|
||||
free(e_b64_std);
|
||||
// Build RSA context in pk
|
||||
if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) {
|
||||
lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2;
|
||||
if ((ret = mbedtls_pk_setup(
|
||||
&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) {
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "mbedtls_pk_setup failed");
|
||||
return 2;
|
||||
}
|
||||
mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk);
|
||||
mbedtls_rsa_init(rsa, MBEDTLS_RSA_PKCS_V15, 0);
|
||||
|
@ -1812,17 +1845,25 @@ static int LuaConvertJwkToPem(lua_State *L) {
|
|||
if (has_private) {
|
||||
// Decode and set private fields
|
||||
size_t d_len, p_len, q_len, dp_len, dq_len, qi_len;
|
||||
unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512], dq_bin[512], qi_bin[512];
|
||||
// Decode all private fields (skip if NULL)
|
||||
#define DECODE_B64URL(var, b64, bin, binlen) \
|
||||
if (b64 && *b64) { \
|
||||
char *b64_std = strdup(b64); \
|
||||
for (char *p = b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; \
|
||||
int mod = strlen(b64_std) % 4; \
|
||||
if (mod) for (int i = 0; i < 4 - mod; ++i) strcat(b64_std, "="); \
|
||||
mbedtls_base64_decode(bin, sizeof(bin), &binlen, (const unsigned char *)b64_std, strlen(b64_std)); \
|
||||
free(b64_std); \
|
||||
}
|
||||
unsigned char d_bin[1024], p_bin[512], q_bin[512], dp_bin[512],
|
||||
dq_bin[512], qi_bin[512];
|
||||
// Decode all private fields (skip if NULL)
|
||||
#define DECODE_B64URL(var, b64, bin, binlen) \
|
||||
if (b64 && *b64) { \
|
||||
char *b64_std = strdup(b64); \
|
||||
for (char *p = b64_std; *p; ++p) \
|
||||
if (*p == '-') \
|
||||
*p = '+'; \
|
||||
else if (*p == '_') \
|
||||
*p = '/'; \
|
||||
int mod = strlen(b64_std) % 4; \
|
||||
if (mod) \
|
||||
for (int i = 0; i < 4 - mod; ++i) \
|
||||
strcat(b64_std, "="); \
|
||||
mbedtls_base64_decode(bin, sizeof(bin), &binlen, \
|
||||
(const unsigned char *)b64_std, strlen(b64_std)); \
|
||||
free(b64_std); \
|
||||
}
|
||||
DECODE_B64URL(d, d_b64, d_bin, d_len);
|
||||
DECODE_B64URL(p, p_b64, p_bin, p_len);
|
||||
DECODE_B64URL(q, q_b64, q_bin, q_len);
|
||||
|
@ -1845,7 +1886,9 @@ static int LuaConvertJwkToPem(lua_State *L) {
|
|||
}
|
||||
if (ret != 0) {
|
||||
mbedtls_pk_free(&pk);
|
||||
lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2;
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "PEM write failed");
|
||||
return 2;
|
||||
}
|
||||
pem = strdup((char *)buf);
|
||||
mbedtls_pk_free(&pk);
|
||||
|
@ -1863,36 +1906,67 @@ static int LuaConvertJwkToPem(lua_State *L) {
|
|||
const char *y_b64 = lua_tostring(L, -2);
|
||||
const char *d_b64 = lua_tostring(L, -1);
|
||||
if (!crv || !*crv) {
|
||||
lua_pushnil(L); lua_pushstring(L, "Missing or empty 'crv' in JWK"); return 2;
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Missing or empty 'crv' in JWK");
|
||||
return 2;
|
||||
}
|
||||
if (!x_b64 || !*x_b64) {
|
||||
lua_pushnil(L); lua_pushstring(L, "Missing or empty 'x' in JWK"); return 2;
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Missing or empty 'x' in JWK");
|
||||
return 2;
|
||||
}
|
||||
if (!y_b64 || !*y_b64) {
|
||||
lua_pushnil(L); lua_pushstring(L, "Missing or empty 'y' in JWK"); return 2;
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Missing or empty 'y' in JWK");
|
||||
return 2;
|
||||
}
|
||||
int has_private = d_b64 && *d_b64;
|
||||
mbedtls_ecp_group_id gid = find_curve_by_name(crv);
|
||||
if (gid == MBEDTLS_ECP_DP_NONE) {
|
||||
lua_pushnil(L); lua_pushstring(L, "Unknown curve"); return 2;
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Unknown curve");
|
||||
return 2;
|
||||
}
|
||||
size_t x_len, y_len;
|
||||
unsigned char x_bin[72], y_bin[72];
|
||||
char *x_b64_std = strdup(x_b64), *y_b64_std = strdup(y_b64);
|
||||
for (char *p = x_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/';
|
||||
for ( char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/';
|
||||
for (char *p = x_b64_std; *p; ++p)
|
||||
if (*p == '-')
|
||||
*p = '+';
|
||||
else if (*p == '_')
|
||||
*p = '/';
|
||||
for (char *p = y_b64_std; *p; ++p)
|
||||
if (*p == '-')
|
||||
*p = '+';
|
||||
else if (*p == '_')
|
||||
*p = '/';
|
||||
int x_mod = strlen(x_b64_std) % 4;
|
||||
int y_mod = strlen(y_b64_std) % 4;
|
||||
if (x_mod) for (int i = 0; i < 4 - x_mod; ++i) strcat(x_b64_std, "=");
|
||||
if (y_mod) for (int i = 0; i < 4 - y_mod; ++i) strcat(y_b64_std, "=");
|
||||
if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len, (const unsigned char *)x_b64_std, strlen(x_b64_std)) != 0 ||
|
||||
mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len, (const unsigned char *)y_b64_std, strlen(y_b64_std)) != 0) {
|
||||
free(x_b64_std); free(y_b64_std);
|
||||
lua_pushnil(L); lua_pushstring(L, "Base64 decode failed"); return 2;
|
||||
if (x_mod)
|
||||
for (int i = 0; i < 4 - x_mod; ++i)
|
||||
strcat(x_b64_std, "=");
|
||||
if (y_mod)
|
||||
for (int i = 0; i < 4 - y_mod; ++i)
|
||||
strcat(y_b64_std, "=");
|
||||
if (mbedtls_base64_decode(x_bin, sizeof(x_bin), &x_len,
|
||||
(const unsigned char *)x_b64_std,
|
||||
strlen(x_b64_std)) != 0 ||
|
||||
mbedtls_base64_decode(y_bin, sizeof(y_bin), &y_len,
|
||||
(const unsigned char *)y_b64_std,
|
||||
strlen(y_b64_std)) != 0) {
|
||||
free(x_b64_std);
|
||||
free(y_b64_std);
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Base64 decode failed");
|
||||
return 2;
|
||||
}
|
||||
free(x_b64_std); free(y_b64_std);
|
||||
if ((ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) {
|
||||
lua_pushnil(L); lua_pushstring(L, "mbedtls_pk_setup failed"); return 2;
|
||||
free(x_b64_std);
|
||||
free(y_b64_std);
|
||||
if ((ret = mbedtls_pk_setup(
|
||||
&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))) != 0) {
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "mbedtls_pk_setup failed");
|
||||
return 2;
|
||||
}
|
||||
mbedtls_ecp_keypair *ec = mbedtls_pk_ec(pk);
|
||||
mbedtls_ecp_keypair_init(ec);
|
||||
|
@ -1914,8 +1988,10 @@ static int LuaConvertJwkToPem(lua_State *L) {
|
|||
}
|
||||
if (ret != 0) {
|
||||
mbedtls_pk_free(&pk);
|
||||
lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2;
|
||||
}
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "PEM write failed");
|
||||
return 2;
|
||||
}
|
||||
pem = strdup((char *)buf);
|
||||
mbedtls_pk_free(&pk);
|
||||
lua_pushstring(L, pem);
|
||||
|
@ -1930,14 +2006,17 @@ static int LuaConvertJwkToPem(lua_State *L) {
|
|||
|
||||
// Convert PEM key to JWK (Lua table) format
|
||||
static void base64_to_base64url(char *str) {
|
||||
if (!str) return;
|
||||
if (!str)
|
||||
return;
|
||||
for (char *p = str; *p; p++) {
|
||||
if (*p == '+') *p = '-';
|
||||
else if (*p == '/') *p = '_';
|
||||
if (*p == '+')
|
||||
*p = '-';
|
||||
else if (*p == '/')
|
||||
*p = '_';
|
||||
}
|
||||
// Remove padding
|
||||
size_t len = strlen(str);
|
||||
while (len > 0 && str[len-1] == '=') {
|
||||
while (len > 0 && str[len - 1] == '=') {
|
||||
str[--len] = '\0';
|
||||
}
|
||||
}
|
||||
|
@ -1990,10 +2069,12 @@ static int LuaConvertPemToJwk(lua_State *L) {
|
|||
mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len);
|
||||
n_b64 = malloc(n_b64_len + 1);
|
||||
e_b64 = malloc(e_b64_len + 1);
|
||||
mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len);
|
||||
mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n,
|
||||
n_len);
|
||||
n_b64[n_b64_len] = '\0';
|
||||
base64_to_base64url(n_b64);
|
||||
mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len);
|
||||
mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e,
|
||||
e_len);
|
||||
e_b64[e_b64_len] = '\0';
|
||||
base64_to_base64url(e_b64);
|
||||
lua_pushstring(L, n_b64);
|
||||
|
@ -2045,12 +2126,18 @@ static int LuaConvertPemToJwk(lua_State *L) {
|
|||
dp_b64 = malloc(dp_b64_len + 1);
|
||||
dq_b64 = malloc(dq_b64_len + 1);
|
||||
qi_b64 = malloc(qi_b64_len + 1);
|
||||
mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len);
|
||||
mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, p_len);
|
||||
mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, q_len);
|
||||
mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, dp, dp_len);
|
||||
mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, dq, dq_len);
|
||||
mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, qi, qi_len);
|
||||
mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d,
|
||||
d_len);
|
||||
mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p,
|
||||
p_len);
|
||||
mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q,
|
||||
q_len);
|
||||
mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len,
|
||||
dp, dp_len);
|
||||
mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len,
|
||||
dq, dq_len);
|
||||
mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len,
|
||||
qi, qi_len);
|
||||
d_b64[d_b64_len] = '\0';
|
||||
p_b64[p_b64_len] = '\0';
|
||||
q_b64[q_b64_len] = '\0';
|
||||
|
@ -2117,16 +2204,19 @@ static int LuaConvertPemToJwk(lua_State *L) {
|
|||
mbedtls_base64_encode(NULL, 0, &y_b64_len, y, y_len);
|
||||
x_b64 = malloc(x_b64_len + 1);
|
||||
y_b64 = malloc(y_b64_len + 1);
|
||||
mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len);
|
||||
mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x,
|
||||
x_len);
|
||||
x_b64[x_b64_len] = '\0';
|
||||
base64_to_base64url(x_b64);
|
||||
mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len);
|
||||
mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y,
|
||||
y_len);
|
||||
y_b64[y_b64_len] = '\0';
|
||||
base64_to_base64url(y_b64);
|
||||
// Set kty and crv for EC keys
|
||||
lua_pushstring(L, "EC");
|
||||
lua_setfield(L, -2, "kty");
|
||||
const mbedtls_ecp_curve_info *curve_info = mbedtls_ecp_curve_info_from_grp_id(ec->grp.id);
|
||||
const mbedtls_ecp_curve_info *curve_info =
|
||||
mbedtls_ecp_curve_info_from_grp_id(ec->grp.id);
|
||||
if (curve_info && curve_info->name) {
|
||||
lua_pushstring(L, curve_info->name);
|
||||
lua_setfield(L, -2, "crv");
|
||||
|
@ -2142,19 +2232,34 @@ static int LuaConvertPemToJwk(lua_State *L) {
|
|||
if (mbedtls_ecp_check_privkey(&ec->grp, &ec->d) == 0 && ec->d.p) {
|
||||
size_t d_len = mbedtls_mpi_size(&ec->d);
|
||||
unsigned char *d = malloc(d_len);
|
||||
if (!d) { free(x); free(y); free(x_b64); free(y_b64); lua_pushnil(L); lua_pushstring(L, "Memory allocation failed"); mbedtls_pk_free(&key); return 2; }
|
||||
if (!d) {
|
||||
free(x);
|
||||
free(y);
|
||||
free(x_b64);
|
||||
free(y_b64);
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Memory allocation failed");
|
||||
mbedtls_pk_free(&key);
|
||||
return 2;
|
||||
}
|
||||
mbedtls_mpi_write_binary(&ec->d, d, d_len);
|
||||
char *d_b64 = NULL;
|
||||
size_t d_b64_len;
|
||||
mbedtls_base64_encode(NULL, 0, &d_b64_len, d, d_len);
|
||||
d_b64 = malloc(d_b64_len + 1);
|
||||
mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len);
|
||||
mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d,
|
||||
d_len);
|
||||
d_b64[d_b64_len] = '\0';
|
||||
base64_to_base64url(d_b64);
|
||||
lua_pushstring(L, d_b64); lua_setfield(L, -2, "d");
|
||||
free(d); free(d_b64);
|
||||
lua_pushstring(L, d_b64);
|
||||
lua_setfield(L, -2, "d");
|
||||
free(d);
|
||||
free(d_b64);
|
||||
}
|
||||
free(x); free(y); free(x_b64); free(y_b64);
|
||||
free(x);
|
||||
free(y);
|
||||
free(x_b64);
|
||||
free(y_b64);
|
||||
} else {
|
||||
lua_pushnil(L);
|
||||
lua_pushstring(L, "Unsupported key type");
|
||||
|
@ -2166,8 +2271,10 @@ static int LuaConvertPemToJwk(lua_State *L) {
|
|||
|
||||
// Merge additional claims if provided and compatible with RFC7517
|
||||
if (has_claims) {
|
||||
static const char *allowed[] = {"kty","use","sig","key_ops","alg","kid","x5u","x5c","x5t","x5t#S256",NULL};
|
||||
lua_pushnil(L); // first key
|
||||
static const char *allowed[] = {"kty", "use", "sig", "key_ops",
|
||||
"alg", "kid", "x5u", "x5c",
|
||||
"x5t", "x5t#S256", NULL};
|
||||
lua_pushnil(L); // first key
|
||||
while (lua_next(L, 2) != 0) {
|
||||
const char *k = lua_tostring(L, -2);
|
||||
int allowed_key = 0;
|
||||
|
@ -2252,7 +2359,6 @@ static int LuaGenerateCSR(lua_State *L) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
|
||||
// LuaCrypto compatible API
|
||||
static int LuaCryptoSign(lua_State *L) {
|
||||
// Type of signature (e.g., "rsa", "ecdsa", "rsa-pss")
|
||||
|
@ -2262,7 +2368,8 @@ static int LuaCryptoSign(lua_State *L) {
|
|||
|
||||
if (strcasecmp(dtype, "rsa") == 0) {
|
||||
return LuaRSASign(L);
|
||||
} else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) {
|
||||
} else if (strcasecmp(dtype, "rsa-pss") == 0 ||
|
||||
strcasecmp(dtype, "rsapss") == 0) {
|
||||
return LuaRSAPSSSign(L);
|
||||
} else if (strcasecmp(dtype, "ecdsa") == 0) {
|
||||
return LuaECDSASign(L);
|
||||
|
@ -2279,7 +2386,8 @@ static int LuaCryptoVerify(lua_State *L) {
|
|||
|
||||
if (strcasecmp(dtype, "rsa") == 0) {
|
||||
return LuaRSAVerify(L);
|
||||
} else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) {
|
||||
} else if (strcasecmp(dtype, "rsa-pss") == 0 ||
|
||||
strcasecmp(dtype, "rsapss") == 0) {
|
||||
return LuaRSAPSSVerify(L);
|
||||
} else if (strcasecmp(dtype, "ecdsa") == 0) {
|
||||
return LuaECDSAVerify(L);
|
||||
|
@ -2339,8 +2447,6 @@ static int LuaCryptoGenerateKeyPair(lua_State *L) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
static const luaL_Reg kLuaCrypto[] = {
|
||||
{"sign", LuaCryptoSign}, //
|
||||
{"verify", LuaCryptoVerify}, //
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue