diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 34418533a..52a7a5521 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -1,14 +1,23 @@ -- Helper function to print test results -local function assert_equal(actual, expected, message) +local function assert_equal(actual, expected, plaintext) if actual ~= expected then - error(message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) + error(plaintext .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) else - print("PASS: " .. message) + print("PASS: " .. plaintext) + end +end + +local function assert_not_equal(actual, not_expected, plaintext) + if actual == not_expected then + error(plaintext .. ": did not expect " .. tostring(not_expected)) + else + print("PASS: " .. plaintext) end end -- Test RSA key pair generation local function test_rsa_keypair_generation() + print('\27[1;7mTest RSA key pair generation \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) assert_equal(type(priv_key), "string", "RSA private key generation") assert_equal(type(pub_key), "string", "RSA public key generation") @@ -16,6 +25,7 @@ end -- Test ECDSA key pair generation local function test_ecdsa_keypair_generation() + print('\n\27[1;7mTest ECDSA key pair generation \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") assert_equal(type(priv_key), "string", "ECDSA private key generation") assert_equal(type(pub_key), "string", "ECDSA public key generation") @@ -23,61 +33,207 @@ end -- Test RSA encryption and decryption local function test_rsa_encryption_decryption() + print('\n\27[1;7mTest RSA encryption and decryption \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local message = "Hello, RSA!" - local encrypted = crypto.encrypt("rsa", pub_key, message) + local plaintext = "Hello, RSA!" + local encrypted = crypto.encrypt("rsa", pub_key, plaintext) assert_equal(type(encrypted), "string", "RSA encryption") local decrypted = crypto.decrypt("rsa", priv_key, encrypted) - assert_equal(decrypted, message, "RSA decryption") + assert_equal(decrypted, plaintext, "RSA decryption") end -- Test RSA signing and verification local function test_rsa_signing_verification() + print('\n\27[1;7mTest RSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local message = "Sign this message" - local signature = crypto.sign("rsa", priv_key, message, "sha256") + local plaintext = "Sign this plaintext" + local signature = crypto.sign("rsa", priv_key, plaintext, "sha256") assert_equal(type(signature), "string", "RSA signing") - local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256") + local is_valid = crypto.verify("rsa", pub_key, plaintext, signature, "sha256") assert_equal(is_valid, true, "RSA signature verification") end -- Test ECDSA signing and verification local function test_ecdsa_signing_verification() + print('\n\27[1;7mTest ECDSA signing and verification \27[0m') local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") - local message = "Sign this message with ECDSA" - local signature = crypto.sign("ecdsa", priv_key, message, "sha256") + local plaintext = "Sign this plaintext with ECDSA" + local signature = crypto.sign("ecdsa", priv_key, plaintext, "sha256") assert_equal(type(signature), "string", "ECDSA signing") - local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") + local is_valid = crypto.verify("ecdsa", pub_key, plaintext, signature, "sha256") assert_equal(is_valid, true, "ECDSA signature verification") end --- Test CSR generation -local function test_csr_generation() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local subject_name = "CN=example.com,O=Example Org,C=US" - local csr = crypto.generateCsr(priv_key, subject_name) - assert_equal(type(csr), "string", "CSR generation") +-- Test AES key generation +local function test_aes_key_generation() + print('\n\27[1;7mTest AES key generation \27[0m') + local key = crypto.generatekeypair('aes', 256) -- 256-bit key + assert_equal(type(key), "string", "AES key generation") + assert_equal(#key, 32, "AES key length (256 bits)") +end + +-- Test AES encryption and decryption (CBC mode) +local function test_aes_encryption_decryption() + print('\n\27[1;7mTest AES encryption and decryption (CBC mode) \27[0m') + local key = crypto.generatekeypair('aes',256) -- 256-bit key + local plaintext = "Hello, AES CBC!" + + -- Encrypt without providing IV (should auto-generate IV) + print('\27[1mAES encryption (auto IV)\27[0m') + local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil) + assert_equal(type(encrypted), "string", "AES encryption (CBC, auto IV)") + assert_equal(type(iv), "string", "AES IV (auto-generated)") + + -- Decrypt + print('\n\27[1mAES decryption (auto IV)\27[0m') + local decrypted = crypto.decrypt("aes", key, encrypted, iv) + assert_equal(decrypted, plaintext, "AES decryption (CBC, auto IV)") + + -- Encrypt with explicit IV + print('\n\27[1mAES encryption (explicit IV)\27[0m') + local iv2 = GetRandomBytes(16) + local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2) + assert_equal(type(encrypted2), "string", "AES encryption (CBC, explicit IV)") + assert_equal(iv_used, iv2, "AES IV (explicit)") + + print('\n\27[1mAES decryption (explicit IV)\27[0m') + local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2) + assert_equal(decrypted2, plaintext, "AES decryption (CBC, explicit IV)") +end + +-- Test AES encryption and decryption (CTR mode) +local function test_aes_encryption_decryption_ctr() + print('\n\27[1;7mTest AES encryption and decryption (CTR mode) \27[0m') + local key = crypto.generatekeypair('aes',256) + local plaintext = "Hello, AES CTR!" + + -- Encrypt without providing IV (should auto-generate IV) + print('\27[1mAES encryption (auto IV)\27[0m') + local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil, "ctr") + assert_equal(type(encrypted), "string", "AES encryption (CTR, auto IV)") + assert_equal(type(iv), "string", "AES IV (auto-generated, CTR)") + + -- Decrypt + print('\n\27[1mAES decryption (auto IV)\27[0m') + local decrypted = crypto.decrypt("aes", key, encrypted, iv, "ctr") + assert_equal(decrypted, plaintext, "AES decryption (CTR, auto IV)") + + -- Encrypt with explicit IV + print('\n\27[1mAES encryption (explicit IV)\27[0m') + local iv2 = GetRandomBytes(16) + local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2, "ctr") + assert_equal(type(encrypted2), "string", "AES encryption (CTR, explicit IV)") + assert_equal(iv_used, iv2, "AES IV (explicit, CTR)") + + print('\n\27[1mAES decryption (explicit IV)\27[0m') + local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "ctr") + assert_equal(decrypted2, plaintext, "AES decryption (CTR, explicit IV)") +end + +-- Test AES encryption and decryption (GCM mode) +local function test_aes_encryption_decryption_gcm() + print('\n\27[1;7mTest AES encryption and decryption (GCM mode) \27[0m') + local key = crypto.generatekeypair('aes',256) + local plaintext = "Hello, AES GCM!" + + -- Encrypt without providing IV (should auto-generate IV) + print('\27[1mAES encryption (auto IV)\27[0m') + local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, nil, "gcm") + assert_equal(type(encrypted), "string", "AES encryption (GCM, auto IV)") + assert_equal(type(iv), "string", "AES IV (auto-generated, GCM)") + assert_equal(type(tag), "string", "AES GCM tag (auto IV)") + + -- Decrypt + print('\n\27[1mAES decryption (auto IV)\27[0m') + local decrypted = crypto.decrypt("aes", key, encrypted, iv, "gcm", nil, tag) + assert_equal(decrypted, plaintext, "AES decryption (GCM, auto IV)") + + -- Encrypt with explicit IV + print('\n\27[1mAES encryption (explicit IV)\27[0m') + local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard + local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, iv2, "gcm") + assert_equal(type(encrypted2), "string", "AES encryption (GCM, explicit IV)") + assert_equal(iv_used, iv2, "AES IV (explicit, GCM)") + assert_equal(type(tag2), "string", "AES GCM tag (explicit IV)") + + print('\n\27[1mAES decryption (explicit IV)\27[0m') + local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "gcm", nil, tag2) + assert_equal(decrypted2, plaintext, "AES decryption (GCM, explicit IV)") end -- Test PemToJwk conversion local function test_pem_to_jwk() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - local jwk = crypto.convertPemToJwk(pub_key) - assert_equal(type(jwk), "table", "PEM to JWK conversion") - assert_equal(jwk.kty, "RSA", "JWK key type") + print('\n\27[1;7mTest PEM to JWK conversion \27[0m') + local priv_key, pub_key = crypto.generatekeypair() + print('\27[1mRSA Private key to JWK conversion\27[0m') + local priv_jwk = crypto.convertPemToJwk(priv_key) + assert_equal(type(priv_jwk), "table", "PEM to JWK conversion") + assert_equal(priv_jwk.kty, "RSA", "JWK key type") + + print('\n\27[1mRSA Public key to JWK conversion\27[0m') + local pub_jwk = crypto.convertPemToJwk(pub_key) + assert_equal(type(pub_jwk), "table", "PEM to JWK conversion") + assert_equal(pub_jwk.kty, "RSA", "JWK key type") + + -- Test ECDSA keys + local priv_key, pub_key = crypto.generatekeypair('ecdsa') + print('\n\27[1mECDSA Private key to JWK conversion\27[0m') + local priv_jwk = crypto.convertPemToJwk(priv_key) + assert_equal(type(priv_jwk), "table", "PEM to JWK conversion") + assert_equal(priv_jwk.kty, "EC", "JWK key type") + + print('\n\27[1mECDSA Public key to JWK conversion\27[0m') + local pub_jwk = crypto.convertPemToJwk(pub_key) + assert_equal(type(pub_jwk), "table", "PEM to JWK conversion") + assert_equal(pub_jwk.kty, "EC", "JWK key type") +end + +-- Test CSR generation +local function test_csr_generation() + print('\n\27[1;7mTest CSR generation \27[0m') + local priv_key, _ = crypto.generatekeypair() + local subject_name = "CN=example.com,O=Example Org,C=US" + local san = "DNS:example.com, DNS:www.example.com, IP:192.168.1.1" + + local csr = crypto.GenerateCsr(priv_key, subject_name) + assert_equal(type(csr), "string", "CSR generation with subject name") + + csr = crypto.GenerateCsr(priv_key, subject_name, san) + assert_equal(type(csr), "string", "CSR generation with subject name and san") + + csr = crypto.GenerateCsr(priv_key, nil, san) + assert_equal(type(csr), "string", "CSR generation with nil subject name and san") + + csr = crypto.GenerateCsr(priv_key, '', san) + assert_equal(type(csr), "string", "CSR generation with empty subject name and san") + + -- These should fail + csr = crypto.GenerateCsr(priv_key, '') + assert_not_equal(type(csr), "string", "CSR generation with empty subject name and no san is rejected") + + csr = crypto.GenerateCsr(priv_key) + assert_not_equal(type(csr), "string", "CSR generation with nil subject name and no san is rejected") end -- Run all tests local function run_tests() print("Running tests for lcrypto...") test_rsa_keypair_generation() - test_ecdsa_keypair_generation() - test_rsa_encryption_decryption() test_rsa_signing_verification() + test_rsa_encryption_decryption() + test_ecdsa_keypair_generation() test_ecdsa_signing_verification() - test_csr_generation() + test_aes_key_generation() + test_aes_encryption_decryption() + test_aes_encryption_decryption_ctr() + test_aes_encryption_decryption_gcm() test_pem_to_jwk() + test_csr_generation() + print('') print("All tests passed!") + EXIT=0 + return EXIT end -run_tests() +EXIT=70 +os.exit(run_tests()) diff --git a/third_party/mbedtls/config.h b/third_party/mbedtls/config.h index 88087503e..d181060b1 100644 --- a/third_party/mbedtls/config.h +++ b/third_party/mbedtls/config.h @@ -40,9 +40,9 @@ #define MBEDTLS_GCM_C #ifndef TINY #define MBEDTLS_CIPHER_MODE_CBC +#define MBEDTLS_CIPHER_MODE_CTR /*#define MBEDTLS_CCM_C*/ /*#define MBEDTLS_CIPHER_MODE_CFB*/ -/*#define MBEDTLS_CIPHER_MODE_CTR*/ /*#define MBEDTLS_CIPHER_MODE_OFB*/ /*#define MBEDTLS_CIPHER_MODE_XTS*/ #endif diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index f5d2bced2..d5ca2890f 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -9,6 +9,10 @@ #include "third_party/mbedtls/oid.h" #include "third_party/mbedtls/md.h" #include "third_party/mbedtls/base64.h" +#include "third_party/mbedtls/aes.h" +#include "third_party/mbedtls/ctr_drbg.h" +#include "third_party/mbedtls/entropy.h" +#include "third_party/mbedtls/gcm.h" // Standard C library and redbean utilities #include "libc/errno.h" @@ -16,8 +20,8 @@ #include "libc/str/str.h" #include "tool/net/luacheck.h" -// Updated PemToJwk to parse PEM keys and convert them into JWK format -static int convertPemToJwk(lua_State *L) { +// Parse PEM keys and convert them into JWK format +static int LuaConvertPemToJwk(lua_State *L) { const char *pem_key = luaL_checkstring(L, 1); mbedtls_pk_context key; @@ -166,11 +170,23 @@ static int convertPemToJwk(lua_State *L) { } // CSR Creation Function -static int generateCsr(lua_State *L) { +static int LuaGenerateCSR(lua_State *L) { const char *key_pem = luaL_checkstring(L, 1); - const char *subject_name = luaL_checkstring(L, 2); + const char *subject_name; const char *san_list = luaL_optstring(L, 3, NULL); + if (lua_isnoneornil(L, 2)) { + subject_name = ""; + } else { + subject_name = luaL_checkstring(L, 2); + } + + + if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') { + lua_pushnil(L); + lua_pushstring(L, "Subject name or SANs are required"); + return 2; + } mbedtls_pk_context key; mbedtls_x509write_csr req; char buf[4096]; @@ -211,7 +227,9 @@ static int generateCsr(lua_State *L) { return 1; } +// RSA +// Generate RSA Key Pair static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, char **public_key_pem, size_t *public_key_len, unsigned int key_length) { @@ -263,6 +281,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, mbedtls_pk_free(&key); return true; } + /** * Lua wrapper for RSA key pair generation * @@ -272,43 +291,38 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, * error_message) */ static int LuaRSAGenerateKeyPair(lua_State *L) { - char *private_key, *public_key; - size_t private_len, public_len; - int key_length = 2048; // Default RSA key length - - // Get key length from Lua (optional parameter) - if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { - key_length = luaL_checkinteger(L, 1); - // Validate key length (common RSA key lengths are 1024, 2048, 3072, 4096) - if (key_length != 1024 && key_length != 2048 && key_length != 3072 && - key_length != 4096) { - lua_pushnil(L); - lua_pushstring(L, - "Invalid RSA key length. Use 1024, 2048, 3072, or 4096."); - return 2; + int bits = 2048; + // If no arguments, or first argument is nil, default to 2048 + if (lua_gettop(L) == 0 || lua_isnoneornil(L, 1)) { + bits = 2048; + } else if (lua_gettop(L) == 1 && lua_type(L, 1) == LUA_TNUMBER) { + bits = (int)lua_tointeger(L, 1); + } else { + bits = (int)luaL_optinteger(L, 2, 2048); } - } - // Call the C function to generate the key pair - if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, - key_length)) { - lua_pushnil(L); - lua_pushstring(L, "Failed to generate RSA key pair"); + char *private_key, *public_key; + size_t private_len, public_len; + + // Call the C function to generate the key pair + if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, bits)) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate RSA key pair"); + return 2; + } + + // Push results to Lua + lua_pushstring(L, private_key); + lua_pushstring(L, public_key); + + // Clean up + free(private_key); + free(public_key); + return 2; - } - - // Push results to Lua - lua_pushstring(L, private_key); - lua_pushstring(L, public_key); - - // Clean up - free(private_key); - free(public_key); - - return 2; } -// RSA + static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, size_t data_len, size_t *out_len) { int rc; @@ -622,7 +636,7 @@ static int LuaRSAVerify(lua_State *L) { return 1; } - +// Elliptic Curve Cryptography Functions // Supported curves mapping typedef struct { const char *name; @@ -710,6 +724,7 @@ static int LuaListHashAlgorithms(lua_State *L) { return 1; } + // List available curves static int LuaListCurves(lua_State *L) { const curve_map_t *curve = supported_curves; @@ -911,68 +926,77 @@ static int LuaECDSAGenerateKeyPair(lua_State *L) { static int ECDSASign(const char *priv_key_pem, const char *message, hash_algorithm_t hash_alg, unsigned char **signature, size_t *sig_len) { - mbedtls_pk_context key; - unsigned char hash[64]; // Max hash size (SHA-512) - size_t hash_size; - int ret; + mbedtls_pk_context key; + unsigned char hash[64]; // Max hash size (SHA-512) + size_t hash_size; + int ret; + *signature = NULL; + *sig_len = 0; + + if (!priv_key_pem) { + WARNF("(ecdsa) Private key is NULL"); + return -1; + } + + // Get the length of the PEM string (excluding null terminator) + size_t key_len = strlen(priv_key_pem); + if (key_len == 0) { + WARNF("(ecdsa) Private key is empty"); + return -1; + } + + // Get hash size for the selected algorithm + hash_size = get_hash_size(hash_alg); + + mbedtls_pk_init(&key); + + // Parse the private key from PEM directly without creating a copy + ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem, + key_len + 1, NULL, 0); + + if (ret != 0) { + WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); + goto cleanup; + } + + // Compute hash of the message using the specified algorithm + ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), + hash, sizeof(hash)); + if (ret != 0) { + WARNF("(ecdsa) Failed to compute message hash"); + goto cleanup; + } + + // Allocate memory for signature (max size for ECDSA) + *signature = malloc(MBEDTLS_ECDSA_MAX_LEN); + if (*signature == NULL) { + WARNF("(ecdsa) Failed to allocate memory for signature"); + ret = -1; + goto cleanup; + } + + // Sign the hash using GenerateHardRandom + ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, + *signature, sig_len, GenerateHardRandom, 0); + + if (ret != 0) { + WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); + free(*signature); *signature = NULL; *sig_len = 0; + goto cleanup; + } - if (!priv_key_pem || strlen(priv_key_pem) == 0) { - WARNF("(ecdsa) Private key is NULL or empty"); - return -1; - } - - mbedtls_pk_init(&key); - - // Parse the private key from PEM (PKCS#8 format) - ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem, - strlen(priv_key_pem) + 1, NULL, 0); - if (ret != 0) { - WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); - mbedtls_pk_free(&key); - return -1; - } - - // Compute hash of the message - hash_size = get_hash_size(hash_alg); - ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), - hash, sizeof(hash)); - if (ret != 0) { - WARNF("(ecdsa) Failed to compute message hash"); - mbedtls_pk_free(&key); - return -1; - } - - // Allocate memory for the signature - *signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE); - if (*signature == NULL) { - WARNF("(ecdsa) Failed to allocate memory for signature"); - mbedtls_pk_free(&key); - return -1; - } - - // Sign the hash - ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, - *signature, sig_len, GenerateHardRandom, NULL); - if (ret != 0) { - WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); - free(*signature); - *signature = NULL; - *sig_len = 0; - mbedtls_pk_free(&key); - return -1; - } - - mbedtls_pk_free(&key); - return 0; -} -// Lua binding for signing a message +cleanup: + mbedtls_pk_free(&key); + return ret; +} // Lua binding for signing a message static int LuaECDSASign(lua_State *L) { - const char *hash_name = luaL_optstring(L, 3, "sha256"); // Default to SHA-256 - const char *message = luaL_checkstring(L, 2); + // Correct order: priv_key, message, hash_name (default sha256) const char *priv_key_pem = luaL_checkstring(L, 1); + const char *message = luaL_checkstring(L, 2); + const char *hash_name = luaL_optstring(L, 3, "sha256"); hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); @@ -1046,12 +1070,12 @@ cleanup: return ret; } static int LuaECDSAVerify(lua_State *L) { + // Correct order: pub_key, message, signature, hash_name (default sha256) const char *pub_key_pem = luaL_checkstring(L, 1); const char *message = luaL_checkstring(L, 2); size_t sig_len; - const unsigned char *signature = - (const unsigned char *)luaL_checklstring(L, 3, &sig_len); - const char *hash_name = luaL_optstring(L, 4, "sha256"); // Default to SHA-256 + const unsigned char *signature = (const unsigned char *)luaL_checklstring(L, 3, &sig_len); + const char *hash_name = luaL_optstring(L, 4, "sha256"); hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); @@ -1061,6 +1085,437 @@ static int LuaECDSAVerify(lua_State *L) { return 1; } + +// AES +// AES key generation helper +static int LuaAesGenerateKey(lua_State *L) { + int keybits = 128; + if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) { + keybits = luaL_checkinteger(L, 1); + } + int keylen = keybits / 8; + if ((keybits != 128 && keybits != 192 && keybits != 256) || (keylen != 16 && keylen != 24 && keylen != 32)) { + lua_pushnil(L); + lua_pushstring(L, "AES key length must be 128, 192, or 256 bits"); + return 2; + } + unsigned char key[32]; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + const char *pers = "aes_keygen"; + int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const unsigned char *)pers, strlen(pers)); + if (ret != 0) { + lua_pushnil(L); + lua_pushstring(L, "Failed to initialize RNG for AES key"); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + return 2; + } + ret = mbedtls_ctr_drbg_random(&ctr_drbg, key, keylen); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + if (ret != 0) { + lua_pushnil(L); + lua_pushstring(L, "Failed to generate random AES key"); + return 2; + } + lua_pushlstring(L, (const char *)key, keylen); + return 1; +} + +// AES encryption supporting CBC, GCM, and CTR modes +static int LuaAesEncrypt(lua_State *L) { + // Accept IV as the 3rd argument (after key, plaintext) + size_t keylen, ivlen = 0, ptlen; + const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); + const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); + const unsigned char *iv = NULL; + unsigned char *gen_iv = NULL; + int iv_was_generated = 0; + + const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided + int ret = 0; + unsigned char *output = NULL; + int is_gcm = 0, is_ctr = 0, is_cbc = 0; + + if (strcasecmp(mode, "cbc") == 0) { + is_cbc = 1; + } else if (strcasecmp(mode, "gcm") == 0) { + is_gcm = 1; + } else if (strcasecmp(mode, "ctr") == 0) { + is_ctr = 1; + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } + + // If IV is not provided (arg3 is nil or missing), auto-generate + if (lua_isnoneornil(L, 3)) { + // For GCM, standard is 12 bytes, but allow 12-16 + if (is_gcm) { + ivlen = 12; + } else { + ivlen = 16; + } + gen_iv = malloc(ivlen); + if (!gen_iv) { + lua_pushnil(L); + lua_pushstring(L, "Failed to allocate IV"); + return 2; + } + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0); + mbedtls_ctr_drbg_random(&ctr_drbg, gen_iv, ivlen); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + iv = gen_iv; + iv_was_generated = 1; + } else { + // IV provided + iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen); + // Do not force ivlen to 16 here! Accept actual length for GCM (12-16) + if (is_cbc || is_ctr) { + if (ivlen != 16) { + lua_pushnil(L); + lua_pushstring(L, "AES IV must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (ivlen < 12 || ivlen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM IV/nonce must be 12-16 bytes"); + return 2; + } + } + iv_was_generated = 0; + } + + if (is_cbc) { + // PKCS7 padding + size_t block_size = 16; + size_t padlen = block_size - (ptlen % block_size); + size_t ctlen = ptlen + padlen; + unsigned char *input = malloc(ctlen); + if (!input) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + memcpy(input, plaintext, ptlen); + memset(input + ptlen, (unsigned char)padlen, padlen); + output = malloc(ctlen); + if (!output) { + free(input); + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(input); + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char iv_copy[16]; + memcpy(iv_copy, iv, 16); + ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy, input, output); + mbedtls_aes_free(&aes); + free(input); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CBC encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + lua_pushlstring(L, (const char *)iv, ivlen); + free(output); + if (iv_was_generated) free(gen_iv); + return 2; + } else if (is_ctr) { + // CTR mode: no padding + output = malloc(ptlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char nonce_counter[16]; + unsigned char stream_block[16]; + size_t nc_off = 0; + memcpy(nonce_counter, iv, 16); + memset(stream_block, 0, 16); + ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter, stream_block, plaintext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CTR encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ptlen); + lua_pushlstring(L, (const char *)iv, ivlen); + free(output); + if (iv_was_generated) free(gen_iv); + return 2; + } else if (is_gcm) { + // GCM mode: authenticated encryption + size_t taglen = 16; + unsigned char tag[16]; + output = malloc(ptlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_gcm_context gcm; + mbedtls_gcm_init(&gcm); + ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_gcm_free(&gcm); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES GCM key"); + return 2; + } + // Use actual ivlen, not hardcoded 16 + ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen, NULL, 0, plaintext, output, taglen, tag); + mbedtls_gcm_free(&gcm); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES GCM encryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ptlen); + lua_pushlstring(L, (const char *)iv, ivlen); + lua_pushlstring(L, (const char *)tag, taglen); + free(output); + if (iv_was_generated) free(gen_iv); + return 3; + } + lua_pushnil(L); + lua_pushstring(L, "Internal error in AES encrypt"); + return 2; +} + +// AES decryption supporting CBC, GCM, and CTR modes +static int LuaAesDecrypt(lua_State *L) { + size_t keylen, ctlen, ivlen; + const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); + const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); + const unsigned char *iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen); + const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided + const unsigned char *aad = NULL; + const unsigned char *tag = NULL; + size_t aadlen = 0, taglen = 0; + int is_gcm = 0, is_ctr = 0, is_cbc = 0; + + if (strcasecmp(mode, "cbc") == 0) { + is_cbc = 1; + } else if (strcasecmp(mode, "gcm") == 0) { + is_gcm = 1; + } else if (strcasecmp(mode, "ctr") == 0) { + is_ctr = 1; + } else { + lua_pushnil(L); + lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } + + // Validate key length (16, 24, 32 bytes) + if (keylen != 16 && keylen != 24 && keylen != 32) { + lua_pushnil(L); + lua_pushstring(L, "AES key must be 16, 24, or 32 bytes"); + return 2; + } + // Validate IV/nonce length + if (is_cbc || is_ctr) { + if (ivlen != 16) { + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (ivlen < 12 || ivlen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; + } + } + + // GCM: require tag and optional AAD + if (is_gcm) { + if (!lua_isnoneornil(L, 5)) { + aad = (const unsigned char *)luaL_checklstring(L, 5, &aadlen); + } + if (!lua_isnoneornil(L, 6)) { + tag = (const unsigned char *)luaL_checklstring(L, 6, &taglen); + if (taglen < 12 || taglen > 16) { + lua_pushnil(L); + lua_pushstring(L, "AES GCM tag must be 12-16 bytes"); + return 2; + } + } else { + lua_pushnil(L); + lua_pushstring(L, "AES GCM tag required as 6th argument"); + return 2; + } + } + + int ret = 0; + unsigned char *output = NULL; + + if (is_cbc) { + // Ciphertext must be a multiple of block size + if (ctlen == 0 || (ctlen % 16) != 0) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext length must be a multiple of 16"); + return 2; + } + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_dec(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES decryption key"); + return 2; + } + unsigned char iv_copy[16]; + memcpy(iv_copy, iv, 16); + ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy, ciphertext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CBC decryption failed"); + return 2; + } + // PKCS7 unpadding + if (ctlen == 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Decrypted data is empty"); + return 2; + } + unsigned char pad = output[ctlen - 1]; + if (pad == 0 || pad > 16) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Invalid PKCS7 padding"); + return 2; + } + for (size_t i = 0; i < pad; ++i) { + if (output[ctlen - 1 - i] != pad) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "Invalid PKCS7 padding"); + return 2; + } + } + size_t ptlen = ctlen - pad; + lua_pushlstring(L, (const char *)output, ptlen); + free(output); + return 1; + } else if (is_ctr) { + // CTR mode: no padding + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_aes_context aes; + mbedtls_aes_init(&aes); + ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_aes_free(&aes); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES encryption key"); + return 2; + } + unsigned char nonce_counter[16]; + unsigned char stream_block[16]; + size_t nc_off = 0; + memcpy(nonce_counter, iv, 16); + memset(stream_block, 0, 16); + ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter, stream_block, ciphertext, output); + mbedtls_aes_free(&aes); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES CTR decryption failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + free(output); + return 1; + } else if (is_gcm) { + // GCM mode: authenticated decryption + output = malloc(ctlen); + if (!output) { + lua_pushnil(L); + lua_pushstring(L, "Memory allocation failed"); + return 2; + } + mbedtls_gcm_context gcm; + mbedtls_gcm_init(&gcm); + ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8); + if (ret != 0) { + free(output); + mbedtls_gcm_free(&gcm); + lua_pushnil(L); + lua_pushstring(L, "Failed to set AES GCM key"); + return 2; + } + ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag, taglen, ciphertext, output); + mbedtls_gcm_free(&gcm); + if (ret != 0) { + free(output); + lua_pushnil(L); + lua_pushstring(L, "AES GCM decryption failed or authentication failed"); + return 2; + } + lua_pushlstring(L, (const char *)output, ctlen); + free(output); + return 1; + } + lua_pushnil(L); + lua_pushstring(L, "Internal error in AES decrypt"); + return 2; +} + +// LuaCrypto compatible API static int LuaCryptoSign(lua_State *L) { const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa") lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching @@ -1088,41 +1543,48 @@ static int LuaCryptoVerify(lua_State *L) { } static int LuaCryptoEncrypt(lua_State *L) { - const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa") + const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes") lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching if (strcasecmp(cipher, "rsa") == 0) { return LuaRSAEncrypt(L); + } else if (strcasecmp(cipher, "aes") == 0) { + return LuaAesEncrypt(L); } else { return luaL_error(L, "Unsupported cipher type: %s", cipher); } } static int LuaCryptoDecrypt(lua_State *L) { - const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa") + const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes") lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching if (strcasecmp(cipher, "rsa") == 0) { return LuaRSADecrypt(L); + } else if (strcasecmp(cipher, "aes") == 0) { + return LuaAesDecrypt(L); } else { return luaL_error(L, "Unsupported cipher type: %s", cipher); } } static int LuaCryptoGenerateKeyPair(lua_State *L) { - const char *key_type = "rsa"; // Key type (e.g., "rsa", "ecdsa") - - if (! lua_isinteger(L, 1) && ! lua_isnoneornil(L, 1)) { - key_type = luaL_checkstring(L, 1); // Get key type from first argumen - lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching - } - - if (strcasecmp(key_type, "rsa") == 0) { + // If the first argument is a number, treat as RSA key length + if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) { + // Call LuaRSAGenerateKeyPair with the number as the key length return LuaRSAGenerateKeyPair(L); - } else if (strcasecmp(key_type, "ecdsa") == 0) { + } + // Otherwise, get the key type from the first argument, default to "rsa" if not provided + const char *type = luaL_optstring(L, 1, "rsa"); + lua_remove(L, 1); + if (strcasecmp(type, "rsa") == 0) { + return LuaRSAGenerateKeyPair(L); + } else if (strcasecmp(type, "ecdsa") == 0) { return LuaECDSAGenerateKeyPair(L); + } else if (strcasecmp(type, "aes") == 0) { + return LuaAesGenerateKey(L); } else { - return luaL_error(L, "Unsupported key type: %s", key_type); + return luaL_error(L, "Unsupported key type: %s", type); } } @@ -1132,8 +1594,8 @@ static const luaL_Reg kLuaCrypto[] = { {"encrypt", LuaCryptoEncrypt}, // {"decrypt", LuaCryptoDecrypt}, // {"generatekeypair", LuaCryptoGenerateKeyPair}, // - {"convertPemToJwk", convertPemToJwk}, // - {"generateCsr", generateCsr}, // + {"convertPemToJwk", LuaConvertPemToJwk}, // + {"GenerateCsr", LuaGenerateCSR}, // {0}, // };