add square_error_loss and cross_entropy_loss functions

This commit is contained in:
xaedes 2023-05-07 01:21:26 +02:00
parent 73fd66e9e5
commit 7a5dec24f8
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -607,6 +607,25 @@ void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * tar
}
}
struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
// todo: instead of a-b: a[1:]-b[:-1]
return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, a, b)));
}
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
const float eps = 1e-3;
return
ggml_sum(ctx,
ggml_neg(ctx,
ggml_sum_rows(ctx,
ggml_mul(ctx,
ggml_soft_max(ctx, a),
ggml_log(ctx,
ggml_add1(ctx,
ggml_soft_max(ctx, b),
ggml_new_f32(ctx, eps)))))));
}
int main(int argc, char ** argv) {
struct ggml_init_params lcparams;
lcparams.mem_size = 1024ll*1024ll*1024ll;
@ -645,7 +664,7 @@ int main(int argc, char ** argv) {
size_t compute_size = 1024ll*1024ll*1024ll;
uint8_t * compute_addr = new uint8_t[compute_size];
int n_examples = 32;
int n_examples = 128;
int n_tokens = model.hparams.n_ctx;
for (int ex=0; ex<n_examples; ++ex) {
@ -660,8 +679,8 @@ int main(int argc, char ** argv) {
// struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
// struct ggml_tensor * before_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
// struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
// struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
struct ggml_tensor * tokens_input = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * targets = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
@ -676,7 +695,7 @@ int main(int argc, char ** argv) {
// print_tokens(tokens_input, model.hparams.n_vocab);
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past);
struct ggml_tensor * e = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, targets, logits)));
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
ggml_build_forward_expand(&gf, e);
ggml_graph_compute(ctx0, &gf);
@ -704,15 +723,21 @@ int main(int argc, char ** argv) {
ggml_graph_compute(ctx0, &gf);
float error_after_opt = ggml_get_f32_1d(e, 0);
// sample_softmax(logits, after_opt_probs, after_opt_best_samples);
printf("error_before_opt: %.2f\n", error_before_opt);
printf("error_after_opt: %.2f\n", error_after_opt);
// printf("probabilities after optimization:\n");
// print_probs(after_opt_probs);
// printf("best samples after optimization:\n");
// print_tokens(after_opt_best_samples, model.hparams.n_vocab);
if (ex % 8 == 0) {
printf("Example %d\n", (ex+1));
printf("error_before_opt: %.2f\n", error_before_opt);
printf("error_after_opt: %.2f\n", error_after_opt);
}
if (ex % 64 == 0) {
sample_softmax(logits, after_opt_probs, after_opt_best_samples);
// printf("probabilities after optimization:\n");
// print_probs(after_opt_probs);
printf("best samples after optimization:\n");
print_tokens(after_opt_best_samples, model.hparams.n_vocab);
}
ggml_free(ctx0);
}