llama training : fix ggml_rms_norm_back calls to pass configurable eps
parent ecdc16163e
commit c1a5e116a4
1 changed file with 9 additions and 9 deletions
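For context, here is a minimal sketch (not part of this commit) of how a configurable eps can be threaded from the training hyperparameters into the backward call. The struct name my_train_hparams and the helper are illustrative assumptions; the only thing taken from the diff below is that ggml_rms_norm_back now accepts a trailing float eps argument.

    #include "ggml.h"

    // Illustrative container for a configurable epsilon; in practice the value
    // should match the eps used in the forward ggml_rms_norm call.
    struct my_train_hparams {
        float rms_norm_eps;
    };

    // Hypothetical helper showing the updated call: ggml_rms_norm_back takes
    // eps as its last argument, so the backward pass normalizes with the same
    // epsilon as the forward pass.
    static struct ggml_tensor * backprop_rms_norm(
            struct ggml_context * ctx,
            struct ggml_tensor  * x,        // input of the forward RMS norm
            struct ggml_tensor  * grad_out, // incoming gradient
            const struct my_train_hparams * hp) {
        return ggml_rms_norm_back(ctx, x, grad_out, hp->rms_norm_eps);
    }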
@@ -1838,7 +1838,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 
     clr_buf(0);
     use_buf(0);
-    t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+    t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad, rms_norm_eps)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
     if (grad_layer_inp) {
         t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
     }
@@ -1854,7 +1854,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
     t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
     use_buf(1);
-    t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
+    t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad, rms_norm_eps))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
     grad_layer_inp = t21;
     use_buf(0);
     t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
@@ -1899,9 +1899,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     }
     clr_buf(0);
     use_buf(0);
-    t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
+    t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad, rms_norm_eps))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
     use_buf(-1);
     model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
     // clr_buf(1);
     // clr_buf(0);
 
@@ -2396,9 +2396,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
 
     clr_buf(0);
     use_buf(0);
-    t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+    t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad, rms_norm_eps)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
     if (grad_layer_inp) {
         t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
     }
     clr_buf(1);
     t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
@@ -2412,7 +2412,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
     t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
     t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
     use_buf(1);
-    t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
+    t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad, rms_norm_eps))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
     grad_layer_inp = t21;
     use_buf(0);
     t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
@@ -2458,9 +2458,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
     GGML_ASSERT(avail_begin == 0);
     clr_buf(0);
     use_buf(0);
-    t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
+    t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad, rms_norm_eps))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
     use_buf(-1);
     model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
 
     *logits = t35;
|
Loading…
Add table
Add a link
Reference in a new issue