Merge parse aes parse options functions

This commit is contained in:
Miguel Terron 2025-06-04 21:42:17 +12:00
parent d06d0879b8
commit e35f99c7db

View file

@ -1,18 +1,18 @@
#include "libc/log/log.h" #include "libc/log/log.h"
#include "net/https/https.h" #include "net/https/https.h"
#include "third_party/lua/lauxlib.h" #include "third_party/lua/lauxlib.h"
#include "third_party/mbedtls/aes.h"
#include "third_party/mbedtls/base64.h"
#include "third_party/mbedtls/ctr_drbg.h"
#include "third_party/mbedtls/ecdsa.h"
#include "third_party/mbedtls/entropy.h"
#include "third_party/mbedtls/error.h" #include "third_party/mbedtls/error.h"
#include "third_party/mbedtls/gcm.h"
#include "third_party/mbedtls/md.h"
#include "third_party/mbedtls/oid.h"
#include "third_party/mbedtls/pk.h" #include "third_party/mbedtls/pk.h"
#include "third_party/mbedtls/rsa.h" #include "third_party/mbedtls/rsa.h"
#include "third_party/mbedtls/ecdsa.h"
#include "third_party/mbedtls/x509_csr.h" #include "third_party/mbedtls/x509_csr.h"
#include "third_party/mbedtls/oid.h"
#include "third_party/mbedtls/md.h"
#include "third_party/mbedtls/base64.h"
#include "third_party/mbedtls/aes.h"
#include "third_party/mbedtls/ctr_drbg.h"
#include "third_party/mbedtls/entropy.h"
#include "third_party/mbedtls/gcm.h"
// Standard C library and redbean utilities // Standard C library and redbean utilities
#include "libc/errno.h" #include "libc/errno.h"
@ -29,8 +29,10 @@ static int LuaConvertPemToJwk(lua_State *L) {
int ret; int ret;
// Parse the PEM key // Parse the PEM key
if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key, strlen(pem_key) + 1, NULL, 0)) != 0 && if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)pem_key,
(ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key, strlen(pem_key) + 1)) != 0) { strlen(pem_key) + 1, NULL, 0)) != 0 &&
(ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pem_key,
strlen(pem_key) + 1)) != 0) {
lua_pushnil(L); lua_pushnil(L);
lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret); lua_pushfstring(L, "Failed to parse PEM key: -0x%04x", -ret);
mbedtls_pk_free(&key); mbedtls_pk_free(&key);
@ -80,8 +82,10 @@ static int LuaConvertPemToJwk(lua_State *L) {
return 2; return 2;
} }
mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len); mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n,
mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len); n_len);
mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e,
e_len);
n_b64[n_b64_len] = '\0'; n_b64[n_b64_len] = '\0';
e_b64[e_b64_len] = '\0'; e_b64[e_b64_len] = '\0';
@ -139,8 +143,10 @@ static int LuaConvertPemToJwk(lua_State *L) {
return 2; return 2;
} }
mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len); mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x,
mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); x_len);
mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y,
y_len);
x_b64[x_b64_len] = '\0'; x_b64[x_b64_len] = '\0';
y_b64[y_b64_len] = '\0'; y_b64[y_b64_len] = '\0';
@ -181,7 +187,6 @@ static int LuaGenerateCSR(lua_State *L) {
subject_name = luaL_checkstring(L, 2); subject_name = luaL_checkstring(L, 2);
} }
if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') { if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, "Subject name or SANs are required"); lua_pushstring(L, "Subject name or SANs are required");
@ -195,7 +200,8 @@ static int LuaGenerateCSR(lua_State *L) {
mbedtls_pk_init(&key); mbedtls_pk_init(&key);
mbedtls_x509write_csr_init(&req); mbedtls_x509write_csr_init(&req);
if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem, strlen(key_pem) + 1, NULL, 0)) != 0) { if ((ret = mbedtls_pk_parse_key(&key, (const unsigned char *)key_pem,
strlen(key_pem) + 1, NULL, 0)) != 0) {
lua_pushnil(L); lua_pushnil(L);
lua_pushfstring(L, "Failed to parse key: %d", ret); lua_pushfstring(L, "Failed to parse key: %d", ret);
return 2; return 2;
@ -206,14 +212,18 @@ static int LuaGenerateCSR(lua_State *L) {
mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256); mbedtls_x509write_csr_set_md_alg(&req, MBEDTLS_MD_SHA256);
if (san_list) { if (san_list) {
if ((ret = mbedtls_x509write_csr_set_extension(&req, MBEDTLS_OID_SUBJECT_ALT_NAME, MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), (const unsigned char *)san_list, strlen(san_list))) != 0) { if ((ret = mbedtls_x509write_csr_set_extension(
&req, MBEDTLS_OID_SUBJECT_ALT_NAME,
MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME),
(const unsigned char *)san_list, strlen(san_list))) != 0) {
lua_pushnil(L); lua_pushnil(L);
lua_pushfstring(L, "Failed to set SANs: %d", ret); lua_pushfstring(L, "Failed to set SANs: %d", ret);
return 2; return 2;
} }
} }
if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf), NULL, NULL)) < 0) { if ((ret = mbedtls_x509write_csr_pem(&req, (unsigned char *)buf, sizeof(buf),
NULL, NULL)) < 0) {
lua_pushnil(L); lua_pushnil(L);
lua_pushfstring(L, "Failed to write CSR: %d", ret); lua_pushfstring(L, "Failed to write CSR: %d", ret);
return 2; return 2;
@ -281,15 +291,6 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len,
mbedtls_pk_free(&key); mbedtls_pk_free(&key);
return true; return true;
} }
/**
* Lua wrapper for RSA key pair generation
*
* Lua function signature: RSAGenerateKeyPair([key_length])
* @param L Lua state
* @return 2 on success (private_key, public_key), 2 on failure (nil,
* error_message)
*/
static int LuaRSAGenerateKeyPair(lua_State *L) { static int LuaRSAGenerateKeyPair(lua_State *L) {
int bits = 2048; int bits = 2048;
// If no arguments, or first argument is nil, default to 2048 // If no arguments, or first argument is nil, default to 2048
@ -305,7 +306,8 @@ static int LuaRSAGenerateKeyPair(lua_State *L) {
size_t private_len, public_len; size_t private_len, public_len;
// Call the C function to generate the key pair // Call the C function to generate the key pair
if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, bits)) { if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len,
bits)) {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, "Failed to generate RSA key pair"); lua_pushstring(L, "Failed to generate RSA key pair");
return 2; return 2;
@ -383,7 +385,8 @@ static int LuaRSAEncrypt(lua_State *L) {
// Args: key, plaintext, options table // Args: key, plaintext, options table
size_t keylen, ptlen; size_t keylen, ptlen;
const char *key = luaL_checklstring(L, 1, &keylen); const char *key = luaL_checklstring(L, 1, &keylen);
const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); const unsigned char *plaintext =
(const unsigned char *)luaL_checklstring(L, 2, &ptlen);
// int options_idx = 3; // int options_idx = 3;
// const char *padding = parse_rsa_options(L, options_idx); // const char *padding = parse_rsa_options(L, options_idx);
size_t out_len; size_t out_len;
@ -402,8 +405,8 @@ static int LuaRSAEncrypt(lua_State *L) {
} }
static char *RSADecrypt(const char *private_key_pem, static char *RSADecrypt(const char *private_key_pem,
const unsigned char *encrypted_data, size_t encrypted_len, const unsigned char *encrypted_data,
size_t *out_len) { size_t encrypted_len, size_t *out_len) {
int rc; int rc;
// Parse private key // Parse private key
@ -450,7 +453,8 @@ static int LuaRSADecrypt(lua_State *L) {
// Args: key, ciphertext, options table // Args: key, ciphertext, options table
size_t keylen, ctlen; size_t keylen, ctlen;
const char *key = luaL_checklstring(L, 1, &keylen); const char *key = luaL_checklstring(L, 1, &keylen);
const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); const unsigned char *ciphertext =
(const unsigned char *)luaL_checklstring(L, 2, &ctlen);
// int options_idx = 3; // int options_idx = 3;
// const char *padding = parse_rsa_options(L, options_idx); // const char *padding = parse_rsa_options(L, options_idx);
size_t out_len; size_t out_len;
@ -470,7 +474,8 @@ static int LuaRSADecrypt(lua_State *L) {
// RSA Signing // RSA Signing
static char *RSASign(const char *private_key_pem, const unsigned char *data, static char *RSASign(const char *private_key_pem, const unsigned char *data,
size_t data_len, const char *hash_algo_str, size_t *sig_len) { size_t data_len, const char *hash_algo_str,
size_t *sig_len) {
int rc; int rc;
unsigned char hash[64]; // Large enough for SHA-512 unsigned char hash[64]; // Large enough for SHA-512
size_t hash_len = 32; // Default for SHA-256 size_t hash_len = 32; // Default for SHA-256
@ -1089,7 +1094,8 @@ static int LuaECDSAVerify(lua_State *L) {
const char *pub_key_pem = luaL_checkstring(L, 1); const char *pub_key_pem = luaL_checkstring(L, 1);
const char *message = luaL_checkstring(L, 2); const char *message = luaL_checkstring(L, 2);
size_t sig_len; size_t sig_len;
const unsigned char *signature = (const unsigned char *)luaL_checklstring(L, 3, &sig_len); const unsigned char *signature =
(const unsigned char *)luaL_checklstring(L, 3, &sig_len);
const char *hash_name = luaL_optstring(L, 4, "sha256"); const char *hash_name = luaL_optstring(L, 4, "sha256");
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
@ -1100,7 +1106,6 @@ static int LuaECDSAVerify(lua_State *L) {
return 1; return 1;
} }
// AES // AES
// AES key generation helper // AES key generation helper
@ -1110,7 +1115,8 @@ static int LuaAesGenerateKey(lua_State *L) {
keybits = luaL_checkinteger(L, 1); keybits = luaL_checkinteger(L, 1);
} }
int keylen = keybits / 8; int keylen = keybits / 8;
if ((keybits != 128 && keybits != 192 && keybits != 256) || (keylen != 16 && keylen != 24 && keylen != 32)) { if ((keybits != 128 && keybits != 192 && keybits != 256) ||
(keylen != 16 && keylen != 24 && keylen != 32)) {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, "AES key length must be 128, 192, or 256 bits"); lua_pushstring(L, "AES key length must be 128, 192, or 256 bits");
return 2; return 2;
@ -1121,7 +1127,8 @@ static int LuaAesGenerateKey(lua_State *L) {
mbedtls_entropy_init(&entropy); mbedtls_entropy_init(&entropy);
mbedtls_ctr_drbg_init(&ctr_drbg); mbedtls_ctr_drbg_init(&ctr_drbg);
const char *pers = "aes_keygen"; const char *pers = "aes_keygen";
int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const unsigned char *)pers, strlen(pers)); int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy,
(const unsigned char *)pers, strlen(pers));
if (ret != 0) { if (ret != 0) {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, "Failed to initialize RNG for AES key"); lua_pushstring(L, "Failed to initialize RNG for AES key");
@ -1142,29 +1149,6 @@ static int LuaAesGenerateKey(lua_State *L) {
} }
// Helper to get string field from options table // Helper to get string field from options table
typedef struct {
const char *mode;
const unsigned char *iv;
size_t ivlen;
} aes_options_t;
static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) {
opts->mode = "cbc";
opts->iv = NULL;
opts->ivlen = 0;
if (lua_istable(L, options_idx)) {
lua_getfield(L, options_idx, "mode");
if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1);
lua_pop(L, 1);
lua_getfield(L, options_idx, "iv");
if (lua_isstring(L, -1)) {
opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen);
}
lua_pop(L, 1);
}
}
// Helper for AES decrypt options
typedef struct { typedef struct {
const char *mode; const char *mode;
const unsigned char *iv; const unsigned char *iv;
@ -1173,9 +1157,10 @@ typedef struct {
size_t taglen; size_t taglen;
const unsigned char *aad; const unsigned char *aad;
size_t aadlen; size_t aadlen;
} aes_decrypt_options_t; } aes_options_t;
static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt_options_t *opts) { static void parse_aes_options(lua_State *L, int options_idx,
aes_options_t *opts) {
opts->mode = "cbc"; opts->mode = "cbc";
opts->iv = NULL; opts->iv = NULL;
opts->ivlen = 0; opts->ivlen = 0;
@ -1185,7 +1170,8 @@ static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt
opts->aadlen = 0; opts->aadlen = 0;
if (lua_istable(L, options_idx)) { if (lua_istable(L, options_idx)) {
lua_getfield(L, options_idx, "mode"); lua_getfield(L, options_idx, "mode");
if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1); if (!lua_isnil(L, -1))
opts->mode = lua_tostring(L, -1);
lua_pop(L, 1); lua_pop(L, 1);
lua_getfield(L, options_idx, "iv"); lua_getfield(L, options_idx, "iv");
if (lua_isstring(L, -1)) { if (lua_isstring(L, -1)) {
@ -1209,8 +1195,10 @@ static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt
static int LuaAesEncrypt(lua_State *L) { static int LuaAesEncrypt(lua_State *L) {
// Args: key, plaintext, options table // Args: key, plaintext, options table
size_t keylen, ptlen; size_t keylen, ptlen;
const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *key =
const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); (const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *plaintext =
(const unsigned char *)luaL_checklstring(L, 2, &ptlen);
int options_idx = 3; int options_idx = 3;
aes_options_t opts; aes_options_t opts;
parse_aes_options(L, options_idx, &opts); parse_aes_options(L, options_idx, &opts);
@ -1290,7 +1278,8 @@ static int LuaAesEncrypt(lua_State *L) {
} }
unsigned char iv_copy[16]; unsigned char iv_copy[16];
memcpy(iv_copy, iv, 16); memcpy(iv_copy, iv, 16);
ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy, input, output); ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy,
input, output);
mbedtls_aes_free(&aes); mbedtls_aes_free(&aes);
free(input); free(input);
if (ret != 0) { if (ret != 0) {
@ -1302,7 +1291,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushlstring(L, (const char *)output, ctlen); lua_pushlstring(L, (const char *)output, ctlen);
lua_pushlstring(L, (const char *)iv, ivlen); lua_pushlstring(L, (const char *)iv, ivlen);
free(output); free(output);
if (iv_was_generated) free(gen_iv); if (iv_was_generated)
free(gen_iv);
return 2; return 2;
} else if (is_ctr) { } else if (is_ctr) {
// CTR mode: no padding // CTR mode: no padding
@ -1327,7 +1317,8 @@ static int LuaAesEncrypt(lua_State *L) {
size_t nc_off = 0; size_t nc_off = 0;
memcpy(nonce_counter, iv, 16); memcpy(nonce_counter, iv, 16);
memset(stream_block, 0, 16); memset(stream_block, 0, 16);
ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter, stream_block, plaintext, output); ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter,
stream_block, plaintext, output);
mbedtls_aes_free(&aes); mbedtls_aes_free(&aes);
if (ret != 0) { if (ret != 0) {
free(output); free(output);
@ -1338,7 +1329,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushlstring(L, (const char *)output, ptlen); lua_pushlstring(L, (const char *)output, ptlen);
lua_pushlstring(L, (const char *)iv, ivlen); lua_pushlstring(L, (const char *)iv, ivlen);
free(output); free(output);
if (iv_was_generated) free(gen_iv); if (iv_was_generated)
free(gen_iv);
return 2; return 2;
} else if (is_gcm) { } else if (is_gcm) {
// GCM mode: authenticated encryption // GCM mode: authenticated encryption
@ -1360,8 +1352,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushstring(L, "Failed to set AES GCM key"); lua_pushstring(L, "Failed to set AES GCM key");
return 2; return 2;
} }
// Use actual ivlen, not hardcoded 16 ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen,
ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen, NULL, 0, plaintext, output, taglen, tag); NULL, 0, plaintext, output, taglen, tag);
mbedtls_gcm_free(&gcm); mbedtls_gcm_free(&gcm);
if (ret != 0) { if (ret != 0) {
free(output); free(output);
@ -1373,7 +1365,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushlstring(L, (const char *)iv, ivlen); lua_pushlstring(L, (const char *)iv, ivlen);
lua_pushlstring(L, (const char *)tag, taglen); lua_pushlstring(L, (const char *)tag, taglen);
free(output); free(output);
if (iv_was_generated) free(gen_iv); if (iv_was_generated)
free(gen_iv);
return 3; return 3;
} }
lua_pushnil(L); lua_pushnil(L);
@ -1385,11 +1378,13 @@ static int LuaAesEncrypt(lua_State *L) {
static int LuaAesDecrypt(lua_State *L) { static int LuaAesDecrypt(lua_State *L) {
// Args: key, ciphertext, options table // Args: key, ciphertext, options table
size_t keylen, ctlen; size_t keylen, ctlen;
const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *key =
const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); (const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *ciphertext =
(const unsigned char *)luaL_checklstring(L, 2, &ctlen);
int options_idx = 3; int options_idx = 3;
aes_decrypt_options_t opts; aes_options_t opts;
parse_aes_decrypt_options(L, options_idx, &opts); parse_aes_options(L, options_idx, &opts);
const char *mode = opts.mode; const char *mode = opts.mode;
const unsigned char *iv = opts.iv; const unsigned char *iv = opts.iv;
size_t ivlen = opts.ivlen; size_t ivlen = opts.ivlen;
@ -1467,7 +1462,8 @@ static int LuaAesDecrypt(lua_State *L) {
} }
unsigned char iv_copy[16]; unsigned char iv_copy[16];
memcpy(iv_copy, iv, 16); memcpy(iv_copy, iv, 16);
ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy, ciphertext, output); ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy,
ciphertext, output);
mbedtls_aes_free(&aes); mbedtls_aes_free(&aes);
if (ret != 0) { if (ret != 0) {
free(output); free(output);
@ -1524,7 +1520,8 @@ static int LuaAesDecrypt(lua_State *L) {
size_t nc_off = 0; size_t nc_off = 0;
memcpy(nonce_counter, iv, 16); memcpy(nonce_counter, iv, 16);
memset(stream_block, 0, 16); memset(stream_block, 0, 16);
ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter, stream_block, ciphertext, output); ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter,
stream_block, ciphertext, output);
mbedtls_aes_free(&aes); mbedtls_aes_free(&aes);
if (ret != 0) { if (ret != 0) {
free(output); free(output);
@ -1553,7 +1550,8 @@ static int LuaAesDecrypt(lua_State *L) {
lua_pushstring(L, "Failed to set AES GCM key"); lua_pushstring(L, "Failed to set AES GCM key");
return 2; return 2;
} }
ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag, taglen, ciphertext, output); ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag,
taglen, ciphertext, output);
mbedtls_gcm_free(&gcm); mbedtls_gcm_free(&gcm);
if (ret != 0) { if (ret != 0) {
free(output); free(output);
@ -1572,8 +1570,10 @@ static int LuaAesDecrypt(lua_State *L) {
// LuaCrypto compatible API // LuaCrypto compatible API
static int LuaCryptoSign(lua_State *L) { static int LuaCryptoSign(lua_State *L) {
const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") const char *dtype =
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa")
lua_remove(L, 1); // Remove the first argument (key type or cipher type)
// before dispatching
if (strcasecmp(dtype, "rsa") == 0) { if (strcasecmp(dtype, "rsa") == 0) {
return LuaRSASign(L); return LuaRSASign(L);
@ -1585,8 +1585,10 @@ static int LuaCryptoSign(lua_State *L) {
} }
static int LuaCryptoVerify(lua_State *L) { static int LuaCryptoVerify(lua_State *L) {
const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") const char *dtype =
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa")
lua_remove(L, 1); // Remove the first argument (key type or cipher type)
// before dispatching
if (strcasecmp(dtype, "rsa") == 0) { if (strcasecmp(dtype, "rsa") == 0) {
return LuaRSAVerify(L); return LuaRSAVerify(L);
@ -1600,10 +1602,10 @@ static int LuaCryptoVerify(lua_State *L) {
static int LuaCryptoEncrypt(lua_State *L) { static int LuaCryptoEncrypt(lua_State *L) {
// Args: cipher_type, key, msg, options table // Args: cipher_type, key, msg, options table
const char *cipher = luaL_checkstring(L, 1); const char *cipher = luaL_checkstring(L, 1);
// Remove cipher_type from stack, so key is at 1, msg at 2, options at 3 lua_remove(L, 1); // Remove cipher_type from stack, so key is at 1, msg at 2,
lua_remove(L, 1); // options at 3
if (strcasecmp(cipher, "rsa") == 0) { if (strcasecmp(cipher, "rsa") == 0) {
// Update LuaRSAEncrypt to accept (key, msg, options)
return LuaRSAEncrypt(L); return LuaRSAEncrypt(L);
} else if (strcasecmp(cipher, "aes") == 0) { } else if (strcasecmp(cipher, "aes") == 0) {
return LuaAesEncrypt(L); return LuaAesEncrypt(L);
@ -1615,7 +1617,10 @@ static int LuaCryptoEncrypt(lua_State *L) {
static int LuaCryptoDecrypt(lua_State *L) { static int LuaCryptoDecrypt(lua_State *L) {
// Args: cipher_type, key, ciphertext, options table // Args: cipher_type, key, ciphertext, options table
const char *cipher = luaL_checkstring(L, 1); const char *cipher = luaL_checkstring(L, 1);
lua_remove(L, 1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3 lua_remove(
L,
1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3
if (strcasecmp(cipher, "rsa") == 0) { if (strcasecmp(cipher, "rsa") == 0) {
return LuaRSADecrypt(L); return LuaRSADecrypt(L);
} else if (strcasecmp(cipher, "aes") == 0) { } else if (strcasecmp(cipher, "aes") == 0) {
@ -1626,12 +1631,13 @@ static int LuaCryptoDecrypt(lua_State *L) {
} }
static int LuaCryptoGenerateKeyPair(lua_State *L) { static int LuaCryptoGenerateKeyPair(lua_State *L) {
// If the first argument is a number, treat as RSA key length // If the first argument is a number, treat it as RSA key length
if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) { if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) {
// Call LuaRSAGenerateKeyPair with the number as the key length // Call LuaRSAGenerateKeyPair with the number as the key length
return LuaRSAGenerateKeyPair(L); return LuaRSAGenerateKeyPair(L);
} }
// Otherwise, get the key type from the first argument, default to "rsa" if not provided // Otherwise, get the key type from the first argument, default to "rsa" if
// not provided
const char *type = luaL_optstring(L, 1, "rsa"); const char *type = luaL_optstring(L, 1, "rsa");
lua_remove(L, 1); lua_remove(L, 1);
if (strcasecmp(type, "rsa") == 0) { if (strcasecmp(type, "rsa") == 0) {