llama training : fix ggml_rms_norm_back calls to pass configurable eps
commit c1a5e116a4
parent ecdc16163e

1 changed file with 9 additions and 9 deletions
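For context, a minimal sketch of what the updated backward op is assumed to look like after this change (the declaration below is written to mirror ggml_rms_norm and is an assumption, not copied from ggml.h; only the call sites and the rms_norm_eps argument come from the diff itself):

    // assumed declaration: the backward op now takes the same eps as the forward ggml_rms_norm
    GGML_API struct ggml_tensor * ggml_rms_norm_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,     // forward input x
            struct ggml_tensor  * b,     // incoming gradient dL/dy
            float                 eps);  // epsilon used inside the forward normalization

    // hypothetical call site, matching the pattern used in the training code below:
    // struct ggml_tensor * dx = ggml_rms_norm_back(ctx0, x, dy, rms_norm_eps);

Passing rms_norm_eps explicitly keeps the backward graphs consistent with the eps configured for the forward passes instead of relying on a hard-coded default.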
@@ -1838,7 +1838,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 
 clr_buf(0);
 use_buf(0);
-t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad, rms_norm_eps)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
 if (grad_layer_inp) {
 t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
 }
@@ -1854,7 +1854,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
 t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
 use_buf(1);
-t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
+t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad, rms_norm_eps))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
 grad_layer_inp = t21;
 use_buf(0);
 t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
@@ -1899,9 +1899,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 }
 clr_buf(0);
 use_buf(0);
-t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
+t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad, rms_norm_eps))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
 use_buf(-1);
 model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
 // clr_buf(1);
 // clr_buf(0);
 
@@ -2396,9 +2396,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
 
 clr_buf(0);
 use_buf(0);
-t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
+t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad, rms_norm_eps)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
 if (grad_layer_inp) {
 t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
 }
 clr_buf(1);
 t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
@@ -2412,7 +2412,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
 t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
 t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
 use_buf(1);
-t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
+t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad, rms_norm_eps))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
 grad_layer_inp = t21;
 use_buf(0);
 t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
@@ -2458,9 +2458,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
 GGML_ASSERT(avail_begin == 0);
 clr_buf(0);
 use_buf(0);
-t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
+t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad, rms_norm_eps))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
 use_buf(-1);
 model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
 
 *logits = t35;
 
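Why eps has to flow into the backward op: RMS norm divides by r = sqrt(mean(x^2) + eps), and its gradient reuses that same r in both of its terms, so forward and backward must agree on eps. A minimal reference sketch of the math for a single row follows (an illustration of the standard formula, not ggml's actual kernel):

    #include <math.h>

    // Reference sketch (assumption, not ggml's implementation): gradient of
    // y = x / sqrt(mean(x^2) + eps) for one row of length n.
    static void rms_norm_back_row(const float * x, const float * dy, float * dx, int n, float eps) {
        double sum_xx = 0.0, sum_xdy = 0.0;
        for (int i = 0; i < n; ++i) {
            sum_xx  += (double) x[i] * x[i];
            sum_xdy += (double) x[i] * dy[i];
        }
        const double r = sqrt(sum_xx / n + eps);  // same denominator as the forward pass
        for (int i = 0; i < n; ++i) {
            // dL/dx_i = dy_i / r - x_i * sum_j(dy_j * x_j) / (n * r^3)
            dx[i] = (float) (dy[i] / r - x[i] * sum_xdy / (n * r * r * r));
        }
    }

If the backward call used a different hard-coded eps than the configured rms_norm_eps, the recomputed r would not match the forward pass and the gradients would be slightly wrong, most noticeably for rows of small magnitude where eps dominates the mean of squares.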