From 01cd5a6670e3368e51f830bceba110c8ee828f0a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 4 Jul 2024 13:45:36 +0300
Subject: [PATCH] llama : minor

---
 examples/main/main.cpp |  1 +
 src/llama.cpp          | 15 ++++++++++-----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 6e5fe769d..22bb37889 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -532,6 +532,7 @@ int main(int argc, char ** argv) {
         if (decoder_start_token_id == -1) {
             decoder_start_token_id = llama_token_bos(model);
         }
+        embd_inp.clear();
         embd_inp.push_back(decoder_start_token_id);
     }
 
diff --git a/src/llama.cpp b/src/llama.cpp
index d1fa683aa..c3edf70bc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2120,7 +2120,6 @@ struct llama_hparams {
     uint32_t n_expert_used = 0;
     uint32_t n_vocab_type = 0; // for BERT-style token types
     uint32_t n_rel_attn_bkts = 0;
-    int32_t dec_start_token_id = -1;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q = 0;
@@ -2156,6 +2155,10 @@ struct llama_hparams {
     bool use_alibi = false;
     bool attn_soft_cap = false;
 
+    // needed by encoder-decoder models (e.g. T5, FLAN-T5)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/8141
+    llama_token dec_start_token_id = -1;
+
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
@@ -2177,7 +2180,6 @@ struct llama_hparams {
         if (this->n_expert_used != other.n_expert_used) return true;
 
         if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true;
-        if (this->dec_start_token_id != other.dec_start_token_id) return true;
         if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
         if (this->n_lora_q != other.n_lora_q) return true;
         if (this->n_lora_kv != other.n_lora_kv) return true;
@@ -2193,6 +2195,8 @@ struct llama_hparams {
         if (this->ssm_d_state != other.ssm_d_state) return true;
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
 
+        if (this->dec_start_token_id != other.dec_start_token_id) return true;
+
         const float EPSILON = 1e-9f;
 
         if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
@@ -4970,9 +4974,10 @@ static void llm_load_hparams(
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
-                uint32_t decoder_start_token_id;
-                if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, decoder_start_token_id, false)) {
-                    hparams.dec_start_token_id = decoder_start_token_id;
+
+                uint32_t dec_start_token_id;
+                if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) {
+                    hparams.dec_start_token_id = dec_start_token_id;
                 }
 
                 switch (hparams.n_layer) {
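
For context on the main.cpp hunk: for encoder-decoder models, embd_inp still holds
the full encoder prompt after encoding, so without the added clear() the decoder
would re-consume the entire encoder input instead of starting from its start token
alone. A minimal sketch of the surrounding flow, assuming the llama.cpp C API of
this era (llama_encode, the 4-argument llama_batch_get_one, llama_model_decoder_start_token,
llama_token_bos); the helper name encode_and_seed_decoder is hypothetical, not from
the patch:

    #include "llama.h"

    #include <vector>

    // Sketch: after running the encoder of an encoder-decoder model (e.g. T5),
    // reset the token buffer so the decoder is seeded with only its start token.
    static bool encode_and_seed_decoder(llama_context * ctx, const llama_model * model,
                                        std::vector<llama_token> & embd_inp) {
        // run the encoder over the tokenized prompt
        if (llama_encode(ctx, llama_batch_get_one(embd_inp.data(), (int32_t) embd_inp.size(), 0, 0)) != 0) {
            return false; // encoding failed
        }

        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
        if (decoder_start_token_id == -1) {
            // models that do not define a decoder start token fall back to BOS
            decoder_start_token_id = llama_token_bos(model);
        }

        // without the clear(), the decoder prompt would still contain the whole
        // encoder input rather than just the start token
        embd_inp.clear();
        embd_inp.push_back(decoder_start_token_id);
        return true;
    }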