swap arguments to commutative ops to be the same as in forward_batch_wo_cache_flash_attn

This commit is contained in:
xaedes 2023-08-14 17:57:13 +02:00
parent 5a11b75875
commit b2f1310196
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@@ -1590,7 +1590,7 @@ struct ggml_tensor * llama_build_train_graphs(
struct my_llama_layer & layer = model->layers[il];
struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, rms_norm_eps); assert_shape_2d(t02, n_embd, N*n_batch);
struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); assert_shape_2d(t03, n_embd, N*n_batch);
-struct ggml_tensor * t04 = ggml_mul (ctx, t02, t03); assert_shape_2d(t04, n_embd, N*n_batch);
+struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); assert_shape_2d(t04, n_embd, N*n_batch);
struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); assert_shape_2d(t05, n_embd, N*n_batch);
struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t07 = ggml_rope_inplace (ctx, t06, n_past, n_rot, rope_mode, n_ctx); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
@@ -1625,7 +1625,7 @@ struct ggml_tensor * llama_build_train_graphs(
struct ggml_tensor * t27 = ggml_silu (ctx, t26); assert_shape_2d(t27, n_ff, N*n_batch);
struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); assert_shape_2d(t28, n_ff, N*n_batch);
struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); assert_shape_2d(t29, n_embd, N*n_batch);
-struct ggml_tensor * t30 = ggml_add (ctx, t21, t29); assert_shape_2d(t30, n_embd, N*n_batch);
+struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); assert_shape_2d(t30, n_embd, N*n_batch);
cur = t30;
checkpoints.push_back(cur);
}