diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp
index 2316391e8..68ed00d9e 100644
--- a/examples/baby-llama/baby-llama.cpp
+++ b/examples/baby-llama/baby-llama.cpp
@@ -681,27 +681,54 @@ int main(int argc, char ** argv) {
         // struct ggml_tensor * before_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
         struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
-        struct ggml_tensor * tokens_input = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        struct ggml_tensor * targets = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
+        struct ggml_tensor * tokens_input1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * tokens_input2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        // struct ggml_tensor * tokens_input3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        // struct ggml_tensor * tokens_input4 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * targets1 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
+        struct ggml_tensor * targets2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
+        // struct ggml_tensor * targets3 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
+        // struct ggml_tensor * targets4 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
 
         int n_past = 0;
 
         ggml_cgraph gf = {};
         gf.n_threads = 1;
 
-        get_example_targets(ex, tokens_input, targets);
-        printf("Example %d\n", (ex+1));
+        get_example_targets(64*ex+0, tokens_input1, targets1);
+        get_example_targets(64*ex+16, tokens_input2, targets2);
+        // get_example_targets(64*ex+32, tokens_input3, targets3);
+        // get_example_targets(64*ex+48, tokens_input4, targets4);
         // print_probs(targets);
         // print_tokens(tokens_input, model.hparams.n_vocab);
 
-        struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past);
-        struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
+        struct ggml_tensor * logits1 = forward(&model, &kv_self, ctx0, &gf, tokens_input1, n_tokens, n_past);
+        struct ggml_tensor * logits2 = forward(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past);
+        // struct ggml_tensor * logits3 = forward(&model, &kv_self, ctx0, &gf, tokens_input3, n_tokens, n_past);
+        // struct ggml_tensor * logits4 = forward(&model, &kv_self, ctx0, &gf, tokens_input4, n_tokens, n_past);
+
+        // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets1, logits1);
+        // struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1);
+
+        struct ggml_tensor * e = ggml_add(ctx0,
+                            square_error_loss(ctx0, targets1, logits1),
+                            square_error_loss(ctx0, targets2, logits2));
+        // struct ggml_tensor * e = ggml_add(ctx0,
+        //                     cross_entropy_loss(ctx0, targets1, logits1),
+        //                     cross_entropy_loss(ctx0, targets2, logits2));
+        // struct ggml_tensor * e = ggml_add(ctx0,
+        //                     ggml_add(ctx0,
+        //                         cross_entropy_loss(ctx0, targets1, logits1),
+        //                         cross_entropy_loss(ctx0, targets2, logits2)),
+        //                     ggml_add(ctx0,
+        //                         cross_entropy_loss(ctx0, targets3, logits3),
+        //                         cross_entropy_loss(ctx0, targets4, logits4)));
 
         ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, &gf);
 
         float error_before_opt = ggml_get_f32_1d(e, 0);
-        // sample_softmax(logits, before_opt_probs, before_opt_best_samples);
+        // sample_softmax(logits1, before_opt_probs, before_opt_best_samples);
 
         // printf("probabilities before optimization:\n");
         // print_probs(before_opt_probs);
@@ -732,7 +759,7 @@ int main(int argc, char ** argv) {
         }
 
         if (ex % 64 == 0) {
-            sample_softmax(logits, after_opt_probs, after_opt_best_samples);
+            sample_softmax(logits1, after_opt_probs, after_opt_best_samples);
             // printf("probabilities after optimization:\n");
             // print_probs(after_opt_probs);
             printf("best samples after optimization:\n");
@@ -804,6 +831,6 @@ int main(int argc, char ** argv) {
     printf("done\n");
     // ggml_free(kv_self.ctx);
-    // ggml_free(model.ctx);
+    ggml_free(model.ctx);
 
     return 0;
 }