use ggml_cross_entropy_loss in text training example
This commit is contained in:
parent
f056a04a80
commit
1fbd19abe1
1 changed files with 3 additions and 13 deletions
|
@ -1237,7 +1237,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
|
||||||
for (int i=1; i<n_tokens+1; ++i) {
|
for (int i=1; i<n_tokens+1; ++i) {
|
||||||
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
||||||
set_f32_2d(target_logits, token, i-1, +1.0f);
|
set_f32_2d(target_logits, token, i-1, +1.0f);
|
||||||
set_f32_2d(target_probs, token, i-1, -1.0f);
|
set_f32_2d(target_probs, token, i-1, +1.0f);
|
||||||
if (i<n_tokens) {
|
if (i<n_tokens) {
|
||||||
ggml_set_i32_1d(tokens_input, i, token);
|
ggml_set_i32_1d(tokens_input, i, token);
|
||||||
}
|
}
|
||||||
|
@ -1269,7 +1269,7 @@ void get_example_targets_batch(struct llama_context * lctx, const int * train_sa
|
||||||
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
||||||
// print_token(lctx, token);
|
// print_token(lctx, token);
|
||||||
set_f32_3d(target_logits, token, i-1, k, +1.0f);
|
set_f32_3d(target_logits, token, i-1, k, +1.0f);
|
||||||
set_f32_3d(target_probs, token, i-1, k, -1.0f);
|
set_f32_3d(target_probs, token, i-1, k, +1.0f);
|
||||||
if (i<n_tokens) {
|
if (i<n_tokens) {
|
||||||
set_i32_2d(tokens_input, i, k, token);
|
set_i32_2d(tokens_input, i, k, token);
|
||||||
}
|
}
|
||||||
|
@ -1301,17 +1301,7 @@ struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_te
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * probs) {
|
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * probs) {
|
||||||
const float eps = 1e-9f;
|
return ggml_cross_entropy_loss(ctx, a, probs);
|
||||||
return
|
|
||||||
ggml_sum(ctx,
|
|
||||||
ggml_mul(ctx,
|
|
||||||
probs,
|
|
||||||
ggml_log(ctx,
|
|
||||||
ggml_add1_inplace(ctx,
|
|
||||||
ggml_scale_inplace(ctx,
|
|
||||||
ggml_soft_max(ctx, a),
|
|
||||||
ggml_new_f32(ctx, 1.0f-eps)),
|
|
||||||
ggml_new_f32(ctx, eps)))));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __GNUC__
|
#ifdef __GNUC__
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue