Support more hashing functions

This commit is contained in:
Miguel Terron 2025-05-20 10:48:21 +12:00
parent 9e9f47342a
commit 87614a732d

View file

@ -23,8 +23,8 @@
#include "third_party/lua/lauxlib.h" #include "third_party/lua/lauxlib.h"
#include "third_party/mbedtls/ecdsa.h" #include "third_party/mbedtls/ecdsa.h"
#include "third_party/mbedtls/error.h" #include "third_party/mbedtls/error.h"
#include "third_party/mbedtls/md.h"
#include "third_party/mbedtls/pk.h" #include "third_party/mbedtls/pk.h"
#include "third_party/mbedtls/sha256.h"
// Supported curves mapping // Supported curves mapping
typedef struct { typedef struct {
@ -41,6 +41,144 @@ static const curve_map_t supported_curves[] = {
{"curve25519", MBEDTLS_ECP_DP_CURVE25519}, {"curve25519", MBEDTLS_ECP_DP_CURVE25519},
{NULL, 0}}; {NULL, 0}};
typedef enum { SHA256, SHA384, SHA512 } hash_algorithm_t;
static mbedtls_md_type_t hash_to_md_type(hash_algorithm_t hash_alg) {
switch (hash_alg) {
case SHA256:
return MBEDTLS_MD_SHA256;
case SHA384:
return MBEDTLS_MD_SHA384;
case SHA512:
return MBEDTLS_MD_SHA512;
default:
return MBEDTLS_MD_SHA256; // Default to SHA-256
}
}
static size_t get_hash_size(hash_algorithm_t hash_alg) {
switch (hash_alg) {
case SHA256:
return 32;
case SHA384:
return 48;
case SHA512:
return 64;
default:
return 32; // Default to SHA-256
}
}
static hash_algorithm_t string_to_hash_alg(const char *hash_name) {
if (!hash_name || !*hash_name) {
return SHA256; // Default to SHA-256 if no name provided
}
if (strcasecmp(hash_name, "sha256") == 0 ||
strcasecmp(hash_name, "sha-256") == 0) {
return SHA256;
} else if (strcasecmp(hash_name, "sha384") == 0 ||
strcasecmp(hash_name, "sha-384") == 0) {
return SHA384;
} else if (strcasecmp(hash_name, "sha512") == 0 ||
strcasecmp(hash_name, "sha-512") == 0) {
return SHA512;
} else {
WARNF("(ecdsa) Unknown hash algorithm '%s', using SHA-256", hash_name);
return SHA256;
}
}
static int LuaListHashAlgorithms(lua_State *L) {
lua_newtable(L);
lua_pushstring(L, "SHA256");
lua_rawseti(L, -2, 1);
lua_pushstring(L, "SHA384");
lua_rawseti(L, -2, 2);
lua_pushstring(L, "SHA512");
lua_rawseti(L, -2, 3);
// Add hyphenated versions
lua_pushstring(L, "SHA-256");
lua_rawseti(L, -2, 4);
lua_pushstring(L, "SHA-384");
lua_rawseti(L, -2, 5);
lua_pushstring(L, "SHA-512");
lua_rawseti(L, -2, 6);
return 1;
}
// List available curves
static int LuaListCurves(lua_State *L) {
const curve_map_t *curve = supported_curves;
int i = 1;
lua_newtable(L);
while (curve->name != NULL) {
lua_pushstring(L, curve->name);
lua_rawseti(L, -2, i++);
curve++;
}
return 1;
}
static int compute_hash(hash_algorithm_t hash_alg, const unsigned char *input,
size_t input_len, unsigned char *output,
size_t output_size) {
mbedtls_md_context_t md_ctx;
const mbedtls_md_info_t *md_info;
int ret;
mbedtls_md_type_t md_type = hash_to_md_type(hash_alg);
md_info = mbedtls_md_info_from_type(md_type);
if (md_info == NULL) {
WARNF("(ecdsa) Unsupported hash algorithm");
return -1;
}
if (output_size < mbedtls_md_get_size(md_info)) {
WARNF("(ecdsa) Output buffer too small for hash");
return -1;
}
mbedtls_md_init(&md_ctx);
ret = mbedtls_md_setup(&md_ctx, md_info, 0); // 0 = non-HMAC
if (ret != 0) {
WARNF("(ecdsa) Failed to set up hash context: -0x%04x", -ret);
goto cleanup;
}
ret = mbedtls_md_starts(&md_ctx);
if (ret != 0) {
WARNF("(ecdsa) Failed to start hash operation: -0x%04x", -ret);
goto cleanup;
}
ret = mbedtls_md_update(&md_ctx, input, input_len);
if (ret != 0) {
WARNF("(ecdsa) Failed to update hash: -0x%04x", -ret);
goto cleanup;
}
ret = mbedtls_md_finish(&md_ctx, output);
if (ret != 0) {
WARNF("(ecdsa) Failed to finish hash: -0x%04x", -ret);
goto cleanup;
}
cleanup:
mbedtls_md_free(&md_ctx);
return ret;
}
// Find curve ID by name // Find curve ID by name
static mbedtls_ecp_group_id find_curve_by_name(const char *name) { static mbedtls_ecp_group_id find_curve_by_name(const char *name) {
const curve_map_t *curve = supported_curves; const curve_map_t *curve = supported_curves;
@ -174,9 +312,11 @@ static int LuaGenerateKeyPair(lua_State *L) {
// Sign a message using an ECDSA private key in PEM format // Sign a message using an ECDSA private key in PEM format
static int Sign(const char *priv_key_pem, const char *message, static int Sign(const char *priv_key_pem, const char *message,
unsigned char **signature, size_t *sig_len) { hash_algorithm_t hash_alg, unsigned char **signature,
size_t *sig_len) {
mbedtls_pk_context key; mbedtls_pk_context key;
unsigned char hash[32]; // SHA-256 hash unsigned char hash[64]; // Max hash size (SHA-512)
size_t hash_size;
int ret; int ret;
*signature = NULL; *signature = NULL;
@ -194,6 +334,9 @@ static int Sign(const char *priv_key_pem, const char *message,
return -1; return -1;
} }
// Get hash size for the selected algorithm
hash_size = get_hash_size(hash_alg);
mbedtls_pk_init(&key); mbedtls_pk_init(&key);
// Parse the private key from PEM directly without creating a copy // Parse the private key from PEM directly without creating a copy
@ -205,11 +348,11 @@ static int Sign(const char *priv_key_pem, const char *message,
goto cleanup; goto cleanup;
} }
// Compute SHA-256 hash of the message // Compute hash of the message using the specified algorithm
ret = mbedtls_sha256_ret((const unsigned char *)message, strlen(message), ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message),
hash, 0); hash, sizeof(hash));
if (ret != 0) { if (ret != 0) {
WARNF("(ecdsa) Failed to compute message hash: -0x%04x", -ret); WARNF("(ecdsa) Failed to compute message hash");
goto cleanup; goto cleanup;
} }
@ -222,8 +365,8 @@ static int Sign(const char *priv_key_pem, const char *message,
} }
// Sign the hash using GenerateHardRandom // Sign the hash using GenerateHardRandom
ret = mbedtls_pk_sign(&key, MBEDTLS_MD_SHA256, hash, sizeof(hash), *signature, ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size,
sig_len, GenerateHardRandom, 0); *signature, sig_len, GenerateHardRandom, 0);
if (ret != 0) { if (ret != 0) {
WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret); WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret);
@ -236,16 +379,18 @@ static int Sign(const char *priv_key_pem, const char *message,
cleanup: cleanup:
mbedtls_pk_free(&key); mbedtls_pk_free(&key);
return ret; return ret;
} } // Lua binding for signing a message
// Lua binding for signing a message
static int LuaSign(lua_State *L) { static int LuaSign(lua_State *L) {
const char *priv_key_pem = luaL_checkstring(L, 1); const char *hash_name = luaL_optstring(L, 1, "sha256"); // Default to SHA-256
const char *message = luaL_checkstring(L, 2); const char *message = luaL_checkstring(L, 2);
const char *priv_key_pem = luaL_checkstring(L, 3);
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
unsigned char *signature = NULL; unsigned char *signature = NULL;
size_t sig_len = 0; size_t sig_len = 0;
int ret = Sign(priv_key_pem, message, &signature, &sig_len); int ret = Sign(priv_key_pem, message, hash_alg, &signature, &sig_len);
if (ret == 0) { if (ret == 0) {
lua_pushlstring(L, (const char *)signature, sig_len); lua_pushlstring(L, (const char *)signature, sig_len);
@ -259,31 +404,48 @@ static int LuaSign(lua_State *L) {
// Verify a signature using an ECDSA public key in PEM format // Verify a signature using an ECDSA public key in PEM format
static int Verify(const char *pub_key_pem, const char *message, static int Verify(const char *pub_key_pem, const char *message,
const unsigned char *signature, size_t sig_len) { const unsigned char *signature, size_t sig_len,
hash_algorithm_t hash_alg) {
mbedtls_pk_context key; mbedtls_pk_context key;
unsigned char hash[32]; // SHA-256 hash unsigned char hash[64]; // Max hash size (SHA-512)
size_t hash_size;
int ret; int ret;
if (!pub_key_pem) {
WARNF("(ecdsa) Public key is NULL");
return -1;
}
// Get the length of the PEM string (excluding null terminator)
size_t key_len = strlen(pub_key_pem);
if (key_len == 0) {
WARNF("(ecdsa) Public key is empty");
return -1;
}
// Get hash size for the selected algorithm
hash_size = get_hash_size(hash_alg);
mbedtls_pk_init(&key); mbedtls_pk_init(&key);
// Parse the public key from PEM // Parse the public key from PEM
ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pub_key_pem, ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pub_key_pem,
strlen(pub_key_pem) + 1); key_len + 1);
if (ret != 0) { if (ret != 0) {
WARNF("(ecdsa) Failed to parse public key: -0x%04x", -ret); WARNF("(ecdsa) Failed to parse public key: -0x%04x", -ret);
goto cleanup; goto cleanup;
} }
// Compute SHA-256 hash of the message // Compute hash of the message using the specified algorithm
ret = mbedtls_sha256_ret((const unsigned char *)message, strlen(message), ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message),
hash, 0); hash, sizeof(hash));
if (ret != 0) { if (ret != 0) {
WARNF("(ecdsa) Failed to compute message hash: -0x%04x", -ret); WARNF("(ecdsa) Failed to compute message hash");
goto cleanup; goto cleanup;
} }
// Verify the signature // Verify the signature
ret = mbedtls_pk_verify(&key, MBEDTLS_MD_SHA256, hash, sizeof(hash), ret = mbedtls_pk_verify(&key, hash_to_md_type(hash_alg), hash, hash_size,
signature, sig_len); signature, sig_len);
if (ret != 0) { if (ret != 0) {
WARNF("(ecdsa) Signature verification failed: -0x%04x", -ret); WARNF("(ecdsa) Signature verification failed: -0x%04x", -ret);
@ -294,43 +456,30 @@ cleanup:
mbedtls_pk_free(&key); mbedtls_pk_free(&key);
return ret; return ret;
} }
// Lua binding for verifying a signature
static int LuaVerify(lua_State *L) { static int LuaVerify(lua_State *L) {
const char *pub_key_pem = luaL_checkstring(L, 1); const char *hash_name = luaL_optstring(L, 1, "sha256"); // Default to SHA-256
const char *message = luaL_checkstring(L, 2); const char *message = luaL_checkstring(L, 2);
const char *pub_key_pem = luaL_checkstring(L, 3);
size_t sig_len; size_t sig_len;
const unsigned char *signature = const unsigned char *signature =
(const unsigned char *)luaL_checklstring(L, 3, &sig_len); (const unsigned char *)luaL_checklstring(L, 4, &sig_len);
int ret = Verify(pub_key_pem, message, signature, sig_len); hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
int ret = Verify(pub_key_pem, message, signature, sig_len, hash_alg);
lua_pushboolean(L, ret == 0); lua_pushboolean(L, ret == 0);
return 1; return 1;
} }
// List available curves
static int LuaListCurves(lua_State *L) {
const curve_map_t *curve = supported_curves;
int i = 1;
lua_newtable(L);
while (curve->name != NULL) {
lua_pushstring(L, curve->name);
lua_rawseti(L, -2, i++);
curve++;
}
return 1;
}
// Register functions // Register functions
static const luaL_Reg kLuaECDSA[] = { static const luaL_Reg kLuaECDSA[] = {
{"GenerateKeyPair", LuaGenerateKeyPair}, // {"GenerateKeyPair", LuaGenerateKeyPair}, //
{"Sign", LuaSign}, // {"Sign", LuaSign}, //
{"Verify", LuaVerify}, // {"Verify", LuaVerify}, //
{"ListCurves", LuaListCurves}, // {"ListCurves", LuaListCurves}, //
{0} // {"ListHashAlgorithms", LuaListHashAlgorithms}, //
{0} //
}; };
int LuaECDSA(lua_State *L) { int LuaECDSA(lua_State *L) {