llama : store YaRN parameters in GGUF
This commit is contained in:
parent
dc26a0dd32
commit
904d4edfa1
7 changed files with 245 additions and 115 deletions
|
@ -192,36 +192,46 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.rope_freq_scale = std::stof(argv[i]);
|
||||
} else if (arg == "--rope-scaling") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
|
||||
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
|
||||
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
|
||||
else { invalid_param = true; break; }
|
||||
} else if (arg == "--rope-scale") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_freq_scale = 1.0f/std::stof(argv[i]);
|
||||
} else if (arg == "--rope-ext-factor") {
|
||||
} else if (arg == "--yarn-ext-factor") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_ext_factor = std::stof(argv[i]);
|
||||
} else if (arg == "--rope-attn-factor") {
|
||||
params.yarn_ext_factor = std::stof(argv[i]);
|
||||
} else if (arg == "--yarn-attn-factor") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_attn_factor = std::stof(argv[i]);
|
||||
} else if (arg == "--rope-beta-fast") {
|
||||
params.yarn_attn_factor = std::stof(argv[i]);
|
||||
} else if (arg == "--yarn-beta-fast") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_beta_fast = std::stof(argv[i]);
|
||||
} else if (arg == "--rope-beta-slow") {
|
||||
params.yarn_beta_fast = std::stof(argv[i]);
|
||||
} else if (arg == "--yarn-beta-slow") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_beta_slow = std::stof(argv[i]);
|
||||
params.yarn_beta_slow = std::stof(argv[i]);
|
||||
} else if (arg == "--memory-f32") {
|
||||
params.memory_f16 = false;
|
||||
} else if (arg == "--top-p") {
|
||||
|
@ -671,13 +681,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" --cfg-negative-prompt-file FNAME\n");
|
||||
printf(" negative prompt file to use for guidance. (default: empty)\n");
|
||||
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
|
||||
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
|
||||
printf(" --rope-scaling {none,linear,yarn}\n");
|
||||
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
|
||||
printf(" --rope-scale N RoPE context scaling factor, inverse of --rope-freq-scale\n");
|
||||
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
|
||||
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n");
|
||||
printf(" --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor);
|
||||
printf(" --rope-attn-factor N RoPE magnitude scaling factor (default: %.1f)\n", params.rope_attn_factor);
|
||||
printf(" --rope-beta-fast N RoPE low correction dim (default: %.1f)\n", params.rope_beta_fast);
|
||||
printf(" --rope-beta-slow N RoPE high correction dim (default: %.1f)\n", params.rope_beta_slow);
|
||||
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
|
||||
printf(" --yarn-ext-factor N YaRN extrapolation mix factor (default: %.1f)\n", params.yarn_ext_factor);
|
||||
printf(" --yarn-attn-factor N YaRN magnitude scaling factor (default: %.1f)\n", params.yarn_attn_factor);
|
||||
printf(" --yarn-beta-fast N YaRN low correction dim (default: %.1f)\n", params.yarn_beta_fast);
|
||||
printf(" --yarn-beta-slow N YaRN high correction dim (default: %.1f)\n", params.yarn_beta_slow);
|
||||
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
|
||||
printf(" --no-penalize-nl do not penalize newline token\n");
|
||||
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
|
||||
|
@ -758,22 +770,23 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
|||
if (params.n_gpu_layers != -1) {
|
||||
lparams.n_gpu_layers = params.n_gpu_layers;
|
||||
}
|
||||
lparams.main_gpu = params.main_gpu;
|
||||
lparams.tensor_split = params.tensor_split;
|
||||
lparams.low_vram = params.low_vram;
|
||||
lparams.mul_mat_q = params.mul_mat_q;
|
||||
lparams.seed = params.seed;
|
||||
lparams.f16_kv = params.memory_f16;
|
||||
lparams.use_mmap = params.use_mmap;
|
||||
lparams.use_mlock = params.use_mlock;
|
||||
lparams.logits_all = params.perplexity;
|
||||
lparams.embedding = params.embedding;
|
||||
lparams.rope_freq_base = params.rope_freq_base;
|
||||
lparams.rope_freq_scale = params.rope_freq_scale;
|
||||
lparams.rope_ext_factor = params.rope_ext_factor;
|
||||
lparams.rope_attn_factor = params.rope_attn_factor;
|
||||
lparams.rope_beta_fast = params.rope_beta_fast;
|
||||
lparams.rope_beta_slow = params.rope_beta_slow;
|
||||
lparams.main_gpu = params.main_gpu;
|
||||
lparams.tensor_split = params.tensor_split;
|
||||
lparams.low_vram = params.low_vram;
|
||||
lparams.mul_mat_q = params.mul_mat_q;
|
||||
lparams.seed = params.seed;
|
||||
lparams.f16_kv = params.memory_f16;
|
||||
lparams.use_mmap = params.use_mmap;
|
||||
lparams.use_mlock = params.use_mlock;
|
||||
lparams.logits_all = params.perplexity;
|
||||
lparams.embedding = params.embedding;
|
||||
lparams.rope_scaling_type = params.rope_scaling_type;
|
||||
lparams.rope_freq_base = params.rope_freq_base;
|
||||
lparams.rope_freq_scale = params.rope_freq_scale;
|
||||
lparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
lparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
lparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
lparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
|
||||
return lparams;
|
||||
}
|
||||
|
|
|
@ -50,10 +50,12 @@ struct gpt_params {
|
|||
int32_t n_beams = 0; // if non-zero then use beam search of given width.
|
||||
float rope_freq_base = 10000.0f; // RoPE base frequency
|
||||
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
|
||||
float rope_ext_factor = 0.0f; // RoPE extrapolation mix factor
|
||||
float rope_attn_factor = 1.0f; // RoPE magnitude scaling factor
|
||||
float rope_beta_fast = 32.0f; // RoPE low correction dim
|
||||
float rope_beta_slow = 1.0f; // RoPE high correction dim
|
||||
float yarn_ext_factor = 0.0f; // YaRN extrapolation mix factor
|
||||
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
|
||||
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||
|
||||
llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
|
|
53
convert.py
53
convert.py
|
@ -152,8 +152,11 @@ class Params:
|
|||
n_head_kv: int
|
||||
f_norm_eps: float
|
||||
|
||||
rope_scaling_type: gguf.RopeScalingType | None = None
|
||||
f_rope_freq_base: float | None = None
|
||||
f_rope_scale: float | None = None
|
||||
n_orig_ctx: int | None = None
|
||||
rope_finetuned: bool | None = None
|
||||
|
||||
ftype: GGMLFileType | None = None
|
||||
|
||||
|
@ -199,11 +202,20 @@ class Params:
|
|||
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
|
||||
config = json.load(open(config_path))
|
||||
|
||||
rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
|
||||
rope_scaling = config.get("rope_scaling")
|
||||
if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
|
||||
f_rope_scale = config["rope_scaling"].get("factor")
|
||||
else:
|
||||
f_rope_scale = None
|
||||
|
||||
if rope_scaling is not None and typ := rope_scaling.get("type"):
|
||||
rope_factor = rope_scaling.get("factor")
|
||||
f_rope_scale = rope_factor
|
||||
if typ == "linear":
|
||||
rope_scaling_type = RopeScalingType.LINEAR
|
||||
elif typ == "yarn":
|
||||
rope_scaling_type = RopeScalingType.YARN
|
||||
n_orig_ctx = rope_scaling['original_max_position_embeddings']
|
||||
rope_finetuned = rope_scaling['finetuned']
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown rope scaling type: {typ}')
|
||||
|
||||
if "max_sequence_length" in config:
|
||||
n_ctx = config["max_sequence_length"]
|
||||
|
@ -214,16 +226,18 @@ class Params:
|
|||
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
|
||||
|
||||
return Params(
|
||||
n_vocab = config["vocab_size"],
|
||||
n_embd = config["hidden_size"],
|
||||
n_layer = config["num_hidden_layers"],
|
||||
n_ctx = n_ctx,
|
||||
n_ff = config["intermediate_size"],
|
||||
n_head = config["num_attention_heads"],
|
||||
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head,
|
||||
f_norm_eps = config["rms_norm_eps"],
|
||||
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None,
|
||||
f_rope_scale = f_rope_scale,
|
||||
n_vocab = config["vocab_size"],
|
||||
n_embd = config["hidden_size"],
|
||||
n_layer = config["num_hidden_layers"],
|
||||
n_ctx = n_ctx,
|
||||
n_ff = config["intermediate_size"],
|
||||
n_head = config["num_attention_heads"],
|
||||
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head,
|
||||
f_norm_eps = config["rms_norm_eps"],
|
||||
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None,
|
||||
f_rope_scale = f_rope_scale,
|
||||
n_orig_ctx = n_orig_ctx,
|
||||
rope_finetuned = rope_finetuned,
|
||||
)
|
||||
|
||||
# LLaMA v2 70B params.json
|
||||
|
@ -819,8 +833,15 @@ class OutputFile:
|
|||
if params.f_rope_freq_base is not None:
|
||||
self.gguf.add_rope_freq_base(params.f_rope_freq_base)
|
||||
|
||||
if params.f_rope_scale is not None:
|
||||
self.gguf.add_rope_scale_linear(params.f_rope_scale)
|
||||
if params.rope_scaling_type:
|
||||
self.gguf.add_rope_scaling_type(params.rope_scaling_type)
|
||||
self.gguf.add_rope_scaling_factor(params.f_rope_scale)
|
||||
|
||||
if params.n_orig_ctx is not None:
|
||||
self.gguf.add_rope_original_context_length(params.n_orig_ctx)
|
||||
|
||||
if params.rope_finetuned is not None:
|
||||
self.gguf.add_rope_finetuned(params.rope_finetuned)
|
||||
|
||||
if params.ftype is not None:
|
||||
self.gguf.add_file_type(params.ftype)
|
||||
|
|
|
@ -701,12 +701,14 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
|
||||
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
printf(" --rope-scaling {none,linear,yarn}\n");
|
||||
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
|
||||
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
|
||||
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
|
||||
printf(" --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor);
|
||||
printf(" --rope-attn-factor N RoPE magnitude scaling factor (default: %.1f)\n", params.rope_attn_factor);
|
||||
printf(" --rope-beta-fast N RoPE low correction dim (default: %.1f)\n", params.rope_beta_fast);
|
||||
printf(" --rope-beta-slow N RoPE high correction dim (default: %.1f)\n", params.rope_beta_slow);
|
||||
printf(" --yarn-ext-factor N YaRN extrapolation mix factor (default: %.1f)\n", params.yarn_ext_factor);
|
||||
printf(" --yarn-attn-factor N YaRN magnitude scaling factor (default: %.1f)\n", params.yarn_attn_factor);
|
||||
printf(" --yarn-beta-fast N YaRN low correction dim (default: %.1f)\n", params.yarn_beta_fast);
|
||||
printf(" --yarn-beta-slow N YaRN high correction dim (default: %.1f)\n", params.yarn_beta_slow);
|
||||
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
|
||||
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
|
||||
|
@ -824,6 +826,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
params.n_ctx = std::stoi(argv[i]);
|
||||
}
|
||||
else if (arg == "--rope-scaling")
|
||||
{
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
|
||||
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
|
||||
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
|
||||
else { invalid_param = true; break; }
|
||||
}
|
||||
else if (arg == "--rope-freq-base")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -842,37 +857,37 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
params.rope_freq_scale = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--rope-ext-factor")
|
||||
else if (arg == "--yarn-ext-factor")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_ext_factor = std::stof(argv[i]);
|
||||
params.yarn_ext_factor = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--rope-attn-factor")
|
||||
else if (arg == "--yarn-attn-factor")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_attn_factor = std::stof(argv[i]);
|
||||
params.yarn_attn_factor = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--rope-beta-fast")
|
||||
else if (arg == "--yarn-beta-fast")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_beta_fast = std::stof(argv[i]);
|
||||
params.yarn_beta_fast = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--rope-beta-slow")
|
||||
else if (arg == "--yarn-beta-slow")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rope_beta_slow = std::stof(argv[i]);
|
||||
params.yarn_beta_slow = std::stof(argv[i]);
|
||||
}
|
||||
else if (arg == "--memory-f32" || arg == "--memory_f32")
|
||||
{
|
||||
|
|
|
@ -52,9 +52,12 @@ KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
|||
KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
|
||||
# RoPE
|
||||
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
|
||||
KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear"
|
||||
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
|
||||
KEY_ROPE_SCALING_TYPE = "{arch}.rope.scaling.type"
|
||||
KEY_ROPE_SCALING_FACTOR = "{arch}.rope.scaling.factor"
|
||||
KEY_ROPE_SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
|
||||
KEY_ROPE_SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
|
||||
# tokenization
|
||||
KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
|
||||
|
@ -407,6 +410,11 @@ class TokenType(IntEnum):
|
|||
UNUSED = 5
|
||||
BYTE = 6
|
||||
|
||||
class RopeScalingType(IntEnum):
|
||||
NONE = 0
|
||||
LINEAR = 1
|
||||
YARN = 2
|
||||
|
||||
#
|
||||
# implementation
|
||||
#
|
||||
|
@ -760,8 +768,17 @@ class GGUFWriter:
|
|||
def add_rope_freq_base(self, value: float):
|
||||
self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scale_linear(self, value: float):
|
||||
self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)
|
||||
def add_rope_scaling_type(self, value: RopeScalingType):
|
||||
self.add_uint8(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), int(value))
|
||||
|
||||
def add_rope_scaling_factor(self, value: float):
|
||||
self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_orig_ctx_len(self, value: int):
|
||||
self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_finetuned(self, value: bool):
|
||||
self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)
|
||||
|
||||
def add_tokenizer_model(self, model: str):
|
||||
self.add_string(KEY_TOKENIZER_MODEL, model)
|
||||
|
|
124
llama.cpp
124
llama.cpp
|
@ -204,7 +204,10 @@ enum llm_kv {
|
|||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
LLM_KV_ROPE_SCALING_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
|
||||
LLM_KV_ROPE_SCALING_FINETUNED,
|
||||
|
||||
LLM_KV_TOKENIZER_MODEL,
|
||||
LLM_KV_TOKENIZER_LIST,
|
||||
|
@ -246,9 +249,12 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
|
|||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
|
||||
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
|
||||
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
|
||||
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
|
||||
|
||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||
|
@ -943,12 +949,17 @@ struct llama_hparams {
|
|||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
float rope_ext_factor;
|
||||
float rope_attn_factor;
|
||||
float rope_beta_fast;
|
||||
float rope_beta_slow;
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
bool rope_finetuned;
|
||||
uint32_t n_yarn_orig_ctx;
|
||||
|
||||
// These hyperparameters are not exposed in GGUF, because all
|
||||
// existing YaRN models use the same values for them.
|
||||
float yarn_ext_factor;
|
||||
float yarn_attn_factor;
|
||||
float yarn_beta_fast;
|
||||
float yarn_beta_slow;
|
||||
|
||||
bool operator!=(const llama_hparams & other) const {
|
||||
return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
|
||||
|
@ -1660,10 +1671,10 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
|
|||
hparams.n_ctx = params.n_ctx;
|
||||
hparams.rope_freq_base = params.rope_freq_base;
|
||||
hparams.rope_freq_scale = params.rope_freq_scale;
|
||||
hparams.rope_ext_factor = params.rope_ext_factor;
|
||||
hparams.rope_attn_factor = params.rope_attn_factor;
|
||||
hparams.rope_beta_fast = params.rope_beta_fast;
|
||||
hparams.rope_beta_slow = params.rope_beta_slow;
|
||||
hparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
hparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
hparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
hparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
|
||||
// get general kv
|
||||
GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
|
||||
|
@ -1680,6 +1691,14 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
|
|||
hparams.n_head_kv = hparams.n_head;
|
||||
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
|
||||
|
||||
hparams.rope_finetuned = false;
|
||||
GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
|
||||
kv(LLM_KV_ROPE_SCALING_FINETUNED));
|
||||
|
||||
hparams.n_yarn_orig_ctx = 0;
|
||||
GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
|
||||
kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
|
||||
|
||||
// rope_freq_base (optional)
|
||||
if (hparams.rope_freq_base == 0.0f) {
|
||||
float rope_freq_base = 10000.0f;
|
||||
|
@ -1687,13 +1706,28 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
|
|||
hparams.rope_freq_base = rope_freq_base;
|
||||
}
|
||||
|
||||
llama_rope_scaling_type rope_scaling_type = params.rope_scaling_type;
|
||||
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
|
||||
uint8_t type = LLAMA_ROPE_SCALING_LINEAR;
|
||||
GGUF_GET_KEY(ctx, type, gguf_get_val_u8, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
|
||||
rope_scaling_type = llama_rope_scaling_type(type);
|
||||
}
|
||||
GGML_ASSERT(rope_scaling_type >= 0 && rope_scaling_type <= LLAMA_ROPE_SCALING_MAX_VALUE);
|
||||
|
||||
// rope_freq_scale (inverse of the kv) is optional
|
||||
if (hparams.rope_freq_scale == 0.0f) {
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
|
||||
hparams.rope_freq_scale = 1.0f;
|
||||
} else if (hparams.rope_freq_scale == 0.0f) {
|
||||
float ropescale = 1.0f;
|
||||
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
|
||||
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
|
||||
hparams.rope_freq_scale = 1.0f/ropescale;
|
||||
}
|
||||
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_YARN) {
|
||||
hparams.yarn_ext_factor = 1.0f; // enable YaRN
|
||||
}
|
||||
|
||||
// sanity check for n_rot (optional)
|
||||
{
|
||||
hparams.n_rot = hparams.n_embd / hparams.n_head;
|
||||
|
@ -1902,6 +1936,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
|||
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
|
||||
LLAMA_LOG_INFO("%s: YaRN scaling = %g\n", __func__, hparams.yarn_ext_factor);
|
||||
LLAMA_LOG_INFO("%s: YaRN orig ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
|
||||
LLAMA_LOG_INFO("%s: YaRN beta_fast = %f\n", __func__, hparams.yarn_beta_fast);
|
||||
LLAMA_LOG_INFO("%s: YaRN beta_slow = %f\n", __func__, hparams.yarn_beta_slow);
|
||||
LLAMA_LOG_INFO("%s: RoPE finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "no");
|
||||
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
|
||||
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
|
||||
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
|
||||
|
@ -2444,10 +2483,10 @@ static struct ggml_cgraph * llm_build_llama(
|
|||
|
||||
const float freq_base = hparams.rope_freq_base;
|
||||
const float freq_scale = hparams.rope_freq_scale;
|
||||
const float ext_factor = hparams.rope_ext_factor;
|
||||
const float attn_factor = hparams.rope_attn_factor;
|
||||
const float beta_fast = hparams.rope_beta_fast;
|
||||
const float beta_slow = hparams.rope_beta_slow;
|
||||
const float ext_factor = hparams.yarn_ext_factor;
|
||||
const float attn_factor = hparams.yarn_attn_factor;
|
||||
const float beta_fast = hparams.yarn_beta_fast;
|
||||
const float beta_slow = hparams.yarn_beta_slow;
|
||||
const float norm_rms_eps = hparams.f_norm_eps;
|
||||
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
@ -2561,15 +2600,13 @@ static struct ggml_cgraph * llm_build_llama(
|
|||
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
|
||||
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base,
|
||||
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
offload_func_kq(Kcur);
|
||||
ggml_set_name(Kcur, "Kcur");
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
|
||||
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base,
|
||||
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
offload_func_kq(Qcur);
|
||||
ggml_set_name(Qcur, "Qcur");
|
||||
|
||||
|
@ -2786,6 +2823,10 @@ static struct ggml_cgraph * llm_build_baichaun(
|
|||
|
||||
const float freq_base = hparams.rope_freq_base;
|
||||
const float freq_scale = hparams.rope_freq_scale;
|
||||
const float ext_factor = hparams.yarn_ext_factor;
|
||||
const float attn_factor = hparams.yarn_attn_factor;
|
||||
const float beta_fast = hparams.yarn_beta_fast;
|
||||
const float beta_slow = hparams.yarn_beta_slow;
|
||||
const float norm_rms_eps = hparams.f_norm_rms_eps;
|
||||
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
@ -2901,8 +2942,16 @@ static struct ggml_cgraph * llm_build_baichaun(
|
|||
struct ggml_tensor * Qcur;
|
||||
switch (model.type) {
|
||||
case MODEL_7B:
|
||||
Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Kcur = ggml_rope_custom_inplace(
|
||||
ctx0,
|
||||
ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
|
||||
n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
Qcur = ggml_rope_custom_inplace(
|
||||
ctx0,
|
||||
ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
|
||||
n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
break;
|
||||
case MODEL_13B:
|
||||
Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N);
|
||||
|
@ -3146,10 +3195,10 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
|
||||
const float freq_base = hparams.rope_freq_base;
|
||||
const float freq_scale = hparams.rope_freq_scale;
|
||||
const float ext_factor = hparams.rope_ext_factor;
|
||||
const float attn_factor = hparams.rope_attn_factor;
|
||||
const float beta_fast = hparams.rope_beta_fast;
|
||||
const float beta_slow = hparams.rope_beta_slow;
|
||||
const float ext_factor = hparams.yarn_ext_factor;
|
||||
const float attn_factor = hparams.yarn_attn_factor;
|
||||
const float beta_fast = hparams.yarn_beta_fast;
|
||||
const float beta_slow = hparams.yarn_beta_slow;
|
||||
const float norm_eps = hparams.f_norm_eps;
|
||||
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
@ -3302,11 +3351,13 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
|
||||
// using mode = 2 for neox mode
|
||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
|
||||
ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
ctx0, tmpq, n_past, n_embd_head, 2, 0,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
offload_func_kq(Qcur);
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
|
||||
ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
ctx0, tmpk, n_past, n_embd_head, 2, 0,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
offload_func_kq(Kcur);
|
||||
|
||||
|
@ -6186,10 +6237,11 @@ struct llama_context_params llama_context_default_params() {
|
|||
/*.tensor_split =*/ nullptr,
|
||||
/*.rope_freq_base =*/ 0.0f,
|
||||
/*.rope_freq_scale =*/ 0.0f,
|
||||
/*.rope_ext_factor =*/ 0.0f,
|
||||
/*.rope_attn_factor =*/ 1.0f,
|
||||
/*.rope_beta_fast =*/ 32.0f,
|
||||
/*.rope_beta_slow =*/ 1.0f,
|
||||
/*.yarn_ext_factor =*/ 0.0f,
|
||||
/*.yarn_attn_factor =*/ 1.0f,
|
||||
/*.yarn_beta_fast =*/ 32.0f,
|
||||
/*.yarn_beta_slow =*/ 1.0f,
|
||||
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
|
||||
/*.progress_callback =*/ nullptr,
|
||||
/*.progress_callback_user_data =*/ nullptr,
|
||||
/*.low_vram =*/ false,
|
||||
|
|
34
llama.h
34
llama.h
|
@ -108,6 +108,14 @@ extern "C" {
|
|||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
|
||||
enum llama_rope_scaling_type: int8_t {
|
||||
LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
|
||||
LLAMA_ROPE_SCALING_NONE = 0,
|
||||
LLAMA_ROPE_SCALING_LINEAR = 1,
|
||||
LLAMA_ROPE_SCALING_YARN = 2,
|
||||
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
|
||||
};
|
||||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
float logit; // log-odds of the token
|
||||
|
@ -134,10 +142,12 @@ extern "C" {
|
|||
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
float rope_freq_base; // RoPE base frequency
|
||||
float rope_freq_scale; // RoPE frequency scaling factor
|
||||
float rope_ext_factor; // RoPE extrapolation mix factor
|
||||
float rope_attn_factor; // RoPE magnitude scaling factor
|
||||
float rope_beta_fast; // RoPE low correction dim
|
||||
float rope_beta_slow; // RoPE high correction dim
|
||||
float yarn_ext_factor; // YaRN extrapolation mix factor
|
||||
float yarn_attn_factor; // YaRN magnitude scaling factor
|
||||
float yarn_beta_fast; // YaRN low correction dim
|
||||
float yarn_beta_slow; // YaRN high correction dim
|
||||
|
||||
llama_rope_scaling_type rope_scaling_type;
|
||||
|
||||
// called with a progress value between 0 and 1, pass NULL to disable
|
||||
llama_progress_callback progress_callback;
|
||||
|
@ -145,14 +155,14 @@ extern "C" {
|
|||
void * progress_callback_user_data;
|
||||
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool low_vram; // if true, reduce VRAM usage at the cost of performance
|
||||
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
|
||||
bool f16_kv; // use fp16 for KV cache
|
||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool embedding; // embedding mode only
|
||||
bool low_vram; // if true, reduce VRAM usage at the cost of performance
|
||||
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
|
||||
bool f16_kv; // use fp16 for KV cache
|
||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool embedding; // embedding mode only
|
||||
};
|
||||
|
||||
// Signature for logging events
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue