add square_error_loss and cross_entropy_loss functions
This commit is contained in:
parent
73fd66e9e5
commit
7a5dec24f8
1 changed file with 36 additions and 11 deletions
|
@ -607,6 +607,25 @@ void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * tar
|
|||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
|
||||
// todo: instead of a-b: a[1:]-b[:-1]
|
||||
return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, a, b)));
|
||||
}
|
||||
|
||||
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
|
||||
const float eps = 1e-3;
|
||||
return
|
||||
ggml_sum(ctx,
|
||||
ggml_neg(ctx,
|
||||
ggml_sum_rows(ctx,
|
||||
ggml_mul(ctx,
|
||||
ggml_soft_max(ctx, a),
|
||||
ggml_log(ctx,
|
||||
ggml_add1(ctx,
|
||||
ggml_soft_max(ctx, b),
|
||||
ggml_new_f32(ctx, eps)))))));
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
struct ggml_init_params lcparams;
|
||||
lcparams.mem_size = 1024ll*1024ll*1024ll;
|
||||
|
@ -645,7 +664,7 @@ int main(int argc, char ** argv) {
|
|||
size_t compute_size = 1024ll*1024ll*1024ll;
|
||||
uint8_t * compute_addr = new uint8_t[compute_size];
|
||||
|
||||
int n_examples = 32;
|
||||
int n_examples = 128;
|
||||
int n_tokens = model.hparams.n_ctx;
|
||||
|
||||
for (int ex=0; ex<n_examples; ++ex) {
|
||||
|
@ -660,8 +679,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
// struct ggml_tensor * before_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
|
||||
// struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
// struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
|
||||
struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
|
||||
struct ggml_tensor * tokens_input = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
struct ggml_tensor * targets = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
|
||||
|
||||
|
@ -676,7 +695,7 @@ int main(int argc, char ** argv) {
|
|||
// print_tokens(tokens_input, model.hparams.n_vocab);
|
||||
|
||||
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past);
|
||||
struct ggml_tensor * e = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, targets, logits)));
|
||||
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
|
||||
|
||||
ggml_build_forward_expand(&gf, e);
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
|
@ -704,15 +723,21 @@ int main(int argc, char ** argv) {
|
|||
ggml_graph_compute(ctx0, &gf);
|
||||
|
||||
float error_after_opt = ggml_get_f32_1d(e, 0);
|
||||
// sample_softmax(logits, after_opt_probs, after_opt_best_samples);
|
||||
|
||||
printf("error_before_opt: %.2f\n", error_before_opt);
|
||||
printf("error_after_opt: %.2f\n", error_after_opt);
|
||||
|
||||
// printf("probabilities after optimization:\n");
|
||||
// print_probs(after_opt_probs);
|
||||
// printf("best samples after optimization:\n");
|
||||
// print_tokens(after_opt_best_samples, model.hparams.n_vocab);
|
||||
if (ex % 8 == 0) {
|
||||
printf("Example %d\n", (ex+1));
|
||||
printf("error_before_opt: %.2f\n", error_before_opt);
|
||||
printf("error_after_opt: %.2f\n", error_after_opt);
|
||||
}
|
||||
|
||||
if (ex % 64 == 0) {
|
||||
sample_softmax(logits, after_opt_probs, after_opt_best_samples);
|
||||
// printf("probabilities after optimization:\n");
|
||||
// print_probs(after_opt_probs);
|
||||
printf("best samples after optimization:\n");
|
||||
print_tokens(after_opt_best_samples, model.hparams.n_vocab);
|
||||
}
|
||||
|
||||
ggml_free(ctx0);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue