bug fix in forward_batch_wo_cache_flash_attn_train
This commit is contained in:
parent edf6fc252a
commit fdeb99784a
1 changed file with 4 additions and 3 deletions
@@ -1708,7 +1708,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
    use_buf(-1);
    // t36->grad gets set to one by optimizer, so we need to create the tensor.
    // initialize it with 1.0f to make sure.
    t36->grad = ggml_new_f32(ctx0, 1.0f);
    GGML_ASSERT(t36->grad != NULL);
    // t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));

    use_buf(1);
    t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);

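The point of this hunk is that the gradient tensor of the scalar loss (t36) must already exist and hold 1.0f before it is passed to ggml_cross_entropy_loss_back and later overwritten by the optimizer. Below is a minimal standalone sketch of that pattern, not the training code itself; it assumes the C ggml API of this period (ggml_init, ggml_new_f32, ggml_get_f32_1d, a grad member on ggml_tensor), and the "loss" tensor is a made-up stand-in for t36.

#include "ggml.h"
#include <stdio.h>

int main(void) {
    // small scratch context; sizes are arbitrary for this sketch
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    // hypothetical stand-in for the scalar loss tensor t36
    struct ggml_tensor * loss = ggml_new_f32(ctx0, 0.0f);

    // seed its gradient with a 1-element f32 tensor holding 1.0f,
    // mirroring: t36->grad = ggml_new_f32(ctx0, 1.0f); GGML_ASSERT(t36->grad != NULL);
    loss->grad = ggml_new_f32(ctx0, 1.0f);
    GGML_ASSERT(loss->grad != NULL);

    printf("dloss/dloss = %f\n", ggml_get_f32_1d(loss->grad, 0)); // prints 1.000000

    ggml_free(ctx0);
    return 0;
}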
@@ -1766,7 +1767,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
    use_buf(1);
    t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
    if (grad_layer_inp) {
-       t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+       t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
    }
    clr_buf(2);
    t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);

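What this hunk fixes: grad_layer_inp evidently holds a tensor from a second branch of the graph, not a gradient, so the extra contribution to t30->grad has to be grad_layer_inp->grad rather than grad_layer_inp itself, and the two contributions are summed with ggml_add. The sketch below only illustrates summing two gradient contributions; the tensor names and values are made up, and it assumes the graph API of this period (ggml_build_forward returning a cgraph by value, ggml_graph_compute taking a context).

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    // hypothetical gradient contributions reaching the same activation
    struct ggml_tensor * grad_from_norm   = ggml_new_f32(ctx0, 0.25f);
    struct ggml_tensor * grad_from_branch = ggml_new_f32(ctx0, 0.75f);

    // total gradient is the sum of the contributions,
    // cf. t30->grad = ggml_add(ctx0, t30->grad, grad_layer_inp->grad)
    struct ggml_tensor * total = ggml_add(ctx0, grad_from_norm, grad_from_branch);

    struct ggml_cgraph gf = ggml_build_forward(total);
    ggml_graph_compute(ctx0, &gf);

    printf("accumulated grad = %f\n", ggml_get_f32_1d(total, 0)); // 1.000000

    ggml_free(ctx0);
    return 0;
}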
@@ -1808,7 +1809,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
    t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t04->grad, n_embd, N*n_batch);
    use_buf(2);
    t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, t03)); assert_shape_2d(t02->grad, n_embd, N*n_batch);
-   back_layer_inp = t02->grad;
+   back_layer_inp = t02;
    use_buf(1);

    use_buf(-1);

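The last hunk is the same kind of fix: back_layer_inp is meant to hold the activation tensor itself (t02) so that the earlier hunk can read back_layer_inp->grad at the point of use; storing t02->grad instead would make back_layer_inp->grad dereference the gradient of a gradient, which is typically NULL. A minimal sketch of that pointer distinction, assuming a ggml version where struct ggml_tensor still has a grad member and using made-up values (in the real code the gradient tensors are created while building the backward graph, not by hand):

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    // hypothetical stand-in for an activation such as t02, with a gradient attached
    struct ggml_tensor * t02 = ggml_new_f32(ctx0, 3.0f);
    t02->grad = ggml_new_f32(ctx0, 0.5f);

    // before the fix: the variable held the gradient tensor,
    // so dereferencing ->grad again yields nothing useful
    struct ggml_tensor * wrong = t02->grad;
    printf("wrong->grad is %s\n", wrong->grad == NULL ? "NULL" : "set");

    // after the fix: hold the activation itself and take ->grad where it is needed,
    // as in back_layer_inp = t02; ... ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)
    struct ggml_tensor * back_layer_inp = t02;
    printf("back_layer_inp->grad = %f\n", ggml_get_f32_1d(back_layer_inp->grad, 0));

    ggml_free(ctx0);
    return 0;
}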