From 7a5dec24f887b558afbd4abfecd8ae8f49a0b48b Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 7 May 2023 01:21:26 +0200
Subject: [PATCH] add square_error_loss and cross_entropy_loss functions

---
 examples/baby-llama/baby-llama.cpp | 47 +++++++++++++++++++++++-------
 1 file changed, 36 insertions(+), 11 deletions(-)

diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp
index 3b02a383e..2316391e8 100644
--- a/examples/baby-llama/baby-llama.cpp
+++ b/examples/baby-llama/baby-llama.cpp
@@ -607,6 +607,25 @@ void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * tar
     }
 }
 
+struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
+    // returns the tensor for sum((a - b)^2); todo: instead of a-b: a[1:]-b[:-1]
+    return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, a, b)));
+}
+
+struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
+    const float eps = 1e-3f; // float literal avoids double->float narrowing; added to softmax(b) so log() never receives 0
+    return // -sum(p * log(q + eps)) with p = softmax(a), q = softmax(b)
+        ggml_sum(ctx,
+            ggml_neg(ctx,
+                ggml_sum_rows(ctx,
+                    ggml_mul(ctx,
+                        ggml_soft_max(ctx, a),
+                        ggml_log(ctx,
+                            ggml_add1(ctx,
+                                ggml_soft_max(ctx, b),
+                                ggml_new_f32(ctx, eps)))))));
+}
+
 int main(int argc, char ** argv) {
     struct ggml_init_params lcparams;
     lcparams.mem_size = 1024ll*1024ll*1024ll;
@@ -645,7 +664,7 @@ int main(int argc, char ** argv) {
     size_t compute_size = 1024ll*1024ll*1024ll;
     uint8_t * compute_addr = new uint8_t[compute_size];
 
-    int n_examples = 32;
+    int n_examples = 128;
     int n_tokens = model.hparams.n_ctx;
 
     for (int ex=0; ex