llama training : fix ggml_rms_norm_back calls to pass configurable eps

This commit is contained in:
xaedes 2023-07-28 23:10:55 +02:00
parent ecdc16163e
commit c1a5e116a4
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1838,7 +1838,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
clr_buf(0);
use_buf(0);
t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad, rms_norm_eps)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
if (grad_layer_inp) {
t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
}
@ -1854,7 +1854,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
use_buf(1);
t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad, rms_norm_eps))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
grad_layer_inp = t21;
use_buf(0);
t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
@ -1899,9 +1899,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
}
clr_buf(0);
use_buf(0);
t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad, rms_norm_eps))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
use_buf(-1);
model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
// clr_buf(1);
// clr_buf(0);
@ -2396,9 +2396,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
clr_buf(0);
use_buf(0);
t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad, rms_norm_eps)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
if (grad_layer_inp) {
t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
}
clr_buf(1);
t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
@ -2412,7 +2412,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
use_buf(1);
t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad, rms_norm_eps))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
grad_layer_inp = t21;
use_buf(0);
t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
@ -2458,9 +2458,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
GGML_ASSERT(avail_begin == 0);
clr_buf(0);
use_buf(0);
t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad, rms_norm_eps))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
use_buf(-1);
model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
*logits = t35;