diff --git a/examples/baby-llama/baby-llama-text.cpp b/examples/baby-llama/baby-llama-text.cpp index ecdb418bf..5d48b7155 100644 --- a/examples/baby-llama/baby-llama-text.cpp +++ b/examples/baby-llama/baby-llama-text.cpp @@ -1191,7 +1191,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); - // inpL shape [n_embd,N*n_batch,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); assert_shape_2d(inpL, n_embd, N*n_batch); for (int il = 0; il < n_layer; ++il) { @@ -1199,11 +1198,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( struct ggml_tensor * cur; - // lctx.use_buf(ctx0, 0); - // norm { - // cur shape [n_embd,N*n_batch,1,1] cur = ggml_rms_norm(ctx0, inpL); assert_shape_2d(cur, n_embd, N*n_batch); @@ -1219,94 +1215,48 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( // compute Q and K and RoPE them // wq shape [n_embd, n_embd, 1, 1] // wk shape [n_embd, n_embd, 1, 1] - // Qcur shape [n_embd/n_head, n_head, N, n_batch] - // Kcur shape [n_embd/n_head, n_head, N, n_batch] struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0); struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0); assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); - // Vcur shape [N, n_batch, n_embd/n_head, n_head] struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head); assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head); - // Qcur shape [n_embd/n_head, n_head, N, n_batch] - // Q shape [n_embd/n_head, N, n_head, n_batch] struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch); - // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] - // K shape [n_embd/n_head, N, n_head, n_batch] struct ggml_tensor * K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch); - // // K * Q - // // KQ shape [N, N, n_head, n_batch] - // struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - // assert_shape_4d(KQ, N, N, n_head, n_batch); - - // // KQ_scaled = KQ / sqrt(n_embd/n_head) - // // KQ_scaled shape [N, N, n_head, n_batch] - // struct ggml_tensor * KQ_scaled = - // ggml_scale_inplace(ctx0, - // KQ, - // ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); - // assert_shape_4d(KQ_scaled, N, N, n_head, n_batch); - - // // KQ_masked = mask_past(KQ_scaled) - // // KQ_masked shape [N, N, n_head, n_batch] - // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); - // assert_shape_4d(KQ_masked, N, N, n_head, n_batch); - - // // KQ = soft_max(KQ_masked) - // // KQ_soft_max shape [N, N, n_head, n_batch] - // struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); - // assert_shape_4d(KQ_soft_max, N, N, n_head, n_batch); - - // Vcur shape [N, n_batch, n_embd/n_head, n_head] - // V shape [N, n_embd/n_head, n_head, n_batch] struct ggml_tensor * V = ggml_permute(ctx0, Vcur, 0, 3, 1, 2); assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch); - // // KQV shape [n_embd/n_head, N, n_head, n_batch] - // struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - // assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); - - bool masked = true; struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, masked); assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); - // KQV_merged = KQV.permute(0, 2, 1, 3) - // KQV_merged shape [n_embd/n_head, n_head, N, n_batch] struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch); - // KQV_merged shape - - // cur shape [n_embd,N*n_batch,1,1] cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch); // projection (no bias) - // cur shape [n_embd,N*n_batch,1,1] cur = ggml_mul_mat(ctx0, model->layers[il].wo, cur); assert_shape_2d(cur, n_embd, N*n_batch); } - // lctx.use_buf(ctx0, 1); - - // inpFF shape [n_embd,N*n_batch,1,1] struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA); assert_shape_2d(inpFF, n_embd, N*n_batch); @@ -1314,52 +1264,43 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( { // norm { - // cur shape [n_embd,N*n_batch,1,1] cur = ggml_rms_norm(ctx0, inpFF); assert_shape_2d(cur, n_embd, N*n_batch); // cur = ffn_norm*cur - // cur shape [n_embd,N*n_batch,1,1] cur = ggml_mul(ctx0, ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), cur); assert_shape_2d(cur, n_embd, N*n_batch); } - // tmp shape [n_ff,N*n_batch,1,1] struct ggml_tensor * tmp = ggml_mul_mat(ctx0, model->layers[il].w3, cur); assert_shape_2d(tmp, n_ff, N*n_batch); - // cur shape [n_ff,N*n_batch,1,1] cur = ggml_mul_mat(ctx0, model->layers[il].w1, cur); assert_shape_2d(cur, n_ff, N*n_batch); // SILU activation - // cur shape [n_ff,N*n_batch,1,1] cur = ggml_silu(ctx0, cur); assert_shape_2d(cur, n_ff, N*n_batch); - // cur shape [n_ff,N*n_batch,1,1] cur = ggml_mul(ctx0, cur, tmp); assert_shape_2d(cur, n_ff, N*n_batch); - // cur shape [n_embd,N*n_batch,1,1] cur = ggml_mul_mat(ctx0, model->layers[il].w2, cur); assert_shape_2d(cur, n_embd, N*n_batch); } - // cur shape [n_embd,N*n_batch,1,1] cur = ggml_add_inplace(ctx0, cur, inpFF); assert_shape_2d(cur, n_embd, N*n_batch); // input for next layer - // inpL shape [n_embd,N*n_batch,1,1] inpL = cur; assert_shape_2d(inpL, n_embd, N*n_batch); } @@ -1367,28 +1308,22 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( // norm { - // inpL shape [n_embd,N*n_batch,1,1] inpL = ggml_rms_norm(ctx0, inpL); assert_shape_2d(inpL, n_embd, N*n_batch); // inpL = norm*inpL - // inpL shape [n_embd,N*n_batch,1,1] inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model->norm, inpL), inpL); assert_shape_2d(inpL, n_embd, N*n_batch); - - //embeddings = inpL; } // lm_head - // inpL shape [n_vocab,N*n_batch,1,1] inpL = ggml_mul_mat(ctx0, model->output, inpL); assert_shape_2d(inpL, n_vocab, N*n_batch); { - // inpL shape [n_vocab,N,n_batch,1] inpL = ggml_reshape_3d(ctx0, inpL, n_vocab, N, n_batch);