bug fix in forward_batch_wo_cache_flash_attn_train
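Three related fixes to how the backward graph is constructed:
- t36->grad is already created and set to one by the optimizer, so assert that it exists instead of allocating a fresh scalar with ggml_new_f32.
- Accumulate grad_layer_inp->grad into t30->grad instead of the activation tensor grad_layer_inp itself.
- Store the activation tensor t02 in back_layer_inp rather than t02->grad, since the consumer reads the gradient via back_layer_inp->grad.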

xaedes 2023-06-11 19:58:36 +02:00
parent edf6fc252a
commit fdeb99784a


@@ -1708,7 +1708,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     use_buf(-1);
     // t36->grad gets set to one by optimizer, so we need to create the tensor.
     // initialize it with 1.0f to make sure.
-    t36->grad = ggml_new_f32(ctx0, 1.0f);
+    GGML_ASSERT(t36->grad != NULL);
+    // t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));
     use_buf(1);
     t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);
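
For context on the first fix, a minimal sketch of the pattern, assuming ggml's tensor->grad convention (seed_loss_grad is a hypothetical helper, not code from this file): the optimizer seeds the loss gradient in a grad tensor the graph already references, so the training code must assert its presence rather than replace it with a freshly allocated scalar.

#include "ggml.h"

// hypothetical helper: reuse the optimizer-created gradient instead of
// allocating a new tensor that the rest of the graph never sees
static void seed_loss_grad(struct ggml_tensor * loss) {
    GGML_ASSERT(loss->grad != NULL); // created before this function runs
    ggml_set_f32(loss->grad, 1.0f);  // dL/dL = 1
}
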
@@ -1766,7 +1767,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     use_buf(1);
     t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
     if (grad_layer_inp) {
-        t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+        t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
     }
     clr_buf(2);
     t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
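
The second fix follows the same convention: a tensor's gradient lives in its ->grad member, so grad_layer_inp->grad is what must be summed into t30->grad, not the activation tensor itself. A minimal sketch (accumulate_from is a hypothetical name):

#include "ggml.h"

// hypothetical helper: gradients reaching one activation from two paths
// are summed, and the gradient is carried by ->grad, not by the tensor
static struct ggml_tensor * accumulate_from(
        struct ggml_context * ctx,
        struct ggml_tensor  * grad_so_far,
        struct ggml_tensor  * other_path) {
    GGML_ASSERT(other_path->grad != NULL);
    return ggml_add(ctx, grad_so_far, other_path->grad); // was: other_path
}
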
@@ -1808,7 +1809,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t03->grad, n_embd, N*n_batch);
     use_buf(2);
     t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, t03)); assert_shape_2d(t02->grad, n_embd, N*n_batch);
-    back_layer_inp = t02->grad;
+    back_layer_inp = t02;
     use_buf(1);
     use_buf(-1);
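
The third fix is forced by the hunk above it: back_layer_inp is dereferenced as back_layer_inp->grad in the ggml_rms_norm_back call, so it must hold the activation tensor t02, not t02->grad. A minimal sketch of that invariant (layer_input_grad is a hypothetical helper):

#include "ggml.h"

// hypothetical helper: the consumer fetches the gradient through ->grad,
// so the producer must hand over the activation tensor itself; storing
// t02->grad would make this read t02->grad->grad instead
static struct ggml_tensor * layer_input_grad(struct ggml_tensor * back_layer_inp) {
    GGML_ASSERT(back_layer_inp->grad != NULL);
    return back_layer_inp->grad;
}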