mirror of
https://github.com/jart/cosmopolitan.git
synced 2025-07-12 14:09:12 +00:00
Add AES based encryption and decryption in multiple modes (CBC, CTR & GCM)
Improve test coverage
This commit is contained in:
parent
9e121882d0
commit
19541f95bd
3 changed files with 755 additions and 137 deletions
|
@ -1,14 +1,23 @@
|
||||||
-- Helper function to print test results
|
-- Helper function to print test results
|
||||||
local function assert_equal(actual, expected, message)
|
local function assert_equal(actual, expected, plaintext)
|
||||||
if actual ~= expected then
|
if actual ~= expected then
|
||||||
error(message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual))
|
error(plaintext .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual))
|
||||||
else
|
else
|
||||||
print("PASS: " .. message)
|
print("PASS: " .. plaintext)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local function assert_not_equal(actual, not_expected, plaintext)
|
||||||
|
if actual == not_expected then
|
||||||
|
error(plaintext .. ": did not expect " .. tostring(not_expected))
|
||||||
|
else
|
||||||
|
print("PASS: " .. plaintext)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Test RSA key pair generation
|
-- Test RSA key pair generation
|
||||||
local function test_rsa_keypair_generation()
|
local function test_rsa_keypair_generation()
|
||||||
|
print('\27[1;7mTest RSA key pair generation \27[0m')
|
||||||
local priv_key, pub_key = crypto.generatekeypair("rsa", 2048)
|
local priv_key, pub_key = crypto.generatekeypair("rsa", 2048)
|
||||||
assert_equal(type(priv_key), "string", "RSA private key generation")
|
assert_equal(type(priv_key), "string", "RSA private key generation")
|
||||||
assert_equal(type(pub_key), "string", "RSA public key generation")
|
assert_equal(type(pub_key), "string", "RSA public key generation")
|
||||||
|
@ -16,6 +25,7 @@ end
|
||||||
|
|
||||||
-- Test ECDSA key pair generation
|
-- Test ECDSA key pair generation
|
||||||
local function test_ecdsa_keypair_generation()
|
local function test_ecdsa_keypair_generation()
|
||||||
|
print('\n\27[1;7mTest ECDSA key pair generation \27[0m')
|
||||||
local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1")
|
local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1")
|
||||||
assert_equal(type(priv_key), "string", "ECDSA private key generation")
|
assert_equal(type(priv_key), "string", "ECDSA private key generation")
|
||||||
assert_equal(type(pub_key), "string", "ECDSA public key generation")
|
assert_equal(type(pub_key), "string", "ECDSA public key generation")
|
||||||
|
@ -23,61 +33,207 @@ end
|
||||||
|
|
||||||
-- Test RSA encryption and decryption
|
-- Test RSA encryption and decryption
|
||||||
local function test_rsa_encryption_decryption()
|
local function test_rsa_encryption_decryption()
|
||||||
|
print('\n\27[1;7mTest RSA encryption and decryption \27[0m')
|
||||||
local priv_key, pub_key = crypto.generatekeypair("rsa", 2048)
|
local priv_key, pub_key = crypto.generatekeypair("rsa", 2048)
|
||||||
local message = "Hello, RSA!"
|
local plaintext = "Hello, RSA!"
|
||||||
local encrypted = crypto.encrypt("rsa", pub_key, message)
|
local encrypted = crypto.encrypt("rsa", pub_key, plaintext)
|
||||||
assert_equal(type(encrypted), "string", "RSA encryption")
|
assert_equal(type(encrypted), "string", "RSA encryption")
|
||||||
local decrypted = crypto.decrypt("rsa", priv_key, encrypted)
|
local decrypted = crypto.decrypt("rsa", priv_key, encrypted)
|
||||||
assert_equal(decrypted, message, "RSA decryption")
|
assert_equal(decrypted, plaintext, "RSA decryption")
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Test RSA signing and verification
|
-- Test RSA signing and verification
|
||||||
local function test_rsa_signing_verification()
|
local function test_rsa_signing_verification()
|
||||||
|
print('\n\27[1;7mTest RSA signing and verification \27[0m')
|
||||||
local priv_key, pub_key = crypto.generatekeypair("rsa", 2048)
|
local priv_key, pub_key = crypto.generatekeypair("rsa", 2048)
|
||||||
local message = "Sign this message"
|
local plaintext = "Sign this plaintext"
|
||||||
local signature = crypto.sign("rsa", priv_key, message, "sha256")
|
local signature = crypto.sign("rsa", priv_key, plaintext, "sha256")
|
||||||
assert_equal(type(signature), "string", "RSA signing")
|
assert_equal(type(signature), "string", "RSA signing")
|
||||||
local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256")
|
local is_valid = crypto.verify("rsa", pub_key, plaintext, signature, "sha256")
|
||||||
assert_equal(is_valid, true, "RSA signature verification")
|
assert_equal(is_valid, true, "RSA signature verification")
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Test ECDSA signing and verification
|
-- Test ECDSA signing and verification
|
||||||
local function test_ecdsa_signing_verification()
|
local function test_ecdsa_signing_verification()
|
||||||
|
print('\n\27[1;7mTest ECDSA signing and verification \27[0m')
|
||||||
local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1")
|
local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1")
|
||||||
local message = "Sign this message with ECDSA"
|
local plaintext = "Sign this plaintext with ECDSA"
|
||||||
local signature = crypto.sign("ecdsa", priv_key, message, "sha256")
|
local signature = crypto.sign("ecdsa", priv_key, plaintext, "sha256")
|
||||||
assert_equal(type(signature), "string", "ECDSA signing")
|
assert_equal(type(signature), "string", "ECDSA signing")
|
||||||
local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256")
|
local is_valid = crypto.verify("ecdsa", pub_key, plaintext, signature, "sha256")
|
||||||
assert_equal(is_valid, true, "ECDSA signature verification")
|
assert_equal(is_valid, true, "ECDSA signature verification")
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Test CSR generation
|
-- Test AES key generation
|
||||||
local function test_csr_generation()
|
local function test_aes_key_generation()
|
||||||
local priv_key, pub_key = crypto.generatekeypair("rsa", 2048)
|
print('\n\27[1;7mTest AES key generation \27[0m')
|
||||||
local subject_name = "CN=example.com,O=Example Org,C=US"
|
local key = crypto.generatekeypair('aes', 256) -- 256-bit key
|
||||||
local csr = crypto.generateCsr(priv_key, subject_name)
|
assert_equal(type(key), "string", "AES key generation")
|
||||||
assert_equal(type(csr), "string", "CSR generation")
|
assert_equal(#key, 32, "AES key length (256 bits)")
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Test AES encryption and decryption (CBC mode)
|
||||||
|
local function test_aes_encryption_decryption()
|
||||||
|
print('\n\27[1;7mTest AES encryption and decryption (CBC mode) \27[0m')
|
||||||
|
local key = crypto.generatekeypair('aes',256) -- 256-bit key
|
||||||
|
local plaintext = "Hello, AES CBC!"
|
||||||
|
|
||||||
|
-- Encrypt without providing IV (should auto-generate IV)
|
||||||
|
print('\27[1mAES encryption (auto IV)\27[0m')
|
||||||
|
local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil)
|
||||||
|
assert_equal(type(encrypted), "string", "AES encryption (CBC, auto IV)")
|
||||||
|
assert_equal(type(iv), "string", "AES IV (auto-generated)")
|
||||||
|
|
||||||
|
-- Decrypt
|
||||||
|
print('\n\27[1mAES decryption (auto IV)\27[0m')
|
||||||
|
local decrypted = crypto.decrypt("aes", key, encrypted, iv)
|
||||||
|
assert_equal(decrypted, plaintext, "AES decryption (CBC, auto IV)")
|
||||||
|
|
||||||
|
-- Encrypt with explicit IV
|
||||||
|
print('\n\27[1mAES encryption (explicit IV)\27[0m')
|
||||||
|
local iv2 = GetRandomBytes(16)
|
||||||
|
local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2)
|
||||||
|
assert_equal(type(encrypted2), "string", "AES encryption (CBC, explicit IV)")
|
||||||
|
assert_equal(iv_used, iv2, "AES IV (explicit)")
|
||||||
|
|
||||||
|
print('\n\27[1mAES decryption (explicit IV)\27[0m')
|
||||||
|
local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2)
|
||||||
|
assert_equal(decrypted2, plaintext, "AES decryption (CBC, explicit IV)")
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Test AES encryption and decryption (CTR mode)
|
||||||
|
local function test_aes_encryption_decryption_ctr()
|
||||||
|
print('\n\27[1;7mTest AES encryption and decryption (CTR mode) \27[0m')
|
||||||
|
local key = crypto.generatekeypair('aes',256)
|
||||||
|
local plaintext = "Hello, AES CTR!"
|
||||||
|
|
||||||
|
-- Encrypt without providing IV (should auto-generate IV)
|
||||||
|
print('\27[1mAES encryption (auto IV)\27[0m')
|
||||||
|
local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil, "ctr")
|
||||||
|
assert_equal(type(encrypted), "string", "AES encryption (CTR, auto IV)")
|
||||||
|
assert_equal(type(iv), "string", "AES IV (auto-generated, CTR)")
|
||||||
|
|
||||||
|
-- Decrypt
|
||||||
|
print('\n\27[1mAES decryption (auto IV)\27[0m')
|
||||||
|
local decrypted = crypto.decrypt("aes", key, encrypted, iv, "ctr")
|
||||||
|
assert_equal(decrypted, plaintext, "AES decryption (CTR, auto IV)")
|
||||||
|
|
||||||
|
-- Encrypt with explicit IV
|
||||||
|
print('\n\27[1mAES encryption (explicit IV)\27[0m')
|
||||||
|
local iv2 = GetRandomBytes(16)
|
||||||
|
local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, iv2, "ctr")
|
||||||
|
assert_equal(type(encrypted2), "string", "AES encryption (CTR, explicit IV)")
|
||||||
|
assert_equal(iv_used, iv2, "AES IV (explicit, CTR)")
|
||||||
|
|
||||||
|
print('\n\27[1mAES decryption (explicit IV)\27[0m')
|
||||||
|
local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "ctr")
|
||||||
|
assert_equal(decrypted2, plaintext, "AES decryption (CTR, explicit IV)")
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Test AES encryption and decryption (GCM mode)
|
||||||
|
local function test_aes_encryption_decryption_gcm()
|
||||||
|
print('\n\27[1;7mTest AES encryption and decryption (GCM mode) \27[0m')
|
||||||
|
local key = crypto.generatekeypair('aes',256)
|
||||||
|
local plaintext = "Hello, AES GCM!"
|
||||||
|
|
||||||
|
-- Encrypt without providing IV (should auto-generate IV)
|
||||||
|
print('\27[1mAES encryption (auto IV)\27[0m')
|
||||||
|
local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, nil, "gcm")
|
||||||
|
assert_equal(type(encrypted), "string", "AES encryption (GCM, auto IV)")
|
||||||
|
assert_equal(type(iv), "string", "AES IV (auto-generated, GCM)")
|
||||||
|
assert_equal(type(tag), "string", "AES GCM tag (auto IV)")
|
||||||
|
|
||||||
|
-- Decrypt
|
||||||
|
print('\n\27[1mAES decryption (auto IV)\27[0m')
|
||||||
|
local decrypted = crypto.decrypt("aes", key, encrypted, iv, "gcm", nil, tag)
|
||||||
|
assert_equal(decrypted, plaintext, "AES decryption (GCM, auto IV)")
|
||||||
|
|
||||||
|
-- Encrypt with explicit IV
|
||||||
|
print('\n\27[1mAES encryption (explicit IV)\27[0m')
|
||||||
|
local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard
|
||||||
|
local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, iv2, "gcm")
|
||||||
|
assert_equal(type(encrypted2), "string", "AES encryption (GCM, explicit IV)")
|
||||||
|
assert_equal(iv_used, iv2, "AES IV (explicit, GCM)")
|
||||||
|
assert_equal(type(tag2), "string", "AES GCM tag (explicit IV)")
|
||||||
|
|
||||||
|
print('\n\27[1mAES decryption (explicit IV)\27[0m')
|
||||||
|
local decrypted2 = crypto.decrypt("aes", key, encrypted2, iv2, "gcm", nil, tag2)
|
||||||
|
assert_equal(decrypted2, plaintext, "AES decryption (GCM, explicit IV)")
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Test PemToJwk conversion
|
-- Test PemToJwk conversion
|
||||||
local function test_pem_to_jwk()
|
local function test_pem_to_jwk()
|
||||||
local priv_key, pub_key = crypto.generatekeypair("rsa", 2048)
|
print('\n\27[1;7mTest PEM to JWK conversion \27[0m')
|
||||||
local jwk = crypto.convertPemToJwk(pub_key)
|
local priv_key, pub_key = crypto.generatekeypair()
|
||||||
assert_equal(type(jwk), "table", "PEM to JWK conversion")
|
print('\27[1mRSA Private key to JWK conversion\27[0m')
|
||||||
assert_equal(jwk.kty, "RSA", "JWK key type")
|
local priv_jwk = crypto.convertPemToJwk(priv_key)
|
||||||
|
assert_equal(type(priv_jwk), "table", "PEM to JWK conversion")
|
||||||
|
assert_equal(priv_jwk.kty, "RSA", "JWK key type")
|
||||||
|
|
||||||
|
print('\n\27[1mRSA Public key to JWK conversion\27[0m')
|
||||||
|
local pub_jwk = crypto.convertPemToJwk(pub_key)
|
||||||
|
assert_equal(type(pub_jwk), "table", "PEM to JWK conversion")
|
||||||
|
assert_equal(pub_jwk.kty, "RSA", "JWK key type")
|
||||||
|
|
||||||
|
-- Test ECDSA keys
|
||||||
|
local priv_key, pub_key = crypto.generatekeypair('ecdsa')
|
||||||
|
print('\n\27[1mECDSA Private key to JWK conversion\27[0m')
|
||||||
|
local priv_jwk = crypto.convertPemToJwk(priv_key)
|
||||||
|
assert_equal(type(priv_jwk), "table", "PEM to JWK conversion")
|
||||||
|
assert_equal(priv_jwk.kty, "EC", "JWK key type")
|
||||||
|
|
||||||
|
print('\n\27[1mECDSA Public key to JWK conversion\27[0m')
|
||||||
|
local pub_jwk = crypto.convertPemToJwk(pub_key)
|
||||||
|
assert_equal(type(pub_jwk), "table", "PEM to JWK conversion")
|
||||||
|
assert_equal(pub_jwk.kty, "EC", "JWK key type")
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Test CSR generation
|
||||||
|
local function test_csr_generation()
|
||||||
|
print('\n\27[1;7mTest CSR generation \27[0m')
|
||||||
|
local priv_key, _ = crypto.generatekeypair()
|
||||||
|
local subject_name = "CN=example.com,O=Example Org,C=US"
|
||||||
|
local san = "DNS:example.com, DNS:www.example.com, IP:192.168.1.1"
|
||||||
|
|
||||||
|
local csr = crypto.GenerateCsr(priv_key, subject_name)
|
||||||
|
assert_equal(type(csr), "string", "CSR generation with subject name")
|
||||||
|
|
||||||
|
csr = crypto.GenerateCsr(priv_key, subject_name, san)
|
||||||
|
assert_equal(type(csr), "string", "CSR generation with subject name and san")
|
||||||
|
|
||||||
|
csr = crypto.GenerateCsr(priv_key, nil, san)
|
||||||
|
assert_equal(type(csr), "string", "CSR generation with nil subject name and san")
|
||||||
|
|
||||||
|
csr = crypto.GenerateCsr(priv_key, '', san)
|
||||||
|
assert_equal(type(csr), "string", "CSR generation with empty subject name and san")
|
||||||
|
|
||||||
|
-- These should fail
|
||||||
|
csr = crypto.GenerateCsr(priv_key, '')
|
||||||
|
assert_not_equal(type(csr), "string", "CSR generation with empty subject name and no san is rejected")
|
||||||
|
|
||||||
|
csr = crypto.GenerateCsr(priv_key)
|
||||||
|
assert_not_equal(type(csr), "string", "CSR generation with nil subject name and no san is rejected")
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Run all tests
|
-- Run all tests
|
||||||
local function run_tests()
|
local function run_tests()
|
||||||
print("Running tests for lcrypto...")
|
print("Running tests for lcrypto...")
|
||||||
test_rsa_keypair_generation()
|
test_rsa_keypair_generation()
|
||||||
test_ecdsa_keypair_generation()
|
|
||||||
test_rsa_encryption_decryption()
|
|
||||||
test_rsa_signing_verification()
|
test_rsa_signing_verification()
|
||||||
|
test_rsa_encryption_decryption()
|
||||||
|
test_ecdsa_keypair_generation()
|
||||||
test_ecdsa_signing_verification()
|
test_ecdsa_signing_verification()
|
||||||
test_csr_generation()
|
test_aes_key_generation()
|
||||||
|
test_aes_encryption_decryption()
|
||||||
|
test_aes_encryption_decryption_ctr()
|
||||||
|
test_aes_encryption_decryption_gcm()
|
||||||
test_pem_to_jwk()
|
test_pem_to_jwk()
|
||||||
|
test_csr_generation()
|
||||||
|
print('')
|
||||||
print("All tests passed!")
|
print("All tests passed!")
|
||||||
|
EXIT=0
|
||||||
|
return EXIT
|
||||||
end
|
end
|
||||||
|
|
||||||
run_tests()
|
EXIT=70
|
||||||
|
os.exit(run_tests())
|
||||||
|
|
2
third_party/mbedtls/config.h
vendored
2
third_party/mbedtls/config.h
vendored
|
@ -40,9 +40,9 @@
|
||||||
#define MBEDTLS_GCM_C
|
#define MBEDTLS_GCM_C
|
||||||
#ifndef TINY
|
#ifndef TINY
|
||||||
#define MBEDTLS_CIPHER_MODE_CBC
|
#define MBEDTLS_CIPHER_MODE_CBC
|
||||||
|
#define MBEDTLS_CIPHER_MODE_CTR
|
||||||
/*#define MBEDTLS_CCM_C*/
|
/*#define MBEDTLS_CCM_C*/
|
||||||
/*#define MBEDTLS_CIPHER_MODE_CFB*/
|
/*#define MBEDTLS_CIPHER_MODE_CFB*/
|
||||||
/*#define MBEDTLS_CIPHER_MODE_CTR*/
|
|
||||||
/*#define MBEDTLS_CIPHER_MODE_OFB*/
|
/*#define MBEDTLS_CIPHER_MODE_OFB*/
|
||||||
/*#define MBEDTLS_CIPHER_MODE_XTS*/
|
/*#define MBEDTLS_CIPHER_MODE_XTS*/
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -9,6 +9,10 @@
|
||||||
#include "third_party/mbedtls/oid.h"
|
#include "third_party/mbedtls/oid.h"
|
||||||
#include "third_party/mbedtls/md.h"
|
#include "third_party/mbedtls/md.h"
|
||||||
#include "third_party/mbedtls/base64.h"
|
#include "third_party/mbedtls/base64.h"
|
||||||
|
#include "third_party/mbedtls/aes.h"
|
||||||
|
#include "third_party/mbedtls/ctr_drbg.h"
|
||||||
|
#include "third_party/mbedtls/entropy.h"
|
||||||
|
#include "third_party/mbedtls/gcm.h"
|
||||||
|
|
||||||
// Standard C library and redbean utilities
|
// Standard C library and redbean utilities
|
||||||
#include "libc/errno.h"
|
#include "libc/errno.h"
|
||||||
|
@ -16,8 +20,8 @@
|
||||||
#include "libc/str/str.h"
|
#include "libc/str/str.h"
|
||||||
#include "tool/net/luacheck.h"
|
#include "tool/net/luacheck.h"
|
||||||
|
|
||||||
// Updated PemToJwk to parse PEM keys and convert them into JWK format
|
// Parse PEM keys and convert them into JWK format
|
||||||
static int convertPemToJwk(lua_State *L) {
|
static int LuaConvertPemToJwk(lua_State *L) {
|
||||||
const char *pem_key = luaL_checkstring(L, 1);
|
const char *pem_key = luaL_checkstring(L, 1);
|
||||||
|
|
||||||
mbedtls_pk_context key;
|
mbedtls_pk_context key;
|
||||||
|
@ -166,11 +170,23 @@ static int convertPemToJwk(lua_State *L) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSR Creation Function
|
// CSR Creation Function
|
||||||
static int generateCsr(lua_State *L) {
|
static int LuaGenerateCSR(lua_State *L) {
|
||||||
const char *key_pem = luaL_checkstring(L, 1);
|
const char *key_pem = luaL_checkstring(L, 1);
|
||||||
const char *subject_name = luaL_checkstring(L, 2);
|
const char *subject_name;
|
||||||
const char *san_list = luaL_optstring(L, 3, NULL);
|
const char *san_list = luaL_optstring(L, 3, NULL);
|
||||||
|
|
||||||
|
if (lua_isnoneornil(L, 2)) {
|
||||||
|
subject_name = "";
|
||||||
|
} else {
|
||||||
|
subject_name = luaL_checkstring(L, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if (lua_isnoneornil(L, 3) && subject_name[0] == '\0') {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Subject name or SANs are required");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
mbedtls_pk_context key;
|
mbedtls_pk_context key;
|
||||||
mbedtls_x509write_csr req;
|
mbedtls_x509write_csr req;
|
||||||
char buf[4096];
|
char buf[4096];
|
||||||
|
@ -211,7 +227,9 @@ static int generateCsr(lua_State *L) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RSA
|
||||||
|
|
||||||
|
// Generate RSA Key Pair
|
||||||
static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len,
|
static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len,
|
||||||
char **public_key_pem, size_t *public_key_len,
|
char **public_key_pem, size_t *public_key_len,
|
||||||
unsigned int key_length) {
|
unsigned int key_length) {
|
||||||
|
@ -263,6 +281,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len,
|
||||||
mbedtls_pk_free(&key);
|
mbedtls_pk_free(&key);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Lua wrapper for RSA key pair generation
|
* Lua wrapper for RSA key pair generation
|
||||||
*
|
*
|
||||||
|
@ -272,43 +291,38 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len,
|
||||||
* error_message)
|
* error_message)
|
||||||
*/
|
*/
|
||||||
static int LuaRSAGenerateKeyPair(lua_State *L) {
|
static int LuaRSAGenerateKeyPair(lua_State *L) {
|
||||||
char *private_key, *public_key;
|
int bits = 2048;
|
||||||
size_t private_len, public_len;
|
// If no arguments, or first argument is nil, default to 2048
|
||||||
int key_length = 2048; // Default RSA key length
|
if (lua_gettop(L) == 0 || lua_isnoneornil(L, 1)) {
|
||||||
|
bits = 2048;
|
||||||
// Get key length from Lua (optional parameter)
|
} else if (lua_gettop(L) == 1 && lua_type(L, 1) == LUA_TNUMBER) {
|
||||||
if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) {
|
bits = (int)lua_tointeger(L, 1);
|
||||||
key_length = luaL_checkinteger(L, 1);
|
} else {
|
||||||
// Validate key length (common RSA key lengths are 1024, 2048, 3072, 4096)
|
bits = (int)luaL_optinteger(L, 2, 2048);
|
||||||
if (key_length != 1024 && key_length != 2048 && key_length != 3072 &&
|
|
||||||
key_length != 4096) {
|
|
||||||
lua_pushnil(L);
|
|
||||||
lua_pushstring(L,
|
|
||||||
"Invalid RSA key length. Use 1024, 2048, 3072, or 4096.");
|
|
||||||
return 2;
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Call the C function to generate the key pair
|
char *private_key, *public_key;
|
||||||
if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len,
|
size_t private_len, public_len;
|
||||||
key_length)) {
|
|
||||||
lua_pushnil(L);
|
// Call the C function to generate the key pair
|
||||||
lua_pushstring(L, "Failed to generate RSA key pair");
|
if (!RSAGenerateKeyPair(&private_key, &private_len, &public_key, &public_len, bits)) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to generate RSA key pair");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push results to Lua
|
||||||
|
lua_pushstring(L, private_key);
|
||||||
|
lua_pushstring(L, public_key);
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
free(private_key);
|
||||||
|
free(public_key);
|
||||||
|
|
||||||
return 2;
|
return 2;
|
||||||
}
|
|
||||||
|
|
||||||
// Push results to Lua
|
|
||||||
lua_pushstring(L, private_key);
|
|
||||||
lua_pushstring(L, public_key);
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
free(private_key);
|
|
||||||
free(public_key);
|
|
||||||
|
|
||||||
return 2;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RSA
|
|
||||||
static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data,
|
static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data,
|
||||||
size_t data_len, size_t *out_len) {
|
size_t data_len, size_t *out_len) {
|
||||||
int rc;
|
int rc;
|
||||||
|
@ -622,7 +636,7 @@ static int LuaRSAVerify(lua_State *L) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Elliptic Curve Cryptography Functions
|
||||||
// Supported curves mapping
|
// Supported curves mapping
|
||||||
typedef struct {
|
typedef struct {
|
||||||
const char *name;
|
const char *name;
|
||||||
|
@ -710,6 +724,7 @@ static int LuaListHashAlgorithms(lua_State *L) {
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// List available curves
|
// List available curves
|
||||||
static int LuaListCurves(lua_State *L) {
|
static int LuaListCurves(lua_State *L) {
|
||||||
const curve_map_t *curve = supported_curves;
|
const curve_map_t *curve = supported_curves;
|
||||||
|
@ -911,68 +926,77 @@ static int LuaECDSAGenerateKeyPair(lua_State *L) {
|
||||||
static int ECDSASign(const char *priv_key_pem, const char *message,
|
static int ECDSASign(const char *priv_key_pem, const char *message,
|
||||||
hash_algorithm_t hash_alg, unsigned char **signature,
|
hash_algorithm_t hash_alg, unsigned char **signature,
|
||||||
size_t *sig_len) {
|
size_t *sig_len) {
|
||||||
mbedtls_pk_context key;
|
mbedtls_pk_context key;
|
||||||
unsigned char hash[64]; // Max hash size (SHA-512)
|
unsigned char hash[64]; // Max hash size (SHA-512)
|
||||||
size_t hash_size;
|
size_t hash_size;
|
||||||
int ret;
|
int ret;
|
||||||
|
|
||||||
|
*signature = NULL;
|
||||||
|
*sig_len = 0;
|
||||||
|
|
||||||
|
if (!priv_key_pem) {
|
||||||
|
WARNF("(ecdsa) Private key is NULL");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the length of the PEM string (excluding null terminator)
|
||||||
|
size_t key_len = strlen(priv_key_pem);
|
||||||
|
if (key_len == 0) {
|
||||||
|
WARNF("(ecdsa) Private key is empty");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get hash size for the selected algorithm
|
||||||
|
hash_size = get_hash_size(hash_alg);
|
||||||
|
|
||||||
|
mbedtls_pk_init(&key);
|
||||||
|
|
||||||
|
// Parse the private key from PEM directly without creating a copy
|
||||||
|
ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem,
|
||||||
|
key_len + 1, NULL, 0);
|
||||||
|
|
||||||
|
if (ret != 0) {
|
||||||
|
WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret);
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute hash of the message using the specified algorithm
|
||||||
|
ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message),
|
||||||
|
hash, sizeof(hash));
|
||||||
|
if (ret != 0) {
|
||||||
|
WARNF("(ecdsa) Failed to compute message hash");
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate memory for signature (max size for ECDSA)
|
||||||
|
*signature = malloc(MBEDTLS_ECDSA_MAX_LEN);
|
||||||
|
if (*signature == NULL) {
|
||||||
|
WARNF("(ecdsa) Failed to allocate memory for signature");
|
||||||
|
ret = -1;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign the hash using GenerateHardRandom
|
||||||
|
ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size,
|
||||||
|
*signature, sig_len, GenerateHardRandom, 0);
|
||||||
|
|
||||||
|
if (ret != 0) {
|
||||||
|
WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret);
|
||||||
|
free(*signature);
|
||||||
*signature = NULL;
|
*signature = NULL;
|
||||||
*sig_len = 0;
|
*sig_len = 0;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
if (!priv_key_pem || strlen(priv_key_pem) == 0) {
|
cleanup:
|
||||||
WARNF("(ecdsa) Private key is NULL or empty");
|
mbedtls_pk_free(&key);
|
||||||
return -1;
|
return ret;
|
||||||
}
|
} // Lua binding for signing a message
|
||||||
|
|
||||||
mbedtls_pk_init(&key);
|
|
||||||
|
|
||||||
// Parse the private key from PEM (PKCS#8 format)
|
|
||||||
ret = mbedtls_pk_parse_key(&key, (const unsigned char *)priv_key_pem,
|
|
||||||
strlen(priv_key_pem) + 1, NULL, 0);
|
|
||||||
if (ret != 0) {
|
|
||||||
WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret);
|
|
||||||
mbedtls_pk_free(&key);
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute hash of the message
|
|
||||||
hash_size = get_hash_size(hash_alg);
|
|
||||||
ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message),
|
|
||||||
hash, sizeof(hash));
|
|
||||||
if (ret != 0) {
|
|
||||||
WARNF("(ecdsa) Failed to compute message hash");
|
|
||||||
mbedtls_pk_free(&key);
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allocate memory for the signature
|
|
||||||
*signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE);
|
|
||||||
if (*signature == NULL) {
|
|
||||||
WARNF("(ecdsa) Failed to allocate memory for signature");
|
|
||||||
mbedtls_pk_free(&key);
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign the hash
|
|
||||||
ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size,
|
|
||||||
*signature, sig_len, GenerateHardRandom, NULL);
|
|
||||||
if (ret != 0) {
|
|
||||||
WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret);
|
|
||||||
free(*signature);
|
|
||||||
*signature = NULL;
|
|
||||||
*sig_len = 0;
|
|
||||||
mbedtls_pk_free(&key);
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
mbedtls_pk_free(&key);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
// Lua binding for signing a message
|
|
||||||
static int LuaECDSASign(lua_State *L) {
|
static int LuaECDSASign(lua_State *L) {
|
||||||
const char *hash_name = luaL_optstring(L, 3, "sha256"); // Default to SHA-256
|
// Correct order: priv_key, message, hash_name (default sha256)
|
||||||
const char *message = luaL_checkstring(L, 2);
|
|
||||||
const char *priv_key_pem = luaL_checkstring(L, 1);
|
const char *priv_key_pem = luaL_checkstring(L, 1);
|
||||||
|
const char *message = luaL_checkstring(L, 2);
|
||||||
|
const char *hash_name = luaL_optstring(L, 3, "sha256");
|
||||||
|
|
||||||
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
|
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
|
||||||
|
|
||||||
|
@ -1046,12 +1070,12 @@ cleanup:
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
static int LuaECDSAVerify(lua_State *L) {
|
static int LuaECDSAVerify(lua_State *L) {
|
||||||
|
// Correct order: pub_key, message, signature, hash_name (default sha256)
|
||||||
const char *pub_key_pem = luaL_checkstring(L, 1);
|
const char *pub_key_pem = luaL_checkstring(L, 1);
|
||||||
const char *message = luaL_checkstring(L, 2);
|
const char *message = luaL_checkstring(L, 2);
|
||||||
size_t sig_len;
|
size_t sig_len;
|
||||||
const unsigned char *signature =
|
const unsigned char *signature = (const unsigned char *)luaL_checklstring(L, 3, &sig_len);
|
||||||
(const unsigned char *)luaL_checklstring(L, 3, &sig_len);
|
const char *hash_name = luaL_optstring(L, 4, "sha256");
|
||||||
const char *hash_name = luaL_optstring(L, 4, "sha256"); // Default to SHA-256
|
|
||||||
|
|
||||||
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
|
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
|
||||||
|
|
||||||
|
@ -1061,6 +1085,437 @@ static int LuaECDSAVerify(lua_State *L) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// AES
|
||||||
|
// AES key generation helper
|
||||||
|
static int LuaAesGenerateKey(lua_State *L) {
|
||||||
|
int keybits = 128;
|
||||||
|
if (lua_gettop(L) >= 1 && !lua_isnil(L, 1)) {
|
||||||
|
keybits = luaL_checkinteger(L, 1);
|
||||||
|
}
|
||||||
|
int keylen = keybits / 8;
|
||||||
|
if ((keybits != 128 && keybits != 192 && keybits != 256) || (keylen != 16 && keylen != 24 && keylen != 32)) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES key length must be 128, 192, or 256 bits");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
unsigned char key[32];
|
||||||
|
mbedtls_entropy_context entropy;
|
||||||
|
mbedtls_ctr_drbg_context ctr_drbg;
|
||||||
|
mbedtls_entropy_init(&entropy);
|
||||||
|
mbedtls_ctr_drbg_init(&ctr_drbg);
|
||||||
|
const char *pers = "aes_keygen";
|
||||||
|
int ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const unsigned char *)pers, strlen(pers));
|
||||||
|
if (ret != 0) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to initialize RNG for AES key");
|
||||||
|
mbedtls_ctr_drbg_free(&ctr_drbg);
|
||||||
|
mbedtls_entropy_free(&entropy);
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
ret = mbedtls_ctr_drbg_random(&ctr_drbg, key, keylen);
|
||||||
|
mbedtls_ctr_drbg_free(&ctr_drbg);
|
||||||
|
mbedtls_entropy_free(&entropy);
|
||||||
|
if (ret != 0) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to generate random AES key");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
lua_pushlstring(L, (const char *)key, keylen);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// AES encryption supporting CBC, GCM, and CTR modes
|
||||||
|
static int LuaAesEncrypt(lua_State *L) {
|
||||||
|
// Accept IV as the 3rd argument (after key, plaintext)
|
||||||
|
size_t keylen, ivlen = 0, ptlen;
|
||||||
|
const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen);
|
||||||
|
const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen);
|
||||||
|
const unsigned char *iv = NULL;
|
||||||
|
unsigned char *gen_iv = NULL;
|
||||||
|
int iv_was_generated = 0;
|
||||||
|
|
||||||
|
const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided
|
||||||
|
int ret = 0;
|
||||||
|
unsigned char *output = NULL;
|
||||||
|
int is_gcm = 0, is_ctr = 0, is_cbc = 0;
|
||||||
|
|
||||||
|
if (strcasecmp(mode, "cbc") == 0) {
|
||||||
|
is_cbc = 1;
|
||||||
|
} else if (strcasecmp(mode, "gcm") == 0) {
|
||||||
|
is_gcm = 1;
|
||||||
|
} else if (strcasecmp(mode, "ctr") == 0) {
|
||||||
|
is_ctr = 1;
|
||||||
|
} else {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'.");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If IV is not provided (arg3 is nil or missing), auto-generate
|
||||||
|
if (lua_isnoneornil(L, 3)) {
|
||||||
|
// For GCM, standard is 12 bytes, but allow 12-16
|
||||||
|
if (is_gcm) {
|
||||||
|
ivlen = 12;
|
||||||
|
} else {
|
||||||
|
ivlen = 16;
|
||||||
|
}
|
||||||
|
gen_iv = malloc(ivlen);
|
||||||
|
if (!gen_iv) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to allocate IV");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
mbedtls_entropy_context entropy;
|
||||||
|
mbedtls_ctr_drbg_context ctr_drbg;
|
||||||
|
mbedtls_entropy_init(&entropy);
|
||||||
|
mbedtls_ctr_drbg_init(&ctr_drbg);
|
||||||
|
mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0);
|
||||||
|
mbedtls_ctr_drbg_random(&ctr_drbg, gen_iv, ivlen);
|
||||||
|
mbedtls_ctr_drbg_free(&ctr_drbg);
|
||||||
|
mbedtls_entropy_free(&entropy);
|
||||||
|
iv = gen_iv;
|
||||||
|
iv_was_generated = 1;
|
||||||
|
} else {
|
||||||
|
// IV provided
|
||||||
|
iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen);
|
||||||
|
// Do not force ivlen to 16 here! Accept actual length for GCM (12-16)
|
||||||
|
if (is_cbc || is_ctr) {
|
||||||
|
if (ivlen != 16) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES IV must be 16 bytes for CBC/CTR");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
} else if (is_gcm) {
|
||||||
|
if (ivlen < 12 || ivlen > 16) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES GCM IV/nonce must be 12-16 bytes");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
iv_was_generated = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_cbc) {
|
||||||
|
// PKCS7 padding
|
||||||
|
size_t block_size = 16;
|
||||||
|
size_t padlen = block_size - (ptlen % block_size);
|
||||||
|
size_t ctlen = ptlen + padlen;
|
||||||
|
unsigned char *input = malloc(ctlen);
|
||||||
|
if (!input) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Memory allocation failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
memcpy(input, plaintext, ptlen);
|
||||||
|
memset(input + ptlen, (unsigned char)padlen, padlen);
|
||||||
|
output = malloc(ctlen);
|
||||||
|
if (!output) {
|
||||||
|
free(input);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Memory allocation failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
mbedtls_aes_context aes;
|
||||||
|
mbedtls_aes_init(&aes);
|
||||||
|
ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(input);
|
||||||
|
free(output);
|
||||||
|
mbedtls_aes_free(&aes);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to set AES encryption key");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
unsigned char iv_copy[16];
|
||||||
|
memcpy(iv_copy, iv, 16);
|
||||||
|
ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT, ctlen, iv_copy, input, output);
|
||||||
|
mbedtls_aes_free(&aes);
|
||||||
|
free(input);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES CBC encryption failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
lua_pushlstring(L, (const char *)output, ctlen);
|
||||||
|
lua_pushlstring(L, (const char *)iv, ivlen);
|
||||||
|
free(output);
|
||||||
|
if (iv_was_generated) free(gen_iv);
|
||||||
|
return 2;
|
||||||
|
} else if (is_ctr) {
|
||||||
|
// CTR mode: no padding
|
||||||
|
output = malloc(ptlen);
|
||||||
|
if (!output) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Memory allocation failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
mbedtls_aes_context aes;
|
||||||
|
mbedtls_aes_init(&aes);
|
||||||
|
ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
mbedtls_aes_free(&aes);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to set AES encryption key");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
unsigned char nonce_counter[16];
|
||||||
|
unsigned char stream_block[16];
|
||||||
|
size_t nc_off = 0;
|
||||||
|
memcpy(nonce_counter, iv, 16);
|
||||||
|
memset(stream_block, 0, 16);
|
||||||
|
ret = mbedtls_aes_crypt_ctr(&aes, ptlen, &nc_off, nonce_counter, stream_block, plaintext, output);
|
||||||
|
mbedtls_aes_free(&aes);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES CTR encryption failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
lua_pushlstring(L, (const char *)output, ptlen);
|
||||||
|
lua_pushlstring(L, (const char *)iv, ivlen);
|
||||||
|
free(output);
|
||||||
|
if (iv_was_generated) free(gen_iv);
|
||||||
|
return 2;
|
||||||
|
} else if (is_gcm) {
|
||||||
|
// GCM mode: authenticated encryption
|
||||||
|
size_t taglen = 16;
|
||||||
|
unsigned char tag[16];
|
||||||
|
output = malloc(ptlen);
|
||||||
|
if (!output) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Memory allocation failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
mbedtls_gcm_context gcm;
|
||||||
|
mbedtls_gcm_init(&gcm);
|
||||||
|
ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
mbedtls_gcm_free(&gcm);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to set AES GCM key");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
// Use actual ivlen, not hardcoded 16
|
||||||
|
ret = mbedtls_gcm_crypt_and_tag(&gcm, MBEDTLS_GCM_ENCRYPT, ptlen, iv, ivlen, NULL, 0, plaintext, output, taglen, tag);
|
||||||
|
mbedtls_gcm_free(&gcm);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES GCM encryption failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
lua_pushlstring(L, (const char *)output, ptlen);
|
||||||
|
lua_pushlstring(L, (const char *)iv, ivlen);
|
||||||
|
lua_pushlstring(L, (const char *)tag, taglen);
|
||||||
|
free(output);
|
||||||
|
if (iv_was_generated) free(gen_iv);
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Internal error in AES encrypt");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// AES decryption supporting CBC, GCM, and CTR modes
|
||||||
|
static int LuaAesDecrypt(lua_State *L) {
|
||||||
|
size_t keylen, ctlen, ivlen;
|
||||||
|
const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen);
|
||||||
|
const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen);
|
||||||
|
const unsigned char *iv = (const unsigned char *)luaL_checklstring(L, 3, &ivlen);
|
||||||
|
const char *mode = luaL_optstring(L, 4, "cbc"); // Default to CBC if not provided
|
||||||
|
const unsigned char *aad = NULL;
|
||||||
|
const unsigned char *tag = NULL;
|
||||||
|
size_t aadlen = 0, taglen = 0;
|
||||||
|
int is_gcm = 0, is_ctr = 0, is_cbc = 0;
|
||||||
|
|
||||||
|
if (strcasecmp(mode, "cbc") == 0) {
|
||||||
|
is_cbc = 1;
|
||||||
|
} else if (strcasecmp(mode, "gcm") == 0) {
|
||||||
|
is_gcm = 1;
|
||||||
|
} else if (strcasecmp(mode, "ctr") == 0) {
|
||||||
|
is_ctr = 1;
|
||||||
|
} else {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'.");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate key length (16, 24, 32 bytes)
|
||||||
|
if (keylen != 16 && keylen != 24 && keylen != 32) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES key must be 16, 24, or 32 bytes");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
// Validate IV/nonce length
|
||||||
|
if (is_cbc || is_ctr) {
|
||||||
|
if (ivlen != 16) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
} else if (is_gcm) {
|
||||||
|
if (ivlen < 12 || ivlen > 16) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES GCM nonce must be 12-16 bytes");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GCM: require tag and optional AAD
|
||||||
|
if (is_gcm) {
|
||||||
|
if (!lua_isnoneornil(L, 5)) {
|
||||||
|
aad = (const unsigned char *)luaL_checklstring(L, 5, &aadlen);
|
||||||
|
}
|
||||||
|
if (!lua_isnoneornil(L, 6)) {
|
||||||
|
tag = (const unsigned char *)luaL_checklstring(L, 6, &taglen);
|
||||||
|
if (taglen < 12 || taglen > 16) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES GCM tag must be 12-16 bytes");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES GCM tag required as 6th argument");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int ret = 0;
|
||||||
|
unsigned char *output = NULL;
|
||||||
|
|
||||||
|
if (is_cbc) {
|
||||||
|
// Ciphertext must be a multiple of block size
|
||||||
|
if (ctlen == 0 || (ctlen % 16) != 0) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Ciphertext length must be a multiple of 16");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
output = malloc(ctlen);
|
||||||
|
if (!output) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Memory allocation failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
mbedtls_aes_context aes;
|
||||||
|
mbedtls_aes_init(&aes);
|
||||||
|
ret = mbedtls_aes_setkey_dec(&aes, key, keylen * 8);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
mbedtls_aes_free(&aes);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to set AES decryption key");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
unsigned char iv_copy[16];
|
||||||
|
memcpy(iv_copy, iv, 16);
|
||||||
|
ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT, ctlen, iv_copy, ciphertext, output);
|
||||||
|
mbedtls_aes_free(&aes);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES CBC decryption failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
// PKCS7 unpadding
|
||||||
|
if (ctlen == 0) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Decrypted data is empty");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
unsigned char pad = output[ctlen - 1];
|
||||||
|
if (pad == 0 || pad > 16) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Invalid PKCS7 padding");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < pad; ++i) {
|
||||||
|
if (output[ctlen - 1 - i] != pad) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Invalid PKCS7 padding");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
size_t ptlen = ctlen - pad;
|
||||||
|
lua_pushlstring(L, (const char *)output, ptlen);
|
||||||
|
free(output);
|
||||||
|
return 1;
|
||||||
|
} else if (is_ctr) {
|
||||||
|
// CTR mode: no padding
|
||||||
|
output = malloc(ctlen);
|
||||||
|
if (!output) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Memory allocation failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
mbedtls_aes_context aes;
|
||||||
|
mbedtls_aes_init(&aes);
|
||||||
|
ret = mbedtls_aes_setkey_enc(&aes, key, keylen * 8);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
mbedtls_aes_free(&aes);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to set AES encryption key");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
unsigned char nonce_counter[16];
|
||||||
|
unsigned char stream_block[16];
|
||||||
|
size_t nc_off = 0;
|
||||||
|
memcpy(nonce_counter, iv, 16);
|
||||||
|
memset(stream_block, 0, 16);
|
||||||
|
ret = mbedtls_aes_crypt_ctr(&aes, ctlen, &nc_off, nonce_counter, stream_block, ciphertext, output);
|
||||||
|
mbedtls_aes_free(&aes);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES CTR decryption failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
lua_pushlstring(L, (const char *)output, ctlen);
|
||||||
|
free(output);
|
||||||
|
return 1;
|
||||||
|
} else if (is_gcm) {
|
||||||
|
// GCM mode: authenticated decryption
|
||||||
|
output = malloc(ctlen);
|
||||||
|
if (!output) {
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Memory allocation failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
mbedtls_gcm_context gcm;
|
||||||
|
mbedtls_gcm_init(&gcm);
|
||||||
|
ret = mbedtls_gcm_setkey(&gcm, MBEDTLS_CIPHER_ID_AES, key, keylen * 8);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
mbedtls_gcm_free(&gcm);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Failed to set AES GCM key");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
ret = mbedtls_gcm_auth_decrypt(&gcm, ctlen, iv, ivlen, aad, aadlen, tag, taglen, ciphertext, output);
|
||||||
|
mbedtls_gcm_free(&gcm);
|
||||||
|
if (ret != 0) {
|
||||||
|
free(output);
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "AES GCM decryption failed or authentication failed");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
lua_pushlstring(L, (const char *)output, ctlen);
|
||||||
|
free(output);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
lua_pushnil(L);
|
||||||
|
lua_pushstring(L, "Internal error in AES decrypt");
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// LuaCrypto compatible API
|
||||||
static int LuaCryptoSign(lua_State *L) {
|
static int LuaCryptoSign(lua_State *L) {
|
||||||
const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa")
|
const char *dtype = luaL_checkstring(L, 1); // Type of signature (e.g., "rsa", "ecdsa")
|
||||||
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
|
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
|
||||||
|
@ -1088,41 +1543,48 @@ static int LuaCryptoVerify(lua_State *L) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static int LuaCryptoEncrypt(lua_State *L) {
|
static int LuaCryptoEncrypt(lua_State *L) {
|
||||||
const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa")
|
const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes")
|
||||||
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
|
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
|
||||||
|
|
||||||
if (strcasecmp(cipher, "rsa") == 0) {
|
if (strcasecmp(cipher, "rsa") == 0) {
|
||||||
return LuaRSAEncrypt(L);
|
return LuaRSAEncrypt(L);
|
||||||
|
} else if (strcasecmp(cipher, "aes") == 0) {
|
||||||
|
return LuaAesEncrypt(L);
|
||||||
} else {
|
} else {
|
||||||
return luaL_error(L, "Unsupported cipher type: %s", cipher);
|
return luaL_error(L, "Unsupported cipher type: %s", cipher);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static int LuaCryptoDecrypt(lua_State *L) {
|
static int LuaCryptoDecrypt(lua_State *L) {
|
||||||
const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa")
|
const char *cipher = luaL_checkstring(L, 1); // Cipher type (e.g., "rsa", "aes")
|
||||||
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
|
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
|
||||||
|
|
||||||
if (strcasecmp(cipher, "rsa") == 0) {
|
if (strcasecmp(cipher, "rsa") == 0) {
|
||||||
return LuaRSADecrypt(L);
|
return LuaRSADecrypt(L);
|
||||||
|
} else if (strcasecmp(cipher, "aes") == 0) {
|
||||||
|
return LuaAesDecrypt(L);
|
||||||
} else {
|
} else {
|
||||||
return luaL_error(L, "Unsupported cipher type: %s", cipher);
|
return luaL_error(L, "Unsupported cipher type: %s", cipher);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static int LuaCryptoGenerateKeyPair(lua_State *L) {
|
static int LuaCryptoGenerateKeyPair(lua_State *L) {
|
||||||
const char *key_type = "rsa"; // Key type (e.g., "rsa", "ecdsa")
|
// If the first argument is a number, treat as RSA key length
|
||||||
|
if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TNUMBER) {
|
||||||
if (! lua_isinteger(L, 1) && ! lua_isnoneornil(L, 1)) {
|
// Call LuaRSAGenerateKeyPair with the number as the key length
|
||||||
key_type = luaL_checkstring(L, 1); // Get key type from first argumen
|
|
||||||
lua_remove(L, 1); // Remove the first argument (key type or cipher type) before dispatching
|
|
||||||
}
|
|
||||||
|
|
||||||
if (strcasecmp(key_type, "rsa") == 0) {
|
|
||||||
return LuaRSAGenerateKeyPair(L);
|
return LuaRSAGenerateKeyPair(L);
|
||||||
} else if (strcasecmp(key_type, "ecdsa") == 0) {
|
}
|
||||||
|
// Otherwise, get the key type from the first argument, default to "rsa" if not provided
|
||||||
|
const char *type = luaL_optstring(L, 1, "rsa");
|
||||||
|
lua_remove(L, 1);
|
||||||
|
if (strcasecmp(type, "rsa") == 0) {
|
||||||
|
return LuaRSAGenerateKeyPair(L);
|
||||||
|
} else if (strcasecmp(type, "ecdsa") == 0) {
|
||||||
return LuaECDSAGenerateKeyPair(L);
|
return LuaECDSAGenerateKeyPair(L);
|
||||||
|
} else if (strcasecmp(type, "aes") == 0) {
|
||||||
|
return LuaAesGenerateKey(L);
|
||||||
} else {
|
} else {
|
||||||
return luaL_error(L, "Unsupported key type: %s", key_type);
|
return luaL_error(L, "Unsupported key type: %s", type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1132,8 +1594,8 @@ static const luaL_Reg kLuaCrypto[] = {
|
||||||
{"encrypt", LuaCryptoEncrypt}, //
|
{"encrypt", LuaCryptoEncrypt}, //
|
||||||
{"decrypt", LuaCryptoDecrypt}, //
|
{"decrypt", LuaCryptoDecrypt}, //
|
||||||
{"generatekeypair", LuaCryptoGenerateKeyPair}, //
|
{"generatekeypair", LuaCryptoGenerateKeyPair}, //
|
||||||
{"convertPemToJwk", convertPemToJwk}, //
|
{"convertPemToJwk", LuaConvertPemToJwk}, //
|
||||||
{"generateCsr", generateCsr}, //
|
{"GenerateCsr", LuaGenerateCSR}, //
|
||||||
{0}, //
|
{0}, //
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue