Change of API for Encrypt and Decrypt.

The options are now passed in a table instead of positional parameters. This is not LuaCrypto compatible but it is a nicer interface.
This commit is contained in:
Miguel Terron 2025-06-04 21:13:09 +12:00
parent cef06a5b22
commit d06d0879b8
2 changed files with 155 additions and 97 deletions

View file

@ -69,7 +69,7 @@ local function test_aes_key_generation()
end end
-- Test AES encryption and decryption (CBC mode) -- Test AES encryption and decryption (CBC mode)
local function test_aes_encryption_decryption() local function test_aes_encryption_decryption_cbc()
local key = crypto.generatekeypair('aes', 256) -- 256-bit key local key = crypto.generatekeypair('aes', 256) -- 256-bit key
local plaintext = "Hello, AES CBC!" local plaintext = "Hello, AES CBC!"
@ -79,16 +79,16 @@ local function test_aes_encryption_decryption()
assert_equal(type(iv), "string", "IV type") assert_equal(type(iv), "string", "IV type")
-- Decrypt -- Decrypt
local decrypted = crypto.decrypt("aes", key, encrypted, iv) local decrypted = crypto.decrypt("aes", key, encrypted, {mode="cbc",iv=iv})
assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext")
-- Encrypt with explicit IV -- Encrypt with explicit IV
local iv2 = GetRandomBytes(16) local iv2 = GetRandomBytes(16)
local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2) local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, {mode="cbc",iv=iv2})
assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(type(encrypted2), "string", "Ciphertext type")
assert_equal(iv_used, iv2, "IV match") assert_equal(iv_used, iv2, "IV match")
local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2) local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="cbc",iv=iv2})
assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext")
end end
@ -98,48 +98,49 @@ local function test_aes_encryption_decryption_ctr()
local plaintext = "Hello, AES CTR!" local plaintext = "Hello, AES CTR!"
-- Encrypt without providing IV (should auto-generate IV) -- Encrypt without providing IV (should auto-generate IV)
local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil, "ctr") local encrypted, iv = crypto.encrypt("aes", key, plaintext, {mode="ctr"})
assert_equal(type(encrypted), "string", "Ciphertext type") assert_equal(type(encrypted), "string", "Ciphertext type")
assert_equal(type(iv), "string", "IV type") assert_equal(type(iv), "string", "IV type")
-- Decrypt -- Decrypt
local decrypted = crypto.decrypt("aes", key, encrypted, iv, "ctr") local decrypted = crypto.decrypt("aes", key, encrypted, {mode="ctr", iv=iv})
assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext")
-- Encrypt with explicit IV -- Encrypt with explicit IV
local iv2 = GetRandomBytes(16) local iv2 = GetRandomBytes(16)
local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2, "ctr") local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, {mode="ctr", iv=iv2})
assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(type(encrypted2), "string", "Ciphertext type")
assert_equal(iv_used, iv2, "IV match") assert_equal(iv_used, iv2, "IV match")
local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "ctr") local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="ctr", iv=iv2})
assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext")
end end
-- Test AES encryption and decryption (GCM mode) -- Test AES encryption and decryption (GCM mode)
local function test_aes_encryption_decryption_gcm() local function test_aes_encryption_decryption_gcm()
local key = crypto.generatekeypair('aes', 256) local key = crypto.generatekeypair('aes', 256)
assert_equal(type(key), "string", "key type")
local plaintext = "Hello, AES GCM!" local plaintext = "Hello, AES GCM!"
-- Encrypt without providing IV (should auto-generate IV) -- Encrypt without providing IV (should auto-generate IV)
local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, nil, "gcm") local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, {mode="gcm"})
assert_equal(#plaintext, #encrypted, "Ciphertext length matches plaintext") assert_equal(#plaintext, #encrypted, "Ciphertext length matches plaintext")
assert_equal(type(encrypted), "string", "Ciphertext type") assert_equal(type(encrypted), "string", "Ciphertext type")
assert_equal(type(iv), "string", "IV type") assert_equal(type(iv), "string", "IV type")
assert_equal(type(tag), "string", "Tag type") assert_equal(type(tag), "string", "Tag type")
-- Decrypt -- Decrypt
local decrypted = crypto.decrypt("aes", key, encrypted, iv, "gcm", nil, tag) local decrypted = crypto.decrypt("aes", key, encrypted, {mode="gcm",iv=iv,tag=tag})
assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext")
-- Encrypt with explicit IV -- Encrypt with explicit IV
local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard
local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, iv2, "gcm") local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, {mode="gcm",iv=iv2})
assert_equal(type(encrypted2), "string", "Ciphertext type") assert_equal(type(encrypted2), "string", "Ciphertext type")
assert_equal(iv_used, iv2, "IV match") assert_equal(iv_used, iv2, "IV match")
assert_equal(type(tag2), "string", "Tag type") assert_equal(type(tag2), "string", "Tag type")
local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "gcm", nil, tag2) local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="gcm",iv=iv2,tag=tag2})
assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext")
end end
@ -200,7 +201,7 @@ local function run_tests()
test_ecdsa_keypair_generation() test_ecdsa_keypair_generation()
test_ecdsa_signing_verification() test_ecdsa_signing_verification()
test_aes_key_generation() test_aes_key_generation()
test_aes_encryption_decryption() test_aes_encryption_decryption_cbc()
test_aes_encryption_decryption_ctr() test_aes_encryption_decryption_ctr()
test_aes_encryption_decryption_gcm() test_aes_encryption_decryption_gcm()
test_pem_to_jwk() test_pem_to_jwk()

View file

@ -322,6 +322,18 @@ static int LuaRSAGenerateKeyPair(lua_State *L) {
return 2; return 2;
} }
// Helper to get string field from options table for RSA
// static const char *parse_rsa_options(lua_State *L, int options_idx) {
// const char *padding = "pkcs1"; // default
// if (lua_istable(L, options_idx)) {
// lua_getfield(L, options_idx, "padding");
// if (lua_isstring(L, -1)) {
// padding = lua_tostring(L, -1);
// }
// lua_pop(L, 1);
// }
// return padding;
// }
static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data,
size_t data_len, size_t *out_len) { size_t data_len, size_t *out_len) {
@ -368,13 +380,15 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data,
return (char *)output; return (char *)output;
} }
static int LuaRSAEncrypt(lua_State *L) { static int LuaRSAEncrypt(lua_State *L) {
const char *public_key = luaL_checkstring(L, 1); // Args: key, plaintext, options table
size_t data_len; size_t keylen, ptlen;
const unsigned char *data = const char *key = luaL_checklstring(L, 1, &keylen);
(const unsigned char *)luaL_checklstring(L, 2, &data_len); const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen);
// int options_idx = 3;
// const char *padding = parse_rsa_options(L, options_idx);
size_t out_len; size_t out_len;
char *encrypted = RSAEncrypt(public_key, data, data_len, &out_len); char *encrypted = RSAEncrypt(key, plaintext, ptlen, &out_len);
if (!encrypted) { if (!encrypted) {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, "Encryption failed"); lua_pushstring(L, "Encryption failed");
@ -433,14 +447,15 @@ static char *RSADecrypt(const char *private_key_pem,
return (char *)output; return (char *)output;
} }
static int LuaRSADecrypt(lua_State *L) { static int LuaRSADecrypt(lua_State *L) {
const char *private_key = luaL_checkstring(L, 1); // Args: key, ciphertext, options table
size_t encrypted_len; size_t keylen, ctlen;
const unsigned char *encrypted_data = const char *key = luaL_checklstring(L, 1, &keylen);
(const unsigned char *)luaL_checklstring(L, 2, &encrypted_len); const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen);
// int options_idx = 3;
// const char *padding = parse_rsa_options(L, options_idx);
size_t out_len; size_t out_len;
char *decrypted = char *decrypted = RSADecrypt(key, ciphertext, ctlen, &out_len);
RSADecrypt(private_key, encrypted_data, encrypted_len, &out_len);
if (!decrypted) { if (!decrypted) {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, "Decryption failed"); lua_pushstring(L, "Decryption failed");
@ -1087,6 +1102,7 @@ static int LuaECDSAVerify(lua_State *L) {
// AES // AES
// AES key generation helper // AES key generation helper
static int LuaAesGenerateKey(lua_State *L) { static int LuaAesGenerateKey(lua_State *L) {
int keybits = 128; int keybits = 128;
@ -1125,21 +1141,87 @@ static int LuaAesGenerateKey(lua_State *L) {
return 1; return 1;
} }
// Helper to get string field from options table
typedef struct {
const char *mode;
const unsigned char *iv;
size_t ivlen;
} aes_options_t;
static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) {
opts->mode = "cbc";
opts->iv = NULL;
opts->ivlen = 0;
if (lua_istable(L, options_idx)) {
lua_getfield(L, options_idx, "mode");
if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1);
lua_pop(L, 1);
lua_getfield(L, options_idx, "iv");
if (lua_isstring(L, -1)) {
opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen);
}
lua_pop(L, 1);
}
}
// Helper for AES decrypt options
typedef struct {
const char *mode;
const unsigned char *iv;
size_t ivlen;
const unsigned char *tag;
size_t taglen;
const unsigned char *aad;
size_t aadlen;
} aes_decrypt_options_t;
static void parse_aes_decrypt_options(lua_State *L, int options_idx, aes_decrypt_options_t *opts) {
opts->mode = "cbc";
opts->iv = NULL;
opts->ivlen = 0;
opts->tag = NULL;
opts->taglen = 0;
opts->aad = NULL;
opts->aadlen = 0;
if (lua_istable(L, options_idx)) {
lua_getfield(L, options_idx, "mode");
if (!lua_isnil(L, -1)) opts->mode = lua_tostring(L, -1);
lua_pop(L, 1);
lua_getfield(L, options_idx, "iv");
if (lua_isstring(L, -1)) {
opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen);
}
lua_pop(L, 1);
lua_getfield(L, options_idx, "tag");
if (lua_isstring(L, -1)) {
opts->tag = (const unsigned char *)lua_tolstring(L, -1, &opts->taglen);
}
lua_pop(L, 1);
lua_getfield(L, options_idx, "aad");
if (lua_isstring(L, -1)) {
opts->aad = (const unsigned char *)lua_tolstring(L, -1, &opts->aadlen);
}
lua_pop(L, 1);
}
}
// AES encryption supporting CBC, GCM, and CTR modes // AES encryption supporting CBC, GCM, and CTR modes
static int LuaAesEncrypt(lua_State *L) { static int LuaAesEncrypt(lua_State *L) {
// Accept IV as the 3rd argument (after key, plaintext) // Args: key, plaintext, options table
size_t keylen, ivlen = 0, ptlen; size_t keylen, ptlen;
const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen);
const unsigned char *iv = NULL; int options_idx = 3;
aes_options_t opts;
parse_aes_options(L, options_idx, &opts);
const char *mode = opts.mode;
const unsigned char *iv = opts.iv;
size_t ivlen = opts.ivlen;
unsigned char *gen_iv = NULL; unsigned char *gen_iv = NULL;
int iv_was_generated = 0; int iv_was_generated = 0;
const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided
int ret = 0; int ret = 0;
unsigned char *output = NULL; unsigned char *output = NULL;
int is_gcm = 0, is_ctr = 0, is_cbc = 0; int is_gcm = 0, is_ctr = 0, is_cbc = 0;
if (strcasecmp(mode, "cbc") == 0) { if (strcasecmp(mode, "cbc") == 0) {
is_cbc = 1; is_cbc = 1;
} else if (strcasecmp(mode, "gcm") == 0) { } else if (strcasecmp(mode, "gcm") == 0) {
@ -1151,10 +1233,8 @@ static int LuaAesEncrypt(lua_State *L) {
lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'.");
return 2; return 2;
} }
// If IV is not provided, auto-generate
// If IV is not provided (arg3 is nil or missing), auto-generate if (!iv) {
if (lua_isnoneornil(L, 3)) {
// For GCM, standard is 12 bytes, but allow 12-16
if (is_gcm) { if (is_gcm) {
ivlen = 12; ivlen = 12;
} else { } else {
@ -1176,26 +1256,7 @@ static int LuaAesEncrypt(lua_State *L) {
mbedtls_entropy_free(&entropy); mbedtls_entropy_free(&entropy);
iv = gen_iv; iv = gen_iv;
iv_was_generated = 1; iv_was_generated = 1;
} else {
// IV provided
iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen);
// Do not force ivlen to 16 here! Accept actual length for GCM (12-16)
if (is_cbc || is_ctr) {
if (ivlen != 16) {
lua_pushnil(L);
lua_pushstring(L, "AES IV must be 16 bytes for CBC/CTR");
return 2;
} }
} else if (is_gcm) {
if (ivlen < 12 || ivlen > 16) {
lua_pushnil(L);
lua_pushstring(L, "AES GCM IV/nonce must be 12-16 bytes");
return 2;
}
}
iv_was_generated = 0;
}
if (is_cbc) { if (is_cbc) {
// PKCS7 padding // PKCS7 padding
size_t block_size = 16; size_t block_size = 16;
@ -1322,16 +1383,21 @@ static int LuaAesEncrypt(lua_State *L) {
// AES decryption supporting CBC, GCM, and CTR modes // AES decryption supporting CBC, GCM, and CTR modes
static int LuaAesDecrypt(lua_State *L) { static int LuaAesDecrypt(lua_State *L) {
size_t keylen, ctlen, ivlen; // Args: key, ciphertext, options table
size_t keylen, ctlen;
const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen);
const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen);
const unsigned char *iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen); int options_idx = 3;
const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided aes_decrypt_options_t opts;
const unsigned char *aad = NULL; parse_aes_decrypt_options(L, options_idx, &opts);
const unsigned char *tag = NULL; const char *mode = opts.mode;
size_t aadlen = 0, taglen = 0; const unsigned char *iv = opts.iv;
size_t ivlen = opts.ivlen;
const unsigned char *tag = opts.tag;
size_t taglen = opts.taglen;
const unsigned char *aad = opts.aad;
size_t aadlen = opts.aadlen;
int is_gcm = 0, is_ctr = 0, is_cbc = 0; int is_gcm = 0, is_ctr = 0, is_cbc = 0;
if (strcasecmp(mode, "cbc") == 0) { if (strcasecmp(mode, "cbc") == 0) {
is_cbc = 1; is_cbc = 1;
} else if (strcasecmp(mode, "gcm") == 0) { } else if (strcasecmp(mode, "gcm") == 0) {
@ -1343,7 +1409,6 @@ static int LuaAesDecrypt(lua_State *L) {
lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'.");
return 2; return 2;
} }
// Validate key length (16, 24, 32 bytes) // Validate key length (16, 24, 32 bytes)
if (keylen != 16 && keylen != 24 && keylen != 32) { if (keylen != 16 && keylen != 24 && keylen != 32) {
lua_pushnil(L); lua_pushnil(L);
@ -1367,21 +1432,11 @@ static int LuaAesDecrypt(lua_State *L) {
// GCM: require tag and optional AAD // GCM: require tag and optional AAD
if (is_gcm) { if (is_gcm) {
if (!lua_isnoneornil(L, 5)) { if (!tag || taglen < 12 || taglen > 16) {
aad = (const unsigned char *)luaL_checklstring(L, 5, &aadlen);
}
if (!lua_isnoneornil(L, 6)) {
tag = (const unsigned char *)luaL_checklstring(L, 6, &taglen);
if (taglen < 12 || taglen > 16) {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, "AES GCM tag must be 12-16 bytes"); lua_pushstring(L, "AES GCM tag must be 12-16 bytes");
return 2; return 2;
} }
} else {
lua_pushnil(L);
lua_pushstring(L, "AES GCM tag required as 6th argument");
return 2;
}
} }
int ret = 0; int ret = 0;
@ -1543,10 +1598,12 @@ static int LuaCryptoVerify(lua_State *L) {
} }
static int LuaCryptoEncrypt(lua_State *L) { static int LuaCryptoEncrypt(lua_State *L) {
const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes") // Args: cipher_type, key, msg, options table
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching const char *cipher = luaL_checkstring(L, 1);
// Remove cipher_type from stack, so key is at 1, msg at 2, options at 3
lua_remove(L, 1);
if (strcasecmp(cipher, "rsa") == 0) { if (strcasecmp(cipher, "rsa") == 0) {
// Update LuaRSAEncrypt to accept (key, msg, options)
return LuaRSAEncrypt(L); return LuaRSAEncrypt(L);
} else if (strcasecmp(cipher, "aes") == 0) { } else if (strcasecmp(cipher, "aes") == 0) {
return LuaAesEncrypt(L); return LuaAesEncrypt(L);
@ -1556,9 +1613,9 @@ static int LuaCryptoEncrypt(lua_State *L) {
} }
static int LuaCryptoDecrypt(lua_State *L) { static int LuaCryptoDecrypt(lua_State *L) {
const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes") // Args: cipher_type, key, ciphertext, options table
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching const char *cipher = luaL_checkstring(L, 1);
lua_remove(L, 1); // Remove cipher_type, so key is at 1, ciphertext at 2, options at 3
if (strcasecmp(cipher, "rsa") == 0) { if (strcasecmp(cipher, "rsa") == 0) {
return LuaRSADecrypt(L); return LuaRSADecrypt(L);
} else if (strcasecmp(cipher, "aes") == 0) { } else if (strcasecmp(cipher, "aes") == 0) {