swap arguments to commutative ops to be the same as in forward_batch_wo_cache_flash_attn
commit b2f1310196
parent 5a11b75875
1 changed file with 2 additions and 2 deletions
@@ -1590,7 +1590,7 @@ struct ggml_tensor * llama_build_train_graphs(
         struct my_llama_layer & layer = model->layers[il];
         struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, rms_norm_eps); assert_shape_2d(t02, n_embd, N*n_batch);
         struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); assert_shape_2d(t03, n_embd, N*n_batch);
-        struct ggml_tensor * t04 = ggml_mul (ctx, t02, t03); assert_shape_2d(t04, n_embd, N*n_batch);
+        struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); assert_shape_2d(t04, n_embd, N*n_batch);
         struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); assert_shape_2d(t05, n_embd, N*n_batch);
         struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
         struct ggml_tensor * t07 = ggml_rope_inplace (ctx, t06, n_past, n_rot, rope_mode, n_ctx); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
@@ -1625,7 +1625,7 @@ struct ggml_tensor * llama_build_train_graphs(
         struct ggml_tensor * t27 = ggml_silu (ctx, t26); assert_shape_2d(t27, n_ff, N*n_batch);
         struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); assert_shape_2d(t28, n_ff, N*n_batch);
         struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); assert_shape_2d(t29, n_embd, N*n_batch);
-        struct ggml_tensor * t30 = ggml_add (ctx, t21, t29); assert_shape_2d(t30, n_embd, N*n_batch);
+        struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); assert_shape_2d(t30, n_embd, N*n_batch);
         cur = t30;
         checkpoints.push_back(cur);
     }