Support more hashing functions

This commit is contained in:
Miguel Terron 2025-05-20 10:48:21 +12:00
parent 9e9f47342a
commit 87614a732d

View file

@ -23,8 +23,8 @@
#include "third_party/lua/lauxlib.h"
#include "third_party/mbedtls/ecdsa.h"
#include "third_party/mbedtls/error.h"
#include "third_party/mbedtls/md.h"
#include "third_party/mbedtls/pk.h"
#include "third_party/mbedtls/sha256.h"
// Supported curves mapping
typedef struct {
@ -41,6 +41,144 @@ static const curve_map_t supported_curves[] = {
{"curve25519", MBEDTLS_ECP_DP_CURVE25519},
{NULL, 0}};
typedef enum { SHA256, SHA384, SHA512 } hash_algorithm_t;
static mbedtls_md_type_t hash_to_md_type(hash_algorithm_t hash_alg) {
switch (hash_alg) {
case SHA256:
return MBEDTLS_MD_SHA256;
case SHA384:
return MBEDTLS_MD_SHA384;
case SHA512:
return MBEDTLS_MD_SHA512;
default:
return MBEDTLS_MD_SHA256; // Default to SHA-256
}
}
static size_t get_hash_size(hash_algorithm_t hash_alg) {
switch (hash_alg) {
case SHA256:
return 32;
case SHA384:
return 48;
case SHA512:
return 64;
default:
return 32; // Default to SHA-256
}
}
static hash_algorithm_t string_to_hash_alg(const char *hash_name) {
if (!hash_name || !*hash_name) {
return SHA256; // Default to SHA-256 if no name provided
}
if (strcasecmp(hash_name, "sha256") == 0 ||
strcasecmp(hash_name, "sha-256") == 0) {
return SHA256;
} else if (strcasecmp(hash_name, "sha384") == 0 ||
strcasecmp(hash_name, "sha-384") == 0) {
return SHA384;
} else if (strcasecmp(hash_name, "sha512") == 0 ||
strcasecmp(hash_name, "sha-512") == 0) {
return SHA512;
} else {
WARNF("(ecdsa) Unknown hash algorithm '%s', using SHA-256", hash_name);
return SHA256;
}
}
static int LuaListHashAlgorithms(lua_State *L) {
lua_newtable(L);
lua_pushstring(L, "SHA256");
lua_rawseti(L, -2, 1);
lua_pushstring(L, "SHA384");
lua_rawseti(L, -2, 2);
lua_pushstring(L, "SHA512");
lua_rawseti(L, -2, 3);
// Add hyphenated versions
lua_pushstring(L, "SHA-256");
lua_rawseti(L, -2, 4);
lua_pushstring(L, "SHA-384");
lua_rawseti(L, -2, 5);
lua_pushstring(L, "SHA-512");
lua_rawseti(L, -2, 6);
return 1;
}
// List available curves
static int LuaListCurves(lua_State *L) {
const curve_map_t *curve = supported_curves;
int i = 1;
lua_newtable(L);
while (curve->name != NULL) {
lua_pushstring(L, curve->name);
lua_rawseti(L, -2, i++);
curve++;
}
return 1;
}
static int compute_hash(hash_algorithm_t hash_alg, const unsigned char *input,
size_t input_len, unsigned char *output,
size_t output_size) {
mbedtls_md_context_t md_ctx;
const mbedtls_md_info_t *md_info;
int ret;
mbedtls_md_type_t md_type = hash_to_md_type(hash_alg);
md_info = mbedtls_md_info_from_type(md_type);
if (md_info == NULL) {
WARNF("(ecdsa) Unsupported hash algorithm");
return -1;
}
if (output_size < mbedtls_md_get_size(md_info)) {
WARNF("(ecdsa) Output buffer too small for hash");
return -1;
}
mbedtls_md_init(&md_ctx);
ret = mbedtls_md_setup(&md_ctx, md_info, 0); // 0 = non-HMAC
if (ret != 0) {
WARNF("(ecdsa) Failed to set up hash context: -0x%04x", -ret);
goto cleanup;
}
ret = mbedtls_md_starts(&md_ctx);
if (ret != 0) {
WARNF("(ecdsa) Failed to start hash operation: -0x%04x", -ret);
goto cleanup;
}
ret = mbedtls_md_update(&md_ctx, input, input_len);
if (ret != 0) {
WARNF("(ecdsa) Failed to update hash: -0x%04x", -ret);
goto cleanup;
}
ret = mbedtls_md_finish(&md_ctx, output);
if (ret != 0) {
WARNF("(ecdsa) Failed to finish hash: -0x%04x", -ret);
goto cleanup;
}
cleanup:
mbedtls_md_free(&md_ctx);
return ret;
}
// Find curve ID by name
static mbedtls_ecp_group_id find_curve_by_name(const char *name) {
const curve_map_t *curve = supported_curves;
@ -174,9 +312,11 @@ static int LuaGenerateKeyPair(lua_State *L) {
// Sign a message using an ECDSA private key in PEM format
static int Sign(const char *priv_key_pem, const char *message,
unsigned char **signature, size_t *sig_len) {
hash_algorithm_t hash_alg, unsigned char **signature,
size_t *sig_len) {
mbedtls_pk_context key;
unsigned char hash[32]; // SHA-256 hash
unsigned char hash[64]; // Max hash size (SHA-512)
size_t hash_size;
int ret;
*signature = NULL;
@ -194,6 +334,9 @@ static int Sign(const char *priv_key_pem, const char *message,
return -1;
}
// Get hash size for the selected algorithm
hash_size = get_hash_size(hash_alg);
mbedtls_pk_init(&key);
// Parse the private key from PEM directly without creating a copy
@ -205,11 +348,11 @@ static int Sign(const char *priv_key_pem, const char *message,
goto cleanup;
}
// Compute SHA-256 hash of the message
ret = mbedtls_sha256_ret((const unsigned char *)message, strlen(message),
hash, 0);
// Compute hash of the message using the specified algorithm
ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message),
hash, sizeof(hash));
if (ret != 0) {
WARNF("(ecdsa) Failed to compute message hash: -0x%04x", -ret);
WARNF("(ecdsa) Failed to compute message hash");
goto cleanup;
}
@ -222,8 +365,8 @@ static int Sign(const char *priv_key_pem, const char *message,
}
// Sign the hash using GenerateHardRandom
ret = mbedtls_pk_sign(&key, MBEDTLS_MD_SHA256, hash, sizeof(hash), *signature,
sig_len, GenerateHardRandom, 0);
ret = mbedtls_pk_sign(&key, hash_to_md_type(hash_alg), hash, hash_size,
*signature, sig_len, GenerateHardRandom, 0);
if (ret != 0) {
WARNF("(ecdsa) Failed to sign message: -0x%04x", -ret);
@ -236,16 +379,18 @@ static int Sign(const char *priv_key_pem, const char *message,
cleanup:
mbedtls_pk_free(&key);
return ret;
}
// Lua binding for signing a message
} // Lua binding for signing a message
static int LuaSign(lua_State *L) {
const char *priv_key_pem = luaL_checkstring(L, 1);
const char *hash_name = luaL_optstring(L, 1, "sha256"); // Default to SHA-256
const char *message = luaL_checkstring(L, 2);
const char *priv_key_pem = luaL_checkstring(L, 3);
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
unsigned char *signature = NULL;
size_t sig_len = 0;
int ret = Sign(priv_key_pem, message, &signature, &sig_len);
int ret = Sign(priv_key_pem, message, hash_alg, &signature, &sig_len);
if (ret == 0) {
lua_pushlstring(L, (const char *)signature, sig_len);
@ -259,31 +404,48 @@ static int LuaSign(lua_State *L) {
// Verify a signature using an ECDSA public key in PEM format
static int Verify(const char *pub_key_pem, const char *message,
const unsigned char *signature, size_t sig_len) {
const unsigned char *signature, size_t sig_len,
hash_algorithm_t hash_alg) {
mbedtls_pk_context key;
unsigned char hash[32]; // SHA-256 hash
unsigned char hash[64]; // Max hash size (SHA-512)
size_t hash_size;
int ret;
if (!pub_key_pem) {
WARNF("(ecdsa) Public key is NULL");
return -1;
}
// Get the length of the PEM string (excluding null terminator)
size_t key_len = strlen(pub_key_pem);
if (key_len == 0) {
WARNF("(ecdsa) Public key is empty");
return -1;
}
// Get hash size for the selected algorithm
hash_size = get_hash_size(hash_alg);
mbedtls_pk_init(&key);
// Parse the public key from PEM
ret = mbedtls_pk_parse_public_key(&key, (const unsigned char *)pub_key_pem,
strlen(pub_key_pem) + 1);
key_len + 1);
if (ret != 0) {
WARNF("(ecdsa) Failed to parse public key: -0x%04x", -ret);
goto cleanup;
}
// Compute SHA-256 hash of the message
ret = mbedtls_sha256_ret((const unsigned char *)message, strlen(message),
hash, 0);
// Compute hash of the message using the specified algorithm
ret = compute_hash(hash_alg, (const unsigned char *)message, strlen(message),
hash, sizeof(hash));
if (ret != 0) {
WARNF("(ecdsa) Failed to compute message hash: -0x%04x", -ret);
WARNF("(ecdsa) Failed to compute message hash");
goto cleanup;
}
// Verify the signature
ret = mbedtls_pk_verify(&key, MBEDTLS_MD_SHA256, hash, sizeof(hash),
ret = mbedtls_pk_verify(&key, hash_to_md_type(hash_alg), hash, hash_size,
signature, sig_len);
if (ret != 0) {
WARNF("(ecdsa) Signature verification failed: -0x%04x", -ret);
@ -294,42 +456,29 @@ cleanup:
mbedtls_pk_free(&key);
return ret;
}
// Lua binding for verifying a signature
static int LuaVerify(lua_State *L) {
const char *pub_key_pem = luaL_checkstring(L, 1);
const char *hash_name = luaL_optstring(L, 1, "sha256"); // Default to SHA-256
const char *message = luaL_checkstring(L, 2);
const char *pub_key_pem = luaL_checkstring(L, 3);
size_t sig_len;
const unsigned char *signature =
(const unsigned char *)luaL_checklstring(L, 3, &sig_len);
(const unsigned char *)luaL_checklstring(L, 4, &sig_len);
int ret = Verify(pub_key_pem, message, signature, sig_len);
hash_algorithm_t hash_alg = string_to_hash_alg(hash_name);
int ret = Verify(pub_key_pem, message, signature, sig_len, hash_alg);
lua_pushboolean(L, ret == 0);
return 1;
}
// List available curves
static int LuaListCurves(lua_State *L) {
const curve_map_t *curve = supported_curves;
int i = 1;
lua_newtable(L);
while (curve->name != NULL) {
lua_pushstring(L, curve->name);
lua_rawseti(L, -2, i++);
curve++;
}
return 1;
}
// Register functions
static const luaL_Reg kLuaECDSA[] = {
{"GenerateKeyPair", LuaGenerateKeyPair}, //
{"Sign", LuaSign}, //
{"Verify", LuaVerify}, //
{"ListCurves", LuaListCurves}, //
{"ListHashAlgorithms", LuaListHashAlgorithms}, //
{0} //
};