llama : minor

Georgi Gerganov 2024-07-04 13:45:36 +03:00
parent ded682d43b
commit 01cd5a6670
2 changed files with 11 additions and 5 deletions


@@ -532,6 +532,7 @@ int main(int argc, char ** argv) {
         if (decoder_start_token_id == -1) {
             decoder_start_token_id = llama_token_bos(model);
         }
+
         embd_inp.clear();
         embd_inp.push_back(decoder_start_token_id);
     }
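
For context, this hunk sits in the encoder-decoder path of the example program: the prompt tokens go through the encoder, and generation is then re-seeded from the model's decoder start token. A minimal sketch of that surrounding logic (llama_model_has_encoder, llama_model_decoder_start_token and llama_token_bos come from PR #8141; the encoder call is elided and variable names are illustrative, not the exact file contents):

    // sketch: seed the decoder for an encoder-decoder model such as T5
    if (llama_model_has_encoder(model)) {
        // ... run the encoder pass over embd_inp (the tokenized prompt) here ...

        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
        if (decoder_start_token_id == -1) {
            // models that do not declare a start token fall back to BOS
            decoder_start_token_id = llama_token_bos(model);
        }

        // the decoder then starts from this single token
        embd_inp.clear();
        embd_inp.push_back(decoder_start_token_id);
    }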


@@ -2120,7 +2120,6 @@ struct llama_hparams {
     uint32_t n_expert_used = 0;
     uint32_t n_vocab_type = 0; // for BERT-style token types
     uint32_t n_rel_attn_bkts = 0;
-    int32_t dec_start_token_id = -1;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q = 0;
@@ -2156,6 +2155,10 @@ struct llama_hparams {
     bool use_alibi = false;
     bool attn_soft_cap = false;
 
+    // needed by encoder-decoder models (e.g. T5, FLAN-T5)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/8141
+    llama_token dec_start_token_id = -1;
+
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
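
The field moved above is what the public accessor added in PR #8141 reports to callers. A minimal sketch of that accessor, assuming it simply forwards the hparam (the actual implementation may differ):

    // sketch: returns -1 when the model defines no decoder start token,
    // which is why the example program falls back to llama_token_bos()
    llama_token llama_model_decoder_start_token(const struct llama_model * model) {
        return model->hparams.dec_start_token_id;
    }
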
@@ -2177,7 +2180,6 @@ struct llama_hparams {
     if (this->n_expert_used != other.n_expert_used) return true;
 
     if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true;
-    if (this->dec_start_token_id != other.dec_start_token_id) return true;
     if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
     if (this->n_lora_q != other.n_lora_q) return true;
     if (this->n_lora_kv != other.n_lora_kv) return true;
@@ -2193,6 +2195,8 @@ struct llama_hparams {
     if (this->ssm_d_state != other.ssm_d_state) return true;
     if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
 
+    if (this->dec_start_token_id != other.dec_start_token_id) return true;
+
     const float EPSILON = 1e-9f;
     if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
@@ -4970,9 +4974,10 @@ static void llm_load_hparams(
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
-                uint32_t decoder_start_token_id;
-                if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, decoder_start_token_id, false)) {
-                    hparams.dec_start_token_id = decoder_start_token_id;
+
+                uint32_t dec_start_token_id;
+                if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) {
+                    hparams.dec_start_token_id = dec_start_token_id;
                 }
 
                 switch (hparams.n_layer) {
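
Note that the third argument to ml.get_key() marks the key as optional, so models without this metadata keep the default dec_start_token_id of -1 and the example program falls back to BOS. A hedged sketch of reading the same metadata directly with the gguf API, assuming the key follows the usual "<arch>.decoder_start_token_id" naming (e.g. "t5.decoder_start_token_id") and is stored as a uint32:

    #include "ggml.h"   // the gguf_* API lived in ggml.h at the time of this commit

    // sketch: -1 mirrors the hparams default when the key is absent
    static int32_t read_decoder_start_token(const struct gguf_context * ctx) {
        const int key_id = gguf_find_key(ctx, "t5.decoder_start_token_id");
        if (key_id < 0) {
            return -1;
        }
        return (int32_t) gguf_get_val_u32(ctx, key_id);
    }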