use ggml_cross_entropy_loss in text training example

This commit is contained in:
xaedes 2023-05-28 22:00:26 +02:00
parent f056a04a80
commit 1fbd19abe1
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1237,7 +1237,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
for (int i=1; i<n_tokens+1; ++i) { for (int i=1; i<n_tokens+1; ++i) {
int token = clamp(train_data[sample+i-1], 0, n_vocab-1); int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
set_f32_2d(target_logits, token, i-1, +1.0f); set_f32_2d(target_logits, token, i-1, +1.0f);
set_f32_2d(target_probs, token, i-1, -1.0f); set_f32_2d(target_probs, token, i-1, +1.0f);
if (i<n_tokens) { if (i<n_tokens) {
ggml_set_i32_1d(tokens_input, i, token); ggml_set_i32_1d(tokens_input, i, token);
} }
@ -1269,7 +1269,7 @@ void get_example_targets_batch(struct llama_context * lctx, const int * train_sa
int token = clamp(train_data[sample+i-1], 0, n_vocab-1); int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
// print_token(lctx, token); // print_token(lctx, token);
set_f32_3d(target_logits, token, i-1, k, +1.0f); set_f32_3d(target_logits, token, i-1, k, +1.0f);
set_f32_3d(target_probs, token, i-1, k, -1.0f); set_f32_3d(target_probs, token, i-1, k, +1.0f);
if (i<n_tokens) { if (i<n_tokens) {
set_i32_2d(tokens_input, i, k, token); set_i32_2d(tokens_input, i, k, token);
} }
@ -1301,17 +1301,7 @@ struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_te
} }
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * probs) { struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * probs) {
const float eps = 1e-9f; return ggml_cross_entropy_loss(ctx, a, probs);
return
ggml_sum(ctx,
ggml_mul(ctx,
probs,
ggml_log(ctx,
ggml_add1_inplace(ctx,
ggml_scale_inplace(ctx,
ggml_soft_max(ctx, a),
ggml_new_f32(ctx, 1.0f-eps)),
ggml_new_f32(ctx, eps)))));
} }
#ifdef __GNUC__ #ifdef __GNUC__