remove unnecessary comments

This commit is contained in:
xaedes 2023-05-30 15:58:22 +02:00
parent ec8e262d1d
commit ad966da955
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

@@ -1191,7 +1191,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
     memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
-    // inpL shape [n_embd,N*n_batch,1]
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
     assert_shape_2d(inpL, n_embd, N*n_batch);
     for (int il = 0; il < n_layer; ++il) {
@@ -1199,11 +1198,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
         struct ggml_tensor * cur;
-        // lctx.use_buf(ctx0, 0);
         // norm
         {
-            // cur shape [n_embd,N*n_batch,1,1]
             cur = ggml_rms_norm(ctx0, inpL);
             assert_shape_2d(cur, n_embd, N*n_batch);
@@ -1219,94 +1215,48 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
             // compute Q and K and RoPE them
             // wq shape [n_embd, n_embd, 1, 1]
             // wk shape [n_embd, n_embd, 1, 1]
-            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
-            // Kcur shape [n_embd/n_head, n_head, N, n_batch]
             struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
             struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
             assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
             assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
-            // Vcur shape [N, n_batch, n_embd/n_head, n_head]
             struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head);
             assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head);
-            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
-            // Q shape [n_embd/n_head, N, n_head, n_batch]
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                     Qcur,
                     0, 2, 1, 3);
             assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
-            // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
-            // K shape [n_embd/n_head, N, n_head, n_batch]
             struct ggml_tensor * K =
                 ggml_permute(ctx0,
                     Kcur,
                     0, 2, 1, 3);
             assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch);
-            // // K * Q
-            // // KQ shape [N, N, n_head, n_batch]
-            // struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            // assert_shape_4d(KQ, N, N, n_head, n_batch);
-            // // KQ_scaled = KQ / sqrt(n_embd/n_head)
-            // // KQ_scaled shape [N, N, n_head, n_batch]
-            // struct ggml_tensor * KQ_scaled =
-            //     ggml_scale_inplace(ctx0,
-            //         KQ,
-            //         ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
-            // assert_shape_4d(KQ_scaled, N, N, n_head, n_batch);
-            // // KQ_masked = mask_past(KQ_scaled)
-            // // KQ_masked shape [N, N, n_head, n_batch]
-            // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-            // assert_shape_4d(KQ_masked, N, N, n_head, n_batch);
-            // // KQ = soft_max(KQ_masked)
-            // // KQ_soft_max shape [N, N, n_head, n_batch]
-            // struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
-            // assert_shape_4d(KQ_soft_max, N, N, n_head, n_batch);
-            // Vcur shape [N, n_batch, n_embd/n_head, n_head]
-            // V shape [N, n_embd/n_head, n_head, n_batch]
             struct ggml_tensor * V =
                 ggml_permute(ctx0,
                     Vcur,
                     0, 3, 1, 2);
             assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch);
-            // // KQV shape [n_embd/n_head, N, n_head, n_batch]
-            // struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-            // assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
             bool masked = true;
             struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, masked);
             assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
-            // KQV_merged = KQV.permute(0, 2, 1, 3)
-            // KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
             assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
-            // KQV_merged shape
-            // cur shape [n_embd,N*n_batch,1,1]
             cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
             assert_shape_2d(cur, n_embd, N*n_batch);
             // projection (no bias)
-            // cur shape [n_embd,N*n_batch,1,1]
             cur = ggml_mul_mat(ctx0,
                     model->layers[il].wo,
                     cur);
             assert_shape_2d(cur, n_embd, N*n_batch);
         }
-        // lctx.use_buf(ctx0, 1);
-        // inpFF shape [n_embd,N*n_batch,1,1]
         struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA);
         assert_shape_2d(inpFF, n_embd, N*n_batch);
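
Note: ggml_flash_attn(ctx0, Q, K, V, masked) computes in one fused op what the deleted commented-out lines spelled out step by step. For reference only, a minimal sketch of that unfused path, reusing the tensors and ggml calls already named above (this code is not part of the file after this commit):

            struct ggml_tensor * KQ          = ggml_mul_mat(ctx0, K, Q);                                  // [N, N, n_head, n_batch]
            struct ggml_tensor * KQ_scaled   = ggml_scale_inplace(ctx0, KQ,
                                                   ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); // scale by 1/sqrt(head dim)
            struct ggml_tensor * KQ_masked   = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);       // causal mask
            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
            struct ggml_tensor * KQV         = ggml_mul_mat(ctx0, V, KQ_soft_max);                        // [n_embd/n_head, N, n_head, n_batch]
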
@@ -1314,52 +1264,43 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
         {
             // norm
             {
-                // cur shape [n_embd,N*n_batch,1,1]
                 cur = ggml_rms_norm(ctx0, inpFF);
                 assert_shape_2d(cur, n_embd, N*n_batch);
                 // cur = ffn_norm*cur
-                // cur shape [n_embd,N*n_batch,1,1]
                 cur = ggml_mul(ctx0,
                         ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
                         cur);
                 assert_shape_2d(cur, n_embd, N*n_batch);
             }
-            // tmp shape [n_ff,N*n_batch,1,1]
             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                     model->layers[il].w3,
                     cur);
             assert_shape_2d(tmp, n_ff, N*n_batch);
-            // cur shape [n_ff,N*n_batch,1,1]
             cur = ggml_mul_mat(ctx0,
                     model->layers[il].w1,
                     cur);
             assert_shape_2d(cur, n_ff, N*n_batch);
             // SILU activation
-            // cur shape [n_ff,N*n_batch,1,1]
             cur = ggml_silu(ctx0, cur);
             assert_shape_2d(cur, n_ff, N*n_batch);
-            // cur shape [n_ff,N*n_batch,1,1]
             cur = ggml_mul(ctx0, cur, tmp);
             assert_shape_2d(cur, n_ff, N*n_batch);
-            // cur shape [n_embd,N*n_batch,1,1]
             cur = ggml_mul_mat(ctx0,
                     model->layers[il].w2,
                     cur);
             assert_shape_2d(cur, n_embd, N*n_batch);
         }
-        // cur shape [n_embd,N*n_batch,1,1]
         cur = ggml_add_inplace(ctx0, cur, inpFF);
         assert_shape_2d(cur, n_embd, N*n_batch);
         // input for next layer
-        // inpL shape [n_embd,N*n_batch,1,1]
         inpL = cur;
         assert_shape_2d(inpL, n_embd, N*n_batch);
     }
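
Note: the feed-forward block above is the LLaMA-style SwiGLU FFN, i.e. cur = w2 * (silu(w1*x) ⊙ (w3*x)). A condensed sketch of the same computation with the same ggml calls (the names up and gate are illustrative only, not identifiers from the file):

            struct ggml_tensor * up   = ggml_mul_mat(ctx0, model->layers[il].w3, cur);                  // [n_ff,   N*n_batch]
            struct ggml_tensor * gate = ggml_silu(ctx0, ggml_mul_mat(ctx0, model->layers[il].w1, cur)); // [n_ff,   N*n_batch]
            cur = ggml_mul_mat(ctx0, model->layers[il].w2, ggml_mul(ctx0, gate, up));                   // [n_embd, N*n_batch]
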
@@ -1367,28 +1308,22 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     // norm
     {
-        // inpL shape [n_embd,N*n_batch,1,1]
         inpL = ggml_rms_norm(ctx0, inpL);
         assert_shape_2d(inpL, n_embd, N*n_batch);
         // inpL = norm*inpL
-        // inpL shape [n_embd,N*n_batch,1,1]
         inpL = ggml_mul(ctx0,
                     ggml_repeat(ctx0, model->norm, inpL),
                     inpL);
         assert_shape_2d(inpL, n_embd, N*n_batch);
-        //embeddings = inpL;
     }
     // lm_head
-    // inpL shape [n_vocab,N*n_batch,1,1]
     inpL = ggml_mul_mat(ctx0, model->output, inpL);
     assert_shape_2d(inpL, n_vocab, N*n_batch);
     {
-        // inpL shape [n_vocab,N,n_batch,1]
         inpL = ggml_reshape_3d(ctx0,
                     inpL,
                     n_vocab, N, n_batch);
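
Note: after this reshape the logits live in an f32 tensor of shape [n_vocab, N, n_batch]. A small hypothetical helper (not part of the diff; name and signature are illustrative) for reading a single logit out of such a tensor through ggml's byte strides:

#include "ggml.h"

// logit for vocab id v at sequence position i in batch b,
// assuming a contiguous f32 tensor of shape [n_vocab, N, n_batch]
static float get_logit(const struct ggml_tensor * logits, int v, int i, int b) {
    const char * base = (const char *) logits->data;
    return *(const float *) (base + v*logits->nb[0] + i*logits->nb[1] + b*logits->nb[2]);
}
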