From fdeb99784abb1f6ad399df53aa6a4fae1b977e9d Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 11 Jun 2023 19:58:36 +0200
Subject: [PATCH] bug fix in forward_batch_wo_cache_flash_attn_train

---
 .../train-text-from-scratch/train-text-from-scratch.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index f933c0164..9bbeda125 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1708,7 +1708,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     use_buf(-1);
     // t36->grad gets set to one by optimizer, so we need to create the tensor.
     // initialize it with 1.0f to make sure.
-    t36->grad = ggml_new_f32(ctx0, 1.0f);
+    GGML_ASSERT(t36->grad != NULL);
+    // t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));
     use_buf(1);
     t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);
@@ -1766,7 +1767,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     use_buf(1);
     t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
     if (grad_layer_inp) {
-        t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+        t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
     }
     clr_buf(2);
     t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
@@ -1808,7 +1809,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t04->grad, n_embd, N*n_batch);
     use_buf(2);
     t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, t03)); assert_shape_2d(t02->grad, n_embd, N*n_batch);
-    back_layer_inp = t02->grad;
+    back_layer_inp = t02;
     use_buf(1);
     use_buf(-1);