llama : store YaRN parameters in GGUF

This commit is contained in:
Cebtenzzre 2023-09-14 13:26:10 -04:00
parent dc26a0dd32
commit 904d4edfa1
7 changed files with 245 additions and 115 deletions

View file

@ -192,36 +192,46 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--rope-scaling") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { invalid_param = true; break; }
} else if (arg == "--rope-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = 1.0f/std::stof(argv[i]);
} else if (arg == "--rope-ext-factor") {
} else if (arg == "--yarn-ext-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_ext_factor = std::stof(argv[i]);
} else if (arg == "--rope-attn-factor") {
params.yarn_ext_factor = std::stof(argv[i]);
} else if (arg == "--yarn-attn-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_attn_factor = std::stof(argv[i]);
} else if (arg == "--rope-beta-fast") {
params.yarn_attn_factor = std::stof(argv[i]);
} else if (arg == "--yarn-beta-fast") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_beta_fast = std::stof(argv[i]);
} else if (arg == "--rope-beta-slow") {
params.yarn_beta_fast = std::stof(argv[i]);
} else if (arg == "--yarn-beta-slow") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_beta_slow = std::stof(argv[i]);
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
@ -671,13 +681,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --cfg-negative-prompt-file FNAME\n");
printf(" negative prompt file to use for guidance. (default: empty)\n");
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
printf(" --rope-scale N RoPE context scaling factor, inverse of --rope-freq-scale\n");
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n");
printf(" --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor);
printf(" --rope-attn-factor N RoPE magnitude scaling factor (default: %.1f)\n", params.rope_attn_factor);
printf(" --rope-beta-fast N RoPE low correction dim (default: %.1f)\n", params.rope_beta_fast);
printf(" --rope-beta-slow N RoPE high correction dim (default: %.1f)\n", params.rope_beta_slow);
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
printf(" --yarn-ext-factor N YaRN extrapolation mix factor (default: %.1f)\n", params.yarn_ext_factor);
printf(" --yarn-attn-factor N YaRN magnitude scaling factor (default: %.1f)\n", params.yarn_attn_factor);
printf(" --yarn-beta-fast N YaRN low correction dim (default: %.1f)\n", params.yarn_beta_fast);
printf(" --yarn-beta-slow N YaRN high correction dim (default: %.1f)\n", params.yarn_beta_slow);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
@ -758,22 +770,23 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
if (params.n_gpu_layers != -1) {
lparams.n_gpu_layers = params.n_gpu_layers;
}
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
lparams.low_vram = params.low_vram;
lparams.mul_mat_q = params.mul_mat_q;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;
lparams.rope_ext_factor = params.rope_ext_factor;
lparams.rope_attn_factor = params.rope_attn_factor;
lparams.rope_beta_fast = params.rope_beta_fast;
lparams.rope_beta_slow = params.rope_beta_slow;
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
lparams.low_vram = params.low_vram;
lparams.mul_mat_q = params.mul_mat_q;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
lparams.rope_scaling_type = params.rope_scaling_type;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;
lparams.yarn_ext_factor = params.yarn_ext_factor;
lparams.yarn_attn_factor = params.yarn_attn_factor;
lparams.yarn_beta_fast = params.yarn_beta_fast;
lparams.yarn_beta_slow = params.yarn_beta_slow;
return lparams;
}

View file

@ -50,10 +50,12 @@ struct gpt_params {
int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
float rope_ext_factor = 0.0f; // RoPE extrapolation mix factor
float rope_attn_factor = 1.0f; // RoPE magnitude scaling factor
float rope_beta_fast = 32.0f; // RoPE low correction dim
float rope_beta_slow = 1.0f; // RoPE high correction dim
float yarn_ext_factor = 0.0f; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
// sampling parameters
int32_t top_k = 40; // <= 0 to use vocab size

View file

@ -152,8 +152,11 @@ class Params:
n_head_kv: int
f_norm_eps: float
rope_scaling_type: gguf.RopeScalingType | None = None
f_rope_freq_base: float | None = None
f_rope_scale: float | None = None
n_orig_ctx: int | None = None
rope_finetuned: bool | None = None
ftype: GGMLFileType | None = None
@ -199,11 +202,20 @@ class Params:
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))
rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
rope_scaling = config.get("rope_scaling")
if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
f_rope_scale = config["rope_scaling"].get("factor")
else:
f_rope_scale = None
if rope_scaling is not None and typ := rope_scaling.get("type"):
rope_factor = rope_scaling.get("factor")
f_rope_scale = rope_factor
if typ == "linear":
rope_scaling_type = RopeScalingType.LINEAR
elif typ == "yarn":
rope_scaling_type = RopeScalingType.YARN
n_orig_ctx = rope_scaling['original_max_position_embeddings']
rope_finetuned = rope_scaling['finetuned']
else:
raise NotImplementedError(f'Unknown rope scaling type: {typ}')
if "max_sequence_length" in config:
n_ctx = config["max_sequence_length"]
@ -214,16 +226,18 @@ class Params:
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
return Params(
n_vocab = config["vocab_size"],
n_embd = config["hidden_size"],
n_layer = config["num_hidden_layers"],
n_ctx = n_ctx,
n_ff = config["intermediate_size"],
n_head = config["num_attention_heads"],
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head,
f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None,
f_rope_scale = f_rope_scale,
n_vocab = config["vocab_size"],
n_embd = config["hidden_size"],
n_layer = config["num_hidden_layers"],
n_ctx = n_ctx,
n_ff = config["intermediate_size"],
n_head = config["num_attention_heads"],
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head,
f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None,
f_rope_scale = f_rope_scale,
n_orig_ctx = n_orig_ctx,
rope_finetuned = rope_finetuned,
)
# LLaMA v2 70B params.json
@ -819,8 +833,15 @@ class OutputFile:
if params.f_rope_freq_base is not None:
self.gguf.add_rope_freq_base(params.f_rope_freq_base)
if params.f_rope_scale is not None:
self.gguf.add_rope_scale_linear(params.f_rope_scale)
if params.rope_scaling_type:
self.gguf.add_rope_scaling_type(params.rope_scaling_type)
self.gguf.add_rope_scaling_factor(params.f_rope_scale)
if params.n_orig_ctx is not None:
self.gguf.add_rope_original_context_length(params.n_orig_ctx)
if params.rope_finetuned is not None:
self.gguf.add_rope_finetuned(params.rope_finetuned)
if params.ftype is not None:
self.gguf.add_file_type(params.ftype)

View file

@ -701,12 +701,14 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" --rope-scaling {none,linear,yarn}\n");
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
printf(" --rope-ext-factor N RoPE extrapolation mix factor (default: %.1f)\n", params.rope_ext_factor);
printf(" --rope-attn-factor N RoPE magnitude scaling factor (default: %.1f)\n", params.rope_attn_factor);
printf(" --rope-beta-fast N RoPE low correction dim (default: %.1f)\n", params.rope_beta_fast);
printf(" --rope-beta-slow N RoPE high correction dim (default: %.1f)\n", params.rope_beta_slow);
printf(" --yarn-ext-factor N YaRN extrapolation mix factor (default: %.1f)\n", params.yarn_ext_factor);
printf(" --yarn-attn-factor N YaRN magnitude scaling factor (default: %.1f)\n", params.yarn_attn_factor);
printf(" --yarn-beta-fast N YaRN low correction dim (default: %.1f)\n", params.yarn_beta_fast);
printf(" --yarn-beta-slow N YaRN high correction dim (default: %.1f)\n", params.yarn_beta_slow);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
@ -824,6 +826,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_ctx = std::stoi(argv[i]);
}
else if (arg == "--rope-scaling")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
else { invalid_param = true; break; }
}
else if (arg == "--rope-freq-base")
{
if (++i >= argc)
@ -842,37 +857,37 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.rope_freq_scale = std::stof(argv[i]);
}
else if (arg == "--rope-ext-factor")
else if (arg == "--yarn-ext-factor")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_ext_factor = std::stof(argv[i]);
params.yarn_ext_factor = std::stof(argv[i]);
}
else if (arg == "--rope-attn-factor")
else if (arg == "--yarn-attn-factor")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_attn_factor = std::stof(argv[i]);
params.yarn_attn_factor = std::stof(argv[i]);
}
else if (arg == "--rope-beta-fast")
else if (arg == "--yarn-beta-fast")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_beta_fast = std::stof(argv[i]);
params.yarn_beta_fast = std::stof(argv[i]);
}
else if (arg == "--rope-beta-slow")
else if (arg == "--yarn-beta-slow")
{
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_beta_slow = std::stof(argv[i]);
params.yarn_beta_slow = std::stof(argv[i]);
}
else if (arg == "--memory-f32" || arg == "--memory_f32")
{

View file

@ -52,9 +52,12 @@ KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
# RoPE
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear"
KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
KEY_ROPE_FREQ_BASE = "{arch}.rope.freq_base"
KEY_ROPE_SCALING_TYPE = "{arch}.rope.scaling.type"
KEY_ROPE_SCALING_FACTOR = "{arch}.rope.scaling.factor"
KEY_ROPE_SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
KEY_ROPE_SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
# tokenization
KEY_TOKENIZER_MODEL = "tokenizer.ggml.model"
@ -407,6 +410,11 @@ class TokenType(IntEnum):
UNUSED = 5
BYTE = 6
class RopeScalingType(IntEnum):
NONE = 0
LINEAR = 1
YARN = 2
#
# implementation
#
@ -760,8 +768,17 @@ class GGUFWriter:
def add_rope_freq_base(self, value: float):
self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
def add_rope_scale_linear(self, value: float):
self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)
def add_rope_scaling_type(self, value: RopeScalingType):
self.add_uint8(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), int(value))
def add_rope_scaling_factor(self, value: float):
self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
def add_rope_scaling_orig_ctx_len(self, value: int):
self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
def add_rope_scaling_finetuned(self, value: bool):
self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)
def add_tokenizer_model(self, model: str):
self.add_string(KEY_TOKENIZER_MODEL, model)

124
llama.cpp
View file

@ -204,7 +204,10 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_LIST,
@ -246,9 +249,12 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@ -943,12 +949,17 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;
float rope_freq_base;
float rope_freq_scale;
float rope_ext_factor;
float rope_attn_factor;
float rope_beta_fast;
float rope_beta_slow;
float rope_freq_base;
float rope_freq_scale;
bool rope_finetuned;
uint32_t n_yarn_orig_ctx;
// These hyperparameters are not exposed in GGUF, because all
// existing YaRN models use the same values for them.
float yarn_ext_factor;
float yarn_attn_factor;
float yarn_beta_fast;
float yarn_beta_slow;
bool operator!=(const llama_hparams & other) const {
return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
@ -1660,10 +1671,10 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
hparams.n_ctx = params.n_ctx;
hparams.rope_freq_base = params.rope_freq_base;
hparams.rope_freq_scale = params.rope_freq_scale;
hparams.rope_ext_factor = params.rope_ext_factor;
hparams.rope_attn_factor = params.rope_attn_factor;
hparams.rope_beta_fast = params.rope_beta_fast;
hparams.rope_beta_slow = params.rope_beta_slow;
hparams.yarn_ext_factor = params.yarn_ext_factor;
hparams.yarn_attn_factor = params.yarn_attn_factor;
hparams.yarn_beta_fast = params.yarn_beta_fast;
hparams.yarn_beta_slow = params.yarn_beta_slow;
// get general kv
GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
@ -1680,6 +1691,14 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
hparams.n_head_kv = hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
hparams.rope_finetuned = false;
GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
kv(LLM_KV_ROPE_SCALING_FINETUNED));
hparams.n_yarn_orig_ctx = 0;
GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
// rope_freq_base (optional)
if (hparams.rope_freq_base == 0.0f) {
float rope_freq_base = 10000.0f;
@ -1687,13 +1706,28 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
hparams.rope_freq_base = rope_freq_base;
}
llama_rope_scaling_type rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
uint8_t type = LLAMA_ROPE_SCALING_LINEAR;
GGUF_GET_KEY(ctx, type, gguf_get_val_u8, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
rope_scaling_type = llama_rope_scaling_type(type);
}
GGML_ASSERT(rope_scaling_type >= 0 && rope_scaling_type <= LLAMA_ROPE_SCALING_MAX_VALUE);
// rope_freq_scale (inverse of the kv) is optional
if (hparams.rope_freq_scale == 0.0f) {
if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
hparams.rope_freq_scale = 1.0f;
} else if (hparams.rope_freq_scale == 0.0f) {
float ropescale = 1.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
hparams.rope_freq_scale = 1.0f/ropescale;
}
if (rope_scaling_type == LLAMA_ROPE_SCALING_YARN) {
hparams.yarn_ext_factor = 1.0f; // enable YaRN
}
// sanity check for n_rot (optional)
{
hparams.n_rot = hparams.n_embd / hparams.n_head;
@ -1902,6 +1936,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
LLAMA_LOG_INFO("%s: YaRN scaling = %g\n", __func__, hparams.yarn_ext_factor);
LLAMA_LOG_INFO("%s: YaRN orig ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
LLAMA_LOG_INFO("%s: YaRN beta_fast = %f\n", __func__, hparams.yarn_beta_fast);
LLAMA_LOG_INFO("%s: YaRN beta_slow = %f\n", __func__, hparams.yarn_beta_slow);
LLAMA_LOG_INFO("%s: RoPE finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "no");
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
@ -2444,10 +2483,10 @@ static struct ggml_cgraph * llm_build_llama(
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
const float ext_factor = hparams.rope_ext_factor;
const float attn_factor = hparams.rope_attn_factor;
const float beta_fast = hparams.rope_beta_fast;
const float beta_slow = hparams.rope_beta_slow;
const float ext_factor = hparams.yarn_ext_factor;
const float attn_factor = hparams.yarn_attn_factor;
const float beta_fast = hparams.yarn_beta_fast;
const float beta_slow = hparams.yarn_beta_slow;
const float norm_rms_eps = hparams.f_norm_eps;
const int n_gpu_layers = model.n_gpu_layers;
@ -2561,15 +2600,13 @@ static struct ggml_cgraph * llm_build_llama(
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base,
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base,
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");
@ -2786,6 +2823,10 @@ static struct ggml_cgraph * llm_build_baichaun(
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
const float ext_factor = hparams.yarn_ext_factor;
const float attn_factor = hparams.yarn_attn_factor;
const float beta_fast = hparams.yarn_beta_fast;
const float beta_slow = hparams.yarn_beta_slow;
const float norm_rms_eps = hparams.f_norm_rms_eps;
const int n_gpu_layers = model.n_gpu_layers;
@ -2901,8 +2942,16 @@ static struct ggml_cgraph * llm_build_baichaun(
struct ggml_tensor * Qcur;
switch (model.type) {
case MODEL_7B:
Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
Kcur = ggml_rope_custom_inplace(
ctx0,
ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
Qcur = ggml_rope_custom_inplace(
ctx0,
ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
break;
case MODEL_13B:
Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N);
@ -3146,10 +3195,10 @@ static struct ggml_cgraph * llm_build_falcon(
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
const float ext_factor = hparams.rope_ext_factor;
const float attn_factor = hparams.rope_attn_factor;
const float beta_fast = hparams.rope_beta_fast;
const float beta_slow = hparams.rope_beta_slow;
const float ext_factor = hparams.yarn_ext_factor;
const float attn_factor = hparams.yarn_attn_factor;
const float beta_fast = hparams.yarn_beta_fast;
const float beta_slow = hparams.yarn_beta_slow;
const float norm_eps = hparams.f_norm_eps;
const int n_gpu_layers = model.n_gpu_layers;
@ -3302,11 +3351,13 @@ static struct ggml_cgraph * llm_build_falcon(
// using mode = 2 for neox mode
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
ctx0, tmpq, n_past, n_embd_head, 2, 0,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
offload_func_kq(Qcur);
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
ctx0, tmpk, n_past, n_embd_head, 2, 0,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
offload_func_kq(Kcur);
@ -6186,10 +6237,11 @@ struct llama_context_params llama_context_default_params() {
/*.tensor_split =*/ nullptr,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.rope_ext_factor =*/ 0.0f,
/*.rope_attn_factor =*/ 1.0f,
/*.rope_beta_fast =*/ 32.0f,
/*.rope_beta_slow =*/ 1.0f,
/*.yarn_ext_factor =*/ 0.0f,
/*.yarn_attn_factor =*/ 1.0f,
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.low_vram =*/ false,

34
llama.h
View file

@ -108,6 +108,14 @@ extern "C" {
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
enum llama_rope_scaling_type: int8_t {
LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
LLAMA_ROPE_SCALING_NONE = 0,
LLAMA_ROPE_SCALING_LINEAR = 1,
LLAMA_ROPE_SCALING_YARN = 2,
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
};
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
@ -134,10 +142,12 @@ extern "C" {
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency
float rope_freq_scale; // RoPE frequency scaling factor
float rope_ext_factor; // RoPE extrapolation mix factor
float rope_attn_factor; // RoPE magnitude scaling factor
float rope_beta_fast; // RoPE low correction dim
float rope_beta_slow; // RoPE high correction dim
float yarn_ext_factor; // YaRN extrapolation mix factor
float yarn_attn_factor; // YaRN magnitude scaling factor
float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim
llama_rope_scaling_type rope_scaling_type;
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
@ -145,14 +155,14 @@ extern "C" {
void * progress_callback_user_data;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool low_vram; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
bool low_vram; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
};
// Signature for logging events