diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 2c17d0b99..70fcdc5de 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1838,7 +1838,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( clr_buf(0); use_buf(0); - t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); + t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad, rms_norm_eps)); assert_shape_2d(t30->grad, n_embd, N*n_batch); if (grad_layer_inp) { t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); } @@ -1854,7 +1854,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch); t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch); use_buf(1); - t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch); + t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad, rms_norm_eps))); assert_shape_2d(t21->grad, n_embd, N*n_batch); grad_layer_inp = t21; use_buf(0); t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch); @@ -1899,9 +1899,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( } clr_buf(0); use_buf(0); - t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch); + t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad, rms_norm_eps))); assert_shape_2d(t01->grad, n_embd, N*n_batch); use_buf(-1); - model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab); + model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab); // clr_buf(1); // clr_buf(0); @@ -2396,9 +2396,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing( clr_buf(0); use_buf(0); - t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); + t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad, rms_norm_eps)); assert_shape_2d(t30->grad, n_embd, N*n_batch); if (grad_layer_inp) { - t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); + t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); } clr_buf(1); t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch); @@ -2412,7 +2412,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing( t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch); t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch); use_buf(1); - t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch); + t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad, rms_norm_eps))); assert_shape_2d(t21->grad, n_embd, N*n_batch); grad_layer_inp = t21; use_buf(0); t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch); @@ -2458,9 +2458,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing( GGML_ASSERT(avail_begin == 0); clr_buf(0); use_buf(0); - t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch); + t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad, rms_norm_eps))); assert_shape_2d(t01->grad, n_embd, N*n_batch); use_buf(-1); - model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab); + model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab); *logits = t35;