joshcarp 2024-05-07 14:38:58 -04:00
parent 16b8ecdaf5
commit 308c817af4
3 changed files with 43 additions and 155 deletions

convert-hf-to-gguf.py

@@ -2916,11 +2916,14 @@ class OpenELM(Model):
head_dim = self.find_hparam(["head_dim"])
n_head = n_embd // head_dim
rot_pct = 1.0
self.gguf_writer.add_context_length(self.find_hparam(["max_context_length"]))
self.gguf_writer.add_embedding_length(n_embd)
# self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
# self.gguf_writer.add_head_count(n_head)
# self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_head_count_kv(n_head*10)
self.gguf_writer.add_head_count(n_head*10)
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_feed_forward_length(0) # dynamically calculated
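
Note: the add_head_count(n_head*10)/add_head_count_kv(n_head*10) writes above are a stopgap; OpenELM's head counts vary per layer, and a single scalar cannot represent them. A minimal array-based sketch, assuming a gguf-py version whose writers accept per-layer sequences and assuming OpenELM's usual config keys (num_query_heads, num_kv_heads):

    # sketch: store per-layer head counts directly (assumes sequence support
    # in add_head_count/add_head_count_kv and these hparam keys)
    num_query_heads = self.hparams["num_query_heads"]  # e.g. [12, 12, ..., 20]
    num_kv_heads = self.hparams["num_kv_heads"]        # e.g. [3, 3, ..., 5]
    self.gguf_writer.add_head_count(num_query_heads)
    self.gguf_writer.add_head_count_kv(num_kv_heads)
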
@@ -2977,6 +2980,7 @@ class OpenELM(Model):
block_count = self.hparams.get("num_transformer_layers", self.hparams.get("num_hidden_layers"))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
n_head = self.hparams.get("model_dim") // self.hparams.get("head_dim") # TODO: propagate this
foobar = {}  # debug: collect converted tensors for inspection
for name, data_torch in self.get_tensors():
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
@@ -3002,6 +3006,8 @@ class OpenELM(Model):
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
foobar[name] = (data_torch, new_name, data)  # debug: stash converted tensors
foobar  # no-op expression; left in as a debugger breakpoint anchor
###### CONVERSION LOGIC ######

gguf-py/gguf/tensor_mapping.py

@@ -145,7 +145,6 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key", # Grok
"transformer.layers.{bid}.attn.k_norm.weight" # openelm
),
# Attention value
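
The removed openelm entry also broke the map's convention: entries are stored without a ".weight"/".bias" suffix, because the conversion script calls get_name(name, try_suffixes=(".weight", ".bias")), which strips the suffix, looks up the base name, and re-appends the suffix to the result. An entry that carries ".weight" itself matches only exactly, and the returned GGUF name then lacks its suffix. A sketch of the lookup (names as used by the conversion script):

    # sketch: why map entries must be suffix-less
    new_name = tensor_map.get_name("transformer.layers.0.attn.k_norm.weight",
                                   try_suffixes=(".weight", ".bias"))
    # suffix-less entry "transformer.layers.{bid}.attn.k_norm"
    #   -> "blk.0.attn_k_norm" + ".weight"   (correct)
    # entry carrying ".weight" -> "blk.0.attn_k_norm"  (suffix lost)
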

llama.cpp

@@ -2435,8 +2435,8 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) {
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size*10);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size*10);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
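
The *10 factor oversizes every layer's K/V cache so OpenELM's widest layer fits, wasting memory everywhere else. A per-layer sizing sketch, assuming hparams can report per-layer GQA widths (upstream llama.cpp later gained such accessors, n_embd_k_gqa(il)/n_embd_v_gqa(il)):

    // sketch: size each layer's cache from that layer's own GQA width
    // instead of a global worst case (per-layer accessors assumed)
    for (int i = 0; i < (int) n_layer; i++) {
        struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size);
        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size);
        ggml_format_name(k, "cache_k_l%d", i);
        ggml_format_name(v, "cache_v_l%d", i);
        cache.k_l.push_back(k);
        cache.v_l.push_back(v);
    }
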
@@ -5976,8 +5976,6 @@ static bool llm_load_tensors(
std::vector<int> num_query_heads = {12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 20, 20, 20, 20};
std::vector<float> ffn_multipliers = {0.5, 0.73, 0.97, 1.2, 1.43, 1.67, 1.9, 2.13, 2.37, 2.6, 2.83, 3.07, 3.3, 3.53, 3.77, 4.0};
llama_hparams modified_hparams(hparams);
const int64_t n_embd_head = hparams.n_embd_head_v;
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
@@ -5988,10 +5986,14 @@ static bool llm_load_tensors(
for (int i = 0; i < n_layer; ++i) {
const int64_t n_head_k = num_kv_heads[i];
const int64_t n_head_v = num_kv_heads[i];
const int64_t n_head_kv = n_head_k+n_head_v;
const int64_t n_head = n_head_kv+ num_query_heads[i];
const int64_t n_head_kv = n_head_k + n_head_v;
const int64_t n_head = n_head_kv + num_query_heads[i];
// const int64_t n_kv = (num_kv_heads[i]+num_kv_heads[i])*n_embd_head;
modified_hparams.n_head = n_head;
modified_hparams.n_embd_head_v = 64;
modified_hparams.n_embd_head_k = 64;
int64_t n_embd_head = modified_hparams.n_embd_head_v;
modified_hparams.n_head_kv = n_head_kv;
const int64_t n_embd_gqa = n_embd_head * n_head;
const int64_t n_embd_k_gqa = modified_hparams.n_embd_k_gqa();
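
For scale, the hard-coded vectors appear to match OpenELM-270M: 16 layers, head_dim 64, model_dim 1280 = 20*64. Worked through layer 0: n_head_k = n_head_v = 3, so n_head_kv = 6 and n_head = 6 + 12 = 18, and the fused QKV projection spans (12 + 3 + 3) * 64 = 1152 rows; on layer 15 it is (20 + 5 + 5) * 64 = 1920. Note that this n_head counts Q, K and V heads together, an unconventional definition the surrounding code has to keep compensating for.
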
@@ -6320,42 +6322,6 @@ static void llm_build_kv_store(
(kv_head)*ggml_element_size(kv.v_l[il]));
cb(v_cache_view, "v_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
// ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
// ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
}
static void llm_build_kv_store2(
struct ggml_context * ctx,
const llama_hparams & hparams,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
struct ggml_tensor * k_cur,
struct ggml_tensor * v_cur,
int64_t n_ctx,
int32_t n_tokens,
int32_t kv_head,
const llm_build_cb & cb,
int64_t il) {
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa()/hparams.n_head_kv;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa()/hparams.n_head_kv;
GGML_ASSERT(kv.size == n_ctx);
// compute the transposed [n_tokens, n_embd] V matrix
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur);
cb(v_cur_t, "v_cur_t", il);
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], ggml_nbytes(k_cur)/4,
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
cb(k_cache_view, "k_cache_view", il);
struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], ggml_nbytes(v_cur)/4,
// ( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il]));
cb(v_cache_view, "v_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
@@ -6746,44 +6712,6 @@ static struct ggml_tensor * llm_build_kv(
return cur;
}
static struct ggml_tensor * llm_build_kv2(
struct ggml_context * ctx,
const llama_model & model,
const llama_hparams & hparams,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
struct ggml_tensor * wo,
struct ggml_tensor * wo_b,
struct ggml_tensor * k_cur,
struct ggml_tensor * v_cur,
struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask,
struct ggml_tensor * kq_pos,
int64_t n_ctx,
int32_t n_tokens,
int32_t kv_head,
int32_t n_kv,
float kq_scale,
const llm_build_cb & cb,
int il) {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(graph, q_cur);
ggml_build_forward_expand(graph, k_cur);
ggml_build_forward_expand(graph, v_cur);
llm_build_kv_store2(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
struct ggml_tensor * cur;
cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
cb(cur, "kqv_out", il);
return cur;
}
struct llm_build_context {
const llama_model & model;
llama_context & lctx;
@@ -10802,7 +10730,7 @@ struct llm_build_context {
struct ggml_cgraph * build_openelm() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = 64;
// TODO: get this from config
std::vector<int> num_kv_heads = {3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5};
std::vector<int> num_query_heads = {12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 20, 20, 20, 20};
@@ -10811,11 +10739,9 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
struct ggml_tensor * inp_pos = build_inp_pos();
llama_hparams modified_hparams(hparams);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
// GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
auto residual = inpL;
// TODO: Want the offsets to be calculated with the num heads at layer level
@@ -10823,24 +10749,22 @@ struct llm_build_context {
const int64_t n_head_k = num_kv_heads[il];
const int64_t n_head_v = num_kv_heads[il];
const int64_t n_head_q = num_query_heads[il];
const int64_t n_head_kv = n_head_k+n_head_v;
int64_t n_head_kv = n_head_k+n_head_v;
const int64_t n_head = n_head_kv+ num_query_heads[il];
// const int64_t n_kv = (num_kv_heads[il]+num_kv_heads[il])*n_embd_head; // This makes asserts fail
modified_hparams.n_head = n_head;
modified_hparams.n_head = 4*n_head_k; // somehow this works: some places expect this to be groups*n_head_kv instead of n_head; maybe that is the definition somewhere.
modified_hparams.n_head_kv = n_head_kv;
modified_hparams.n_head_kv = num_query_heads[il];
modified_hparams.n_embd_head_v = 64;
modified_hparams.n_embd_head_k = 64;
modified_hparams.n_embd = 64*n_head;
n_head_kv = modified_hparams.n_head_kv;
const int64_t n_embd_gqa = n_embd_head * n_head;
struct ggml_tensor * attn_q_norm = model.layers[il].attn_q_norm;
cb(attn_q_norm, "attn_q_norm", il);
struct ggml_tensor * attn_k_norm = model.layers[il].attn_k_norm;
cb(attn_k_norm, "attn_k_norm", il);
// const int64_t n_embd_k_gqa = modified_hparams.n_embd_k_gqa();
// const int64_t n_embd_v_gqa = modified_hparams.n_embd_v_gqa();
// self-attention
{
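
The "somehow this works" has a concrete explanation: OpenELM-270M keeps the query-to-KV-head ratio fixed at 4 on every layer (12/3, 16/4, 20/5), so 4*n_head_k always equals num_query_heads[il], the true query-head count. A sketch of the invariant being relied on (names as in the loop above):

    // sketch: the grouped-query invariant the 4*n_head_k hack depends on
    const int64_t n_rep = num_query_heads[il] / num_kv_heads[il]; // 4 on all layers here
    GGML_ASSERT(num_query_heads[il] == n_rep * num_kv_heads[il]);
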
@@ -10850,7 +10774,6 @@ struct llm_build_context {
NULL,
LLM_NORM_RMS, cb, il);
cb(attn_norm_output, "attn_norm", il);
struct ggml_tensor * Qcur = nullptr;
struct ggml_tensor * Kcur = nullptr;
struct ggml_tensor * Vcur = nullptr;
@@ -10859,85 +10782,52 @@ struct llm_build_context {
cb(cur, "qkv", il);
cur = ggml_reshape_3d(ctx0, cur, n_embd_head, n_tokens, n_head);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// TODO: these need to be calculated correctly
/*
struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens);
cb(tmpqkv, "tmpqkv", il);
struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
cb(tmpqkv_perm, "tmpqkv", il);
struct ggml_tensor * tmpq = ggml_view_3d(
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
ggml_element_size(tmpqkv_perm) * n_embd_head,
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
0
);
cb(tmpq, "tmpq", il);
struct ggml_tensor * tmpk = ggml_view_3d(
ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
ggml_element_size(tmpqkv_perm) * n_embd_head,
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens
);
*/
size_t elemsize = ggml_element_size(cur); // note: unused in the views below
Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_tokens, num_query_heads[il], cur->nb[1], cur->nb[2]*num_query_heads[il], 0));
cb(Qcur, "queries", il);
Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_tokens, n_head_k, cur->nb[1], cur->nb[2]*n_head_k, cur->nb[2]*num_query_heads[il]));
cb(Kcur, "keys", il);
Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_tokens, n_head_q, cur->nb[1], cur->nb[2]*n_head_v, cur->nb[2]*(num_query_heads[il]+n_head_k)));
Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_tokens, n_head_v, cur->nb[1], cur->nb[2]*n_head_v, cur->nb[2]*(num_query_heads[il]+n_head_k)));
cb(Vcur, "values", il);
// Q/K Layernorm
cb(Qcur, "queries", il);
Qcur = llm_build_norm(ctx0, Qcur, modified_hparams,
model.layers[il].attn_q_norm,
NULL,
LLM_NORM_RMS, cb, il);
Kcur = llm_build_norm(ctx0, Kcur, modified_hparams,
model.layers[il].attn_k_norm,
NULL,
LLM_NORM_RMS, cb, il);
cb(Kcur, "keys", il);
// reshape, Qcur -> [64][12(first layer)][n_tokens]
// reshape, Kcur -> [64][3(first layer)][n_tokens]
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, num_query_heads[il], n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_k, n_tokens);
Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "queries", il);
// Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
// cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
int64_t nev[GGML_MAX_DIMS] = {Vcur->ne[0], Vcur->ne[1], 4*Vcur->ne[2], Vcur->ne[3]};
struct ggml_tensor * Vcur2 = ggml_new_tensor(ctx0, Vcur->type, GGML_MAX_DIMS, nev);
Vcur2 = ggml_repeat(ctx0, Vcur, Vcur2);
Vcur = ggml_repeat(ctx0, Vcur, Vcur2);
// Vcur = Vcur2;
cb(Vcur, "values", il);
int64_t nek[GGML_MAX_DIMS] = {Kcur->ne[0], Kcur->ne[1], 4*Kcur->ne[2], Kcur->ne[3]};
struct ggml_tensor * Kcur2 = ggml_new_tensor(ctx0, Vcur->type, GGML_MAX_DIMS, nek);
Kcur2 = ggml_repeat(ctx0, Kcur, Kcur2);
struct ggml_tensor * Kcur2 = ggml_new_tensor(ctx0, Kcur->type, GGML_MAX_DIMS, nek);
Kcur = ggml_repeat(ctx0, Kcur, Kcur2);
// Kcur = Kcur2;
cb(Kcur, "keys", il);
Vcur = ggml_reshape_2d(ctx0, Vcur, 4*modified_hparams.n_embd_head_v*n_head_v, n_tokens);
Qcur = ggml_reshape_3d(ctx0, Qcur, modified_hparams.n_embd_head_v, n_head_q, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, 4*modified_hparams.n_embd_head_v, n_head_k, n_tokens);
cur = llm_build_kv(ctx0, model, modified_hparams, kv_self, gf,
model.layers[il].wo, NULL,
Kcur2, Vcur2, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
model.layers[il].wo, nullptr,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, n_head_kv, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
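
Two details in the attention block above are worth spelling out. The ggml_repeat calls materialize grouped-query attention by tiling: with 4 query heads per KV head, K and V are broadcast to 4x their head count so the generic attention path sees matching Q/K/V shapes. And the new kq_scale of 1.0f/sqrtf(float(n_embd_head)) restores standard scaled dot-product scaling (1/sqrt(64) = 0.125) in place of the earlier 1.0f. A sketch of the tiling with the repeat factor derived rather than hardcoded as 4:

    // sketch: derive the GQA repeat factor instead of hardcoding it
    const int64_t n_rep = n_head_q / n_head_k; // 4 on every layer of this model
    int64_t nek[GGML_MAX_DIMS] = {Kcur->ne[0], Kcur->ne[1], n_rep*Kcur->ne[2], Kcur->ne[3]};
    struct ggml_tensor * Kcur2 = ggml_new_tensor(ctx0, Kcur->type, GGML_MAX_DIMS, nek);
    Kcur = ggml_repeat(ctx0, Kcur, Kcur2); // tile K across the query groups
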
@@ -10946,15 +10836,10 @@ struct llm_build_context {
residual = ggml_get_rows(ctx0, residual, inp_out_ids);
}
cur = ggml_add(ctx0, cur, residual);
cur = llm_build_norm(ctx0, cur, modified_hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
residual = cur;
// FF
{
cur = llm_build_norm(ctx0, cur, hparams,
cur = llm_build_norm(ctx0, cur, modified_hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
@@ -10969,14 +10854,12 @@ struct llm_build_context {
LLM_FFN_SILU2, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il);
}
residual = cur;
cur = ggml_add(ctx0, residual, cur);
cb(cur, "l_out", il);
inpL = cur;
}
cur = llm_build_norm(ctx0, cur, hparams,
cur = llm_build_norm(ctx0, cur, modified_hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);