YaRN : store rope scaling type as int32_t in memory

This commit is contained in:
Jared Van Bortel 2024-02-02 10:27:59 -05:00
parent 191221178f
commit 1dd7aa9b1c
3 changed files with 6 additions and 7 deletions

View file

@ -75,8 +75,7 @@ struct gpt_params {
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
// pinging @cebtenzzre
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
// // sampling parameters
struct llama_sampling_params sparams;

View file

@ -747,13 +747,13 @@ struct LLM_TN {
// gguf helpers
//
static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
static std::map<int32_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_NONE, "none" },
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_YARN, "yarn" },
};
static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
if (kv.second == name) {
return kv.first;
@ -1415,6 +1415,7 @@ static const size_t GiB = 1024*MiB;
struct llama_hparams {
bool vocab_only;
bool rope_finetuned;
uint32_t n_vocab;
uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_embd;
@ -1434,8 +1435,7 @@ struct llama_hparams {
float rope_freq_base_train;
float rope_freq_scale_train;
uint32_t n_yarn_orig_ctx;
int8_t rope_scaling_type_train : 3;
bool rope_finetuned : 1;
int32_t rope_scaling_type_train;
float f_clamp_kqv;
float f_max_alibi_bias;

View file

@ -213,7 +213,7 @@ extern "C" {
uint32_t n_batch; // prompt processing maximum batch size
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model