it runs; tokenization is messed up; pooling is wrong for multi batches

Douglas Hanley 2024-02-08 01:02:10 -06:00
parent ef10d7867e
commit 0051c82d52
2 changed files with 96 additions and 79 deletions

convert-hf-to-gguf.py

@@ -1590,6 +1590,33 @@ class BertModel(Model):
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
         self.gguf_writer.add_file_type(self.ftype)

+    def set_vocab(self):
+        path = self.dir_model
+        added_tokens_path = self.dir_model if self.dir_model.exists() else None
+        vocab = HfVocab(path, added_tokens_path)
+        tokens, scores, toktypes = zip(*vocab.all_tokens())
+        assert len(tokens) == vocab.vocab_size
+
+        # for some reason set(toktypes) = {1, 3} so we need to compress it
+        all_types, toktypes1 = np.unique(toktypes, return_inverse=True)
+        n_token_types, toktypes1 = len(all_types), toktypes1.tolist()
+        self.gguf_writer.add_uint32("tokenizer.ggml.token_type_count", n_token_types)
+
+        # convert tokens to SPM style
+        tokens = [
+            (t[2:] if t.startswith(b"##") else b"\xe2\x96\x81" + t) for t in tokens
+        ]
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes) # ignore types for now (all zero)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)

     def write_tensors(self):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         tensors = dict(self.get_tensors())
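
Note: two transformations in set_vocab are easy to get subtly wrong, which may be part of what the commit message's "tokenization is messed up" refers to. A minimal, self-contained Python sketch of both (the sample tokens are hypothetical, not from this commit): WordPiece marks continuation pieces with a "##" prefix, while SPM marks word-initial pieces with U+2581, so the conversion strips one marker and prepends the other; and np.unique() compresses the sparse token-type ids {1, 3} into contiguous indices.

    import numpy as np

    # WordPiece -> SPM surface forms, on hypothetical byte-string tokens
    SPM_UNDERLINE = b"\xe2\x96\x81"  # UTF-8 encoding of U+2581
    wordpiece = [b"hello", b"##world", b"!"]
    spm = [t[2:] if t.startswith(b"##") else SPM_UNDERLINE + t for t in wordpiece]
    print(spm)  # [b'\xe2\x96\x81hello', b'world', b'\xe2\x96\x81!']

    # compressing token types {1, 3} to contiguous {0, 1}, as set_vocab does
    all_types, toktypes1 = np.unique([1, 3, 3, 1], return_inverse=True)
    print(all_types, toktypes1.tolist())  # [1 3] [0, 1, 1, 0]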

llama.cpp

@@ -359,6 +359,7 @@ struct LLM_KV {
 enum llm_tensor {
     LLM_TENSOR_TOKEN_EMBD,
     LLM_TENSOR_TOKEN_EMBD_NORM,
+    LLM_TENSOR_TOKEN_TYPES,
     LLM_TENSOR_POS_EMBD,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
@@ -547,13 +548,12 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
         { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
         { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
         { LLM_TENSOR_POS_EMBD,        "position_embd" },
-        { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+        { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_output_norm" },
         { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
         { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
         { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
         { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-        { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-        { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+        { LLM_TENSOR_FFN_NORM,        "blk.%d.layer_output_norm" },
         { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
         { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
     },
@@ -1519,6 +1519,8 @@ struct llama_hparams {
     float f_clamp_kqv;
     float f_max_alibi_bias;

+    bool causal_attn = true;
+
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -1611,8 +1613,6 @@ struct llama_layer {
     struct ggml_tensor * bqkv;

     // normalization
-    struct ggml_tensor * attn_out_norm;
-    struct ggml_tensor * attn_out_norm_b;
     struct ggml_tensor * ffn_norm;
     struct ggml_tensor * ffn_norm_b;
@@ -1631,10 +1631,6 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_b; // b2
     struct ggml_tensor * ffn_up_b;   // b3
     struct ggml_tensor * ffn_act;
-
-    // normalization
-    struct ggml_tensor * layer_out_norm;
-    struct ggml_tensor * layer_out_norm_b;
 };

 struct llama_kv_cell {
@@ -3036,7 +3032,7 @@ static void llm_load_hparams(
         case LLM_ARCH_BERT:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);

                 switch (hparams.n_embd) {
                     case 384: // MiniLM
@@ -3279,7 +3275,11 @@ static void llm_load_vocab(
     // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
     if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-        vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+        try {
+            vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+        } catch (const std::exception & e) {
+            LLAMA_LOG_WARN("%s: model vocab missing newline token: %s\n", __func__, e.what());
+        }
     } else {
         const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
         GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
@@ -3617,6 +3617,7 @@ static bool llm_load_tensors(
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
     const int64_t n_embd_gqa   = n_embd_v_gqa;
     const int64_t n_vocab      = hparams.n_vocab;
+    const int64_t n_vocab_type = hparams.n_vocab_type;
     const int64_t n_ff         = hparams.n_ff;

     GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
@@ -3834,7 +3835,7 @@ static bool llm_load_tensors(
             case LLM_ARCH_BERT:
                 {
                     model.tok_embd   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
-                    model.type_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES,     "weight"), {n_vocab_type, n_embd});
+                    model.type_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES,     "weight"), {n_embd, n_vocab_type});
                    model.pos_embd   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,        "weight"), {n_embd, hparams.n_ctx_train});
                    model.tok_norm   = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
                    model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
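
The shape swap above follows ggml's dimension convention: ne[0] is the fastest-varying (row) dimension, so a table read with ggml_get_rows must be declared {n_embd, n_rows}, the same layout as tok_embd's {n_embd, n_vocab}. A rough numpy analogue (row-major, so the axes read in the opposite order; sizes are illustrative — BERT uses 2 token types, and n_embd = 384 would be MiniLM):

    import numpy as np
    n_vocab_type, n_embd = 2, 384
    type_embd = np.random.rand(n_vocab_type, n_embd).astype(np.float32)
    token_types = np.zeros(8, dtype=np.int64)  # one type id per token, all "Sentence A"
    rows = type_embd[token_types]              # roughly what ggml_get_rows computes
    assert rows.shape == (8, n_embd)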
@@ -3845,11 +3846,11 @@ static bool llm_load_tensors(
                         auto & layer = model.layers[i];

-                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
-                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});

-                        layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
-                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});

                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
                         layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
@@ -4826,6 +4827,7 @@ struct llm_build_context {
     const int32_t n_orig_ctx;

     const bool do_rope_shift;
+    const bool causal_attn;

     const llm_build_cb & cb;
@@ -4869,6 +4871,7 @@ struct llm_build_context {
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         do_rope_shift    (worst_case || kv_self.has_shift),
+        causal_attn      (hparams.causal_attn),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
             // all initializations should be done in init()
@@ -5722,69 +5725,50 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;

+        // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
-        // inp_pos - contains the positions
+        struct ggml_tensor * inp_type = ggml_view_1d(ctx0, lctx.inp_type, n_tokens, 0);
+        inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.type_embd, inp_type), inpL);
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
+        cb(inpL, "inp_embd", -1);
-        inpL = ggml_add(ctx0,
-                ggml_get_rows(ctx0, model.type_embd, lctx.inp_type),
-                inpL);
-        inpL = ggml_add(ctx0,
-                ggml_get_rows(ctx0, model.pos_embd, lctx.inp_pos),
-                inpL);

         // embed layer norm
+        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
-        inpL = llm_build_norm(ctx0, inpL, hparams,
-                model.tok_norm,
-                model.tok_norm_b,
-                LLM_NORM, cb, -1);

+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens]

+        // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * cur = inpL;

             // self-attention
             {
-                // compute Q and K
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                cb(Qcur, "Qcur", il);

+                struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                cb(Kcur, "Kcur", il);

+                struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                cb(Vcur, "Vcur", il);

+                // seems like we just need to do this for Q?
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);

-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head, n_tokens);
-                struct ggml_tensor * K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3);

-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head, n_tokens);
-                struct ggml_tensor * V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);

-                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-                // KQ = soft_max(KQ / sqrt(head width))
-                KQ = ggml_soft_max(ctx0,
-                        ggml_scale(ctx0, KQ, 1.0f / sqrt((float)n_embd_head)));

-                V = ggml_cont(ctx0, ggml_transpose(ctx0, V));
-                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-                KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-                cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV), n_embd, N);

-                // attention output
-                cur = ggml_add(ctx0,
-                        model.layers[il].bo,
-                        ggml_mul_mat(ctx0, model.layers[il].wo, cur));
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
             }

             // re-add the layer input
-            cur = ggml_add(ctx0, cur, inpSA);
+            cur = ggml_add(ctx0, cur, inpL);

             // attention layer norm
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_out_norm,
-                    model.layers[il].attn_out_norm_b,
-                    LLM_NORM, cb, il);
+            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il);

             struct ggml_tensor * ffn_inp = cur;
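
The hand-rolled attention is replaced here by llm_build_kv, which takes over the reshapes, the KV cache plumbing, and the additive KQ_mask. In spirit it computes masked scaled-dot-product attention per head; a minimal numpy sketch of that computation (my own illustration, not the ggml implementation — the cache writes are omitted):

    import numpy as np

    def masked_attention(Q, K, V, mask):
        # Q: [n_tokens, n_head, d_head]; K, V: [n_kv, n_head, d_head]
        # mask: [n_tokens, n_kv], 0.0 where visible, -inf where blocked
        d_head = Q.shape[-1]
        scores = np.einsum("qhd,khd->hqk", Q, K) / np.sqrt(d_head)
        scores = scores + mask[None, :, :]       # one mask broadcast to all heads
        scores -= scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        out = np.einsum("hqk,khd->qhd", weights, V)
        return out.reshape(Q.shape[0], -1)       # concatenate heads

    n_tokens, n_head, d_head = 4, 2, 8
    Q, K, V = (np.random.rand(n_tokens, n_head, d_head) for _ in range(3))
    mask = np.zeros((n_tokens, n_tokens))        # all-visible: causal_attn == false
    print(masked_attention(Q, K, V, mask).shape) # (4, 16)

For BERT the mask carries no causal structure (see the causal_attn change further down); it only needs to block tokens that belong to other sequences in the batch.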
@@ -5800,19 +5784,19 @@ struct llm_build_context {
             cur = ggml_add(ctx0, cur, ffn_inp);

             // output layer norm
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].layer_out_norm,
-                    model.layers[il].layer_out_norm_b,
-                    LLM_NORM, cb, il);
+            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il);

             // input for next layer
             inpL = cur;
         }

         // final output
         cur = inpL;

-        // pooling (sum = [L, 1, B])
-        cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), lctx.inp_sum); // [E, 1, B]
+        // pooling
+        struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
+        cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
         cb(cur, "result_embed", -1);

         ggml_build_forward_expand(gf, cur);
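
The new pooling builds the mean as a matrix product: assuming lctx.inp_sum is filled with uniform weights 1/n_tokens (the fill loop appears further down, truncated here), multiplying it against the transposed hidden states yields one averaged embedding. The commit message's "pooling is wrong for multi batches" is visible in the 1-d view: a single weight vector of length n_tokens averages every sequence in the batch together, whereas a batch needs one weight row per sequence. A numpy sketch of both (illustrative, not code from the commit):

    import numpy as np
    n_tokens, n_embd = 6, 4
    hidden = np.random.rand(n_tokens, n_embd).astype(np.float32)  # cur, transposed

    # what the graph above builds: one weight vector -> one pooled embedding
    inp_sum = np.full(n_tokens, 1.0 / n_tokens, dtype=np.float32)
    pooled = inp_sum @ hidden                   # [n_embd], mean over ALL tokens

    # what multi-sequence batches need: one weight row per sequence
    # (two sequences of lengths 2 and 4 packed into the same batch)
    seq_weights = np.zeros((2, n_tokens), dtype=np.float32)
    seq_weights[0, :2] = 1.0 / 2
    seq_weights[1, 2:] = 1.0 / 4
    pooled_batch = seq_weights @ hidden         # [2, n_embd], one mean per sequence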
@@ -7260,7 +7244,8 @@ static struct ggml_cgraph * llama_build_graph(
                 for (int i = 0; i < n_kv; ++i) {
                     float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
+                        (llm.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
                         f = -INFINITY;
                     } else {
                         f = 0;
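
With causal_attn false (as BERT sets it), the pos > pos check is skipped and the mask only blocks KV cells belonging to other sequences, which gives bidirectional attention. A small Python sketch of the loop's logic (hypothetical helper mirroring the condition above, not code from this commit):

    NEG_INF = float("-inf")

    def kq_mask_row(cell_seq_ids, cell_pos, seq_id, pos, causal_attn):
        # -inf blocks a kv cell, 0.0 admits it
        return [
            NEG_INF if (seq_id not in s) or (causal_attn and p > pos) else 0.0
            for s, p in zip(cell_seq_ids, cell_pos)
        ]

    cells = ([{0}, {0}, {0}], [0, 1, 2])  # three cells of sequence 0 at pos 0..2
    print(kq_mask_row(*cells, seq_id=0, pos=1, causal_attn=True))   # [0.0, 0.0, -inf]
    print(kq_mask_row(*cells, seq_id=0, pos=1, causal_attn=False))  # [0.0, 0.0, 0.0]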
@@ -7283,7 +7268,7 @@ static struct ggml_cgraph * llama_build_graph(
     }

     {
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_inp_sum->buffer));
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
         float * data = (float *) lctx.inp_sum->data;

         for (int i = 0; i < batch.n_tokens; ++i) {
@@ -7482,13 +7467,18 @@ static int llama_decode_internal(
     // the output is always the last tensor in the graph
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
-
-    // the embeddings could be the second to last tensor, or the third to last tensor
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-    if (strcmp(embeddings->name, "result_norm") != 0) {
-        embeddings = gf->nodes[gf->n_nodes - 3];
-        GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+    struct ggml_tensor * embeddings = nullptr;
+    if (strcmp(res->name, "result_embed") == 0) {
+        embeddings = res;
+        res = nullptr;
+    } else {
+        GGML_ASSERT(strcmp(res->name, "result_output") == 0);
+
+        // the embeddings could be the second to last tensor, or the third to last tensor
+        embeddings = gf->nodes[gf->n_nodes - 2];
+        if (strcmp(embeddings->name, "result_norm") != 0) {
+            embeddings = gf->nodes[gf->n_nodes - 3];
+            GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+        }
     }

     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -7555,7 +7545,7 @@ static int llama_decode_internal(
     // extract logits
     // TODO: do not compute and extract logits if only embeddings are needed
     //       need to update the graphs to skip "result_output"
-    {
+    if (res) {
         auto & logits_out = lctx.logits;

 #ifndef NDEBUG
@@ -7601,7 +7591,7 @@ static int llama_decode_internal(
         embedding_out.resize(n_embd);

         ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float));
+        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), 0, n_embd*sizeof(float));
         ggml_backend_synchronize(embeddings_backend);
     }
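
The offset change reflects where the embedding now lives in the tensor's data: the old code read the hidden state of the last token (a byte offset of n_tokens - 1 rows), while the pooled "result_embed" tensor stores its output in the first row. In flat-buffer terms (a sketch using element offsets, where ggml_backend_tensor_get_async takes byte offsets):

    import numpy as np
    n_tokens, n_embd = 5, 4
    data = np.arange(n_tokens * n_embd, dtype=np.float32)  # flattened tensor data
    old_read = data[(n_tokens - 1) * n_embd :][:n_embd]    # last token's row
    new_read = data[:n_embd]                               # row 0: pooled embedding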