diff --git a/test/tool/net/lcrypto_test.lua b/test/tool/net/lcrypto_test.lua index 0516a6ec2..27d4fb3fb 100644 --- a/test/tool/net/lcrypto_test.lua +++ b/test/tool/net/lcrypto_test.lua @@ -1,246 +1,753 @@ --- Helper function to print test results -local function assert_equal(actual, expected, message) - if actual ~= expected then - error("FAIL: " .. message .. ": expected " .. tostring(expected) .. ", got " .. tostring(actual)) - end -end - -local function assert_not_equal(actual, not_expected, message) - if actual == not_expected then - error(message .. ": did not expect " .. tostring(not_expected)) - end -end - +---@diagnostic disable: lowercase-global -- Test RSA key pair generation local function test_rsa_keypair_generation() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) - assert_equal(type(priv_key), "string", "Private key type") - assert_equal(type(pub_key), "string", "Public key type") + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") end -- Test ECDSA key pair generation local function test_ecdsa_keypair_generation() - local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") - assert_equal(type(priv_key), "string", "Private key type") - assert_equal(type(pub_key), "string", "Public key type") + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") end -- Test RSA encryption and decryption local function test_rsa_encryption_decryption() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") + local plaintext = "Hello, RSA!" - local encrypted = crypto.encrypt("rsa", pub_key, plaintext) - assert_equal(type(encrypted), "string", "Ciphertext type") - local decrypted = crypto.decrypt("rsa", priv_key, encrypted) - assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") + local ciphertext = crypto.encrypt("rsa", pub_key, plaintext) + assert(type(ciphertext) == "string", "Ciphertext type") + + local decrypted_plaintext = crypto.decrypt("rsa", priv_key, ciphertext) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") end -- Test RSA signing and verification local function test_rsa_signing_verification() - local priv_key, pub_key = crypto.generatekeypair("rsa", 2048) + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") + local message = "Sign this message" local signature = crypto.sign("rsa", priv_key, message, "sha256") - assert_equal(type(signature), "string", "Signature type") + assert(type(signature) == "string", "Signature type") + local is_valid = crypto.verify("rsa", pub_key, message, signature, "sha256") - assert_equal(is_valid, true, "Signature verification") + assert(is_valid == true, "Signature verification") +end + +-- Test RSA-PSS signing and verification +local function test_rsapss_signing_verification() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key) == "string", "Private key type") + assert(type(pub_key) == "string", "Public key type") + + Log(kLogVerbose," - Testing RSA-PSS signing") + local message = "Sign this message with RSA-PSS" + local signature = crypto.sign("rsapss", priv_key, message, "sha256") + assert(type(signature) == "string", "Signature type") + + Log(kLogVerbose," - Testing RSA-PSS verification") + local is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha256") + assert(is_valid == true, "RSA-PSS Signature verification") + + -- Test with different hash algorithm + Log(kLogVerbose," - Testing RSA-PSS with different hash algorithms") + signature = crypto.sign("rsapss", priv_key, message, "sha384") + assert(type(signature) == "string", "SHA-384 Signature type") + is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha384") + assert(is_valid == true, "RSA-PSS SHA-384 Signature verification") + + Log(kLogVerbose," - Testing RSA-PSS with SHA-512") + signature = crypto.sign("rsapss", priv_key, message, "sha512") + assert(type(signature) == "string", "SHA-512 Signature type") + is_valid = crypto.verify("rsapss", pub_key, message, signature, "sha512") + assert(is_valid == true, "RSA-PSS SHA-512 Signature verification") end -- Test ECDSA signing and verification local function test_ecdsa_signing_verification() - local priv_key, pub_key = crypto.generatekeypair("ecdsa", "secp256r1") + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") assert(type(priv_key) == "string", "Private key type") assert(type(pub_key) == "string", "Public key type") + local message = "Sign this message with ECDSA" local signature = crypto.sign("ecdsa", priv_key, message, "sha256") - assert_equal(type(signature), "string", "Signature type") + assert(type(signature) == "string", "Signature type") + local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") - assert_equal(is_valid, true, "Signature verification") + assert(is_valid == true, "Signature verification") end -- Test AES key generation local function test_aes_key_generation() - local key = crypto.generatekeypair('aes', 256) -- 256-bit key - assert_equal(type(key), "string", "Key type") - assert_equal(#key, 32, "Key length (256 bits)") + local key = crypto.generateKeyPair('aes', 256) -- 256-bit key + assert(type(key) == "string", "Key type") + assert(#key == 32, "Key length (256 bits)") end -- Test AES encryption and decryption (CBC mode) local function test_aes_encryption_decryption_cbc() - local key = crypto.generatekeypair('aes', 256) -- 256-bit key + local key = crypto.generateKeyPair('aes', 256) -- 256-bit key local plaintext = "Hello, AES CBC!" -- Encrypt without providing IV (should auto-generate IV) - local encrypted, iv = crypto.encrypt("aes", key, plaintext, nil) - assert_equal(type(encrypted), "string", "Ciphertext type") - assert_equal(type(iv), "string", "IV type") + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, nil) + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, {mode="cbc",iv=iv}) - assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(16) - local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, {mode="cbc",iv=iv2}) - assert_equal(type(encrypted2), "string", "Ciphertext type") - assert_equal(iv_used, iv2, "IV match") + local ciphertext2, iv_used = crypto.encrypt("aes", key, plaintext, { mode = "cbc", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="cbc",iv=iv2}) - assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "cbc", iv = iv2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (CTR mode) local function test_aes_encryption_decryption_ctr() - local key = crypto.generatekeypair('aes', 256) + local key = crypto.generateKeyPair('aes', 256) local plaintext = "Hello, AES CTR!" -- Encrypt without providing IV (should auto-generate IV) - local encrypted, iv = crypto.encrypt("aes", key, plaintext, {mode="ctr"}) - assert_equal(type(encrypted), "string", "Ciphertext type") - assert_equal(type(iv), "string", "IV type") + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, {mode="ctr", iv=iv}) - assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "ctr", iv = iv }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(16) - local encrypted2, iv_used = crypto.encrypt("aes", key, plaintext, {mode="ctr", iv=iv2}) - assert_equal(type(encrypted2), "string", "Ciphertext type") - assert_equal(iv_used, iv2, "IV match") + local ciphertext2, iv_used = crypto.encrypt("aes", key, plaintext, { mode = "ctr", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="ctr", iv=iv2}) - assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "ctr", iv = iv2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") end -- Test AES encryption and decryption (GCM mode) local function test_aes_encryption_decryption_gcm() - local key = crypto.generatekeypair('aes', 256) - assert_equal(type(key), "string", "key type") + local key = crypto.generateKeyPair('aes', 256) + assert(type(key) == "string", "key type") local plaintext = "Hello, AES GCM!" -- Encrypt without providing IV (should auto-generate IV) - local encrypted, iv, tag = crypto.encrypt("aes", key, plaintext, {mode="gcm"}) - assert_equal(#plaintext, #encrypted, "Ciphertext length matches plaintext") - assert_equal(type(encrypted), "string", "Ciphertext type") - assert_equal(type(iv), "string", "IV type") - assert_equal(type(tag), "string", "Tag type") + local ciphertext, iv, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + assert(#plaintext == #ciphertext, "Ciphertext length matches plaintext") + assert(type(ciphertext) == "string", "Ciphertext type") + assert(type(iv) == "string", "IV type") + assert(type(tag) == "string", "Tag type") -- Decrypt - local decrypted = crypto.decrypt("aes", key, encrypted, {mode="gcm",iv=iv,tag=tag}) - assert_equal(decrypted, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "gcm", iv = iv, tag = tag }) + assert(decrypted_plaintext == plaintext, "Decrypted ciphertext matches plaintext") -- Encrypt with explicit IV local iv2 = GetRandomBytes(13) -- GCM IV/nonce can be 12-16 bytes, 12 is standard - local encrypted2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, {mode="gcm",iv=iv2}) - assert_equal(type(encrypted2), "string", "Ciphertext type") - assert_equal(iv_used, iv2, "IV match") - assert_equal(type(tag2), "string", "Tag type") + local ciphertext2, iv_used, tag2 = crypto.encrypt("aes", key, plaintext, { mode = "gcm", iv = iv2 }) + assert(type(ciphertext2) == "string", "Ciphertext type") + assert(iv_used == iv2, "IV match") + assert(type(tag2) == "string", "Tag type") - local decrypted2 = crypto.decrypt("aes", key, encrypted2, {mode="gcm",iv=iv2,tag=tag2}) - assert_equal(decrypted2, plaintext, "Decrypted ciphertext matches plaintext") + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "gcm", iv = iv2, tag = tag2 }) + assert(decrypted_plaintext2 == plaintext, "Decrypted ciphertext matches plaintext") end -- Test PemToJwk conversion local function test_pem_to_jwk() - local priv_key, pub_key = crypto.generatekeypair() + local priv_key, pub_key = crypto.generateKeyPair() local priv_jwk = crypto.convertPemToJwk(priv_key) - assert_equal(type(priv_jwk), "table", "JWK type") - assert_equal(priv_jwk.kty, "RSA", "kty is correct") + assert(type(priv_jwk) == "table", "JWK type") + assert(priv_jwk.kty == "RSA", "kty is correct") local pub_jwk = crypto.convertPemToJwk(pub_key) - assert_equal(type(pub_jwk), "table", "JWK type") - assert_equal(pub_jwk.kty, "RSA", "kty is correct") + assert(type(pub_jwk) == "table", "JWK type") + assert(pub_jwk.kty == "RSA", "kty is correct") -- Test ECDSA keys - local priv_key, pub_key = crypto.generatekeypair('ecdsa') - local priv_jwk = crypto.convertPemToJwk(priv_key) - assert_equal(type(priv_jwk), "table", "JWK type") - assert_equal(priv_jwk.kty, "EC", "kty is correct") + priv_key, pub_key = crypto.generateKeyPair('ecdsa') + priv_jwk = crypto.convertPemToJwk(priv_key) + assert(type(priv_jwk) == "table", "JWK type") + assert(priv_jwk.kty == "EC", "kty is correct") - local pub_jwk = crypto.convertPemToJwk(pub_key) - assert_equal(type(pub_jwk), "table", "JWK type") - assert_equal(pub_jwk.kty, "EC", "kty is correct") + pub_jwk = crypto.convertPemToJwk(pub_key) + assert(type(pub_jwk) == "table", "JWK type") + assert(pub_jwk.kty == "EC", "kty is correct") end -- Test JwkToPem conversion local function test_jwk_to_pem() - local priv_key, pub_key = crypto.generatekeypair() + local priv_key, pub_key = crypto.generateKeyPair() local priv_jwk = crypto.convertPemToJwk(priv_key) local pub_jwk = crypto.convertPemToJwk(pub_key) local priv_pem = crypto.convertJwkToPem(priv_jwk) local pub_pem = crypto.convertJwkToPem(pub_jwk) - assert_equal(type(priv_pem), "string", "Private PEM type") + assert(type(priv_pem) == "string", "Private PEM type") -- Roundtrip - assert_equal(priv_key,priv_pem, "Private PEM matches original RSA key") - assert_equal(pub_key,pub_pem, "Public PEM matches original RSA key") + assert(priv_key == priv_pem, "Private PEM matches original RSA key") + assert(pub_key == pub_pem, "Public PEM matches original RSA key") - local pub_pem = crypto.convertJwkToPem(pub_jwk) - assert_equal(type(pub_pem), "string", "Public PEM type") + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(pub_pem) == "string", "Public PEM type") -- Test ECDSA keys - local priv_key, pub_key = crypto.generatekeypair('ecdsa') - local priv_jwk = crypto.convertPemToJwk(priv_key) - local pub_jwk = crypto.convertPemToJwk(pub_key) + priv_key, pub_key = crypto.generateKeyPair('ecdsa') + priv_jwk = crypto.convertPemToJwk(priv_key) + pub_jwk = crypto.convertPemToJwk(pub_key) + + priv_pem = crypto.convertJwkToPem(priv_jwk) + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(priv_pem) == "string", "Private PEM type for ECDSA") - local priv_pem = crypto.convertJwkToPem(priv_jwk) - local pub_pem = crypto.convertJwkToPem(pub_jwk) - assert_equal(type(priv_pem), "string", "Private PEM type for ECDSA") - -- Roundtrip - assert_equal(priv_key,priv_pem, "Private PEM matches original ECDSA key") - assert_equal(pub_key,pub_pem, "Public PEM matches original ECDSA key") + assert(priv_key == priv_pem, "Private PEM matches original ECDSA key") + assert(pub_key == pub_pem, "Public PEM matches original ECDSA key") - local pub_pem = crypto.convertJwkToPem(pub_jwk) - assert_equal(type(pub_pem), "string", "Public PEM type for ECDSA") + pub_pem = crypto.convertJwkToPem(pub_jwk) + assert(type(pub_pem) == "string", "Public PEM type for ECDSA") end -- Test CSR generation local function test_csr_generation() - local priv_key, _ = crypto.generatekeypair() + local priv_key, _ = crypto.generateKeyPair() local subject_name = "CN=example.com,O=Example Org,C=US" local san = "DNS:example.com, DNS:www.example.com, IP:192.168.1.1" - assert_equal(type(priv_key), "string", "Private key type") + assert(type(priv_key) == "string", "Private key type") local csr = crypto.generateCsr(priv_key, subject_name) - assert_equal(type(csr), "string", "CSR generation with subject name") + assert(type(csr) == "string", "CSR generation with subject name") csr = crypto.generateCsr(priv_key, subject_name, san) - assert_equal(type(csr), "string", "CSR generation with subject name and san") + assert(type(csr) == "string", "CSR generation with subject name and san") csr = crypto.generateCsr(priv_key, nil, san) - assert_equal(type(csr), "string", "CSR generation with nil subject name and san") + assert(type(csr) == "string", "CSR generation with nil subject name and san") csr = crypto.generateCsr(priv_key, '', san) - assert_equal(type(csr), "string", "CSR generation with empty subject name and san") + assert(type(csr) == "string", "CSR generation with empty subject name and san") -- These should fail csr = crypto.generateCsr(priv_key, '') - assert_not_equal(type(csr), "string", "CSR generation with empty subject name and no san is rejected") + assert(type(csr) ~= "string", "CSR generation with empty subject name and no san is rejected") csr = crypto.generateCsr(priv_key) - assert_not_equal(type(csr), "string", "CSR generation with nil subject name and no san is rejected") + assert(type(csr) ~= "string", "CSR generation with nil subject name and no san is rejected") +end + +-- Test various hash algorithms +local function test_hash_algorithms() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local message = "Test message for hash algorithms" + + -- Test different hash algorithms for RSA signatures + local hash_algorithms = { "sha256", "sha384", "sha512" } + for _, hash in ipairs(hash_algorithms) do + local signature = crypto.sign("rsa", priv_key, message, hash) + assert(type(signature) == "string", "RSA signature with " .. hash) + local is_valid = crypto.verify("rsa", pub_key, message, signature, hash) + assert(is_valid == true, "RSA verification with " .. hash) + + -- Test with RSA-PSS + local signature_pss = crypto.sign("rsapss", priv_key, message, hash) + assert(type(signature_pss) == "string", "RSA-PSS signature with " .. hash) + local is_valid_pss = crypto.verify("rsapss", pub_key, message, signature_pss, hash) + assert(is_valid_pss == true, "RSA-PSS verification with " .. hash) + end + + -- Test ECDSA with different hash algorithms + local ec_priv_key, ec_pub_key = crypto.generateKeyPair("ecdsa", "secp256r1") + for _, hash in ipairs(hash_algorithms) do + local signature = crypto.sign("ecdsa", ec_priv_key, message, hash) + assert(type(signature) == "string", "ECDSA signature with " .. hash) + + local is_valid = crypto.verify("ecdsa", ec_pub_key, message, signature, hash) + assert(is_valid == true, "ECDSA verification with " .. hash) + end +end + +-- Test negative cases for hash algorithms +local function test_negative_hash_algorithms() + local priv_key, pub_key = crypto.generateKeyPair() + local message = "Test message for hash algorithms" + + -- Test with invalid hash algorithm + local ok = pcall(function() return crypto.sign("rsa", priv_key, message, "invalid-hash") end) + assert(ok == false, "Sign with invalid hash should fail") + + -- Test with nil hash algorithm (should default to SHA-256) + local signature = crypto.sign("rsa", priv_key, message) + assert(type(signature) == "string", "Sign with nil hash") + + local is_valid = crypto.verify("rsa", pub_key, message, signature) + assert(is_valid == true, "Verify with nil hash") +end + +-- Negative tests for crypto functions +local function test_negative_keypair_generation() + -- Invalid algorithm + local ok = pcall(function() return crypto.generateKeyPair('invalidalg', 2048) end) + assert(ok == false, "generatekeypair with invalid algorithm should fail") + -- Invalid RSA key size + local pk, _ = crypto.generateKeyPair('rsa', 123) + assert(pk == nil, "generatekeypair with invalid RSA size should fail") + -- Invalid ECDSA curve + pk, _ = crypto.generateKeyPair('ecdsa', 'invalidcurve') + assert(pk == nil, "generatekeypair with invalid ECDSA curve should fail") +end + +local function test_negative_encrypt_decrypt() + local priv_key, pub_key = crypto.generateKeyPair('rsa', 2048) + -- Encrypt with invalid algorithm + local ok, ciphertext = pcall(function() return crypto.encrypt('invalidalg', pub_key, 'data') end) + assert(ok == false, "RSA encrypt with invalid algorithm should fail") + + -- Decrypt with invalid algorithm + ok, _ = pcall(function() return crypto.decrypt('invalidalg', priv_key, 'data') end) + assert(ok == false, "RSA decrypt with invalid algorithm should fail") + + -- Encrypt with invalid key + ciphertext = crypto.encrypt('rsa', 'notakey', 'data') + assert(ciphertext == nil, "RSA encrypt with invalid key should fail") + + -- Decrypt with invalid key + local retval = crypto.decrypt('rsa', 'notakey', 'data') + assert(retval == nil, "RSA decrypt with invalid key should fail") + + -- AES: invalid IV length + local key = crypto.generateKeyPair('aes', 256) + ciphertext = crypto.encrypt('aes', key, 'data', { mode = "cbc", iv = "tooShortIV" }) + assert(ciphertext == nil, "AES encrypt with short IV should fail") + + retval = crypto.decrypt('aes', key, 'data', { mode = "cbc", iv = "tooShortIV" }) + assert(retval == nil, "AES decrypt with short IV should fail") +end + +local function test_negative_sign_verify() + local priv_key, pub_key = crypto.generateKeyPair('rsa', 2048) + -- Sign with invalid algorithm + local ok = pcall(function() return crypto.sign('invalidalg', priv_key, 'msg', 'sha256') end) + assert(ok == false, "RSA sign with invalid algorithm should fail") + + -- Verify with invalid algorithm + ok = pcall(function() return crypto.verify('invalidalg', pub_key, 'msg', 'sig', 'sha256') end) + assert(ok == false, "RSA verify with invalid algorithm should fail") + + -- Sign with invalid key + ok = pcall(function() return crypto.sign('rsa', 'notakey', 'msg', 'sha256') end) + assert(ok == false, "RSA sign with invalid key should fail") + + -- Verify with invalid key + local verified = crypto.verify('rsa', 'notakey', 'msg', 'sig', 'sha256') + assert(verified == false, "verify with invalid key should fail") + + -- Verify with wrong signature (should return false, not error) + local badsig = 'thisisnotavalidsignature' + local result = crypto.verify('rsa', pub_key, 'msg', badsig, 'sha256') + assert(result == false, "RSA verify with wrong signature should return false") +end + +local function test_negative_pem_jwk_conversion() + -- Invalid PEM + local ok = pcall(function() return crypto.convertPemToJwk('notapem') end) + assert(ok == false, "convertPemToJwk with invalid PEM should fail") + + -- Invalid JWK (wrong type, but still a table) + local pem = crypto.convertJwkToPem({ kty = 'INVALID' }) + assert(pem == nil, "convertJwkToPem with invalid JWK should fail") + + -- Missing kty in JWK + pem = crypto.convertJwkToPem({}) + assert(pem == nil, "convertJwkToPem with missing kty should fail") +end + +local function test_negative_csr_generation() + -- Invalid key + local csr = crypto.generateCsr('notakey', 'CN=bad') + assert(csr == nil, "generateCsr with invalid key should fail") +end + +-- Add additional tests for edge cases in crypto functions + +-- Test RSA key size variations +local function test_rsa_key_sizes() + -- Test 2048-bit keys + local priv_key_2048, pub_key_2048 = crypto.generateKeyPair("rsa", 2048) + assert(type(priv_key_2048) == "string", "2048-bit private key type") + assert(type(pub_key_2048) == "string", "2048-bit public key type") + + -- Test 4096-bit keys + local priv_key_4096, pub_key_4096 = crypto.generateKeyPair("rsa", 4096) + assert(type(priv_key_4096) == "string", "4096-bit private key type") + assert(type(pub_key_4096) == "string", "4096-bit public key type") + + -- Test signing and verification with different key sizes + local message = "Test message for RSA key sizes" + + local signature_2048 = crypto.sign("rsa", priv_key_2048, message, "sha256") + assert(type(signature_2048) == "string", "2048-bit key signature") + + local is_valid_2048 = crypto.verify("rsa", pub_key_2048, message, signature_2048, "sha256") + assert(is_valid_2048 == true, "2048-bit key verification") + + local signature_4096 = crypto.sign("rsa", priv_key_4096, message, "sha256") + assert(type(signature_4096) == "string", "4096-bit key signature") + + local is_valid_4096 = crypto.verify("rsa", pub_key_4096, message, signature_4096, "sha256") + assert(is_valid_4096 == true, "4096-bit key verification") +end + +-- Test ECDSA curves +local function test_ecdsa_curves() + local curves = { "secp256r1", "secp384r1", "secp521r1" } + local message = "Test message for ECDSA curves" + + for _, curve in ipairs(curves) do + local priv_key, pub_key = crypto.generateKeyPair("ecdsa", curve) + assert(type(priv_key) == "string", curve .. " private key type") + assert(type(pub_key) == "string", curve .. " public key type") + + local signature = crypto.sign("ecdsa", priv_key, message, "sha256") + assert(type(signature) == "string", curve .. " signature") + local is_valid = crypto.verify("ecdsa", pub_key, message, signature, "sha256") + assert(is_valid == true, curve .. " verification") + end +end + +-- Test AES key sizes +local function test_aes_key_sizes() + local key_sizes = { 128, 192, 256 } + local plaintext = "Test message for AES key sizes" + + for _, size in ipairs(key_sizes) do + local key = crypto.generateKeyPair("aes", size) + assert(type(key) == "string", size .. "-bit AES key type") + assert(#key == size / 8, size .. "-bit AES key length") + + -- Test CBC mode + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + assert(type(ciphertext) == "string", size .. "-bit AES CBC encryption") + local decrypted_plaintext_cbc = crypto.decrypt("aes", key, ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_plaintext_cbc == plaintext, size .. "-bit AES CBC decryption") + + -- Test CTR mode + local ciphertext_ctr, iv_ctr = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + assert(type(ciphertext_ctr) == "string", size .. "-bit AES CTR encryption") + local decrypted_plaintext_ctr = crypto.decrypt("aes", key, ciphertext_ctr, { mode = "ctr", iv = iv_ctr }) + assert(decrypted_plaintext_ctr == plaintext, size .. "-bit AES CTR decryption") + + -- Test GCM mode + local ciphertext_gcm, iv_gcm, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + assert(type(ciphertext_gcm) == "string", size .. "-bit AES GCM encryption") + local decrypted_plaintext_gcm = crypto.decrypt("aes", key, ciphertext_gcm, { mode = "gcm", iv = iv_gcm, tag = tag }) + assert(decrypted_plaintext_gcm == plaintext, size .. "-bit AES GCM decryption") + end +end + +-- Test AES decryption with corrupted ciphertext and tag +local function test_aes_corruption_handling() + local key = crypto.generateKeyPair('aes', 256) + local plaintext = "Sensitive data for corruption test" + -- CBC mode + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + + -- Corrupt ciphertext + Log(kLogVerbose," - CBC decryption with corrupted ciphertext should fail") + local corrupted = ciphertext:sub(1, #ciphertext - 1) .. string.char((ciphertext:byte(-1) ~ 0xFF) % 256) + local plaintext_cbc = crypto.decrypt("aes", key, corrupted, { mode = "cbc", iv = iv }) + assert(plaintext_cbc == nil, "CBC decryption with corrupted ciphertext should fail") + + -- CTR mode (should not error, but output will be wrong) + Log(kLogVerbose," - CTR decryption with corrupted ciphertext should not match original") + local ciphertext_ctr, iv_ctr = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + local corrupted_ctr = ciphertext_ctr:sub(1, #ciphertext_ctr - 1) .. + string.char((ciphertext_ctr:byte(-1) ~ 0xFF) % 256) + local plaintext_ctr = crypto.decrypt("aes", key, corrupted_ctr, { mode = "ctr", iv = iv_ctr }) + assert(plaintext_ctr ~= plaintext, "CTR decryption with corrupted ciphertext should not match original") + + -- GCM mode (should fail authentication) + Log(kLogVerbose,"GCM decryption with corrupted ciphertext should fail") + local ciphertext_gcm, iv_gcm, tag = crypto.encrypt("aes", key, plaintext, { mode = "gcm" }) + local corrupted_gcm = ciphertext_gcm:sub(1, #ciphertext_gcm - 1) .. + string.char((ciphertext_gcm:byte(-1) ~ 0xFF) % 256) + local plaintext_gcm = crypto.decrypt("aes", key, corrupted_gcm, { mode = "gcm", iv = iv_gcm, tag = tag }) + assert(plaintext_gcm == nil, "GCM decryption with corrupted ciphertext should fail") + + -- GCM mode with corrupted tag + Log(kLogVerbose,"GCM decryption with corrupted tag should fail") + local badtag = tag:sub(1, #tag - 1) .. string.char((tag:byte(-1) ~ 0xFF) % 256) + local plaintext_gcm2 = crypto.decrypt("aes", key, ciphertext_gcm, { mode = "gcm", iv = iv_gcm, tag = badtag }) + assert(plaintext_gcm2 ~= plaintext, "GCM decryption with corrupted tag should fail") +end + +-- Test AES encryption/decryption with empty plaintext +local function test_aes_empty_plaintext() + local key = crypto.generateKeyPair('aes', 256) + local empty = "" + for _, mode in ipairs({ "cbc", "ctr", "gcm" }) do + local ciphertext, iv, tag = crypto.encrypt("aes", key, empty, { mode = mode }) + assert(type(ciphertext) == "string", "AES " .. mode .. " encrypt empty string") + + local opts = { mode = mode, iv = iv, tag = tag } + if mode ~= "gcm" then opts.tag = nil end + + local plaintext = crypto.decrypt("aes", key, ciphertext, opts) + assert(plaintext == empty, "AES " .. mode .. " decrypt empty string") + end +end + +-- Test sign/verify with empty message +local function test_sign_verify_empty_message() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local signature = crypto.sign("rsa", priv_key, "", "sha256") + assert(type(signature) == "string", "RSA sign empty message") + + local is_valid = crypto.verify("rsa", pub_key, "", signature, "sha256") + assert(is_valid == true, "RSA verify empty message") + + local ec_priv, ec_pub = crypto.generateKeyPair("ecdsa", "secp256r1") + local ec_sig = crypto.sign("ecdsa", ec_priv, "", "sha256") + assert(type(ec_sig) == "string", "ECDSA sign empty message") + + local ec_valid = crypto.verify("ecdsa", ec_pub, "", ec_sig, "sha256") + assert(ec_valid == true, "ECDSA verify empty message") +end + +-- Test JWK to PEM with minimal valid JWKs and missing fields +local function test_jwk_to_pem_minimal() + -- Minimal valid RSA public JWK + local _, pub_key = crypto.generateKeyPair("rsa", 2048) + local pub_jwk = crypto.convertPemToJwk(pub_key) + local minimal_jwk = { kty = pub_jwk.kty, n = pub_jwk.n, e = pub_jwk.e } + Log(kLogVerbose," - Testing minimal JWK to PEM conversion") + local pem = crypto.convertJwkToPem(minimal_jwk) + assert(type(pem) == "string", "Minimal RSA JWK to PEM") + + -- Missing 'n' field + Log(kLogVerbose," - Testing missing 'n' field in JWK to PEM conversion") + local bad_jwk = { kty = "RSA", e = pub_jwk.e } + local pem2 = crypto.convertJwkToPem(bad_jwk) + assert(pem2 == nil, "JWK to PEM with missing n should fail") + + -- Minimal EC public JWK + Log(kLogVerbose," - Testing minimal EC JWK to PEM conversion") + local _, ec_pub = crypto.generateKeyPair("ecdsa", "secp256r1") + local ec_jwk = crypto.convertPemToJwk(ec_pub) + local minimal_ec_jwk = { kty = ec_jwk.kty, crv = ec_jwk.crv, x = ec_jwk.x, y = ec_jwk.y } + local ec_pem = crypto.convertJwkToPem(minimal_ec_jwk) + assert(type(ec_pem) == "string", "Minimal EC JWK to PEM") + + -- Missing 'x' field + Log(kLogVerbose," - Testing missing 'x' field in EC JWK to PEM conversion") + local bad_ec_jwk = { kty = "EC", crv = ec_jwk.crv, y = ec_jwk.y } + local ec_pem2 = crypto.convertJwkToPem(bad_ec_jwk) + assert(ec_pem2 == nil, "EC JWK to PEM with missing x should fail") +end + +-- Test PEM to JWK with corrupted PEM +local function test_pem_to_jwk_corrupted() + local badpem = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA7\n-----END PUBLIC KEY-----" + local ok = pcall(function() return crypto.convertPemToJwk(badpem) end) + assert(ok == false, "PEM to JWK with corrupted PEM should fail") +end + +-- Test CSR generation with missing/invalid subject/SAN +local function test_csr_generation_edge_cases() + local priv_key, _ = crypto.generateKeyPair() + -- Missing subject and SAN + local csr = crypto.generateCsr(priv_key) + assert(csr == nil, "CSR with missing subject and SAN should fail") + -- Invalid SAN type (not validated yet) + -- local csr2, err2 = crypto.generateCsr(priv_key, "CN=foo", 12345) + -- assert(csr2 == nil, "CSR with invalid SAN type should fail") +end + +-- Test unsupported AES mode +local function test_unsupported_aes_mode() + Log(kLogVerbose," - AES decrypt with unsupported mode should fail") + local key = crypto.generateKeyPair('aes', 256) + local ciphertext = crypto.encrypt('aes', key, 'data', { mode = 'ofb' }) + assert(ciphertext == nil, "AES encrypt with unsupported mode should fail") + + local plaintext = crypto.decrypt('aes', key, 'data', { mode = 'ofb', iv = string.rep('A', 16) }) + assert(plaintext == nil, "AES decrypt with unsupported mode should fail") +end + +-- Test encrypting and signing very large messages +local function test_large_message_handling() + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local large_message = string.rep("A", 1024 * 1024) -- 1MB + + -- RSA encryption (should fail or be limited by key size) + Log(kLogVerbose," - RSA encrypt large message should fail or be limited") + local ciphertext = crypto.encrypt("rsa", pub_key, large_message) + assert(ciphertext == nil, "RSA encrypt large message should fail or be limited") + + -- AES encryption (should succeed) + Log(kLogVerbose," - AES encrypt large message") + local key = crypto.generateKeyPair('aes', 256) + local aes_ciphertext, iv = crypto.encrypt("aes", key, large_message, { mode = "cbc" }) + assert(type(aes_ciphertext) == "string", "AES encrypt large message") + + local decrypted_large_message = crypto.decrypt("aes", key, aes_ciphertext, { mode = "cbc", iv = iv }) + assert(decrypted_large_message == large_message, "AES decrypt large message") + + -- RSA sign large message + Log(kLogVerbose," - RSA verify large message") + local signature = crypto.sign("rsa", priv_key, large_message, "sha256") + assert(type(signature) == "string", "RSA sign large message") + + local is_valid = crypto.verify("rsa", pub_key, large_message, signature, "sha256") + assert(is_valid == true, "RSA verify large message") +end + +-- Test passing non-string values as keys/messages/options +local function test_invalid_types() + local priv_key, _ = crypto.generateKeyPair("rsa", 2048) + local key = crypto.generateKeyPair('aes', 256) + + -- Non-string message + Log(kLogVerbose," - RSA sign with integer message should fail") + local signature = crypto.sign("rsa", priv_key, 12345, "sha256") + assert(signature == nil, "RSA sign with non-string message should fail") + + Log(kLogVerbose," - AES encrypt with boolean message should fail") + ciphertext = crypto.encrypt("aes", key, true, { mode = "cbc" }) + assert(ciphertext == nil, "AES encrypt with boolean message should fail") + + -- Non-string key + Log(kLogVerbose," - RSA sign with table as key should fail") + signature = crypto.sign("rsa", {}, "msg", "sha256") + assert(signature == nil, "RSA sign with table as key should fail") + + -- Non-table options + Log(kLogVerbose," - AES encrypt with number as options should fail") + ciphertext = crypto.encrypt("aes", key, "msg", 123) + assert(ciphertext == nil, "AES encrypt with number as options should fail") +end + +-- Test encrypting with one mode and decrypting with another +local function test_mixed_mode_operations() + local key = crypto.generateKeyPair('aes', 256) + local plaintext = "Mixed mode test" + local ciphertext, iv = crypto.encrypt("aes", key, plaintext, { mode = "cbc" }) + local decrypted_plaintext = crypto.decrypt("aes", key, ciphertext, { mode = "ctr", iv = iv }) + assert(decrypted_plaintext ~= plaintext, "Decrypt CBC ciphertext with CTR mode should not match") + + local ciphertext2, iv2 = crypto.encrypt("aes", key, plaintext, { mode = "ctr" }) + local decrypted_plaintext2 = crypto.decrypt("aes", key, ciphertext2, { mode = "cbc", iv = iv2 }) + assert(decrypted_plaintext2 ~= plaintext, "Decrypt CTR ciphertext with CBC mode should not match") +end + +-- Test signing/verifying/converting with nil or empty parameters +local function test_nil_empty_parameters() + Log(kLogVerbose," - RSA sign with nil message should fail") + local priv_key, pub_key = crypto.generateKeyPair("rsa", 2048) + local signature = crypto.sign("rsa", priv_key, nil, "sha256") + assert(signature == nil, "Sign with nil message should fail") + + Log(kLogVerbose," - RSA verify with nil message should fail") + local is_valid = crypto.verify("rsa", pub_key, nil, "sig", "sha256") + assert(is_valid == nil, "Verify with nil message should fail") + + Log(kLogVerbose," - JWK to PEM with nil should fail") + local ok = pcall(function() return crypto.convertJwkToPem(nil) end) + assert(ok == false, "convertJwkToPem with nil should fail") + + Log(kLogVerbose," - JWK to PEM with empty string should fail") + ok = pcall(function() return crypto.convertJwkToPem("") end) + assert(ok == false, "convertJwkToPem with empty string should fail") end -- Run all tests local function run_tests() + Log(kLogVerbose,"Testing RSA keypair generation...") test_rsa_keypair_generation() + Log(kLogVerbose,"Testing RSA signing and verification...") test_rsa_signing_verification() + Log(kLogVerbose,"Testing RSA-PSS signing and verification...") + test_rsapss_signing_verification() + Log(kLogVerbose,"Testing RSA encryption and decryption...") test_rsa_encryption_decryption() + Log(kLogVerbose,"Testing RSA key size variations...") + test_rsa_key_sizes() + + Log(kLogVerbose,"Testing ECDSA keypair generation...") test_ecdsa_keypair_generation() + Log(kLogVerbose,"Testing ECDSA signing and verification...") test_ecdsa_signing_verification() + Log(kLogVerbose,"Testing ECDSA curves...") + test_ecdsa_curves() + + Log(kLogVerbose,"Testing AES key generation...") test_aes_key_generation() + Log(kLogVerbose,"Testing AES encryption and decryption (CBC mode)...") test_aes_encryption_decryption_cbc() + Log(kLogVerbose,"Testing AES encryption and decryption (CTR mode)...") test_aes_encryption_decryption_ctr() + Log(kLogVerbose,"Testing AES encryption and decryption (GCM mode)...") test_aes_encryption_decryption_gcm() + Log(kLogVerbose,"Testing unsupported AES mode...") + test_unsupported_aes_mode() + Log(kLogVerbose,"Testing AES key sizes...") + test_aes_key_sizes() + Log(kLogVerbose,"Testing AES decryption with corrupted ciphertext and tag...") + test_aes_corruption_handling() + Log(kLogVerbose,"Testing AES encryption/decryption with empty plaintext...") + test_aes_empty_plaintext() + Log(kLogVerbose,"Testing large message encryption and signing...") + test_large_message_handling() + + Log(kLogVerbose,"Testing various hash algorithms...") + test_hash_algorithms() + Log(kLogVerbose,"Testing negative cases for hash algorithms...") + test_negative_hash_algorithms() + Log(kLogVerbose,"Testing sign/verify with empty message...") + test_sign_verify_empty_message() + Log(kLogVerbose,"Testing invalid input types...") + test_invalid_types() + Log(kLogVerbose,"Testing mixed mode encryption/decryption...") + test_mixed_mode_operations() + Log(kLogVerbose,"Testing nil/empty parameters...") + test_nil_empty_parameters() + Log(kLogVerbose,"Testing negative keypair generation...") + test_negative_keypair_generation() + Log(kLogVerbose,"Testing negative encrypt/decrypt...") + test_negative_encrypt_decrypt() + Log(kLogVerbose,"Testing negative sign/verify...") + test_negative_sign_verify() + + Log(kLogVerbose,"Testing PEM to JWK conversion...") test_pem_to_jwk() + Log(kLogVerbose,"Testing PEM to JWK with corrupted PEM...") + test_pem_to_jwk_corrupted() + Log(kLogVerbose,"Testing JWK to PEM conversion...") test_jwk_to_pem() + Log(kLogVerbose,"Testing negative PEM/JWK conversion...") + test_negative_pem_jwk_conversion() + Log(kLogVerbose,"Testing JWK to PEM with minimal valid JWKs and missing fields...") + test_jwk_to_pem_minimal() + Log(kLogVerbose,"Testing CSR generation...") test_csr_generation() + Log(kLogVerbose,"Testing CSR generation with missing/invalid subject/SAN...") + test_csr_generation_edge_cases() + Log(kLogVerbose,"Testing negative CSR generation...") + test_negative_csr_generation() EXIT = 0 return EXIT end diff --git a/tool/net/definitions.lua b/tool/net/definitions.lua index 661e0ea27..c21ad82cc 100644 --- a/tool/net/definitions.lua +++ b/tool/net/definitions.lua @@ -8051,68 +8051,85 @@ kUrlLatin1 = nil --- This module provides cryptographic operations. ---- The crypto module for cryptographic operations crypto = {} ---- Converts a PEM-encoded key to JWK format ----@param pem string PEM-encoded key ----@return table?, string? JWK table or nil on error ----@return string? error message -function crypto.convertPemToJwk(pem) end +--- Signs a message using the specified key type. +--- Supported types: "rsa", "rsa-pss", "ecdsa" +---@param type "rsa"|"rsa-pss"|"rsapss"|"ecdsa" +---@param key string PEM-encoded private key +---@param message string +---@param hash? string Hash algorithm ("sha256", "sha384", "sha512"). Default: "sha256" +---@return string signature +---@overload fun(type: string, key: string, message: string, hash?: string): nil, error: string +function crypto.sign(type, key, message, hash) end ---- Generates a Certificate Signing Request (CSR) ----@param key_pem string PEM-encoded private key ----@param subject_name string? X.509 subject name ----@param san_list string? Subject Alternative Names ----@return string?, string? CSR in PEM format or nil on error and error message -function crypto.generateCsr(key_pem, subject_name, san_list) end +--- Verifies a signature using the specified key type. +--- Supported types: "rsa", "rsa-pss", "ecdsa" +---@param type "rsa"|"rsa-pss"|"rsapss"|"ecdsa" +---@param key string PEM-encoded public key +---@param message string +---@param signature string +---@param hash? string Hash algorithm ("sha256", "sha384", "sha512"). Default: "sha256" +---@return boolean valid +function crypto.verify(type, key, message, signature, hash) end ---- Signs data using a private key ----@param key_type string "rsa" or "ecdsa" ----@param private_key string PEM-encoded private key ----@param message string Data to sign ----@param hash_algo string? Hash algorithm (default: SHA-256) ----@return string?, string? Signature or nil on error and error message -function crypto.sign(key_type, private_key, message, hash_algo) end +--- Encrypts data using the specified cipher. +--- Supported ciphers: "rsa", "aes" +---@param cipher "rsa"|"aes" +---@param key string PEM-encoded public key (RSA) or raw key (AES) +---@param plaintext string +---@param options? table Options for AES: { mode="cbc"|"gcm"|"ctr", iv=string, aad=string } +---@return string ciphertext, string? iv, string? tag +---@overload fun(cipher: string, key: string, plaintext: string, options?: table): nil, error: string +function crypto.encrypt(cipher, key, plaintext, options) end ---- Verifies a signature ----@param key_type string "rsa" or "ecdsa" ----@param public_key string PEM-encoded public key ----@param message string Original message ----@param signature string Signature to verify ----@param hash_algo string? Hash algorithm (default: SHA-256) ----@return boolean?, string? True if valid or nil on error and error message -function crypto.verify(key_type, public_key, message, signature, hash_algo) end +--- Decrypts data using the specified cipher. +--- Supported ciphers: "rsa", "aes" +---@param cipher "rsa"|"aes" +---@param key string PEM-encoded private key (RSA) or raw key (AES) +---@param ciphertext string +---@param options? table Options for AES: { mode="cbc"|"gcm"|"ctr", iv=string, tag=string, aad=string } +---@return string plaintext +---@overload fun(cipher: string, key: string, ciphertext: string, options?: table): nil, error: string +function crypto.decrypt(cipher, key, ciphertext, options) end ---- Encrypts data ----@param cipher_type string "rsa" or "aes" ----@param key string Public key or symmetric key ----@param plaintext string Data to encrypt ----@param mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") ----@param iv string? Initialization Vector for AES ----@param aad string? Additional data for AES-GCM ----@return string? Encrypted data or nil on error ----@return string? IV or error message ----@return string? Authentication tag for GCM mode -function crypto.encrypt(cipher_type, key, plaintext, mode, iv, aad) end +--- Generates a key pair. +--- For RSA: bits = 2048 or 4096. +--- For ECDSA: curve = "secp256r1", "secp384r1", "secp521r1", "curve25519" +--- For AES: bits = 128, 192, or 256. +---@param type "rsa"|"ecdsa"|"aes" +---@param param? integer|string For RSA: bits; for ECDSA: curve name; for AES: bits +---@return string private_key, string public_key|nil +---@overload fun(type: string, param?: integer|string): nil, error: string +function crypto.generateKeyPair(type, param) end ---- Decrypts data ----@param cipher_type string "rsa" or "aes" ----@param key string Private key or symmetric key ----@param ciphertext string Data to decrypt ----@param iv string? Initialization Vector for AES ----@param mode string? AES mode: "cbc", "gcm", "ctr" (default: "cbc") ----@param tag string? Authentication tag for AES-GCM ----@param aad string? Additional data for AES-GCM ----@return string?, string? Decrypted data or nil on error and error message -function crypto.decrypt(cipher_type, key, ciphertext, iv, mode, tag, aad) end +--- Converts a JWK (JSON Web Key, as a Lua table) to PEM format. +---@param jwk table +---@return string pem +---@overload fun(jwk: table): nil, error: string +function crypto.convertJwkToPem(jwk) end ---- Generates cryptographic keys ----@param key_type string? "rsa", "ecdsa", or "aes" ----@param key_size_or_curve number|string? Key size or curve name ----@return string? Private key or nil on error ----@return string? Public key (nil for AES) or error message -function crypto.generatekeypair(key_type, key_size_or_curve) end +--- Converts a PEM key to JWK (JSON Web Key, as a Lua table). +---@param pem string +---@param claims? table Additional claims to merge (RFC7517 fields) +---@return table jwk +---@overload fun(pem: string, claims?: table): nil, error: string +function crypto.convertPemToJwk(pem, claims) end + +--- Generates a Certificate Signing Request (CSR) from a PEM key. +---@param key string PEM-encoded private key +---@param subject string Subject name (e.g., "CN=example.com") +---@param sans? string Subject Alternative Names (SANs) as a comma-separated string +---@return string csr_pem +---@overload fun(key: string, subject: string, sans?: string): nil, error: string +function crypto.generateCsr(key, subject, sans) end + +-- AES options table for encrypt/decrypt: +---@class CryptoAesOptions +---@field mode? "cbc"|"gcm"|"ctr" +---@field iv? string +---@field tag? string +---@field aad? string --[[ diff --git a/tool/net/lcrypto.c b/tool/net/lcrypto.c index e50304fe4..7ce0c3cd4 100644 --- a/tool/net/lcrypto.c +++ b/tool/net/lcrypto.c @@ -32,6 +32,163 @@ #include "third_party/mbedtls/rsa.h" #include "third_party/mbedtls/x509_csr.h" +// Supported curves mapping +typedef struct { + const char *name; + mbedtls_ecp_group_id id; +} curve_map_t; + +static const curve_map_t supported_curves[] = { + {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, + {"P-256", MBEDTLS_ECP_DP_SECP256R1}, + {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, + {"P-384", MBEDTLS_ECP_DP_SECP384R1}, + {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, + {"P-521", MBEDTLS_ECP_DP_SECP521R1}, + {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, + {NULL, 0}}; + +// List available curves +static int LuaListCurves(lua_State *L) { + const curve_map_t *curve = supported_curves; + int i = 1; + + lua_newtable(L); + + while (curve->name != NULL) { + lua_pushstring(L, curve->name); + lua_rawseti(L, -2, i++); + curve++; + } + + return 1; +} + +// Remove hash_algorithm_t and hash_to_md_type indirection +static mbedtls_md_type_t string_to_md_type(const char *hash_name) { + if (!hash_name || !*hash_name) { + return MBEDTLS_MD_SHA256; // Default to SHA-256 if no name provided + } + if (strcasecmp(hash_name, "sha256") == 0 || strcasecmp(hash_name, "sha-256") == 0) { + return MBEDTLS_MD_SHA256; + } else if (strcasecmp(hash_name, "sha384") == 0 || strcasecmp(hash_name, "sha-384") == 0) { + return MBEDTLS_MD_SHA384; + } else if (strcasecmp(hash_name, "sha512") == 0 || strcasecmp(hash_name, "sha-512") == 0) { + return MBEDTLS_MD_SHA512; + } else { + WARNF("(crypto) Unknown hash algorithm '%s'", hash_name); + return -1; + } +} + +static size_t get_hash_size_from_md_type(mbedtls_md_type_t md_type) { + switch (md_type) { + case MBEDTLS_MD_SHA256: + return 32; + case MBEDTLS_MD_SHA384: + return 48; + case MBEDTLS_MD_SHA512: + return 64; + default: + return 32; + } +} + +// List available hash algorithms +static int LuaListHashAlgorithms(lua_State *L) { + lua_newtable(L); + + lua_pushstring(L, "SHA256"); + lua_rawseti(L, -2, 1); + + lua_pushstring(L, "SHA384"); + lua_rawseti(L, -2, 2); + + lua_pushstring(L, "SHA512"); + lua_rawseti(L, -2, 3); + + // Add hyphenated versions + lua_pushstring(L, "SHA-256"); + lua_rawseti(L, -2, 4); + + lua_pushstring(L, "SHA-384"); + lua_rawseti(L, -2, 5); + + lua_pushstring(L, "SHA-512"); + lua_rawseti(L, -2, 6); + + lua_pushstring(L, "SHA1"); + lua_rawseti(L, -2, 7); + + lua_pushstring(L, "MD5"); + lua_rawseti(L, -2, 8); + + return 1; +} + +static int compute_hash(mbedtls_md_type_t md_type, const unsigned char *input, + size_t input_len, unsigned char *output, + size_t output_size) { + mbedtls_md_context_t md_ctx; + const mbedtls_md_info_t *md_info; + int ret; + + md_info = mbedtls_md_info_from_type(md_type); + if (md_info == NULL) { + WARNF("(crypto) Unsupported hash algorithm"); + return -1; + } + + if (output_size < mbedtls_md_get_size(md_info)) { + WARNF("(crypto) Output buffer too small for hash"); + return -1; + } + + mbedtls_md_init(&md_ctx); + + ret = mbedtls_md_setup(&md_ctx, md_info, 0); // 0 = non-HMAC + if (ret != 0) { + WARNF("(crypto) Failed to set up hash context: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_starts(&md_ctx); + if (ret != 0) { + WARNF("(crypto) Failed to start hash operation: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_update(&md_ctx, input, input_len); + if (ret != 0) { + WARNF("(crypto) Failed to update hash: -0x%04x", -ret); + goto cleanup; + } + + ret = mbedtls_md_finish(&md_ctx, output); + if (ret != 0) { + WARNF("(crypto) Failed to finish hash: -0x%04x", -ret); + goto cleanup; + } + +cleanup: + mbedtls_md_free(&md_ctx); + return ret; +} + +// Find curve ID by name +static mbedtls_ecp_group_id find_curve_by_name(const char *name) { + const curve_map_t *curve = supported_curves; + + while (curve->name != NULL) { + if (strcasecmp(curve->name, name) == 0) { + return curve->id; + } + curve++; + } + + return MBEDTLS_ECP_DP_NONE; +} + // Strong RNG using mbedtls_entropy_context and mbedtls_ctr_drbg_context int GenerateRandom(void *ctx, unsigned char *output, size_t len) { static mbedtls_entropy_context entropy; @@ -77,7 +234,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, // Initialize as RSA key if ((rc = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))) != 0) { - WARNF("Failed to setup key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to setup key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return false; } @@ -85,7 +242,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, // Generate RSA key if ((rc = mbedtls_rsa_gen_key(mbedtls_pk_rsa(key), GenerateRandom, 0, key_length, 65537)) != 0) { - WARNF("Failed to generate key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to generate key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return false; } @@ -95,7 +252,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, *private_key_pem = calloc(1, *private_key_len); if ((rc = mbedtls_pk_write_key_pem(&key, (unsigned char *)*private_key_pem, *private_key_len)) != 0) { - WARNF("Failed to write private key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to write private key (grep -0x%04x)", -rc); free(*private_key_pem); mbedtls_pk_free(&key); return false; @@ -107,7 +264,7 @@ static bool RSAGenerateKeyPair(char **private_key_pem, size_t *private_key_len, *public_key_pem = calloc(1, *public_key_len); if ((rc = mbedtls_pk_write_pubkey_pem(&key, (unsigned char *)*public_key_pem, *public_key_len)) != 0) { - WARNF("Failed to write public key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to write public key (grep -0x%04x)", -rc); free(*private_key_pem); free(*public_key_pem); mbedtls_pk_free(&key); @@ -128,6 +285,12 @@ static int LuaRSAGenerateKeyPair(lua_State *L) { } else { bits = (int)luaL_optinteger(L, 2, 2048); } + // Check if key length is valid (only 2048 or 4096 bits are allowed) + if (bits != 2048 && bits != 4096) { + lua_pushnil(L); + lua_pushfstring(L, "Invalid RSA key length: %d. Only 2048 or 4096 bits key lengths are supported", bits); + return 2; + } char *private_key, *public_key; size_t private_len, public_len; @@ -174,14 +337,14 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, if ((rc = mbedtls_pk_parse_public_key(&key, (const unsigned char *)public_key_pem, strlen(public_key_pem) + 1)) != 0) { - WARNF("Failed to parse public key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return NULL; } // Check if key is RSA if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { - WARNF("Key is not an RSA key"); + WARNF("(crypto) Key is not an RSA key"); mbedtls_pk_free(&key); return NULL; } @@ -198,7 +361,7 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, if ((rc = mbedtls_rsa_pkcs1_encrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, MBEDTLS_RSA_PUBLIC, data_len, data, output)) != 0) { - WARNF("Encryption failed (grep -0x%04x)", -rc); + WARNF("(crypto) Encryption failed (grep -0x%04x)", -rc); free(output); mbedtls_pk_free(&key); return NULL; @@ -211,7 +374,19 @@ static char *RSAEncrypt(const char *public_key_pem, const unsigned char *data, static int LuaRSAEncrypt(lua_State *L) { // Args: key, plaintext, options table size_t keylen, ptlen; + // Ensure key is a string + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } const char *key = luaL_checklstring(L, 1, &keylen); + // Ensure plaintext is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } const unsigned char *plaintext = (const unsigned char *)luaL_checklstring(L, 2, &ptlen); // int options_idx = 3; @@ -230,7 +405,6 @@ static int LuaRSAEncrypt(lua_State *L) { return 1; } - static char *RSADecrypt(const char *private_key_pem, const unsigned char *encrypted_data, size_t encrypted_len, size_t *out_len) { @@ -239,16 +413,17 @@ static char *RSADecrypt(const char *private_key_pem, // Parse private key mbedtls_pk_context key; mbedtls_pk_init(&key); - if ((rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, - strlen(private_key_pem) + 1, NULL, 0)) != 0) { - WARNF("Failed to parse private key (grep -0x%04x)", -rc); + rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0); + if (rc != 0) { + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return NULL; } // Check if key is RSA if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { - WARNF("Key is not an RSA key"); + WARNF("(crypto) Key is not an RSA key"); mbedtls_pk_free(&key); return NULL; } @@ -263,10 +438,11 @@ static char *RSADecrypt(const char *private_key_pem, // Decrypt data size_t output_len = 0; - if ((rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, - MBEDTLS_RSA_PRIVATE, &output_len, - encrypted_data, output, key_size)) != 0) { - WARNF("Decryption failed (grep -0x%04x)", -rc); + rc = mbedtls_rsa_pkcs1_decrypt(mbedtls_pk_rsa(key), GenerateRandom, 0, + MBEDTLS_RSA_PRIVATE, &output_len, + encrypted_data, output, key_size); + if (rc != 0) { + WARNF("(crypto) Decryption failed (grep -0x%04x)", -rc); free(output); mbedtls_pk_free(&key); return NULL; @@ -279,7 +455,19 @@ static char *RSADecrypt(const char *private_key_pem, static int LuaRSADecrypt(lua_State *L) { // Args: key, ciphertext, options table size_t keylen, ctlen; + // Ensure key is a string + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } const char *key = luaL_checklstring(L, 1, &keylen); + // Ensure ciphertext is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext must be a string"); + return 2; + } const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); // int options_idx = 3; @@ -311,18 +499,8 @@ static char *RSASign(const char *private_key_pem, const unsigned char *data, // Determine hash algorithm if (hash_algo_str) { - if (strcasecmp(hash_algo_str, "sha256") == 0) { - hash_algo = MBEDTLS_MD_SHA256; - hash_len = 32; - } else if (strcasecmp(hash_algo_str, "sha384") == 0) { - hash_algo = MBEDTLS_MD_SHA384; - hash_len = 48; - } else if (strcasecmp(hash_algo_str, "sha512") == 0) { - hash_algo = MBEDTLS_MD_SHA512; - hash_len = 64; - } else { - return NULL; // Unsupported hash algorithm - } + hash_algo = string_to_md_type(hash_algo_str); + hash_len = get_hash_size_from_md_type(hash_algo); } // Parse private key @@ -330,14 +508,14 @@ static char *RSASign(const char *private_key_pem, const unsigned char *data, mbedtls_pk_init(&key); if ((rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, strlen(private_key_pem) + 1, NULL, 0)) != 0) { - WARNF("Failed to parse private key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return NULL; } // Check if key is RSA if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { - WARNF("Key is not an RSA key"); + WARNF("(crypto) Key is not an RSA key"); mbedtls_pk_free(&key); return NULL; } @@ -376,6 +554,24 @@ static int LuaRSASign(lua_State *L) { size_t sig_len = 0; // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + if (lua_type(L, 1) == LUA_TTABLE) { + DEBUGF("(crypto) Detected table type for parameter 1"); + lua_pushnil(L); + lua_pushstring(L, "Key must be a string, got a table instead"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + key_pem = luaL_checklstring(L, 1, &key_len); msg = luaL_checklstring(L, 2, &msg_len); @@ -401,6 +597,119 @@ static int LuaRSASign(lua_State *L) { return 1; } +// RSA PSS Signing +static char *RSAPSSSign(const char *private_key_pem, const unsigned char *data, + size_t data_len, const char *hash_algo_str, + size_t *sig_len, int salt_len) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 32; // Default for SHA-256 + unsigned char *signature; + mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default + + // Determine hash algorithm + if (hash_algo_str) { + hash_algo = string_to_md_type(hash_algo_str); + hash_len = get_hash_size_from_md_type(hash_algo); + } + + // Parse private key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + rc = mbedtls_pk_parse_key(&key, (const unsigned char *)private_key_pem, + strlen(private_key_pem) + 1, NULL, 0); + if (rc != 0) { + WARNF("(crypto) Failed to parse private key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return NULL; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return NULL; + } + + // Hash the message + rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, hash); + if (rc != 0) { + mbedtls_pk_free(&key); + return NULL; + } + + // Allocate buffer for signature + signature = malloc(MBEDTLS_PK_SIGNATURE_MAX_SIZE); + if (!signature) { + mbedtls_pk_free(&key); + return NULL; + } + + // Get RSA context + mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, hash_algo); + + // Sign the hash using PSS + rc = mbedtls_rsa_rsassa_pss_sign( + rsa, GenerateRandom, 0, MBEDTLS_RSA_PRIVATE, hash_algo, (unsigned int)hash_len, + hash, signature); + if (rc != 0) { + free(signature); + mbedtls_pk_free(&key); + return NULL; + } + + *sig_len = mbedtls_pk_get_len(&key); + + // Clean up + mbedtls_pk_free(&key); + + return (char *)signature; +} + +static int LuaRSAPSSSign(lua_State *L) { + size_t key_len, msg_len; + const char *key_pem, *msg; + unsigned char *signature; + size_t sig_len = 0; + + // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + + // Get parameters from Lua + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + + // Optional hash algorithm parameter + const char *hash_name = luaL_optstring(L, 3, "sha256"); + + // Optional salt length parameter + int salt_len = luaL_optinteger(L, 4, -1); + + // Call the C implementation + signature = (unsigned char *)RSAPSSSign(key_pem, (const unsigned char *)msg, + msg_len, hash_name, &sig_len, salt_len); + + if (!signature) { + return luaL_error(L, "failed to sign message (PSS)"); + } + + // Return the signature as a Lua string + lua_pushlstring(L, (char *)signature, sig_len); + free(signature); + return 1; +} + static int RSAVerify(const char *public_key_pem, const unsigned char *data, size_t data_len, const unsigned char *signature, size_t sig_len, const char *hash_algo_str) { @@ -411,18 +720,8 @@ static int RSAVerify(const char *public_key_pem, const unsigned char *data, // Determine hash algorithm if (hash_algo_str) { - if (strcasecmp(hash_algo_str, "sha256") == 0) { - hash_algo = MBEDTLS_MD_SHA256; - hash_len = 32; - } else if (strcasecmp(hash_algo_str, "sha384") == 0) { - hash_algo = MBEDTLS_MD_SHA384; - hash_len = 48; - } else if (strcasecmp(hash_algo_str, "sha512") == 0) { - hash_algo = MBEDTLS_MD_SHA512; - hash_len = 64; - } else { - return -1; // Unsupported hash algorithm - } + hash_algo = string_to_md_type(hash_algo_str); + hash_len = get_hash_size_from_md_type(hash_algo); } // Parse public key @@ -431,14 +730,14 @@ static int RSAVerify(const char *public_key_pem, const unsigned char *data, if ((rc = mbedtls_pk_parse_public_key(&key, (const unsigned char *)public_key_pem, strlen(public_key_pem) + 1)) != 0) { - WARNF("Failed to parse public key (grep -0x%04x)", -rc); + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); mbedtls_pk_free(&key); return -1; } // Check if key is RSA if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { - WARNF("Key is not an RSA key"); + WARNF("(crypto) Key is not an RSA key"); mbedtls_pk_free(&key); return -1; } @@ -464,6 +763,17 @@ static int LuaRSAVerify(lua_State *L) { int result; // Get parameters from Lua + if (!lua_isstring(L, 1)) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (!lua_isstring(L, 2)) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } key_pem = luaL_checklstring(L, 1, &key_len); msg = luaL_checklstring(L, 2, &msg_len); signature = luaL_checklstring(L, 3, &sig_len); @@ -483,177 +793,120 @@ static int LuaRSAVerify(lua_State *L) { return 1; } +// RSA PSS Verification +static int RSAPSSVerify(const char *public_key_pem, const unsigned char *data, + size_t data_len, const unsigned char *signature, + size_t sig_len, const char *hash_algo_str, + int expected_salt_len) { + int rc; + unsigned char hash[64]; // Large enough for SHA-512 + size_t hash_len = 32; // Default for SHA-256 + mbedtls_md_type_t hash_algo = MBEDTLS_MD_SHA256; // Default + + // Determine hash algorithm + if (hash_algo_str) { + hash_algo = string_to_md_type(hash_algo_str); + DEBUGF("(DEBUG) Using hash algorithm: %s", hash_algo_str); + hash_len = get_hash_size_from_md_type(hash_algo); + DEBUGF("(DEBUG) Hash length: %zu", hash_len); + } + + // Parse public key + mbedtls_pk_context key; + mbedtls_pk_init(&key); + if ((rc = mbedtls_pk_parse_public_key(&key, + (const unsigned char *)public_key_pem, + strlen(public_key_pem) + 1)) != 0) { + WARNF("(crypto) Failed to parse public key (grep -0x%04x)", -rc); + mbedtls_pk_free(&key); + return -1; + } + + // Check if key is RSA + if (mbedtls_pk_get_type(&key) != MBEDTLS_PK_RSA) { + WARNF("(crypto) Key is not an RSA key"); + mbedtls_pk_free(&key); + return -1; + } + + // Hash the message + if ((rc = mbedtls_md(mbedtls_md_info_from_type(hash_algo), data, data_len, + hash)) != 0) { + mbedtls_pk_free(&key); + return -1; + } + + // Get RSA context + mbedtls_rsa_context *rsa = mbedtls_pk_rsa(key); + + // Verify the signature using PSS + rc = mbedtls_rsa_rsassa_pss_verify( + rsa, NULL, NULL, MBEDTLS_RSA_PUBLIC, hash_algo, (unsigned int)hash_len, + hash, signature); + + // Clean up + mbedtls_pk_free(&key); + + return rc; // 0 means success (valid signature) +} + +static int LuaRSAPSSVerify(lua_State *L) { + // Args: key, msg, signature, hash_algo (optional), salt_len (optional + // crypto.verify('rsapss', key, msg, signature, hash_algo, salt_len) + size_t msg_len, key_len, sig_len; + const char *msg, *key_pem, *signature; + const char *hash_algo_str = NULL; // Default to SHA-256 + int expected_salt_len = -1; + int result; + + // Get parameters from Lua + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure msg is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + // Ensure signature is a string + if (lua_type(L, 3) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Signature must be a string"); + return 2; + } + // Get parameters from Lua + key_pem = luaL_checklstring(L, 1, &key_len); + msg = luaL_checklstring(L, 2, &msg_len); + signature = luaL_checklstring(L, 3, &sig_len); + // Optional hash algorithm parameter + if (lua_isstring(L, 4)) { + hash_algo_str = luaL_checkstring(L, 4); + // Optional salt length parameter + expected_salt_len = luaL_optinteger(L, 5, 32); + } else if (lua_isnumber(L, 4)) { + // If it's a number, treat it as the expected salt length + expected_salt_len = (int)lua_tointeger(L, 4); + } + DEBUGF("(DEBUG) Key PEM: %s", key_pem); + DEBUGF("(DEBUG) Message length: %zu", msg_len); + DEBUGF("(DEBUG) Signature length: %zu", sig_len); + DEBUGF("(DEBUG) Hash algorithm: %s", hash_algo_str ? hash_algo_str : "SHA-256"); + DEBUGF("(DEBUG) Expected salt length: %d", expected_salt_len); + DEBUGF("(DEBUG) Signature: %.*s", (int)sig_len, signature); + // Call the C implementation + result = RSAPSSVerify(key_pem, (const unsigned char *)msg, msg_len, + (const unsigned char *)signature, sig_len, hash_algo_str, expected_salt_len); + + // Return boolean result (0 means valid signature) + lua_pushboolean(L, result == 0); + + return 1; +} + // Elliptic Curve Cryptography Functions -// Supported curves mapping -typedef struct { - const char *name; - mbedtls_ecp_group_id id; -} curve_map_t; - -static const curve_map_t supported_curves[] = { - {"secp256r1", MBEDTLS_ECP_DP_SECP256R1}, - {"secp384r1", MBEDTLS_ECP_DP_SECP384R1}, - {"secp521r1", MBEDTLS_ECP_DP_SECP521R1}, - {"secp192r1", MBEDTLS_ECP_DP_SECP192R1}, - {"secp224r1", MBEDTLS_ECP_DP_SECP224R1}, - {"curve25519", MBEDTLS_ECP_DP_CURVE25519}, - {NULL, 0}}; - -typedef enum { SHA256, SHA384, SHA512 } hash_algorithm_t; - -static mbedtls_md_type_t hash_to_md_type(hash_algorithm_t hash_alg) { - switch (hash_alg) { - case SHA256: - return MBEDTLS_MD_SHA256; - case SHA384: - return MBEDTLS_MD_SHA384; - case SHA512: - return MBEDTLS_MD_SHA512; - default: - return MBEDTLS_MD_SHA256; // Default to SHA-256 - } -} - -static size_t get_hash_size(hash_algorithm_t hash_alg) { - switch (hash_alg) { - case SHA256: - return 32; - case SHA384: - return 48; - case SHA512: - return 64; - default: - return 32; // Default to SHA-256 - } -} - -static hash_algorithm_t string_to_hash_alg(const char *hash_name) { - if (!hash_name || !*hash_name) { - return SHA256; // Default to SHA-256 if no name provided - } - - if (strcasecmp(hash_name, "sha256") == 0 || - strcasecmp(hash_name, "sha-256") == 0) { - return SHA256; - } else if (strcasecmp(hash_name, "sha384") == 0 || - strcasecmp(hash_name, "sha-384") == 0) { - return SHA384; - } else if (strcasecmp(hash_name, "sha512") == 0 || - strcasecmp(hash_name, "sha-512") == 0) { - return SHA512; - } else { - WARNF("(ecdsa) Unknown hash algorithm '%s', using SHA-256", hash_name); - return SHA256; - } -} - -static int LuaListHashAlgorithms(lua_State *L) { - lua_newtable(L); - - lua_pushstring(L, "SHA256"); - lua_rawseti(L, -2, 1); - - lua_pushstring(L, "SHA384"); - lua_rawseti(L, -2, 2); - - lua_pushstring(L, "SHA512"); - lua_rawseti(L, -2, 3); - - // Add hyphenated versions - lua_pushstring(L, "SHA-256"); - lua_rawseti(L, -2, 4); - - lua_pushstring(L, "SHA-384"); - lua_rawseti(L, -2, 5); - - lua_pushstring(L, "SHA-512"); - lua_rawseti(L, -2, 6); - - lua_pushstring(L, "MD5"); - lua_rawseti(L, -2, 7); - - return 1; -} - -// List available curves -static int LuaListCurves(lua_State *L) { - const curve_map_t *curve = supported_curves; - int i = 1; - - lua_newtable(L); - - while (curve->name != NULL) { - lua_pushstring(L, curve->name); - lua_rawseti(L, -2, i++); - curve++; - } - - return 1; -} - -static int compute_hash(hash_algorithm_t hash_alg, const unsigned char *input, - size_t input_len, unsigned char *output, - size_t output_size) { - mbedtls_md_context_t md_ctx; - const mbedtls_md_info_t *md_info; - int ret; - - mbedtls_md_type_t md_type = hash_to_md_type(hash_alg); - md_info = mbedtls_md_info_from_type(md_type); - if (md_info == NULL) { - WARNF("(ecdsa) Unsupported hash algorithm"); - return -1; - } - - if (output_size < mbedtls_md_get_size(md_info)) { - WARNF("(ecdsa) Output buffer too small for hash"); - return -1; - } - - mbedtls_md_init(&md_ctx); - - ret = mbedtls_md_setup(&md_ctx, md_info, 0); // 0 = non-HMAC - if (ret != 0) { - WARNF("(ecdsa) Failed to set up hash context: -0x%04x", -ret); - goto cleanup; - } - - ret = mbedtls_md_starts(&md_ctx); - if (ret != 0) { - WARNF("(ecdsa) Failed to start hash operation: -0x%04x", -ret); - goto cleanup; - } - - ret = mbedtls_md_update(&md_ctx, input, input_len); - if (ret != 0) { - WARNF("(ecdsa) Failed to update hash: -0x%04x", -ret); - goto cleanup; - } - - ret = mbedtls_md_finish(&md_ctx, output); - if (ret != 0) { - WARNF("(ecdsa) Failed to finish hash: -0x%04x", -ret); - goto cleanup; - } - -cleanup: - mbedtls_md_free(&md_ctx); - return ret; -} - -// Find curve ID by name -static mbedtls_ecp_group_id find_curve_by_name(const char *name) { - const curve_map_t *curve = supported_curves; - - while (curve->name != NULL) { - if (strcasecmp(curve->name, name) == 0) { - return curve->id; - } - curve++; - } - - return MBEDTLS_ECP_DP_NONE; -} // Generate an ECDSA key pair and return in PEM format static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, @@ -672,13 +925,13 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, // Use secp256r1 as default if curve_name is NULL or empty if (curve_name == NULL || curve_name[0] == '\0') { curve_id = MBEDTLS_ECP_DP_SECP256R1; - VERBOSEF("(ecdsa) No curve specified, using default: secp256r1"); + VERBOSEF("(crypto) No curve specified, using default: secp256r1"); } else { // Find the curve by name curve_id = find_curve_by_name(curve_name); if (curve_id == MBEDTLS_ECP_DP_NONE) { - WARNF("(ecdsa) Unknown curve: %s, using default: secp256r1", curve_name); - curve_id = MBEDTLS_ECP_DP_SECP256R1; + WARNF("(crypto) Unknown curve: '%s'", curve_name); + return -1; } } @@ -687,13 +940,13 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, // Generate the key with the specified curve ret = mbedtls_pk_setup(&key, mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)); if (ret != 0) { - WARNF("(ecdsa) Failed to setup key: -0x%04x", -ret); + WARNF("(crypto) Failed to setup key: -0x%04x", -ret); goto cleanup; } ret = mbedtls_ecp_gen_key(curve_id, mbedtls_pk_ec(key), GenerateRandom, 0); if (ret != 0) { - WARNF("(ecdsa) Failed to generate key: -0x%04x", -ret); + WARNF("(crypto) Failed to generate key: -0x%04x", -ret); goto cleanup; } @@ -702,12 +955,12 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, memset(output_buf, 0, sizeof(output_buf)); ret = mbedtls_pk_write_key_pem(&key, output_buf, sizeof(output_buf)); if (ret != 0) { - WARNF("(ecdsa) Failed to write private key: -0x%04x", -ret); + WARNF("(crypto) Failed to write private key: -0x%04x", -ret); goto cleanup; } *priv_key_pem = strdup((char *)output_buf); if (*priv_key_pem == NULL) { - WARNF("(ecdsa) Failed to allocate memory for private key PEM"); + WARNF("(crypto) Failed to allocate memory for private key PEM"); ret = -1; goto cleanup; } @@ -718,12 +971,12 @@ static int ECDSAGenerateKeyPair(const char *curve_name, char **priv_key_pem, memset(output_buf, 0, sizeof(output_buf)); ret = mbedtls_pk_write_pubkey_pem(&key, output_buf, sizeof(output_buf)); if (ret != 0) { - WARNF("(ecdsa) Failed to write public key: -0x%04x", -ret); + WARNF("(crypto) Failed to write public key: -0x%04x", -ret); goto cleanup; } *pub_key_pem = strdup((char *)output_buf); if (*pub_key_pem == NULL) { - WARNF("(ecdsa) Failed to allocate memory for public key PEM"); + WARNF("(crypto) Failed to allocate memory for public key PEM"); ret = -1; goto cleanup; } @@ -771,7 +1024,7 @@ static int LuaECDSAGenerateKeyPair(lua_State *L) { // Sign a message using an ECDSA private key in PEM format static int ECDSASign(const char *priv_key_pem, const char *message, - hash_algorithm_t hash_alg, unsigned char **signature, + mbedtls_md_type_t hash_alg, unsigned char **signature, size_t *sig_len) { mbedtls_pk_context key; unsigned char hash[64]; // Max hash size (SHA-512) @@ -782,19 +1035,19 @@ static int ECDSASign(const char *priv_key_pem, const char *message, *sig_len = 0; if (!priv_key_pem) { - WARNF("(ecdsa) Private key is NULL"); + WARNF("(crypto) Private key is NULL"); return -1; } // Get the length of the PEM string (excluding null terminator) size_t key_len = strlen(priv_key_pem); if (key_len == 0) { - WARNF("(ecdsa) Private key is empty"); + WARNF("(crypto) Private key is empty"); return -1; } // Get hash size for the selected algorithm - hash_size = get_hash_size(hash_alg); + hash_size = get_hash_size_from_md_type(hash_alg); mbedtls_pk_init(&key); @@ -803,7 +1056,7 @@ static int ECDSASign(const char *priv_key_pem, const char *message, key_len + 1, NULL, 0); if (ret != 0) { - WARNF("(ecdsa) Failed to parse private key: -0x%04x", -ret); + WARNF("(crypto) Failed to parse private key: -0x%04x", -ret); goto cleanup; } @@ -811,24 +1064,24 @@ static int ECDSASign(const char *priv_key_pem, const char *message, ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), hash, sizeof(hash)); if (ret != 0) { - WARNF("(ecdsa) Failed to compute message hash"); + WARNF("(crypto) Failed to compute message hash"); goto cleanup; } // Allocate memory for signature (max size for ECDSA) *signature = malloc(MBEDTLS_ECDSA_MAX_LEN); if (*signature == NULL) { - WARNF("(ecdsa) Failed to allocate memory for signature"); + WARNF("(crypto) Failed to allocate memory for signature"); ret = -1; goto cleanup; } // Sign the hash using GenerateRandom - ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size, + ret = mbedtls_pk_sign(&key, hash_alg, hash, hash_size, *signature, sig_len, GenerateRandom, 0); if (ret != 0) { - WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); + WARNF("(crypto) Failed to sign message: -0x%04x", -ret); free(*signature); *signature = NULL; *sig_len = 0; @@ -845,7 +1098,7 @@ static int LuaECDSASign(lua_State *L) { const char *message = luaL_checkstring(L, 2); const char *hash_name = luaL_optstring(L, 3, "sha256"); - hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); + mbedtls_md_type_t hash_alg = string_to_md_type(hash_name); unsigned char *signature = NULL; size_t sig_len = 0; @@ -865,26 +1118,26 @@ static int LuaECDSASign(lua_State *L) { // Verify a signature using an ECDSA public key in PEM format static int ECDSAVerify(const char *pub_key_pem, const char *message, const unsigned char *signature, size_t sig_len, - hash_algorithm_t hash_alg) { + mbedtls_md_type_t hash_alg) { mbedtls_pk_context key; unsigned char hash[64]; // Max hash size (SHA-512) size_t hash_size; int ret; if (!pub_key_pem) { - WARNF("(ecdsa) Public key is NULL"); + WARNF("(crypto) Public key is NULL"); return -1; } // Get the length of the PEM string (excluding null terminator) size_t key_len = strlen(pub_key_pem); if (key_len == 0) { - WARNF("(ecdsa) Public key is empty"); + WARNF("(crypto) Public key is empty"); return -1; } // Get hash size for the selected algorithm - hash_size = get_hash_size(hash_alg); + hash_size = get_hash_size_from_md_type(hash_alg); mbedtls_pk_init(&key); @@ -892,7 +1145,7 @@ static int ECDSAVerify(const char *pub_key_pem, const char *message, ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pub_key_pem, key_len + 1); if (ret != 0) { - WARNF("(ecdsa) Failed to parse public key: -0x%04x", -ret); + WARNF("(crypto) Failed to parse public key: -0x%04x", -ret); goto cleanup; } @@ -900,15 +1153,15 @@ static int ECDSAVerify(const char *pub_key_pem, const char *message, ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message), hash, sizeof(hash)); if (ret != 0) { - WARNF("(ecdsa) Failed to compute message hash"); + WARNF("(crypto) Failed to compute message hash"); goto cleanup; } // Verify the signature - ret = mbedtls_pk_verify(&key, hash_to_md_type(hash_alg), hash, hash_size, + ret = mbedtls_pk_verify(&key, hash_alg, hash, hash_size, signature, sig_len); if (ret != 0) { - WARNF("(ecdsa) Signature verification failed: -0x%04x", -ret); + WARNF("(crypto) Signature verification failed: -0x%04x", -ret); goto cleanup; } @@ -925,7 +1178,7 @@ static int LuaECDSAVerify(lua_State *L) { (const unsigned char *)luaL_checklstring(L, 3, &sig_len); const char *hash_name = luaL_optstring(L, 4, "sha256"); - hash_algorithm_t hash_alg = string_to_hash_alg(hash_name); + mbedtls_md_type_t hash_alg = string_to_md_type(hash_name); int ret = ECDSAVerify(pub_key_pem, message, signature, sig_len, hash_alg); @@ -970,42 +1223,92 @@ typedef struct { size_t aadlen; } aes_options_t; -static void parse_aes_options(lua_State *L, int options_idx, - aes_options_t *opts) { - opts->mode = "cbc"; +static void parse_aes_options(lua_State *L, int options_idx, aes_options_t *opts) { + opts->mode = NULL; opts->iv = NULL; opts->ivlen = 0; opts->tag = NULL; opts->taglen = 0; opts->aad = NULL; opts->aadlen = 0; + + int mode_field_found = 0; + if (lua_istable(L, options_idx)) { + // Get mode lua_getfield(L, options_idx, "mode"); - if (!lua_isnil(L, -1)) - opts->mode = lua_tostring(L, -1); + if (!lua_isnil(L, -1)) { + mode_field_found = 1; + const char *mode = lua_tostring(L, -1); + if (mode && (strcasecmp(mode, "cbc") == 0 || + strcasecmp(mode, "gcm") == 0 || + strcasecmp(mode, "ctr") == 0)) { + opts->mode = mode; + } else { + opts->mode = NULL; // Invalid mode + } + } lua_pop(L, 1); + + // Get IV lua_getfield(L, options_idx, "iv"); if (lua_isstring(L, -1)) { - opts->iv = (const unsigned char *)lua_tolstring(L, -1, &opts->ivlen); + size_t ivlen; + opts->iv = (const unsigned char *)lua_tolstring(L, -1, &ivlen); + opts->ivlen = ivlen; } lua_pop(L, 1); + + // Get tag (for GCM) lua_getfield(L, options_idx, "tag"); if (lua_isstring(L, -1)) { - opts->tag = (const unsigned char *)lua_tolstring(L, -1, &opts->taglen); + size_t taglen; + opts->tag = (const unsigned char *)lua_tolstring(L, -1, &taglen); + opts->taglen = taglen; } lua_pop(L, 1); + + // Get aad (for GCM) lua_getfield(L, options_idx, "aad"); if (lua_isstring(L, -1)) { - opts->aad = (const unsigned char *)lua_tolstring(L, -1, &opts->aadlen); + size_t aadlen; + opts->aad = (const unsigned char *)lua_tolstring(L, -1, &aadlen); + opts->aadlen = aadlen; } lua_pop(L, 1); } -} + // Only default to cbc if mode field was not found at all + if (!mode_field_found) { + opts->mode = "cbc"; + } +} // AES encryption supporting CBC, GCM, and CTR modes static int LuaAesEncrypt(lua_State *L) { // Args: key, plaintext, options table size_t keylen, ptlen; + + + // Get parameters from Lua + // Ensure key is a string + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + // Ensure plaintext is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Plaintext must be a string"); + return 2; + } + // Ensure options is a table or nil + if (!lua_istable(L, 3) && !lua_isnil(L, 3)) { + lua_pushnil(L); + lua_pushstring(L, "Options must be a table or nil"); + return 2; + } + const unsigned char *key = (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *plaintext = @@ -1014,6 +1317,11 @@ static int LuaAesEncrypt(lua_State *L) { aes_options_t opts; parse_aes_options(L, options_idx, &opts); const char *mode = opts.mode; + if (!mode) { + lua_pushnil(L); + lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } const unsigned char *iv = opts.iv; size_t ivlen = opts.ivlen; unsigned char *gen_iv = NULL; @@ -1032,6 +1340,7 @@ static int LuaAesEncrypt(lua_State *L) { lua_pushstring(L, "Unsupported AES mode. Use 'cbc', 'gcm', or 'ctr'."); return 2; } + // If IV is not provided, auto-generate if (!iv) { if (is_gcm) { @@ -1055,6 +1364,26 @@ static int LuaAesEncrypt(lua_State *L) { iv = gen_iv; iv_was_generated = 1; } + + // Validate IV/nonce length + if (is_cbc || is_ctr) { + // Validate IV/nonce length + if (is_cbc || is_ctr) { + if (opts.iv && opts.ivlen != 16) { + if (iv_was_generated) free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES IV/nonce must be 16 bytes for CBC/CTR"); + return 2; + } + } else if (is_gcm) { + if (opts.iv && (opts.ivlen < 12 || opts.ivlen > 16)) { + if (iv_was_generated) free(gen_iv); + lua_pushnil(L); + lua_pushstring(L, "AES GCM nonce must be 12-16 bytes"); + return 2; + } + } + } if (is_cbc) { // PKCS7 padding size_t block_size = 16; @@ -1188,14 +1517,32 @@ static int LuaAesEncrypt(lua_State *L) { static int LuaAesDecrypt(lua_State *L) { // Args: key, ciphertext, options table size_t keylen, ctlen; + // Ensure key is a string + if (lua_type(L, 1) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Key must be a string"); + return 2; + } + + // Ensure ciphertext is a string + if (lua_type(L, 2) != LUA_TSTRING) { + lua_pushnil(L); + lua_pushstring(L, "Ciphertext must be a string"); + return 2; + } const unsigned char *key = - (const unsigned char *)luaL_checklstring(L, 1, &keylen); + (const unsigned char *)luaL_checklstring(L, 1, &keylen); const unsigned char *ciphertext = (const unsigned char *)luaL_checklstring(L, 2, &ctlen); int options_idx = 3; aes_options_t opts; parse_aes_options(L, options_idx, &opts); const char *mode = opts.mode; + if (!mode) { + lua_pushnil(L); + lua_pushstring(L, "Invalid AES mode specified. Use 'cbc', 'gcm', or 'ctr'."); + return 2; + } const unsigned char *iv = opts.iv; size_t ivlen = opts.ivlen; const unsigned char *tag = opts.tag; @@ -1385,6 +1732,13 @@ static int LuaAesDecrypt(lua_State *L) { static int LuaConvertJwkToPem(lua_State *L) { luaL_checktype(L, 1, LUA_TTABLE); const char *kty; + + if (lua_isnoneornil(L, 1) || lua_type(L, 1) != LUA_TTABLE) { + lua_pushnil(L); + lua_pushstring(L, "Expected a JWK table, got nil"); + return 2; + } + lua_getfield(L, 1, "kty"); kty = lua_tostring(L, -1); if (!kty) { @@ -1404,6 +1758,16 @@ static int LuaConvertJwkToPem(lua_State *L) { lua_getfield(L, 1, "e"); const char *n_b64 = lua_tostring(L, -2); const char *e_b64 = lua_tostring(L, -1); + if (!n_b64 || !*n_b64) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'n' in JWK"); + return 2; + } + if (!e_b64 || !*e_b64) { + lua_pushnil(L); + lua_pushstring(L, "Missing or empty 'e' in JWK"); + return 2; + } // Optional private fields lua_getfield(L, 1, "d"); lua_getfield(L, 1, "p"); @@ -1498,6 +1862,15 @@ static int LuaConvertJwkToPem(lua_State *L) { const char *x_b64 = lua_tostring(L, -3); const char *y_b64 = lua_tostring(L, -2); const char *d_b64 = lua_tostring(L, -1); + if (!crv || !*crv) { + lua_pushnil(L); lua_pushstring(L, "Missing or empty 'crv' in JWK"); return 2; + } + if (!x_b64 || !*x_b64) { + lua_pushnil(L); lua_pushstring(L, "Missing or empty 'x' in JWK"); return 2; + } + if (!y_b64 || !*y_b64) { + lua_pushnil(L); lua_pushstring(L, "Missing or empty 'y' in JWK"); return 2; + } int has_private = d_b64 && *d_b64; mbedtls_ecp_group_id gid = find_curve_by_name(crv); if (gid == MBEDTLS_ECP_DP_NONE) { @@ -1507,7 +1880,7 @@ static int LuaConvertJwkToPem(lua_State *L) { unsigned char x_bin[72], y_bin[72]; char *x_b64_std = strdup(x_b64), *y_b64_std = strdup(y_b64); for (char *p = x_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; - for (char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; + for ( char *p = y_b64_std; *p; ++p) if (*p == '-') *p = '+'; else if (*p == '_') *p = '/'; int x_mod = strlen(x_b64_std) % 4; int y_mod = strlen(y_b64_std) % 4; if (x_mod) for (int i = 0; i < 4 - x_mod; ++i) strcat(x_b64_std, "="); @@ -1542,7 +1915,7 @@ static int LuaConvertJwkToPem(lua_State *L) { if (ret != 0) { mbedtls_pk_free(&pk); lua_pushnil(L); lua_pushstring(L, "PEM write failed"); return 2; - } + } pem = strdup((char *)buf); mbedtls_pk_free(&pk); lua_pushstring(L, pem); @@ -1556,8 +1929,25 @@ static int LuaConvertJwkToPem(lua_State *L) { } // Convert PEM key to JWK (Lua table) format +static void base64_to_base64url(char *str) { + if (!str) return; + for (char *p = str; *p; p++) { + if (*p == '+') *p = '-'; + else if (*p == '/') *p = '_'; + } + // Remove padding + size_t len = strlen(str); + while (len > 0 && str[len-1] == '=') { + str[--len] = '\0'; + } +} + static int LuaConvertPemToJwk(lua_State *L) { const char *pem_key = luaL_checkstring(L, 1); + int has_claims = 0; + if (!lua_isnoneornil(L, 2) && lua_istable(L, 2)) { + has_claims = 1; + } mbedtls_pk_context key; mbedtls_pk_init(&key); @@ -1600,12 +1990,12 @@ static int LuaConvertPemToJwk(lua_State *L) { mbedtls_base64_encode(NULL, 0, &e_b64_len, e, e_len); n_b64 = malloc(n_b64_len + 1); e_b64 = malloc(e_b64_len + 1); - mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, - n_len); - mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, - e_len); + mbedtls_base64_encode((unsigned char *)n_b64, n_b64_len, &n_b64_len, n, n_len); n_b64[n_b64_len] = '\0'; + base64_to_base64url(n_b64); + mbedtls_base64_encode((unsigned char *)e_b64, e_b64_len, &e_b64_len, e, e_len); e_b64[e_b64_len] = '\0'; + base64_to_base64url(e_b64); lua_pushstring(L, n_b64); lua_setfield(L, -2, "n"); lua_pushstring(L, e_b64); @@ -1655,24 +2045,25 @@ static int LuaConvertPemToJwk(lua_State *L) { dp_b64 = malloc(dp_b64_len + 1); dq_b64 = malloc(dq_b64_len + 1); qi_b64 = malloc(qi_b64_len + 1); - mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, - d_len); - mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, - p_len); - mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, - q_len); - mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, - dp, dp_len); - mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, - dq, dq_len); - mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, - qi, qi_len); + mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); + mbedtls_base64_encode((unsigned char *)p_b64, p_b64_len, &p_b64_len, p, p_len); + mbedtls_base64_encode((unsigned char *)q_b64, q_b64_len, &q_b64_len, q, q_len); + mbedtls_base64_encode((unsigned char *)dp_b64, dp_b64_len, &dp_b64_len, dp, dp_len); + mbedtls_base64_encode((unsigned char *)dq_b64, dq_b64_len, &dq_b64_len, dq, dq_len); + mbedtls_base64_encode((unsigned char *)qi_b64, qi_b64_len, &qi_b64_len, qi, qi_len); d_b64[d_b64_len] = '\0'; p_b64[p_b64_len] = '\0'; q_b64[q_b64_len] = '\0'; dp_b64[dp_b64_len] = '\0'; dq_b64[dq_b64_len] = '\0'; qi_b64[qi_b64_len] = '\0'; + // Convert all private components to base64url + base64_to_base64url(d_b64); + base64_to_base64url(p_b64); + base64_to_base64url(q_b64); + base64_to_base64url(dp_b64); + base64_to_base64url(dq_b64); + base64_to_base64url(qi_b64); lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); lua_pushstring(L, p_b64); @@ -1727,9 +2118,11 @@ static int LuaConvertPemToJwk(lua_State *L) { x_b64 = malloc(x_b64_len + 1); y_b64 = malloc(y_b64_len + 1); mbedtls_base64_encode((unsigned char *)x_b64, x_b64_len, &x_b64_len, x, x_len); - mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); x_b64[x_b64_len] = '\0'; + base64_to_base64url(x_b64); + mbedtls_base64_encode((unsigned char *)y_b64, y_b64_len, &y_b64_len, y, y_len); y_b64[y_b64_len] = '\0'; + base64_to_base64url(y_b64); // Set kty and crv for EC keys lua_pushstring(L, "EC"); lua_setfield(L, -2, "kty"); @@ -1757,6 +2150,7 @@ static int LuaConvertPemToJwk(lua_State *L) { d_b64 = malloc(d_b64_len + 1); mbedtls_base64_encode((unsigned char *)d_b64, d_b64_len, &d_b64_len, d, d_len); d_b64[d_b64_len] = '\0'; + base64_to_base64url(d_b64); lua_pushstring(L, d_b64); lua_setfield(L, -2, "d"); free(d); free(d_b64); } @@ -1769,10 +2163,33 @@ static int LuaConvertPemToJwk(lua_State *L) { } mbedtls_pk_free(&key); + + // Merge additional claims if provided and compatible with RFC7517 + if (has_claims) { + static const char *allowed[] = {"kty","use","sig","key_ops","alg","kid","x5u","x5c","x5t","x5t#S256",NULL}; + lua_pushnil(L); // first key + while (lua_next(L, 2) != 0) { + const char *k = lua_tostring(L, -2); + int allowed_key = 0; + for (int i = 0; allowed[i]; ++i) { + if (strcmp(k, allowed[i]) == 0) { + allowed_key = 1; + break; + } + } + if (allowed_key) { + lua_pushvalue(L, -2); + lua_insert(L, -2); + lua_settable(L, -4); + } else { + lua_pop(L, 1); + } + } + } + return 1; } - // CSR creation Function static int LuaGenerateCSR(lua_State *L) { const char *key_pem = luaL_checkstring(L, 1); @@ -1838,13 +2255,15 @@ static int LuaGenerateCSR(lua_State *L) { // LuaCrypto compatible API static int LuaCryptoSign(lua_State *L) { - // Type of signature (e.g., "rsa", "ecdsa") + // Type of signature (e.g., "rsa", "ecdsa", "rsa-pss") const char *dtype = luaL_checkstring(L, 1); // Remove the first argument (key type or cipher type) before dispatching lua_remove(L, 1); if (strcasecmp(dtype, "rsa") == 0) { return LuaRSASign(L); + } else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) { + return LuaRSAPSSSign(L); } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSASign(L); } else { @@ -1853,13 +2272,15 @@ static int LuaCryptoSign(lua_State *L) { } static int LuaCryptoVerify(lua_State *L) { - // Type of signature (e.g., "rsa", "ecdsa") + // Type of signature (e.g., "rsa", "ecdsa", "rsa-pss") const char *dtype = luaL_checkstring(L, 1); // Remove the first argument (key type or cipher type) before dispatching lua_remove(L, 1); if (strcasecmp(dtype, "rsa") == 0) { return LuaRSAVerify(L); + } else if (strcasecmp(dtype, "rsa-pss") == 0 || strcasecmp(dtype, "rsapss") == 0) { + return LuaRSAPSSVerify(L); } else if (strcasecmp(dtype, "ecdsa") == 0) { return LuaECDSAVerify(L); } else { @@ -1925,7 +2346,7 @@ static const luaL_Reg kLuaCrypto[] = { {"verify", LuaCryptoVerify}, // {"encrypt", LuaCryptoEncrypt}, // {"decrypt", LuaCryptoDecrypt}, // - {"generatekeypair", LuaCryptoGenerateKeyPair}, // + {"generateKeyPair", LuaCryptoGenerateKeyPair}, // {"convertJwkToPem", LuaConvertJwkToPem}, // {"convertPemToJwk", LuaConvertPemToJwk}, // {"generateCsr", LuaGenerateCSR}, //