From 0d4b87de3de6e0d910de5a0a2416ef6b10332fbe Mon Sep 17 00:00:00 2001
From: xaedes
Date: Thu, 1 Jun 2023 19:50:48 +0200
Subject: [PATCH] improve training memory usage with scratch buffers

Instead of relying on the automatic backward pass, we manually create the
graph for the backward pass. It turns out that all backward-pass operations
need only temporary memory, which can be reused after each layer is
processed.

This will compute the backward pass for ALL model parameters.
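Note: the core idea is a small set of scratch arenas between which graph
construction rotates: buffer 0 holds activations the backward pass will read
again, buffer 1 holds per-layer temporaries and is reset at the start of every
layer, and buffer 2 carries gradients from one layer to the next. A minimal
standalone sketch of that rotation (plain C++, no ggml; all names here are
illustrative only, not part of the patch):

    #include <cstdio>
    #include <cstddef>

    // Toy bump allocator standing in for one ggml scratch buffer.
    struct Arena {
        size_t offs = 0;            // current bump offset
        // alloc() only advances the offset; clr() makes the whole arena
        // reusable, which is why per-layer temporaries cost no extra memory.
        size_t alloc(size_t n) { size_t o = offs; offs += n; return o; }
        void   clr()           { offs = 0; }
    };

    int main() {
        Arena buf[2];               // buf[0]: persistent, buf[1]: per layer
        const int n_layer = 4;
        for (int il = 0; il < n_layer; ++il) {
            buf[1].clr();           // reuse temporary space every layer
            buf[0].alloc(1024);     // activation kept for the backward pass
            buf[1].alloc(4096);     // temporary, dead once the layer is done
        }
        // buf[0] grew with the layer count; buf[1] stayed at one layer's worth.
        printf("persistent: %zu bytes, temporary: %zu bytes\n",
               buf[0].offs, buf[1].offs);
        return 0;
    }

In the patch itself the same rotation is done through ggml_set_scratch():
use_buf(i) switches the active scratch buffer (remembering the previous
buffer's offset) and clr_buf(i) resets a buffer's offset to zero.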
---
 .../train-text-from-scratch.cpp               | 593 +++++++++++++++++-
 1 file changed, 577 insertions(+), 16 deletions(-)

diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 7e8d80b94..ee17bd8e4 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1337,6 +1337,505 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     return inpL;
 }
 
+struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
+        struct my_llama_model * model,
+        struct ggml_context   * ctx0,
+        struct ggml_cgraph    * gf,
+        struct ggml_cgraph    * gb,
+        struct ggml_tensor  * * logits,
+        struct ggml_tensor    * tokens_input,
+        struct ggml_tensor    * targets,
+        void                  * compute_buf_0,
+        void                  * compute_buf_1,
+        void                  * compute_buf_2,
+        size_t                  size_buf_0,
+        size_t                  size_buf_1,
+        size_t                  size_buf_2,
+        const int               n_tokens,
+        const int               n_batch) {
+
+    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+
+    const int n_past = 0;
+    const int N = n_tokens;
+
+    gf->n_nodes = 0;
+    gf->n_leafs = 0;
+    gf->work_size = 0;
+    gf->perf_runs = 0;
+    gf->perf_cycles = 0;
+    gf->perf_time_us = 0;
+    gf->work = NULL;
+
+    const auto & hparams = model->hparams;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_vocab = hparams.n_vocab;
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_head  = hparams.n_head;
+    const int n_rot   = hparams.n_rot;
+    const int n_ff    = get_n_ff(&hparams);
+    const int rope_mode = 0;
+
+    auto expand = [] (struct ggml_cgraph * g, struct ggml_tensor * t) -> struct ggml_tensor * {
+        ggml_build_forward_expand(g, t);
+        return t;
+    };
+
+    int last_buf = -1;
+    size_t buf_offs[3] = { 0, 0, 0 };
+    size_t buf_size[3] = { size_buf_0,
+                           size_buf_1,
+                           size_buf_2 };
+    void * buf_data[3] = { compute_buf_0,
+                           compute_buf_1,
+                           compute_buf_2 };
+    auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) {
+        size_t last_offs = 0;
+        last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+        if (last_buf >= 0) {
+            buf_offs[last_buf] = last_offs;
+        }
+        if (buf >= 0) {
+            size_t offs = buf_offs[buf];
+            size_t size = buf_size[buf];
+            void * data = buf_data[buf];
+            ggml_set_scratch(ctx0, { offs, size, data, });
+        }
+        last_buf = buf;
+    };
+
+    auto clr_buf = [&buf_offs] (int buf) {
+        if (buf < 0) return;
+        // size_t last_offs = 0;
+        // last_offs = ggml_set_scratch(ctx, { 0, 0, nullptr, });
+        // if (last_buf >= 0) {
+        //     buf_offs[last_buf] = last_offs;
+        // }
+        // buf_max_size[buf] = std::max(buf_max_size[buf], buf_offs[buf]);
+        buf_offs[buf] = 0;
+        // if (last_buf >= 0) {
+        //     size_t offs = buf_offs[last_buf];
+        //     size_t size = buf_size[last_buf];
+        //     void * data = buf_data[last_buf];
+        //     ggml_set_scratch(ctx0, { offset, size, data, });
+        // }
+    };
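+
+    // The three view__* helpers below slice the packed result of
+    // ggml_flash_attn_back(), which (as used in the backward pass further
+    // down) returns the gradients of q, k and v concatenated in one tensor:
+    // q at offset 0, k at offset nb3*ne3, v at offset 2*nb3*ne3.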
+
+    auto view__q = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
+        int64_t ne0 = n_embd/n_head;
+        int64_t ne1 = N;
+        int64_t ne2 = n_head;
+        int64_t ne3 = n_batch;
+        size_t  nb0 = ggml_element_size(t);
+        size_t  nb1 = nb0*ne0;
+        size_t  nb2 = nb1*ne1;
+        size_t  nb3 = nb2*ne2;
+        size_t offset = 0;
+        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
+    };
+
+    auto view__k = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
+        int64_t ne0 = n_embd/n_head;
+        int64_t ne1 = N;
+        int64_t ne2 = n_head;
+        int64_t ne3 = n_batch;
+        size_t  nb0 = ggml_element_size(t);
+        size_t  nb1 = nb0*ne0;
+        size_t  nb2 = nb1*ne1;
+        size_t  nb3 = nb2*ne2;
+        size_t offset = nb3*ne3;
+        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
+    };
+
+    auto view__v = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
+        int64_t ne0 = N;
+        int64_t ne1 = n_embd/n_head;
+        int64_t ne2 = n_head;
+        int64_t ne3 = n_batch;
+        size_t  nb0 = ggml_element_size(t);
+        size_t  nb1 = nb0*ne0;
+        size_t  nb2 = nb1*ne1;
+        size_t  nb3 = nb2*ne2;
+        size_t offset = 2*nb3*ne3;
+        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
+    };
+
+    auto add_or_set = [ctx0] (struct ggml_tensor * a, struct ggml_tensor * b) -> struct ggml_tensor * {
+        if (a == NULL) {
+            return b;
+        } else {
+            return ggml_add_inplace(ctx0, a, b);
+        }
+    };
+
+    use_buf(-1);
+
+    model->tok_embeddings->grad = ggml_dup_tensor(ctx0, model->tok_embeddings->grad);
+    model->norm->grad           = ggml_dup_tensor(ctx0, model->norm->grad);
+    model->output->grad         = ggml_dup_tensor(ctx0, model->output->grad);
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct my_llama_layer & layer = model->layers[il];
+        layer.attention_norm->grad = ggml_dup_tensor(ctx0, layer.attention_norm->grad);
+        layer.wq->grad             = ggml_dup_tensor(ctx0, layer.wq->grad);
+        layer.wk->grad             = ggml_dup_tensor(ctx0, layer.wk->grad);
+        layer.wv->grad             = ggml_dup_tensor(ctx0, layer.wv->grad);
+        layer.wo->grad             = ggml_dup_tensor(ctx0, layer.wo->grad);
+        layer.ffn_norm->grad       = ggml_dup_tensor(ctx0, layer.ffn_norm->grad);
+        layer.w1->grad             = ggml_dup_tensor(ctx0, layer.w1->grad);
+        layer.w2->grad             = ggml_dup_tensor(ctx0, layer.w2->grad);
+        layer.w3->grad             = ggml_dup_tensor(ctx0, layer.w3->grad);
+    }
+
+    clr_buf(1);
+    clr_buf(2);
+
+    use_buf(0);
+
+    struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
+    memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
+
+    struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch);
+
+    // need to remember these for the backward pass
+    std::vector<struct ggml_tensor *> t02L; t02L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t03L; t03L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t04L; t04L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t05L; t05L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t06L; t06L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t07L; t07L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t08L; t08L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t09L; t09L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t10L; t10L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t11L; t11L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t12L; t12L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t13L; t13L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t14L; t14L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t15L; t15L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t16L; t16L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t17L; t17L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t18L; t18L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t19L; t19L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t20L; t20L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t21L; t21L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t22L; t22L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t23L; t23L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t24L; t24L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t25L; t25L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t26L; t26L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t27L; t27L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t28L; t28L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t29L; t29L.resize(n_layer, NULL);
+    std::vector<struct ggml_tensor *> t30L; t30L.resize(n_layer, NULL);
+
+    struct ggml_tensor * cur = t01;
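+
+    // Shape of the forward pass per layer (t02..t30):
+    //   t02..t04  RMSNorm, scaled by attention_norm
+    //   t05..t19  self-attention: q/k/v projections, RoPE, flash attention
+    //   t20..t21  output projection wo, plus residual
+    //   t22..t24  RMSNorm, scaled by ffn_norm
+    //   t25..t29  SwiGLU feed-forward: w1/w3 gate, silu, w2
+    //   t30       residual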
+
+    for (int il = 0; il < n_layer; ++il) {
+        clr_buf(1);
+        struct my_llama_layer & layer = model->layers[il];
+        // tensors with values necessary for the backward pass are in persistent buf(0);
+        // other tensors with buf(1) are only temporarily needed and their memory is reused after each layer is completed.
+        use_buf(0); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm     (ctx0, cur));                                    assert_shape_2d(t02, n_embd, N*n_batch);
+        use_buf(1); struct ggml_tensor * t03 = expand(gf, ggml_repeat       (ctx0, layer.attention_norm, t02));              assert_shape_2d(t03, n_embd, N*n_batch);
+        use_buf(0); struct ggml_tensor * t04 = expand(gf, ggml_mul          (ctx0, t02, t03));                               assert_shape_2d(t04, n_embd, N*n_batch);
+        use_buf(1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat      (ctx0, layer.wq, t04));                          assert_shape_2d(t05, n_embd, N*n_batch);
+        use_buf(1); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d   (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
+        use_buf(1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode));          assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
+        use_buf(1); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat      (ctx0, layer.wk, t04));                          assert_shape_2d(t08, n_embd, N*n_batch);
+        use_buf(1); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d   (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
+        use_buf(1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode));          assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
+        use_buf(1); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat      (ctx0, t04, layer.wv));                          assert_shape_2d(t11, N*n_batch, n_embd);
+        use_buf(1); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d   (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
+        use_buf(0); struct ggml_tensor * t13 = expand(gf, ggml_permute      (ctx0, t07, 0, 2, 1, 3));                        assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
+        use_buf(0); struct ggml_tensor * t14 = expand(gf, ggml_permute      (ctx0, t10, 0, 2, 1, 3));                        assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
+        use_buf(0); struct ggml_tensor * t15 = expand(gf, ggml_permute      (ctx0, t12, 0, 3, 1, 2));                        assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
+        use_buf(1); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn   (ctx0, t13, t14, t15, true));                    assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
+        use_buf(1); struct ggml_tensor * t17 = expand(gf, ggml_permute      (ctx0, t16, 0, 2, 1, 3));                        assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
+        use_buf(1); struct ggml_tensor * t18 = expand(gf, ggml_cont         (ctx0, t17));                                    assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
+        use_buf(0); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d   (ctx0, t18, n_embd, N*n_batch));                 assert_shape_2d(t19, n_embd, N*n_batch);
+        use_buf(1); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat      (ctx0, layer.wo, t19));                          assert_shape_2d(t20, n_embd, N*n_batch);
+        use_buf(0); struct ggml_tensor * t21 = expand(gf, ggml_add          (ctx0, t20, cur));                               assert_shape_2d(t21, n_embd, N*n_batch);
+        use_buf(0); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm     (ctx0, t21));                                    assert_shape_2d(t22, n_embd, N*n_batch);
+        use_buf(1); struct ggml_tensor * t23 = expand(gf, ggml_repeat       (ctx0, layer.ffn_norm, t22));                    assert_shape_2d(t23, n_embd, N*n_batch);
+        use_buf(0); struct ggml_tensor * t24 = expand(gf, ggml_mul          (ctx0, t23, t22));                               assert_shape_2d(t24, n_embd, N*n_batch);
+        use_buf(0); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat      (ctx0, layer.w3, t24));                          assert_shape_2d(t25, n_ff, N*n_batch);
+        use_buf(0); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat      (ctx0, layer.w1, t24));                          assert_shape_2d(t26, n_ff, N*n_batch);
+        use_buf(0); struct ggml_tensor * t27 = expand(gf, ggml_silu         (ctx0, t26));                                    assert_shape_2d(t27, n_ff, N*n_batch);
+        use_buf(0); struct ggml_tensor * t28 = expand(gf, ggml_mul          (ctx0, t27, t25));                               assert_shape_2d(t28, n_ff, N*n_batch);
+        use_buf(1); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat      (ctx0, layer.w2, t28));                          assert_shape_2d(t29, n_embd, N*n_batch);
+        use_buf(0); struct ggml_tensor * t30 = expand(gf, ggml_add          (ctx0, t21, t29));                               assert_shape_2d(t30, n_embd, N*n_batch);
+
+        t02L[il] = t02;
+        t03L[il] = t03;
+        t04L[il] = t04;
+        t05L[il] = t05;
+        t06L[il] = t06;
+        t07L[il] = t07;
+        t08L[il] = t08;
+        t09L[il] = t09;
+        t10L[il] = t10;
+        t11L[il] = t11;
+        t12L[il] = t12;
+        t13L[il] = t13;
+        t14L[il] = t14;
+        t15L[il] = t15;
+        t16L[il] = t16;
+        t17L[il] = t17;
+        t18L[il] = t18;
+        t19L[il] = t19;
+        t20L[il] = t20;
+        t21L[il] = t21;
+        t22L[il] = t22;
+        t23L[il] = t23;
+        t24L[il] = t24;
+        t25L[il] = t25;
+        t26L[il] = t26;
+        t27L[il] = t27;
+        t28L[il] = t28;
+        t29L[il] = t29;
+        t30L[il] = t30;
+
+        cur = t30;
+    }
+    clr_buf(1);
+    use_buf(1);
+    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm          (ctx0, cur));                      assert_shape_2d(t31, n_embd, N*n_batch);
+    struct ggml_tensor * t32 = expand(gf, ggml_repeat            (ctx0, model->norm, t31));         assert_shape_2d(t32, n_embd, N*n_batch);
+    struct ggml_tensor * t33 = expand(gf, ggml_mul               (ctx0, t32, t31));                 assert_shape_2d(t33, n_embd, N*n_batch);
+    struct ggml_tensor * t34 = expand(gf, ggml_mul_mat           (ctx0, model->output, t33));       assert_shape_2d(t34, n_vocab, N*n_batch);
+    struct ggml_tensor * t35 = expand(gf, ggml_reshape_3d        (ctx0, t34, n_vocab, N, n_batch)); assert_shape_3d(t35, n_vocab, N, n_batch);
+    struct ggml_tensor * t36 = expand(gf, ggml_cross_entropy_loss(ctx0, t35, targets));             assert_shape_1d(t36, 1);
+
+    {
+        /*
+        tok_embeddings | grad_tok_embeddings = ggml_get_rows_back(grad_t01, t00)
+        L0_att_norm    | grad_L0_att_norm    = ggml_repeat_back(grad_t03L0, L0_att_norm.shape)
+        L0_wq          | grad_L0_wq          = ggml_out_prod(t04L0, grad_t05L0)
+        L0_wk          | grad_L0_wk          = ggml_out_prod(t04L0, grad_t08L0)
+        L0_wv          | grad_L0_wv          = ggml_out_prod(t04L0, ggml_transpose(grad_t11L0))
+        L0_wo          | grad_L0_wo          = ggml_out_prod(t19L0, grad_t20L0)
+        L0_ffn_norm    | grad_L0_ffn_norm    = ggml_repeat_back(grad_t23L0, L0_ffn_norm.shape)
+        L0_w1          | grad_L0_w1          = ggml_out_prod(t24L0, grad_t26L0)
+        L0_w2          | grad_L0_w2          = ggml_out_prod(t28L0, grad_t29L0)
+        L0_w3          | grad_L0_w3          = ggml_out_prod(t24L0, grad_t25L0)
+        L1_att_norm    | grad_L1_att_norm    = ggml_repeat_back(grad_t03L1, L1_att_norm.shape)
+        L1_wq          | grad_L1_wq          = ggml_out_prod(t04L1, grad_t05L1)
+        L1_wk          | grad_L1_wk          = ggml_out_prod(t04L1, grad_t08L1)
+        L1_wv          | grad_L1_wv          = ggml_out_prod(t04L1, ggml_transpose(grad_t11L1))
+        L1_wo          | grad_L1_wo          = ggml_out_prod(t19L1, grad_t20L1)
+        L1_ffn_norm    | grad_L1_ffn_norm    = ggml_repeat_back(grad_t23L1, L1_ffn_norm.shape)
+        L1_w1          | grad_L1_w1          = ggml_out_prod(t24L1, grad_t26L1)
+        L1_w2          | grad_L1_w2          = ggml_out_prod(t28L1, grad_t29L1)
+        L1_w3          | grad_L1_w3          = ggml_out_prod(t24L1, grad_t25L1)
+        norm           | grad_norm           = ggml_repeat_back(grad_t32, norm.shape)
+        output         | grad_output         = ggml_out_prod(t33, grad_t34)
+                       |
+        t01 = ggml_get_rows(tok_embeddings, t00)                             | grad_t01 = grad_t21L0 + ggml_rms_norm_back(t01, grad_t02L0)
+        for layer:                                                           |
+        t02L0*= ggml_rms_norm     (t01)                                      | grad_t02L0 = ggml_mul(grad_t04L0, t03L0)
+        t03L0 = ggml_repeat       (L0_att_norm, t02L0_shape)                 | grad_t03L0 = ggml_mul(grad_t04L0, t02L0)
+        t04L0*= ggml_mul          (t02L0, t03L0)                             | grad_t04L0 = ggml_out_prod(L0_wv, grad_t11L0) + ggml_out_prod(L0_wk, ggml_transpose(grad_t08L0)) + ggml_out_prod(L0_wq, ggml_transpose(grad_t05L0))
+        t05L0 = ggml_mul_mat      (L0_wq, t04L0)                             | grad_t05L0 = ggml_reshape(grad_t06L0, t05L0_shape)
+        t06L0 = ggml_reshape_4d   (t05L0, n_embd/n_head, n_head, N, n_batch) | grad_t06L0 = ggml_rope_back(grad_t07L0)
+        t07L0 = ggml_rope_inplace (t06L0)                                    | grad_t07L0 = ggml_permute_back(grad_t13L0, 0, 2, 1, 3) = ggml_permute(grad_t13L0, 0, 2, 1, 3)
+        t08L0 = ggml_mul_mat      (L0_wk, t04L0)                             | grad_t08L0 = ggml_reshape(grad_t09L0, t08L0_shape)
+        t09L0 = ggml_reshape_4d   (t08L0, n_embd/n_head, n_head, N, n_batch) | grad_t09L0 = ggml_rope_back(grad_t10L0)
+        t10L0 = ggml_rope_inplace (t09L0)                                    | grad_t10L0 = ggml_permute_back(grad_t14L0, 0, 2, 1, 3) = ggml_permute(grad_t14L0, 0, 2, 1, 3)
+        t11L0 = ggml_mul_mat      (t04L0, L0_wv)                             | grad_t11L0 = ggml_reshape(grad_t12L0, t11L0_shape)
+        t12L0 = ggml_reshape_4d   (t11L0, N, n_batch, n_embd/n_head, n_head) | grad_t12L0 = ggml_permute_back(grad_t15L0, 0, 3, 1, 2) = ggml_permute(grad_t15L0, 0, 2, 3, 1)
+        t13L0*= ggml_permute      (t07L0, 0, 2, 1, 3)                        | grad_t13L0 = view__q(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
+        t14L0*= ggml_permute      (t10L0, 0, 2, 1, 3)                        | grad_t14L0 = view__k(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
+        t15L0*= ggml_permute      (t12L0, 0, 3, 1, 2)                        | grad_t15L0 = view__v(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
+        t16L0 = ggml_flash_attn   (t13L0, t14L0, t15L0)                      | grad_t16L0 = ggml_permute_back(grad_t17L0, 0, 2, 1, 3) = ggml_permute(grad_t17L0, 0, 2, 1, 3)
+        t17L0 = ggml_permute      (t16L0, 0, 2, 1, 3)                        | grad_t17L0 = grad_t18L0
+        t18L0 = ggml_cont         (t17L0)                                    | grad_t18L0 = ggml_reshape(grad_t19L0, t18L0_shape)
+        t19L0*= ggml_reshape_2d   (t18L0, n_embd, N*n_batch)                 | grad_t19L0 = ggml_out_prod(L0_wo, ggml_transpose(grad_t20L0))
+        t20L0 = ggml_mul_mat      (L0_wo, t19L0)                             | grad_t20L0 = grad_t21L0
+        t21L0*= ggml_add          (t20L0, t01)                               | grad_t21L0 = grad_t30L0 + ggml_rms_norm_back(t21L0, grad_t22L0)
+        t22L0*= ggml_rms_norm     (t21L0)                                    | grad_t22L0 = ggml_mul(grad_t24L0, t23L0)
+        t23L0 = ggml_repeat       (L0_ffn_norm, t22L0_shape)                 | grad_t23L0 = ggml_mul(grad_t24L0, t22L0)
+        t24L0*= ggml_mul          (t23L0, t22L0)                             | grad_t24L0 = ggml_out_prod(L0_w1, ggml_transpose(grad_t26L0)) + ggml_out_prod(L0_w3, ggml_transpose(grad_t25L0))
+        t25L0*= ggml_mul_mat      (L0_w3, t24L0)                             | grad_t25L0 = ggml_mul(grad_t28L0, t27L0)
+        t26L0*= ggml_mul_mat      (L0_w1, t24L0)                             | grad_t26L0 = ggml_silu_back(t26L0, grad_t27L0)
+        t27L0*= ggml_silu         (t26L0)                                    | grad_t27L0 = ggml_mul(grad_t28L0, t25L0)
+        t28L0*= ggml_mul          (t27L0, t25L0)                             | grad_t28L0 = ggml_out_prod(L0_w2, ggml_transpose(grad_t29L0))
+        t29L0 = ggml_mul_mat      (L0_w2, t28L0)                             | grad_t29L0 = grad_t30L0
+        t30L0*= ggml_add          (t21L0, t29L0)                             | grad_t30L0 = ggml_rms_norm_back(t30L0, grad_t02L1) + grad_t21L1
+                                                                             |                                              ^
+        t02L1*= ggml_rms_norm     (t30L0)                                    | grad_t02L1 = ggml_mul(grad_t04L1, t03L1)
+        t03L1 = ggml_repeat       (L1_att_norm, t02L1_shape)                 | grad_t03L1 = ggml_mul(grad_t04L1, t02L1)
+        t04L1*= ggml_mul          (t02L1, t03L1)                             | grad_t04L1 = ggml_out_prod(L1_wv, grad_t11L1) + ggml_out_prod(L1_wk, ggml_transpose(grad_t08L1)) + ggml_out_prod(L1_wq, ggml_transpose(grad_t05L1))
+        t05L1 = ggml_mul_mat      (L1_wq, t04L1)                             | grad_t05L1 = ggml_reshape(grad_t06L1, t05L1_shape)
+        t06L1 = ggml_reshape_4d   (t05L1, n_embd/n_head, n_head, N, n_batch) | grad_t06L1 = ggml_rope_back(grad_t07L1)
+        t07L1 = ggml_rope_inplace (t06L1)                                    | grad_t07L1 = ggml_permute_back(grad_t13L1, 0, 2, 1, 3) = ggml_permute(grad_t13L1, 0, 2, 1, 3)
+        t08L1 = ggml_mul_mat      (L1_wk, t04L1)                             | grad_t08L1 = ggml_reshape(grad_t09L1, t08L1_shape)
+        t09L1 = ggml_reshape_4d   (t08L1, n_embd/n_head, n_head, N, n_batch) | grad_t09L1 = ggml_rope_back(grad_t10L1)
+        t10L1 = ggml_rope_inplace (t09L1)                                    | grad_t10L1 = ggml_permute_back(grad_t14L1, 0, 2, 1, 3) = ggml_permute(grad_t14L1, 0, 2, 1, 3)
+        t11L1 = ggml_mul_mat      (t04L1, L1_wv)                             | grad_t11L1 = ggml_reshape(grad_t12L1, t11L1_shape)
+        t12L1 = ggml_reshape_4d   (t11L1, N, n_batch, n_embd/n_head, n_head) | grad_t12L1 = ggml_permute_back(grad_t15L1, 0, 3, 1, 2) = ggml_permute(grad_t15L1, 0, 2, 3, 1)
+        t13L1*= ggml_permute      (t07L1, 0, 2, 1, 3)                        | grad_t13L1 = view__q(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
+        t14L1*= ggml_permute      (t10L1, 0, 2, 1, 3)                        | grad_t14L1 = view__k(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
+        t15L1*= ggml_permute      (t12L1, 0, 3, 1, 2)                        | grad_t15L1 = view__v(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
+        t16L1 = ggml_flash_attn   (t13L1, t14L1, t15L1)                      | grad_t16L1 = ggml_permute_back(grad_t17L1, 0, 2, 1, 3) = ggml_permute(grad_t17L1, 0, 2, 1, 3)
+        t17L1 = ggml_permute      (t16L1, 0, 2, 1, 3)                        | grad_t17L1 = grad_t18L1
+        t18L1 = ggml_cont         (t17L1)                                    | grad_t18L1 = ggml_reshape(grad_t19L1, t18L1_shape)
+        t19L1*= ggml_reshape_2d   (t18L1, n_embd, N*n_batch)                 | grad_t19L1 = ggml_out_prod(L1_wo, ggml_transpose(grad_t20L1))
+        t20L1 = ggml_mul_mat      (L1_wo, t19L1)                             | grad_t20L1 = grad_t21L1
+        t21L1*= ggml_add          (t20L1, t30L0)                             | grad_t21L1 = grad_t30L1 + ggml_rms_norm_back(t21L1, grad_t22L1)
+        t22L1*= ggml_rms_norm     (t21L1)                                    | grad_t22L1 = ggml_mul(grad_t24L1, t23L1)
+        t23L1 = ggml_repeat       (L1_ffn_norm, t22L1_shape)                 | grad_t23L1 = ggml_mul(grad_t24L1, t22L1)
+        t24L1*= ggml_mul          (t23L1, t22L1)                             | grad_t24L1 = ggml_out_prod(L1_w1, ggml_transpose(grad_t26L1)) + ggml_out_prod(L1_w3, ggml_transpose(grad_t25L1))
+        t25L1*= ggml_mul_mat      (L1_w3, t24L1)                             | grad_t25L1 = ggml_mul(grad_t28L1, t27L1)
+        t26L1*= ggml_mul_mat      (L1_w1, t24L1)                             | grad_t26L1 = ggml_silu_back(t26L1, grad_t27L1)
+        t27L1*= ggml_silu         (t26L1)                                    | grad_t27L1 = ggml_mul(grad_t28L1, t25L1)
+        t28L1*= ggml_mul          (t27L1, t25L1)                             | grad_t28L1 = ggml_out_prod(L1_w2, ggml_transpose(grad_t29L1))
+        t29L1 = ggml_mul_mat      (L1_w2, t28L1)                             | grad_t29L1 = grad_t30L1
+        t30L1*= ggml_add          (t21L1, t29L1)                             | grad_t30L1 = ggml_rms_norm_back(t30L1, grad_t31)
+                                                                             |                                        ^
+        t31 = ggml_rms_norm          (t30L1)                                 | grad_t31 = ggml_mul(grad_t33, t32)
+        t32 = ggml_repeat            (norm, t31.shape)                       | grad_t32 = ggml_mul(grad_t33, t31)
+        t33 = ggml_mul               (t32, t31)                              | grad_t33 = ggml_out_prod(output, ggml_transpose(grad_t34))
+        t34 = ggml_mul_mat           (output, t33)                           | grad_t34 = ggml_reshape(grad_t35, t34.shape)
+        t35 = ggml_reshape_3d        (t34, n_vocab, N, n_batch)              | grad_t35 = ggml_cross_entropy_loss_back(t35, targets, grad_t36)
+        t36 = ggml_cross_entropy_loss(t35, targets)                          | grad_t36 = 1 (optimizer)
+
+        tensors marked with * need to be stored until grad computation
+        tensors during grad computation are all temporary
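+
+        general rule for the matrix products above: for C = ggml_mul_mat(A, B)
+            grad_A = ggml_out_prod(B, grad_C)
+            grad_B = ggml_out_prod(A, ggml_transpose(grad_C))
+        e.g. t05L0 = ggml_mul_mat(L0_wq, t04L0) yields
+            grad_L0_wq  = ggml_out_prod(t04L0, grad_t05L0)
+            grad_t04L0 += ggml_out_prod(L0_wq, ggml_transpose(grad_t05L0))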
+        */
+    }
+
+    *gb = *gf;
+
+    use_buf(-1);
+    // t36->grad gets set to one by the optimizer, so we need to create the tensor.
+    // initialize it with 1.0f to make sure.
+    t36->grad = ggml_new_f32(ctx0, 1.0f);
+
+    use_buf(1);
+    t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);
+    t34->grad = expand(gb, ggml_reshape_2d(ctx0, t35->grad, n_vocab, N*n_batch));        assert_shape_2d(t34->grad, n_vocab, N*n_batch);
+    t33->grad = expand(gb, ggml_out_prod  (ctx0, model->output, ggml_transpose(ctx0, t34->grad))); assert_shape_2d(t33->grad, n_embd, N*n_batch);
+    t32->grad = expand(gb, ggml_mul       (ctx0, t33->grad, t31));                       assert_shape_2d(t32->grad, n_embd, N*n_batch);
+
+    use_buf(-1);
+
+    model->norm->grad   = expand(gb, add_or_set(model->norm->grad,   ggml_repeat_back(ctx0, t32->grad, model->norm))); assert_shape_1d(model->norm->grad, n_embd);
+    model->output->grad = expand(gb, add_or_set(model->output->grad, ggml_out_prod(ctx0, t33, t34->grad)));            assert_shape_2d(model->output->grad, n_embd, n_vocab);
+
+    clr_buf(2);
+    use_buf(2);
+    t31->grad = expand(gb, ggml_mul(ctx0, t33->grad, t32)); assert_shape_2d(t31->grad, n_embd, N*n_batch);
+
+    struct ggml_tensor * back_layer_inp = t31;
+    struct ggml_tensor * grad_layer_inp = NULL;
+
+    for (int k = 0; k < n_layer; ++k) {
+        int il = n_layer-1-k;
+        struct my_llama_layer & layer = model->layers[il];
+
+        struct ggml_tensor * t02 = t02L[il];
+        struct ggml_tensor * t03 = t03L[il];
+        struct ggml_tensor * t04 = t04L[il];
+        struct ggml_tensor * t05 = t05L[il];
+        struct ggml_tensor * t06 = t06L[il];
+        struct ggml_tensor * t07 = t07L[il];
+        struct ggml_tensor * t08 = t08L[il];
+        struct ggml_tensor * t09 = t09L[il];
+        struct ggml_tensor * t10 = t10L[il];
+        struct ggml_tensor * t11 = t11L[il];
+        struct ggml_tensor * t12 = t12L[il];
+        struct ggml_tensor * t13 = t13L[il];
+        struct ggml_tensor * t14 = t14L[il];
+        struct ggml_tensor * t15 = t15L[il];
+        struct ggml_tensor * t16 = t16L[il];
+        struct ggml_tensor * t17 = t17L[il];
+        struct ggml_tensor * t18 = t18L[il];
+        struct ggml_tensor * t19 = t19L[il];
+        struct ggml_tensor * t20 = t20L[il];
+        struct ggml_tensor * t21 = t21L[il];
+        struct ggml_tensor * t22 = t22L[il];
+        struct ggml_tensor * t23 = t23L[il];
+        struct ggml_tensor * t24 = t24L[il];
+        struct ggml_tensor * t25 = t25L[il];
+        struct ggml_tensor * t26 = t26L[il];
+        struct ggml_tensor * t27 = t27L[il];
+        struct ggml_tensor * t28 = t28L[il];
+        struct ggml_tensor * t29 = t29L[il];
+        struct ggml_tensor * t30 = t30L[il];
+
+        clr_buf(1);
+        use_buf(1);
+        t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+        if (grad_layer_inp) {
+            t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+        }
+        clr_buf(2);
+        t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
+        t28->grad = expand(gb, ggml_out_prod(ctx0, layer.w2, ggml_transpose(ctx0, t29->grad))); assert_shape_2d(t28->grad, n_ff, N*n_batch);
+        t27->grad = expand(gb, ggml_mul(ctx0, t28->grad, t25));       assert_shape_2d(t27->grad, n_ff, N*n_batch);
+        t26->grad = expand(gb, ggml_silu_back(ctx0, t26, t27->grad)); assert_shape_2d(t26->grad, n_ff, N*n_batch);
+        t25->grad = expand(gb, ggml_mul(ctx0, t28->grad, t27));       assert_shape_2d(t25->grad, n_ff, N*n_batch);
+        t24->grad = expand(gb, ggml_add_inplace(ctx0,
+                        ggml_out_prod(ctx0, layer.w1, ggml_transpose(ctx0, t26->grad)),
+                        ggml_out_prod(ctx0, layer.w3, ggml_transpose(ctx0, t25->grad)))); assert_shape_2d(t24->grad, n_embd, N*n_batch);
+        t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
+        t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
+        use_buf(2);
+        t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
+        grad_layer_inp = t21;
+        use_buf(1);
+        t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
+        t19->grad = expand(gb, ggml_out_prod(ctx0, layer.wo, ggml_transpose(ctx0, t20->grad))); assert_shape_2d(t19->grad, n_embd, N*n_batch);
+        t18->grad = expand(gb, ggml_reshape(ctx0, t19->grad, t18)); assert_shape_4d(t18->grad, n_embd/n_head, n_head, N, n_batch);
+        t17->grad = t18->grad; assert_shape_4d(t17->grad, n_embd/n_head, n_head, N, n_batch);
+        t16->grad = expand(gb, ggml_permute(ctx0, t17->grad, 0, 2, 1, 3)); assert_shape_4d(t16->grad, n_embd/n_head, N, n_head, n_batch);
+        struct ggml_tensor * flash_attn = expand(gb, ggml_flash_attn_back(ctx0, t13, t14, t15, t16->grad, true)); assert_shape_4d(flash_attn, n_embd/n_head, N*3, n_head, n_batch);
+        t15->grad = expand(gb, view__v(flash_attn)); assert_shape_4d(t15->grad, N, n_embd/n_head, n_head, n_batch);
+        t14->grad = expand(gb, view__k(flash_attn)); assert_shape_4d(t14->grad, n_embd/n_head, N, n_head, n_batch);
+        t13->grad = expand(gb, view__q(flash_attn)); assert_shape_4d(t13->grad, n_embd/n_head, N, n_head, n_batch);
+        t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
+        t11->grad = expand(gb, ggml_reshape(ctx0, ggml_cont(ctx0, t12->grad), t11)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
+        t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
+        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+        t08->grad = expand(gb, ggml_reshape(ctx0, t09->grad, t08)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
+        t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
+        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+        t05->grad = expand(gb, ggml_reshape(ctx0, t06->grad, t05)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
+        t04->grad = expand(gb, ggml_add_inplace(ctx0,
+                        ggml_add_inplace(ctx0,
+                            ggml_out_prod(ctx0, layer.wv, t11->grad),
+                            ggml_out_prod(ctx0, layer.wk, ggml_transpose(ctx0, t08->grad))),
+                        ggml_out_prod(ctx0, layer.wq, ggml_transpose(ctx0, t05->grad)))); assert_shape_2d(t04->grad, n_embd, N*n_batch);
+        t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t03->grad, n_embd, N*n_batch);
+        use_buf(2);
+        t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, t03)); assert_shape_2d(t02->grad, n_embd, N*n_batch);
+        back_layer_inp = t02;
+        use_buf(1);
+
+        use_buf(-1);
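+        // Parameter gradients are accumulated outside the scratch buffers
+        // (use_buf(-1)) so they persist across all layers; add_or_set()
+        // returns the gradient itself on first use and an in-place
+        // accumulation afterwards.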
+        layer.attention_norm->grad = expand(gb, add_or_set(layer.attention_norm->grad, ggml_repeat_back(ctx0, t03->grad, layer.attention_norm))); assert_shape_1d(layer.attention_norm->grad, n_embd);
+        layer.wq->grad             = expand(gb, add_or_set(layer.wq->grad,             ggml_out_prod(ctx0, t04, t05->grad)));                     assert_shape_2d(layer.wq->grad, n_embd, n_embd);
+        layer.wk->grad             = expand(gb, add_or_set(layer.wk->grad,             ggml_out_prod(ctx0, t04, t08->grad)));                     assert_shape_2d(layer.wk->grad, n_embd, n_embd);
+        layer.wv->grad             = expand(gb, add_or_set(layer.wv->grad,             ggml_out_prod(ctx0, t04, ggml_transpose(ctx0, t11->grad)))); assert_shape_2d(layer.wv->grad, n_embd, n_embd);
+        layer.wo->grad             = expand(gb, add_or_set(layer.wo->grad,             ggml_out_prod(ctx0, t19, t20->grad)));                     assert_shape_2d(layer.wo->grad, n_embd, n_embd);
+        layer.ffn_norm->grad       = expand(gb, add_or_set(layer.ffn_norm->grad,       ggml_repeat_back(ctx0, t23->grad, layer.ffn_norm)));       assert_shape_1d(layer.ffn_norm->grad, n_embd);
+        layer.w1->grad             = expand(gb, add_or_set(layer.w1->grad,             ggml_out_prod(ctx0, t24, t26->grad)));                     assert_shape_2d(layer.w1->grad, n_embd, n_ff);
+        layer.w2->grad             = expand(gb, add_or_set(layer.w2->grad,             ggml_out_prod(ctx0, t28, t29->grad)));                     assert_shape_2d(layer.w2->grad, n_ff, n_embd);
+        layer.w3->grad             = expand(gb, add_or_set(layer.w3->grad,             ggml_out_prod(ctx0, t24, t25->grad)));                     assert_shape_2d(layer.w3->grad, n_embd, n_ff);
+        use_buf(1);
+    }
+    clr_buf(1);
+    use_buf(1);
+    t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
+    use_buf(-1);
+    model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
+    clr_buf(2);
+    clr_buf(1);
+
+    *logits = t35;
+
+    return t36;
+}
+
 void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
     float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
     *ptr = value;
@@ -2129,6 +2628,9 @@ struct train_params {
 
     int mem_model_gb;
     int mem_compute_gb;
+    int mem_compute0_gb;
+    int mem_compute1_gb;
+    int mem_compute2_gb;
 };
 
 struct train_params get_default_train_params() {
@@ -2172,7 +2674,10 @@ struct train_params get_default_train_params() {
     params.adam_decay = 1e-3;
 
     params.mem_model_gb = 2;
-    params.mem_compute_gb = 32;
+    params.mem_compute_gb = 8;
+    params.mem_compute0_gb = 24;
+    params.mem_compute1_gb = 8;
+    params.mem_compute2_gb = 8;
 
     return params;
 }
@@ -2215,6 +2720,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, "  --adam-decay N   AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
     fprintf(stderr, "  --mem-model N    Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
     fprintf(stderr, "  --mem-compute N  Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
+    fprintf(stderr, "  --mem-compute0 N Memory to allocate for compute buffer 0 in gigabytes. (default %d)\n", params->mem_compute0_gb);
+    fprintf(stderr, "  --mem-compute1 N Memory to allocate for compute buffer 1 in gigabytes. (default %d)\n", params->mem_compute1_gb);
+    fprintf(stderr, "  --mem-compute2 N Memory to allocate for compute buffer 2 in gigabytes. (default %d)\n", params->mem_compute2_gb);
     fprintf(stderr, "\n");
 }
@@ -2408,6 +2916,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
            params->mem_compute_gb = std::stoi(argv[i]);
+        } else if (arg == "--mem-compute0") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->mem_compute0_gb = std::stoi(argv[i]);
+        } else if (arg == "--mem-compute1") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->mem_compute1_gb = std::stoi(argv[i]);
+        } else if (arg == "--mem-compute2") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->mem_compute2_gb = std::stoi(argv[i]);
         } else if (arg == "-h" || arg == "--help") {
            train_print_usage(argc, argv, &default_params);
            exit(0);
@@ -2563,6 +3089,13 @@ int main(int argc, char ** argv) {
     size_t    compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
     uint8_t * compute_addr = new uint8_t[compute_size];
 
+    size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
+    size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
+    size_t size_buf_2 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute2_gb);
+    uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
+    uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
+    uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
+
     GGML_ASSERT(train_tokens.size() > n_tokens);
     std::vector<int> train_samples;
     train_samples.push_back(0);
@@ -2601,22 +3134,46 @@ int main(int argc, char ** argv) {
 
         int n_past = 0;
 
-        ggml_cgraph gf = {};
-        gf.n_threads = params.n_threads;
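+        // Note: the forward and backward graphs are placed inside the ggml
+        // context by allocating each as an I32 tensor of
+        // ceil(sizeof(struct ggml_cgraph) / sizeof(int32_t)) elements; their
+        // lifetime is then tied to ctx0 and no separate allocation is needed.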
+        struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
+        struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
+
+        struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
+        struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+
+        // ggml_cgraph gf = {};
+        gf->n_threads = params.n_threads;
+        gb->n_threads = params.n_threads;
 
         get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
 
-        struct ggml_tensor * logits =
-            (n_past == 0)
-            ? (params.use_flash
-                ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
-                : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
-            : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
+        // struct ggml_tensor * logits =
+        //     (n_past == 0)
+        //     ? (params.use_flash
+        //         ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
+        //         : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
+        //     : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
 
-        struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
+        // struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
+        struct ggml_tensor * logits;
+        struct ggml_tensor * e = forward_batch_wo_cache_flash_attn_train(
+            &model,
+            ctx0,
+            gf,
+            gb,
+            &logits,
+            tokens_input,
+            target_probs,
+            compute_buf_0,
+            compute_buf_1,
+            compute_buf_2,
+            size_buf_0,
+            size_buf_1,
+            size_buf_2,
+            n_tokens,
+            n_batch);
 
-        ggml_build_forward_expand(&gf, e);
-        ggml_graph_compute(ctx0, &gf);
+        // ggml_build_forward_expand(&gf, e);
+        ggml_graph_compute(ctx0, gf);
 
         size_t used_mem_before_opt = ggml_used_mem(ctx0);
 
@@ -2633,7 +3190,8 @@ int main(int argc, char ** argv) {
         printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
         // ggml_opt(ctx0, opt->params, e);
-        ggml_opt_resume(ctx0, opt, e);
+        // ggml_opt_resume(ctx0, opt, e);
+        ggml_opt_resume_g(ctx0, opt, e, gf, gb);
 
         size_t used_mem_after_opt = ggml_used_mem(ctx0);
 
@@ -2641,8 +3199,8 @@
         model.train_samples += n_batch;
         model.train_tokens  += n_batch * n_tokens;
 
-        ggml_build_forward_expand(&gf, e);
-        ggml_graph_compute(ctx0, &gf);
+        //ggml_build_forward_expand(&gf, e);
+        ggml_graph_compute(ctx0, gf);
 
         float error_after_opt = ggml_get_f32_1d(e, 0);
 
@@ -2753,7 +3311,10 @@ int main(int argc, char ** argv) {
         }
     }
 
-    free(compute_addr);
+    delete[] compute_addr;
+    delete[] compute_buf_0;
+    delete[] compute_buf_1;
+    delete[] compute_buf_2;
     ggml_free(model.ctx);
 
     return 0;
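
Note: with these defaults the example reserves 8 GB for the main compute
context plus 24 + 8 + 8 GB for the three scratch buffers (and 2 GB for the
model). The buffer sizes can be overridden from the command line; a sketch of
an invocation (only the --mem-compute0/1/2 flags are new in this patch, the
remaining options are the example's existing ones and would be combined with
the usual model/training arguments):

    ./train-text-from-scratch \
        --mem-model 2 --mem-compute 8 \
        --mem-compute0 24 --mem-compute1 8 --mem-compute2 8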