From 0051c82d52de7ce81fb4aa427af4f181e61091cd Mon Sep 17 00:00:00 2001
From: Douglas Hanley
Date: Thu, 8 Feb 2024 01:02:10 -0600
Subject: [PATCH] it runs; tokenization is messed up; pooling is wrong for
 multi batches

---
 convert-hf-to-gguf.py |  27 ++++++++
 llama.cpp             | 148 ++++++++++++++++++++----------------------
 2 files changed, 96 insertions(+), 79 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 88b9b912b..a952a5282 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1590,6 +1590,33 @@ class BertModel(Model):
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
         self.gguf_writer.add_file_type(self.ftype)
 
+    def set_vocab(self):
+        path = self.dir_model
+        added_tokens_path = self.dir_model if self.dir_model.exists() else None
+
+        vocab = HfVocab(path, added_tokens_path)
+        tokens, scores, toktypes = zip(*vocab.all_tokens())
+
+        assert len(tokens) == vocab.vocab_size
+
+        # for some reason set(toktypes) = {1, 3} so we need to compress it
+        all_types, toktypes1 = np.unique(toktypes, return_inverse=True)
+        n_token_types, toktypes1 = len(all_types), toktypes1.tolist()
+        self.gguf_writer.add_uint32("tokenizer.ggml.token_type_count", n_token_types)
+
+        # convert tokens to SPM style
+        tokens = [
+            (t[2:] if t.startswith(b"##") else b"\xe2\x96\x81" + t) for t in tokens
+        ]
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)  # ignore types for now (all zero)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def write_tensors(self):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         tensors = dict(self.get_tensors())
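
The set_vocab hunk above remaps WordPiece surface forms onto SPM conventions:
WordPiece marks word-internal pieces with a leading "##" and leaves word-initial
pieces bare, while SPM marks word-initial pieces with U+2581 (b"\xe2\x96\x81" in
UTF-8) and leaves continuations bare. A minimal standalone sketch of that
mapping (the function name is illustrative, not part of the patch):

    # Sketch of the WordPiece -> SPM surface-form conversion done in set_vocab.
    def wordpiece_to_spm(token: bytes) -> bytes:
        if token.startswith(b"##"):
            return token[2:]                # continuation piece: drop the "##" marker
        return b"\xe2\x96\x81" + token      # word-initial piece: prepend U+2581

    assert wordpiece_to_spm(b"##ing") == b"ing"
    assert wordpiece_to_spm(b"walk") == b"\xe2\x96\x81walk"
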
diff --git a/llama.cpp b/llama.cpp
index 18ff8c547..1118a291b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -359,6 +359,7 @@ struct LLM_KV {
 enum llm_tensor {
     LLM_TENSOR_TOKEN_EMBD,
     LLM_TENSOR_TOKEN_EMBD_NORM,
+    LLM_TENSOR_TOKEN_TYPES,
     LLM_TENSOR_POS_EMBD,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
@@ -547,13 +548,12 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
         { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+        { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
         { LLM_TENSOR_POS_EMBD,        "position_embd" },
-        { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+        { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_output_norm" },
         { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
         { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
         { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
         { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-        { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-        { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+        { LLM_TENSOR_FFN_NORM,        "blk.%d.layer_output_norm" },
         { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
         { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
     },
@@ -1519,6 +1519,8 @@ struct llama_hparams {
     float f_clamp_kqv;
     float f_max_alibi_bias;
 
+    bool causal_attn = true;
+
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -1611,8 +1613,6 @@ struct llama_layer {
     struct ggml_tensor * bqkv;
 
     // normalization
-    struct ggml_tensor * attn_out_norm;
-    struct ggml_tensor * attn_out_norm_b;
     struct ggml_tensor * ffn_norm;
     struct ggml_tensor * ffn_norm_b;
@@ -1631,10 +1631,6 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_b; // b2
     struct ggml_tensor * ffn_up_b;   // b3
     struct ggml_tensor * ffn_act;
-
-    // normalization
-    struct ggml_tensor * layer_out_norm;
-    struct ggml_tensor * layer_out_norm_b;
 };
 
 struct llama_kv_cell {
@@ -3036,7 +3032,7 @@ static void llm_load_hparams(
         case LLM_ARCH_BERT:
            {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
 
                 switch (hparams.n_embd) {
                     case 384: // MiniLM
@@ -3279,7 +3275,11 @@ static void llm_load_vocab(
 
     // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
     if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-        vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+        try {
+            vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+        } catch (const std::exception & e) {
+            LLAMA_LOG_WARN("%s: model vocab missing newline token: %s\n", __func__, e.what());
+        }
     } else {
         const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
         GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
@@ -3617,6 +3617,7 @@ static bool llm_load_tensors(
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
     const int64_t n_embd_gqa   = n_embd_v_gqa;
     const int64_t n_vocab      = hparams.n_vocab;
+    const int64_t n_vocab_type = hparams.n_vocab_type;
     const int64_t n_ff         = hparams.n_ff;
 
     GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
@@ -3834,7 +3835,7 @@ static bool llm_load_tensors(
             case LLM_ARCH_BERT:
                 {
                     model.tok_embd   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"),      {n_embd, n_vocab});
-                    model.type_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"),     {n_vocab_type, n_embd});
+                    model.type_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"),     {n_embd, n_vocab_type});
                     model.pos_embd   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"),        {n_embd, hparams.n_ctx_train});
                     model.tok_norm   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
                     model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
@@ -3845,11 +3846,11 @@ static bool llm_load_tensors(
                     auto & layer = model.layers[i];
 
-                    layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
-                    layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
+                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                    layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
 
-                    layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
-                    layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd});
+                    layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                    layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
 
                     layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
                     layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
@@ -4826,6 +4827,7 @@ struct llm_build_context {
     const int32_t n_orig_ctx;
 
     const bool do_rope_shift;
+    const bool causal_attn;
 
     const llm_build_cb & cb;
 
@@ -4869,6 +4871,7 @@ struct llm_build_context {
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         do_rope_shift    (worst_case || kv_self.has_shift),
+        causal_attn      (hparams.causal_attn),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
             // all initializations should be done in init()
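
The new causal_attn flag defaults to true in llama_hparams and is threaded into
llm_build_context here; the mask-filling hunk in llama_build_graph further down
only applies the future-position cutoff when it is set, which is what gives
BERT bidirectional attention. A sketch of the resulting rule (illustrative
Python, an assumed simplification of that loop):

    # Assumed simplification of the KQ_mask fill selected by causal_attn.
    def mask_value(causal_attn: bool, same_seq: bool, key_pos: int, query_pos: int) -> float:
        if not same_seq:
            return float("-inf")                 # never attend across sequences
        if causal_attn and key_pos > query_pos:
            return float("-inf")                 # causal models cannot see the future
        return 0.0                               # position is visible

    assert mask_value(causal_attn=False, same_seq=True, key_pos=5, query_pos=2) == 0.0
    assert mask_value(causal_attn=True,  same_seq=True, key_pos=5, query_pos=2) == float("-inf")
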
@@ -5722,69 +5725,50 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
+        // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-
-        // inp_pos - contains the positions
+        struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0);
+        inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.type_embd, inp_type), inpL);
         struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
+        cb(inpL, "inp_embd", -1);
 
-        inpL = ggml_add(ctx0,
-                ggml_get_rows(ctx0, model.type_embd, lctx.inp_type),
-                inpL);
-        inpL = ggml_add(ctx0,
-                ggml_get_rows(ctx0, model.pos_embd, lctx.inp_pos),
-                inpL);
+        // embed layer norm
+        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
 
-        inpL = llm_build_norm(ctx0, inpL, hparams,
-                model.tok_norm,
-                model.tok_norm_b,
-                LLM_NORM, cb, -1);
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens]
 
+        // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * cur = inpL;
 
             // self-attention
             {
-                // compute Q and K
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                // seems like we just need to do this for Q?
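+                // (assumption from reading llm_build_kv: K and V are written
+                // into the KV cache by llm_build_kv_store, which does its own
+                // reshape, so only Q needs the explicit head split here)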
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head, n_tokens);
-                struct ggml_tensor * K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);
-
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head, n_tokens);
-                struct ggml_tensor * V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
-
-                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-                // KQ = soft_max(KQ / sqrt(head width))
-                KQ = ggml_soft_max(ctx0,
-                        ggml_scale(ctx0, KQ, 1.0f / sqrt((float)n_embd_head)));
-
-                V = ggml_cont(ctx0, ggml_transpose(ctx0, V));
-                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-                KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
-                cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV), n_embd, N);
-
-                // attention output
-                cur = ggml_add(ctx0,
-                        model.layers[il].bo,
-                        ggml_mul_mat(ctx0, model.layers[il].wo, cur));
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
             }
 
             // re-add the layer input
-            cur = ggml_add(ctx0, cur, inpSA);
+            cur = ggml_add(ctx0, cur, inpL);
 
             // attention layer norm
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_out_norm,
-                    model.layers[il].attn_out_norm_b,
-                    LLM_NORM, cb, il);
+            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il);
 
             struct ggml_tensor * ffn_inp = cur;
@@ -5800,19 +5784,19 @@ struct llm_build_context {
             cur = ggml_add(ctx0, cur, ffn_inp);
 
             // output layer norm
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].layer_out_norm,
-                    model.layers[il].layer_out_norm_b,
-                    LLM_NORM, cb, il);
+            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il);
 
             // input for next layer
             inpL = cur;
         }
 
+        // final output
         cur = inpL;
 
-        // pooling (sum = [L, 1, B])
-        cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), lctx.inp_sum); // [E, 1, B]
+        // pooling
+        struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
+        cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
+        cb(cur, "result_embed", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -7260,7 +7244,8 @@ static struct ggml_cgraph * llama_build_graph(
                 for (int i = 0; i < n_kv; ++i) {
                     float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
+                        (llm.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
                         f = -INFINITY;
                     } else {
                         f = 0;
@@ -7283,7 +7268,7 @@ static struct ggml_cgraph * llama_build_graph(
             }
 
             {
-                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_inp_sum->buffer));
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
                 float * data = (float *) lctx.inp_sum->data;
 
                 for (int i = 0; i < batch.n_tokens; ++i) {
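
The pooling above multiplies the transposed [n_embd, n_tokens] hidden states by
the length-n_tokens inp_sum vector, so whatever weights llama_build_graph
writes into inp_sum define the pooling; uniform 1/n_tokens weights would make
it a mean. A standalone sketch of that arithmetic (illustrative Python; the
uniform weighting is an assumption, and a single shared weight vector mixes
tokens from different sequences in one batch, which matches the subject line's
"pooling is wrong for multi batches"):

    # Weighted-sum pooling as a plain loop: one output per embedding dim.
    def mean_pool(hidden: list[list[float]]) -> list[float]:
        n_tokens = len(hidden)                       # hidden: n_tokens rows of n_embd floats
        weights = [1.0 / n_tokens] * n_tokens        # the assumed inp_sum contents
        n_embd = len(hidden[0])
        return [sum(w * row[j] for w, row in zip(weights, hidden))
                for j in range(n_embd)]

    print(mean_pool([[1.0, 2.0], [3.0, 4.0]]))       # [2.0, 3.0]
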
@@ -7482,13 +7467,18 @@ static int llama_decode_internal(
     // the output is always the last tensor in the graph
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
-
-    // the embeddings could be the second to last tensor, or the third to last tensor
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-    if (strcmp(embeddings->name, "result_norm") != 0) {
-        embeddings = gf->nodes[gf->n_nodes - 3];
-        GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+    struct ggml_tensor * embeddings = nullptr;
+    if (strcmp(res->name, "result_embed") == 0) {
+        embeddings = res;
+        res = nullptr;
+    } else {
+        // the embeddings could be the second to last tensor, or the third to last tensor
+        GGML_ASSERT(strcmp(res->name, "result_output") == 0);
+        embeddings = gf->nodes[gf->n_nodes - 2];
+        if (strcmp(embeddings->name, "result_norm") != 0) {
+            embeddings = gf->nodes[gf->n_nodes - 3];
+            GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+        }
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -7555,7 +7545,7 @@ static int llama_decode_internal(
     // extract logits
     // TODO: do not compute and extract logits if only embeddings are needed
     //       need to update the graphs to skip "result_output"
-    {
+    if (res) {
         auto & logits_out = lctx.logits;
 
 #ifndef NDEBUG
@@ -7601,7 +7591,7 @@ static int llama_decode_internal(
         embedding_out.resize(n_embd);
 
         ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
+        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), 0, n_embd*sizeof(float));
         ggml_backend_synchronize(embeddings_backend);
     }
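
The extraction change in the last hunk follows from the new graph output: the
old code read the hidden state of the final token out of a per-token
result_norm tensor, while the pooled result_embed tensor has its single row at
byte offset 0. A sketch of the offset arithmetic (illustrative Python, not part
of the patch):

    # Byte offset of the embedding row inside the output tensor.
    def embedding_offset(pooled: bool, n_embd: int, n_tokens: int, float_size: int = 4) -> int:
        return 0 if pooled else n_embd * (n_tokens - 1) * float_size

    assert embedding_offset(pooled=False, n_embd=4, n_tokens=3) == 32
    assert embedding_offset(pooled=True,  n_embd=4, n_tokens=3) == 0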