From 581e5eb95406c2753fe8ef65bf1d2f1d6436465f Mon Sep 17 00:00:00 2001 From: xaedes Date: Thu, 11 May 2023 19:49:41 +0200 Subject: [PATCH] cleanup code for batched training --- examples/baby-llama/baby-llama.cpp | 62 +++++------------------------- 1 file changed, 9 insertions(+), 53 deletions(-) diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 60d81bc4a..67059921a 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -1566,63 +1566,26 @@ int main(int argc, char ** argv) { struct ggml_context * ctx0 = ggml_init(params); - struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); - struct ggml_tensor * before_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); - struct ggml_tensor * tokens_input1 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); - struct ggml_tensor * tokens_input2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); - // struct ggml_tensor * tokens_input3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - // struct ggml_tensor * tokens_input4 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - struct ggml_tensor * targets1 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); - struct ggml_tensor * targets2 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); - // struct ggml_tensor * targets3 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens); - // struct ggml_tensor * targets4 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens); + struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); + struct ggml_tensor * targets = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); int n_past = 0; ggml_cgraph gf = {}; gf.n_threads = 1; - get_example_targets_batch(ctx0, 64*ex+0, tokens_input1, targets1); - // get_example_targets_batch(64*ex+16, tokens_input2, targets2); - // get_example_targets(64*ex+32, tokens_input3, targets3); - // get_example_targets(64*ex+48, tokens_input4, targets4); - // print_matrix(targets); - // print_tokens(tokens_input, n_vocab); + get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets); - struct ggml_tensor * logits1 = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input1, n_tokens, n_past, n_batch); - // struct ggml_tensor * logits2 = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past, n_batch); - // struct ggml_tensor * logits3 = forward(&model, &kv_self, ctx0, &gf, tokens_input3, n_tokens, n_past); - // struct ggml_tensor * logits4 = forward(&model, &kv_self, ctx0, &gf, tokens_input4, n_tokens, n_past); - - // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets1, logits1); - struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1); - - // struct ggml_tensor * e = ggml_add(ctx0, - // square_error_loss(ctx0, targets1, logits1), - // square_error_loss(ctx0, targets2, logits2)); - // struct ggml_tensor * e = ggml_add(ctx0, - // cross_entropy_loss(ctx0, targets1, logits1), - // cross_entropy_loss(ctx0, targets2, logits2)); - // struct ggml_tensor * e = ggml_add(ctx0, - // ggml_add(ctx0, - // cross_entropy_loss(ctx0, targets1, logits1), - // cross_entropy_loss(ctx0, targets2, logits2)), - // ggml_add(ctx0, - // cross_entropy_loss(ctx0, targets3, logits3), - // cross_entropy_loss(ctx0, targets4, logits4))); + struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch); + // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits); + struct ggml_tensor * e = square_error_loss(ctx0, targets, logits); ggml_build_forward_expand(&gf, e); ggml_graph_compute(ctx0, &gf); float error_before_opt = ggml_get_f32_1d(e, 0); - // sample_softmax(logits1, before_opt_probs, before_opt_best_samples); - - // printf("probabilities before optimization:\n"); - // print_matrix(before_opt_probs); - // printf("best samples before optimization:\n"); - // print_tokens(before_opt_best_samples, n_vocab); struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM); struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS); @@ -1632,15 +1595,14 @@ int main(int argc, char ** argv) { opt_params_lbfgs.print_backward_graph = false; opt_params_adam.adam.n_iter = 16; opt_params_lbfgs.lbfgs.n_iter = 16; - ggml_opt(ctx0, opt_params_adam, e); - // ggml_opt(ctx0, opt_params_lbfgs, e); + // ggml_opt(ctx0, opt_params_adam, e); + ggml_opt(ctx0, opt_params_lbfgs, e); // ggml_build_forward_expand(&gf, e); ggml_graph_compute(ctx0, &gf); float error_after_opt = ggml_get_f32_1d(e, 0); - if (ex % 8 == 0) { printf("Example %d\n", (ex+1)); printf("error_before_opt: %.2f\n", error_before_opt); @@ -1648,7 +1610,7 @@ int main(int argc, char ** argv) { } if (ex % 64 == 0) { - sample_softmax_batch(ctx0, logits1, after_opt_probs, after_opt_best_samples); + sample_softmax_batch(ctx0, logits, after_opt_probs, after_opt_best_samples); // printf("probabilities after optimization:\n"); // print_matrix(after_opt_probs); printf("best samples after optimization:\n"); @@ -1708,12 +1670,6 @@ int main(int argc, char ** argv) { ggml_set_i32_1d(tokens_input, 0, 0); ggml_set_i32_1d(tokens_input, sample_ctx-1, token); - // printf("---\n"); - // for (int i=0; i