working in cpu, metal buggy
This commit is contained in:
parent
101c578715
commit
a1cf66ea94
2 changed files with 163 additions and 151 deletions
|
@ -109,7 +109,7 @@ gguf_writer.add_max_position_embeddings(hparams["n_positions"])
|
|||
gguf_writer.add_feed_forward_length(4 * hparams["n_embd"])
|
||||
gguf_writer.add_block_count(block_count)
|
||||
gguf_writer.add_head_count(hparams["n_head"])
|
||||
gguf_writer.add_head_count_kv(1)
|
||||
gguf_writer.add_head_count_kv(hparams["n_head"])
|
||||
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
|
||||
gguf_writer.add_file_type(ftype)
|
||||
|
||||
|
@ -209,6 +209,24 @@ for part_name in part_names:
|
|||
|
||||
data = data.squeeze().numpy()
|
||||
|
||||
if name.endswith(".attn.c_attn.weight") or name.endswith(".attn.c_attn.bias"):
|
||||
print("Duplicate K,V heads to use MHA instead of MQA for", name)
|
||||
|
||||
embed_dim = hparams["n_embd"]
|
||||
head_dim = embed_dim // hparams["n_head"]
|
||||
|
||||
# ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
|
||||
q, k ,v = np.split(data, (hparams["n_head"] * head_dim, (hparams["n_head"] + 1) * head_dim), axis=0)
|
||||
# duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
|
||||
if len(k.shape) == 2:
|
||||
k = np.tile(k, (hparams["n_head"], 1))
|
||||
v = np.tile(v, (hparams["n_head"], 1))
|
||||
elif len(k.shape) == 1:
|
||||
k = np.tile(k, (hparams["n_head"]))
|
||||
v = np.tile(v, (hparams["n_head"]))
|
||||
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
|
||||
data = np.concatenate((q, k, v), axis=0)
|
||||
|
||||
# map tensor names
|
||||
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
|
||||
if new_name is None:
|
||||
|
|
284
llama.cpp
284
llama.cpp
|
@ -1221,6 +1221,7 @@ static bool llama_kv_cache_init(
|
|||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "n_embed: %d n_layer: %d n_ctx: %d n_elements: %d\n", n_embd, n_layer, n_ctx, n_elements);
|
||||
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||
ggml_set_name(cache.k, "cache_k");
|
||||
|
@ -2259,8 +2260,8 @@ static void llm_load_tensors(
|
|||
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
|
||||
|
||||
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
|
||||
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
|
||||
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
|
||||
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {3*n_embd}, backend_split);
|
||||
|
||||
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
|
||||
|
@ -3540,16 +3541,8 @@ static struct ggml_cgraph * llm_build_starcoder(
|
|||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
}
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
|
||||
#define PRINT_SHAPE(x) fprintf(stderr, "%d %s: (%s)\n", __LINE__, #x, llama_format_tensor_shape(x).c_str())
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * attn_norm;
|
||||
|
||||
offload_func_t offload_func = llama_nop;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
|
@ -3558,186 +3551,187 @@ static struct ggml_cgraph * llm_build_starcoder(
|
|||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
// self-attention
|
||||
// TODO: refactor into common function (shared with LLaMA)
|
||||
{
|
||||
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
|
||||
offload_func(attn_norm);
|
||||
// Norm
|
||||
cur = ggml_norm(ctx0, inpL, norm_eps);
|
||||
|
||||
attn_norm = ggml_add(ctx0,
|
||||
ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
|
||||
model.layers[il].attn_norm_b);
|
||||
offload_func(attn_norm->src[0]);
|
||||
offload_func(attn_norm);
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
|
||||
|
||||
cur = attn_norm;
|
||||
}
|
||||
|
||||
// compute QKV
|
||||
{
|
||||
// Compute QKV
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
offload_func_kq(cur);
|
||||
|
||||
// ===== TBD (QKV Split + FF) ====
|
||||
#define PRINT_SHAPE(x) fprintf(stderr, "%d %s: (%s)\n", __LINE__, #x, llama_format_tensor_shape(x).c_str())
|
||||
GGML_ASSERT(false);
|
||||
|
||||
// Note that the strides for Kcur, Vcur are set up so that the
|
||||
// resulting views are misaligned with the tensor's storage
|
||||
// (by applying the K/V offset we shift the tensor's original
|
||||
// view to stick out behind the viewed QKV tensor's allocated
|
||||
// memory, so to say). This is ok because no actual accesses
|
||||
// happen to that out-of-range memory, but it can require some
|
||||
// trickery when trying to accurately dump these views for
|
||||
// debugging.
|
||||
|
||||
const size_t wsize = ggml_type_size(cur->type);
|
||||
|
||||
// TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for
|
||||
// non-contiguous views is added for the rope operator
|
||||
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head, N,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
0));
|
||||
offload_func_kq(tmpq);
|
||||
|
||||
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, N,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * n_head));
|
||||
offload_func_kq(tmpk);
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, N,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * (n_head + n_head_kv));
|
||||
offload_func_v(tmpv);
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
offload_func_kq(Qcur);
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
offload_func_kq(Kcur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
}
|
||||
|
||||
{
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
|
||||
offload_func_v(Vcur);
|
||||
offload_func_v(Vcur->src[0]->src[0]);
|
||||
ggml_set_name(Vcur, "Vcur");
|
||||
// Self Attention
|
||||
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
|
||||
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
|
||||
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
|
||||
offload_func_kq(k);
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
|
||||
offload_func_v(v);
|
||||
// store key and value to memory
|
||||
if (N >= 1) {
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
offload_func_kq(Q);
|
||||
ggml_set_name(Q, "Q");
|
||||
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
|
||||
// [64, N, 12]
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_cpy(ctx0,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
|
||||
// [64, n_past + N, 12]
|
||||
struct ggml_tensor * K =
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_past + N, n_head_kv,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
|
||||
offload_func_kq(K);
|
||||
ggml_set_name(K, "K");
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3); //TODO: need to be tiled
|
||||
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
offload_func_kq(KQ);
|
||||
ggml_set_name(KQ, "KQ");
|
||||
// GG: flash attention
|
||||
//struct ggml_tensor * V =
|
||||
// ggml_cpy(ctx0,
|
||||
// ggml_permute(ctx0,
|
||||
// ggml_reshape_3d(ctx0,
|
||||
// ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
|
||||
// n_embd/n_head, n_head, n_past + N),
|
||||
// 1, 2, 0, 3),
|
||||
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
|
||||
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
offload_func_kq(KQ_scaled);
|
||||
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
|
||||
|
||||
// K * Q
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); //TODO: check if it broadcasts
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale_inplace(ctx0,
|
||||
KQ,
|
||||
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
|
||||
);
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
offload_func_kq(KQ_masked);
|
||||
ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
// [n_past + N, N, 12]
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
offload_func_v(KQ_soft_max);
|
||||
ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
n_past + N, n_embd_head, n_head_kv,
|
||||
ggml_element_size(kv_self.v)*n_ctx,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
|
||||
offload_func_v(V);
|
||||
ggml_set_name(V, "V");
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
// [n_past + N, 64, 12]
|
||||
struct ggml_tensor * V_trans =
|
||||
ggml_cpy(ctx0,
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
offload_func_v(KQV);
|
||||
ggml_set_name(KQV, "KQV");
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
// [64, N, 12]
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
||||
|
||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||
// [64, 12, N]
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
offload_func_v(KQV_merged);
|
||||
ggml_set_name(KQV_merged, "KQV_merged");
|
||||
|
||||
cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
offload_func_v(cur);
|
||||
ggml_set_name(cur, "KQV_merged_contiguous");
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_wo");
|
||||
// cur = KQV_merged.contiguous().view(n_embd, N)
|
||||
// [768, N]
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
}
|
||||
|
||||
struct ggml_tensor * attn_out = cur;
|
||||
|
||||
// feed forward
|
||||
// Projection
|
||||
{
|
||||
struct ggml_tensor * inpFF = attn_norm;
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
|
||||
offload_func(cur);
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
offload_func(cur);
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
|
||||
offload_func(cur);
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bo);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, attn_out);
|
||||
offload_func(cur);
|
||||
// add the input
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
offload_func(cur);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
struct ggml_tensor * inpFF = cur;
|
||||
|
||||
// FF
|
||||
{
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpFF, norm_eps);
|
||||
|
||||
// cur = ln_2_g*cur + ln_2_b
|
||||
// [ 768, N]
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
// fully connected
|
||||
// [3072, 768] - model.layers[il].c_mlp_fc_w
|
||||
// [3072, 1] - model.layers[il].c_mlp_fc_b
|
||||
// [ 768, N] - cur (in)
|
||||
// [3072, N] - cur (out)
|
||||
//
|
||||
// cur = fc_w*cur + fc_b
|
||||
// [3072, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w3,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].b3);
|
||||
|
||||
// GELU activation
|
||||
// [3072, N]
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
// projection
|
||||
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
|
||||
// [ 768, 1] - model.layers[il].c_mlp_proj_b
|
||||
// [3072, N] - cur (in)
|
||||
// [ 768, N] - cur (out)
|
||||
//
|
||||
// cur = proj_w*cur + proj_b
|
||||
// [768, N]
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w2,
|
||||
cur);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].b2);
|
||||
}
|
||||
|
||||
inpL = ggml_add(ctx0, cur, inpFF);
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, cur, norm_eps);
|
||||
offload_func_nr(cur);
|
||||
// [ 768, N]
|
||||
inpL = ggml_norm(ctx0, inpL, norm_eps);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0, cur, model.output_norm),
|
||||
model.output_norm_b);
|
||||
ggml_set_name(cur, "result_norm");
|
||||
// inpL = ln_f_g*inpL + ln_f_b
|
||||
// [ 768, N]
|
||||
inpL = ggml_add(ctx0, ggml_mul(ctx0, inpL, model.output_norm), model.output_norm_b);
|
||||
}
|
||||
ggml_set_name(inpL, "result_norm");
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cur = ggml_mul_mat(ctx0, model.output, inpL);
|
||||
ggml_set_name(cur, "result_output");
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
// norm
|
||||
return gf;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue