From 904d4edfa1c3ad79e32e48365168bbe0e5bc36f2 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Thu, 14 Sep 2023 13:26:10 -0400 Subject: [PATCH] llama : store YaRN parameters in GGUF --- common/common.cpp | 73 +++++++++++++--------- common/common.h | 10 +-- convert.py | 53 +++++++++++----- examples/server/server.cpp | 39 ++++++++---- gguf-py/gguf/gguf.py | 27 ++++++-- llama.cpp | 124 ++++++++++++++++++++++++++----------- llama.h | 34 ++++++---- 7 files changed, 245 insertions(+), 115 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 9e4452dab..ca4b9c1cc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -192,36 +192,46 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.rope_freq_scale = std::stof(argv[i]); + } else if (arg == "--rope-scaling") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } + else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } + else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } + else { invalid_param = true; break; } } else if (arg == "--rope-scale") { if (++i >= argc) { invalid_param = true; break; } params.rope_freq_scale = 1.0f/std::stof(argv[i]); - } else if (arg == "--rope-ext-factor") { + } else if (arg == "--yarn-ext-factor") { if (++i >= argc) { invalid_param = true; break; } - params.rope_ext_factor = std::stof(argv[i]); - } else if (arg == "--rope-attn-factor") { + params.yarn_ext_factor = std::stof(argv[i]); + } else if (arg == "--yarn-attn-factor") { if (++i >= argc) { invalid_param = true; break; } - params.rope_attn_factor = std::stof(argv[i]); - } else if (arg == "--rope-beta-fast") { + params.yarn_attn_factor = std::stof(argv[i]); + } else if (arg == "--yarn-beta-fast") { if (++i >= argc) { invalid_param = true; break; } - params.rope_beta_fast = std::stof(argv[i]); - } else if (arg == "--rope-beta-slow") { + params.yarn_beta_fast = std::stof(argv[i]); + } else if (arg == "--yarn-beta-slow") { if (++i >= argc) { invalid_param = true; break; } - params.rope_beta_slow = std::stof(argv[i]); + params.yarn_beta_slow = std::stof(argv[i]); } else if (arg == "--memory-f32") { params.memory_f16 = false; } else if (arg == "--top-p") { @@ -671,13 +681,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --cfg-negative-prompt-file FNAME\n"); printf(" negative prompt file to use for guidance. 
(default: empty)\n"); printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); - printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); + printf(" --rope-scaling {none,linear,yarn}\n"); + printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n"); + printf(" --rope-scale N RoPE context scaling factor, inverse of --rope-freq-scale\n"); printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); - printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n"); - printf(" --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor); - printf(" --rope-attn-factor N RoPE magnitude scaling factor (default: %.1f)\n", params.rope_attn_factor); - printf(" --rope-beta-fast N RoPE low correction dim (default: %.1f)\n", params.rope_beta_fast); - printf(" --rope-beta-slow N RoPE high correction dim (default: %.1f)\n", params.rope_beta_slow); + printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n"); + printf(" --yarn-ext-factor N YaRN extrapolation mix factor (default: %.1f)\n", params.yarn_ext_factor); + printf(" --yarn-attn-factor N YaRN magnitude scaling factor (default: %.1f)\n", params.yarn_attn_factor); + printf(" --yarn-beta-fast N YaRN low correction dim (default: %.1f)\n", params.yarn_beta_fast); + printf(" --yarn-beta-slow N YaRN high correction dim (default: %.1f)\n", params.yarn_beta_slow); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); printf(" --no-penalize-nl do not penalize newline token\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); @@ -758,22 +770,23 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param if (params.n_gpu_layers != -1) { lparams.n_gpu_layers = params.n_gpu_layers; } - lparams.main_gpu = params.main_gpu; - lparams.tensor_split = params.tensor_split; - lparams.low_vram = params.low_vram; - lparams.mul_mat_q = params.mul_mat_q; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.logits_all = params.perplexity; - lparams.embedding = params.embedding; - lparams.rope_freq_base = params.rope_freq_base; - lparams.rope_freq_scale = params.rope_freq_scale; - lparams.rope_ext_factor = params.rope_ext_factor; - lparams.rope_attn_factor = params.rope_attn_factor; - lparams.rope_beta_fast = params.rope_beta_fast; - lparams.rope_beta_slow = params.rope_beta_slow; + lparams.main_gpu = params.main_gpu; + lparams.tensor_split = params.tensor_split; + lparams.low_vram = params.low_vram; + lparams.mul_mat_q = params.mul_mat_q; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; + lparams.rope_scaling_type = params.rope_scaling_type; + lparams.rope_freq_base = params.rope_freq_base; + lparams.rope_freq_scale = params.rope_freq_scale; + lparams.yarn_ext_factor = params.yarn_ext_factor; + lparams.yarn_attn_factor = params.yarn_attn_factor; + lparams.yarn_beta_fast = params.yarn_beta_fast; + lparams.yarn_beta_slow = params.yarn_beta_slow; return lparams; } diff --git a/common/common.h b/common/common.h index 
0b45b4278..a1e7da128 100644 --- a/common/common.h +++ b/common/common.h @@ -50,10 +50,12 @@ struct gpt_params { int32_t n_beams = 0; // if non-zero then use beam search of given width. float rope_freq_base = 10000.0f; // RoPE base frequency float rope_freq_scale = 1.0f; // RoPE frequency scaling factor - float rope_ext_factor = 0.0f; // RoPE extrapolation mix factor - float rope_attn_factor = 1.0f; // RoPE magnitude scaling factor - float rope_beta_fast = 32.0f; // RoPE low correction dim - float rope_beta_slow = 1.0f; // RoPE high correction dim + float yarn_ext_factor = 0.0f; // YaRN extrapolation mix factor + float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor + float yarn_beta_fast = 32.0f; // YaRN low correction dim + float yarn_beta_slow = 1.0f; // YaRN high correction dim + + llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // sampling parameters int32_t top_k = 40; // <= 0 to use vocab size diff --git a/convert.py b/convert.py index 649624cff..f08cf01c5 100755 --- a/convert.py +++ b/convert.py @@ -152,8 +152,11 @@ class Params: n_head_kv: int f_norm_eps: float + rope_scaling_type: gguf.RopeScalingType | None = None f_rope_freq_base: float | None = None f_rope_scale: float | None = None + n_orig_ctx: int | None = None + rope_finetuned: bool | None = None ftype: GGMLFileType | None = None @@ -199,11 +202,20 @@ class Params: def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: config = json.load(open(config_path)) + rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None rope_scaling = config.get("rope_scaling") - if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear": - f_rope_scale = config["rope_scaling"].get("factor") - else: - f_rope_scale = None + + if rope_scaling is not None and typ := rope_scaling.get("type"): + rope_factor = rope_scaling.get("factor") + f_rope_scale = rope_factor + if typ == "linear": + rope_scaling_type = RopeScalingType.LINEAR + elif typ == "yarn": + rope_scaling_type = RopeScalingType.YARN + n_orig_ctx = rope_scaling['original_max_position_embeddings'] + rope_finetuned = rope_scaling['finetuned'] + else: + raise NotImplementedError(f'Unknown rope scaling type: {typ}') if "max_sequence_length" in config: n_ctx = config["max_sequence_length"] @@ -214,16 +226,18 @@ class Params: "Suggestion: provide 'config.json' of the model in the same directory containing model files.") return Params( - n_vocab = config["vocab_size"], - n_embd = config["hidden_size"], - n_layer = config["num_hidden_layers"], - n_ctx = n_ctx, - n_ff = config["intermediate_size"], - n_head = config["num_attention_heads"], - n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head, - f_norm_eps = config["rms_norm_eps"], - f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None, - f_rope_scale = f_rope_scale, + n_vocab = config["vocab_size"], + n_embd = config["hidden_size"], + n_layer = config["num_hidden_layers"], + n_ctx = n_ctx, + n_ff = config["intermediate_size"], + n_head = config["num_attention_heads"], + n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head, + f_norm_eps = config["rms_norm_eps"], + f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None, + f_rope_scale = f_rope_scale, + n_orig_ctx = n_orig_ctx, + rope_finetuned = rope_finetuned, ) # LLaMA v2 70B params.json @@ -819,8 +833,15 @@ class OutputFile: if params.f_rope_freq_base is not None: 
self.gguf.add_rope_freq_base(params.f_rope_freq_base) - if params.f_rope_scale is not None: - self.gguf.add_rope_scale_linear(params.f_rope_scale) + if params.rope_scaling_type: + self.gguf.add_rope_scaling_type(params.rope_scaling_type) + self.gguf.add_rope_scaling_factor(params.f_rope_scale) + + if params.n_orig_ctx is not None: + self.gguf.add_rope_original_context_length(params.n_orig_ctx) + + if params.rope_finetuned is not None: + self.gguf.add_rope_finetuned(params.rope_finetuned) if params.ftype is not None: self.gguf.add_file_type(params.ftype) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0fb4e2c32..3a1c55b1c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -701,12 +701,14 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + printf(" --rope-scaling {none,linear,yarn}\n"); + printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n"); printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n"); printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n"); - printf(" --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor); - printf(" --rope-attn-factor N RoPE magnitude scaling factor (default: %.1f)\n", params.rope_attn_factor); - printf(" --rope-beta-fast N RoPE low correction dim (default: %.1f)\n", params.rope_beta_fast); - printf(" --rope-beta-slow N RoPE high correction dim (default: %.1f)\n", params.rope_beta_slow); + printf(" --yarn-ext-factor N YaRN extrapolation mix factor (default: %.1f)\n", params.yarn_ext_factor); + printf(" --yarn-attn-factor N YaRN magnitude scaling factor (default: %.1f)\n", params.yarn_attn_factor); + printf(" --yarn-beta-fast N YaRN low correction dim (default: %.1f)\n", params.yarn_beta_fast); + printf(" --yarn-beta-slow N YaRN high correction dim (default: %.1f)\n", params.yarn_beta_slow); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); @@ -824,6 +826,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_ctx = std::stoi(argv[i]); } + else if (arg == "--rope-scaling") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } + else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } + else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } + else { invalid_param = true; break; } + } else if (arg == "--rope-freq-base") { if (++i >= argc) @@ -842,37 +857,37 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.rope_freq_scale = std::stof(argv[i]); } - else if (arg == "--rope-ext-factor") + else if (arg == "--yarn-ext-factor") { if (++i >= argc) { invalid_param = true; break; } - params.rope_ext_factor = std::stof(argv[i]); + params.yarn_ext_factor = std::stof(argv[i]); } - 
else if (arg == "--rope-attn-factor") + else if (arg == "--yarn-attn-factor") { if (++i >= argc) { invalid_param = true; break; } - params.rope_attn_factor = std::stof(argv[i]); + params.yarn_attn_factor = std::stof(argv[i]); } - else if (arg == "--rope-beta-fast") + else if (arg == "--yarn-beta-fast") { if (++i >= argc) { invalid_param = true; break; } - params.rope_beta_fast = std::stof(argv[i]); + params.yarn_beta_fast = std::stof(argv[i]); } - else if (arg == "--rope-beta-slow") + else if (arg == "--yarn-beta-slow") { if (++i >= argc) { invalid_param = true; break; } - params.rope_beta_slow = std::stof(argv[i]); + params.yarn_beta_slow = std::stof(argv[i]); } else if (arg == "--memory-f32" || arg == "--memory_f32") { diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index e0e0dbcbb..742ce2143 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -52,9 +52,12 @@ KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" # RoPE -KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count" -KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base" -KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear" +KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count" +KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base" +KEY_ROPE_SCALING_TYPE = "{arch}.rope.scaling.type" +KEY_ROPE_SCALING_FACTOR = "{arch}.rope.scaling.factor" +KEY_ROPE_SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" +KEY_ROPE_SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" # tokenization KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" @@ -407,6 +410,11 @@ class TokenType(IntEnum): UNUSED = 5 BYTE = 6 +class RopeScalingType(IntEnum): + NONE = 0 + LINEAR = 1 + YARN = 2 + # # implementation # @@ -760,8 +768,17 @@ class GGUFWriter: def add_rope_freq_base(self, value: float): self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value) - def add_rope_scale_linear(self, value: float): - self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value) + def add_rope_scaling_type(self, value: RopeScalingType): + self.add_uint8(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), int(value)) + + def add_rope_scaling_factor(self, value: float): + self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value) + + def add_rope_scaling_orig_ctx_len(self, value: int): + self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value) + + def add_rope_scaling_finetuned(self, value: bool): + self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value) def add_tokenizer_model(self, model: str): self.add_string(KEY_TOKENIZER_MODEL, model) diff --git a/llama.cpp b/llama.cpp index 87aea2468..cd545b254 100644 --- a/llama.cpp +++ b/llama.cpp @@ -204,7 +204,10 @@ enum llm_kv { LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, - LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_SCALING_TYPE, + LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, + LLM_KV_ROPE_SCALING_FINETUNED, LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_LIST, @@ -246,9 +249,12 @@ static std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { 
LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, @@ -943,12 +949,17 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; - float rope_freq_base; - float rope_freq_scale; - float rope_ext_factor; - float rope_attn_factor; - float rope_beta_fast; - float rope_beta_slow; + float rope_freq_base; + float rope_freq_scale; + bool rope_finetuned; + uint32_t n_yarn_orig_ctx; + + // These hyperparameters are not exposed in GGUF, because all + // existing YaRN models use the same values for them. + float yarn_ext_factor; + float yarn_attn_factor; + float yarn_beta_fast; + float yarn_beta_slow; bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT @@ -1660,10 +1671,10 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const hparams.n_ctx = params.n_ctx; hparams.rope_freq_base = params.rope_freq_base; hparams.rope_freq_scale = params.rope_freq_scale; - hparams.rope_ext_factor = params.rope_ext_factor; - hparams.rope_attn_factor = params.rope_attn_factor; - hparams.rope_beta_fast = params.rope_beta_fast; - hparams.rope_beta_slow = params.rope_beta_slow; + hparams.yarn_ext_factor = params.yarn_ext_factor; + hparams.yarn_attn_factor = params.yarn_attn_factor; + hparams.yarn_beta_fast = params.yarn_beta_fast; + hparams.yarn_beta_slow = params.yarn_beta_slow; // get general kv GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); @@ -1680,6 +1691,14 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const hparams.n_head_kv = hparams.n_head; GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); + hparams.rope_finetuned = false; + GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false, + kv(LLM_KV_ROPE_SCALING_FINETUNED)); + + hparams.n_yarn_orig_ctx = 0; + GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, + kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN)); + // rope_freq_base (optional) if (hparams.rope_freq_base == 0.0f) { float rope_freq_base = 10000.0f; @@ -1687,13 +1706,28 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const hparams.rope_freq_base = rope_freq_base; } + llama_rope_scaling_type rope_scaling_type = params.rope_scaling_type; + + if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) { + uint8_t type = LLAMA_ROPE_SCALING_LINEAR; + GGUF_GET_KEY(ctx, type, gguf_get_val_u8, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE)); + rope_scaling_type = llama_rope_scaling_type(type); + } + GGML_ASSERT(rope_scaling_type >= 0 && rope_scaling_type <= LLAMA_ROPE_SCALING_MAX_VALUE); + // rope_freq_scale (inverse of the kv) is optional - if (hparams.rope_freq_scale == 0.0f) { + if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) { + hparams.rope_freq_scale = 1.0f; + } else if (hparams.rope_freq_scale == 0.0f) { float ropescale = 1.0f; - GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR)); hparams.rope_freq_scale 
= 1.0f/ropescale; } + if (rope_scaling_type == LLAMA_ROPE_SCALING_YARN) { + hparams.yarn_ext_factor = 1.0f; // enable YaRN + } + // sanity check for n_rot (optional) { hparams.n_rot = hparams.n_embd / hparams.n_head; @@ -1902,6 +1936,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: YaRN scaling = %g\n", __func__, hparams.yarn_ext_factor); + LLAMA_LOG_INFO("%s: YaRN orig ctx = %u\n", __func__, hparams.n_yarn_orig_ctx); + LLAMA_LOG_INFO("%s: YaRN beta_fast = %f\n", __func__, hparams.yarn_beta_fast); + LLAMA_LOG_INFO("%s: YaRN beta_slow = %f\n", __func__, hparams.yarn_beta_slow); + LLAMA_LOG_INFO("%s: RoPE finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "no"); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); @@ -2444,10 +2483,10 @@ static struct ggml_cgraph * llm_build_llama( const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; - const float ext_factor = hparams.rope_ext_factor; - const float attn_factor = hparams.rope_attn_factor; - const float beta_fast = hparams.rope_beta_fast; - const float beta_slow = hparams.rope_beta_slow; + const float ext_factor = hparams.yarn_ext_factor; + const float attn_factor = hparams.yarn_attn_factor; + const float beta_fast = hparams.yarn_beta_fast; + const float beta_slow = hparams.yarn_beta_slow; const float norm_rms_eps = hparams.f_norm_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -2561,15 +2600,13 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_tensor * Kcur = ggml_rope_custom_inplace( ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, - freq_scale, ext_factor, attn_factor, beta_fast, beta_slow - ); + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); struct ggml_tensor * Qcur = ggml_rope_custom_inplace( ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, - freq_scale, ext_factor, attn_factor, beta_fast, beta_slow - ); + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); @@ -2786,6 +2823,10 @@ static struct ggml_cgraph * llm_build_baichaun( const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; + const float ext_factor = hparams.yarn_ext_factor; + const float attn_factor = hparams.yarn_attn_factor; + const float beta_fast = hparams.yarn_beta_fast; + const float beta_slow = hparams.yarn_beta_slow; const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -2901,8 +2942,16 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * Qcur; switch (model.type) { case MODEL_7B: - Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); - Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + Kcur = ggml_rope_custom_inplace( + ctx0, + 
ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), + n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + Qcur = ggml_rope_custom_inplace( + ctx0, + ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), + n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); break; case MODEL_13B: Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N); @@ -3146,10 +3195,10 @@ static struct ggml_cgraph * llm_build_falcon( const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; - const float ext_factor = hparams.rope_ext_factor; - const float attn_factor = hparams.rope_attn_factor; - const float beta_fast = hparams.rope_beta_fast; - const float beta_slow = hparams.rope_beta_slow; + const float ext_factor = hparams.yarn_ext_factor; + const float attn_factor = hparams.yarn_attn_factor; + const float beta_fast = hparams.yarn_beta_fast; + const float beta_slow = hparams.yarn_beta_slow; const float norm_eps = hparams.f_norm_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -3302,11 +3351,13 @@ static struct ggml_cgraph * llm_build_falcon( // using mode = 2 for neox mode struct ggml_tensor * Qcur = ggml_rope_custom_inplace( - ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ctx0, tmpq, n_past, n_embd_head, 2, 0, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); offload_func_kq(Qcur); struct ggml_tensor * Kcur = ggml_rope_custom_inplace( - ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ctx0, tmpk, n_past, n_embd_head, 2, 0, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); offload_func_kq(Kcur); @@ -6186,10 +6237,11 @@ struct llama_context_params llama_context_default_params() { /*.tensor_split =*/ nullptr, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, - /*.rope_ext_factor =*/ 0.0f, - /*.rope_attn_factor =*/ 1.0f, - /*.rope_beta_fast =*/ 32.0f, - /*.rope_beta_slow =*/ 1.0f, + /*.yarn_ext_factor =*/ 0.0f, + /*.yarn_attn_factor =*/ 1.0f, + /*.yarn_beta_fast =*/ 32.0f, + /*.yarn_beta_slow =*/ 1.0f, + /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.low_vram =*/ false, diff --git a/llama.h b/llama.h index 762362bc6..5d69997bf 100644 --- a/llama.h +++ b/llama.h @@ -108,6 +108,14 @@ extern "C" { LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; + enum llama_rope_scaling_type: int8_t { + LLAMA_ROPE_SCALING_UNSPECIFIED = -1, + LLAMA_ROPE_SCALING_NONE = 0, + LLAMA_ROPE_SCALING_LINEAR = 1, + LLAMA_ROPE_SCALING_YARN = 2, + LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN, + }; + typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -134,10 +142,12 @@ extern "C" { // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency float rope_freq_scale; // RoPE frequency scaling factor - float rope_ext_factor; // RoPE extrapolation mix factor - float rope_attn_factor; // RoPE magnitude scaling factor - float rope_beta_fast; // RoPE low correction dim - float rope_beta_slow; // RoPE high correction dim + float yarn_ext_factor; // YaRN extrapolation mix factor + float yarn_attn_factor; // YaRN magnitude scaling factor + float yarn_beta_fast; // YaRN low correction dim + float yarn_beta_slow; // YaRN high 
correction dim + + llama_rope_scaling_type rope_scaling_type; // called with a progress value between 0 and 1, pass NULL to disable llama_progress_callback progress_callback; @@ -145,14 +155,14 @@ extern "C" { void * progress_callback_user_data; // Keep the booleans together to avoid misalignment during copy-by-value. - bool low_vram; // if true, reduce VRAM usage at the cost of performance - bool mul_mat_q; // if true, use experimental mul_mat_q kernels - bool f16_kv; // use fp16 for KV cache - bool logits_all; // the llama_eval() call computes all logits, not just the last one - bool vocab_only; // only load the vocabulary, no weights - bool use_mmap; // use mmap if possible - bool use_mlock; // force system to keep model in RAM - bool embedding; // embedding mode only + bool low_vram; // if true, reduce VRAM usage at the cost of performance + bool mul_mat_q; // if true, use experimental mul_mat_q kernels + bool f16_kv; // use fp16 for KV cache + bool logits_all; // the llama_eval() call computes all logits, not just the last one + bool vocab_only; // only load the vocabulary, no weights + bool use_mmap; // use mmap if possible + bool use_mlock; // force system to keep model in RAM + bool embedding; // embedding mode only }; // Signature for logging events
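
For reference, a minimal consumer-side sketch of the new context parameters. The llama_load_model_from_file / llama_new_context_with_model entry points are the pre-existing llama.h API at this point in the tree, and "model.gguf" plus the numeric values are placeholders, not defaults mandated by this patch:

    #include "llama.h"

    int main() {
        // start from the library defaults added above (yarn_ext_factor = 0.0f, etc.)
        llama_context_params cparams = llama_context_default_params();

        cparams.rope_scaling_type = LLAMA_ROPE_SCALING_YARN;  // UNSPECIFIED defers to GGUF metadata
        cparams.rope_freq_scale   = 0.25f;                    // 1/scale-factor, e.g. 4x context extension
        cparams.yarn_ext_factor   = 1.0f;                     // extrapolation mix; 0.0f disables YaRN
        cparams.yarn_attn_factor  = 1.0f;                     // magnitude scaling (default)
        cparams.yarn_beta_fast    = 32.0f;                    // low correction dim (default)
        cparams.yarn_beta_slow    = 1.0f;                     // high correction dim (default)

        llama_model * model = llama_load_model_from_file("model.gguf", cparams);
        if (model == NULL) {
            return 1;
        }
        llama_context * ctx = llama_new_context_with_model(model, cparams);
        if (ctx == NULL) {
            llama_free_model(model);
            return 1;
        }

        // ... evaluate tokens as usual ...

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }

The command-line equivalent is `--rope-scaling yarn --rope-freq-scale 0.25` (or `--rope-scale 4`), with the `--yarn-*` flags only needed to override the defaults. When `rope_scaling_type` is left at LLAMA_ROPE_SCALING_UNSPECIFIED, the loader reads `{arch}.rope.scaling.type` from the GGUF file and falls back to linear if the key is absent.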
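
A note on what `--yarn-beta-fast` and `--yarn-beta-slow` (the "low/high correction dim" in the help text) actually control: the blending itself happens in ggml's RoPE kernels, not in this patch, so the relation below is an assumption based on the YaRN paper rather than something this diff defines. Each beta is a rotation count that maps to a dimension index roughly as

    d(beta) = n_rot * ln(n_orig_ctx / (2*pi*beta)) / (2 * ln(freq_base))

where n_rot is the number of rotary dimensions, freq_base is `rope_freq_base`, and n_orig_ctx is the original training context (`n_yarn_orig_ctx`). Dimensions below d(beta_fast) rotate many times over the original context and are extrapolated unchanged, dimensions above d(beta_slow) are fully interpolated, and the range in between is blended, with `yarn_ext_factor` controlling how much of the extrapolated component is mixed in.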