llama : implement YaRN RoPE scaling (#2268)
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com> Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
This commit is contained in:
		
							parent
							
								
									c43c2da8af
								
							
						
					
					
						commit
						898aeca90a
					
				
					 15 changed files with 763 additions and 257 deletions
				
			
		|  | @ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { | |||
|                 break; | ||||
|             } | ||||
|             params.rope_freq_scale = std::stof(argv[i]); | ||||
|         } else if (arg == "--rope-scaling") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             std::string value(argv[i]); | ||||
|             /**/ if (value == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } | ||||
|             else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } | ||||
|             else if (value == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } | ||||
|             else { invalid_param = true; break; } | ||||
|         } else if (arg == "--rope-scale") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.rope_freq_scale = 1.0f/std::stof(argv[i]); | ||||
|         } else if (arg == "--yarn-orig-ctx") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_orig_ctx = std::stoi(argv[i]); | ||||
|         } else if (arg == "--yarn-ext-factor") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_ext_factor = std::stof(argv[i]); | ||||
|         } else if (arg == "--yarn-attn-factor") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_attn_factor = std::stof(argv[i]); | ||||
|         } else if (arg == "--yarn-beta-fast") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_beta_fast = std::stof(argv[i]); | ||||
|         } else if (arg == "--yarn-beta-slow") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_beta_slow = std::stof(argv[i]); | ||||
|         } else if (arg == "--memory-f32") { | ||||
|             params.memory_f16 = false; | ||||
|         } else if (arg == "--top-p") { | ||||
|  | @ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | |||
|     printf("  --cfg-negative-prompt-file FNAME\n"); | ||||
|     printf("                        negative prompt file to use for guidance. (default: empty)\n"); | ||||
|     printf("  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale); | ||||
|     printf("  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); | ||||
|     printf("  --rope-scaling {none,linear,yarn}\n"); | ||||
|     printf("                        RoPE frequency scaling method, defaults to linear unless specified by the model\n"); | ||||
|     printf("  --rope-scale N        RoPE context scaling factor, expands context by a factor of N\n"); | ||||
|     printf("  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); | ||||
|     printf("  --rope-freq-scale N   RoPE frequency linear scaling factor (default: loaded from model)\n"); | ||||
|     printf("  --rope-freq-scale N   RoPE frequency scaling factor, expands context by a factor of 1/N\n"); | ||||
|     printf("  --yarn-orig-ctx N     YaRN: original context size of model (default: 0 = model training context size)\n"); | ||||
|     printf("  --yarn-ext-factor N   YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n"); | ||||
|     printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); | ||||
|     printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); | ||||
|     printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); | ||||
|     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); | ||||
|     printf("  --no-penalize-nl      do not penalize newline token\n"); | ||||
|     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n"); | ||||
|  | @ -835,8 +882,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param | |||
|     cparams.f16_kv            = params.memory_f16; | ||||
|     cparams.logits_all        = params.logits_all; | ||||
|     cparams.embedding         = params.embedding; | ||||
|     cparams.rope_scaling_type = params.rope_scaling_type; | ||||
|     cparams.rope_freq_base    = params.rope_freq_base; | ||||
|     cparams.rope_freq_scale   = params.rope_freq_scale; | ||||
|     cparams.yarn_ext_factor   = params.yarn_ext_factor; | ||||
|     cparams.yarn_attn_factor  = params.yarn_attn_factor; | ||||
|     cparams.yarn_beta_fast    = params.yarn_beta_fast; | ||||
|     cparams.yarn_beta_slow    = params.yarn_beta_slow; | ||||
|     cparams.yarn_orig_ctx     = params.yarn_orig_ctx; | ||||
| 
 | ||||
|     return cparams; | ||||
| } | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ | |||
| #define LOG_NO_FILE_LINE_FUNCTION | ||||
| #include "log.h" | ||||
| 
 | ||||
| #include <cmath> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include <random> | ||||
|  | @ -54,6 +55,12 @@ struct gpt_params { | |||
|     int32_t n_beams                         = 0;    // if non-zero then use beam search of given width.
 | ||||
|     float   rope_freq_base                  = 0.0f; // RoPE base frequency
 | ||||
|     float   rope_freq_scale                 = 0.0f; // RoPE frequency scaling factor
 | ||||
|     float   yarn_ext_factor                 = NAN;  // YaRN extrapolation mix factor
 | ||||
|     float   yarn_attn_factor                = 1.0f; // YaRN magnitude scaling factor
 | ||||
|     float   yarn_beta_fast                  = 32.0f;// YaRN low correction dim
 | ||||
|     float   yarn_beta_slow                  = 1.0f; // YaRN high correction dim
 | ||||
|     int32_t yarn_orig_ctx                   = 0;    // YaRN original context length
 | ||||
|     int8_t  rope_scaling_type               = LLAMA_ROPE_SCALING_UNSPECIFIED; | ||||
| 
 | ||||
|     // // sampling parameters
 | ||||
|     struct llama_sampling_params sparams; | ||||
|  |  | |||
|  | @ -163,7 +163,8 @@ gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) | |||
| if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]: | ||||
|     if "type" in hparams["rope_scaling"]: | ||||
|         if hparams["rope_scaling"]["type"] == "linear": | ||||
|             gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"]) | ||||
|             gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) | ||||
|             gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) | ||||
| 
 | ||||
| 
 | ||||
| # TOKENIZATION | ||||
|  |  | |||
							
								
								
									
										91
									
								
								convert.py
									
										
									
									
									
								
							
							
						
						
									
										91
									
								
								convert.py
									
										
									
									
									
								
							|  | @ -151,8 +151,11 @@ class Params: | |||
|     n_head_kv:  int | ||||
|     f_norm_eps: float | ||||
| 
 | ||||
|     rope_scaling_type: gguf.RopeScalingType | None = None | ||||
|     f_rope_freq_base: float | None = None | ||||
|     f_rope_scale: float | None = None | ||||
|     n_orig_ctx: int | None = None | ||||
|     rope_finetuned: bool | None = None | ||||
| 
 | ||||
|     ftype: GGMLFileType | None = None | ||||
| 
 | ||||
|  | @ -198,20 +201,20 @@ class Params: | |||
|     def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: | ||||
|         config = json.load(open(config_path)) | ||||
| 
 | ||||
|         n_vocab          = config["vocab_size"] | ||||
|         n_embd           = config["hidden_size"] | ||||
|         n_layer          = config["num_hidden_layers"] | ||||
|         n_ff             = config["intermediate_size"] | ||||
|         n_head           = config["num_attention_heads"] | ||||
|         n_head_kv        = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head | ||||
|         f_norm_eps       = config["rms_norm_eps"] | ||||
|         f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None | ||||
| 
 | ||||
|         rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None | ||||
|         rope_scaling = config.get("rope_scaling") | ||||
|         if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear": | ||||
|             f_rope_scale = config["rope_scaling"].get("factor") | ||||
| 
 | ||||
|         if rope_scaling is not None and (typ := rope_scaling.get("type")): | ||||
|             rope_factor = rope_scaling.get("factor") | ||||
|             f_rope_scale = rope_factor | ||||
|             if typ == "linear": | ||||
|                 rope_scaling_type = gguf.RopeScalingType.LINEAR | ||||
|             elif typ == "yarn": | ||||
|                 rope_scaling_type = gguf.RopeScalingType.YARN | ||||
|                 n_orig_ctx = rope_scaling['original_max_position_embeddings'] | ||||
|                 rope_finetuned = rope_scaling['finetuned'] | ||||
|             else: | ||||
|             f_rope_scale = None | ||||
|                 raise NotImplementedError(f'Unknown rope scaling type: {typ}') | ||||
| 
 | ||||
|         if "max_sequence_length" in config: | ||||
|             n_ctx = config["max_sequence_length"] | ||||
|  | @ -222,16 +225,19 @@ class Params: | |||
|                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.") | ||||
| 
 | ||||
|         return Params( | ||||
|             n_vocab          = n_vocab, | ||||
|             n_embd           = n_embd, | ||||
|             n_layer          = n_layer, | ||||
|             n_vocab           = config["vocab_size"], | ||||
|             n_embd            = config["hidden_size"], | ||||
|             n_layer           = config["num_hidden_layers"], | ||||
|             n_ctx             = n_ctx, | ||||
|             n_ff             = n_ff, | ||||
|             n_head           = n_head, | ||||
|             n_head_kv        = n_head_kv, | ||||
|             f_norm_eps       = f_norm_eps, | ||||
|             f_rope_freq_base = f_rope_freq_base, | ||||
|             n_ff              = config["intermediate_size"], | ||||
|             n_head            = (n_head := config["num_attention_heads"]), | ||||
|             n_head_kv         = config.get("num_key_value_heads", n_head), | ||||
|             f_norm_eps        = config["rms_norm_eps"], | ||||
|             f_rope_freq_base  = config.get("rope_theta"), | ||||
|             rope_scaling_type = rope_scaling_type, | ||||
|             f_rope_scale      = f_rope_scale, | ||||
|             n_orig_ctx        = n_orig_ctx, | ||||
|             rope_finetuned    = rope_finetuned, | ||||
|         ) | ||||
| 
 | ||||
|     # LLaMA v2 70B params.json | ||||
|  | @ -240,17 +246,8 @@ class Params: | |||
|     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: | ||||
|         config = json.load(open(config_path)) | ||||
| 
 | ||||
|         n_vocab          = config["vocab_size"] if "vocab_size" in config else -1 | ||||
|         n_embd           = config["dim"] | ||||
|         n_layer          = config["n_layers"] | ||||
|         n_ff             = -1 | ||||
|         n_head           = config["n_heads"] | ||||
|         n_head_kv        = config["n_kv_heads"] if "n_kv_heads" in config else n_head | ||||
|         f_norm_eps       = config["norm_eps"] | ||||
|         f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None | ||||
| 
 | ||||
|         # hack to determine LLaMA v1 vs v2 vs CodeLlama | ||||
|         if f_rope_freq_base == 1000000: | ||||
|         if config.get("rope_theta") == 1000000: | ||||
|             # CodeLlama | ||||
|             n_ctx = 16384 | ||||
|         elif config["norm_eps"] == 1e-05: | ||||
|  | @ -260,22 +257,16 @@ class Params: | |||
|             # LLaMA v1 | ||||
|             n_ctx = 2048 | ||||
| 
 | ||||
|         if n_vocab == -1: | ||||
|             n_vocab = model["tok_embeddings.weight"].shape[0] | ||||
| 
 | ||||
|         if n_ff == -1: | ||||
|             n_ff = model["layers.0.feed_forward.w1.weight"].shape[0] | ||||
| 
 | ||||
|         return Params( | ||||
|             n_vocab          = n_vocab, | ||||
|             n_embd           = n_embd, | ||||
|             n_layer          = n_layer, | ||||
|             n_vocab          = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), | ||||
|             n_embd           = config["dim"], | ||||
|             n_layer          = config["n_layers"], | ||||
|             n_ctx            = n_ctx, | ||||
|             n_ff             = n_ff, | ||||
|             n_head           = n_head, | ||||
|             n_head_kv        = n_head_kv, | ||||
|             f_norm_eps       = f_norm_eps, | ||||
|             f_rope_freq_base = f_rope_freq_base, | ||||
|             n_ff             = model["layers.0.feed_forward.w1.weight"].shape[0], | ||||
|             n_head           = (n_head := config["n_heads"]), | ||||
|             n_head_kv        = config.get("n_kv_heads", n_head), | ||||
|             f_norm_eps       = config["norm_eps"], | ||||
|             f_rope_freq_base = config.get("rope_theta"), | ||||
|         ) | ||||
| 
 | ||||
|     @staticmethod | ||||
|  | @ -831,8 +822,16 @@ class OutputFile: | |||
|         if params.f_rope_freq_base is not None: | ||||
|             self.gguf.add_rope_freq_base(params.f_rope_freq_base) | ||||
| 
 | ||||
|         if params.f_rope_scale is not None: | ||||
|             self.gguf.add_rope_scale_linear(params.f_rope_scale) | ||||
|         if params.rope_scaling_type: | ||||
|             assert params.f_rope_scale is not None | ||||
|             self.gguf.add_rope_scaling_type(params.rope_scaling_type) | ||||
|             self.gguf.add_rope_scaling_factor(params.f_rope_scale) | ||||
| 
 | ||||
|         if params.n_orig_ctx is not None: | ||||
|             self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx) | ||||
| 
 | ||||
|         if params.rope_finetuned is not None: | ||||
|             self.gguf.add_rope_scaling_finetuned(params.rope_finetuned) | ||||
| 
 | ||||
|         if params.ftype is not None: | ||||
|             self.gguf.add_file_type(params.ftype) | ||||
|  |  | |||
|  | @ -642,8 +642,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( | |||
|         const int rope_mode = 0; | ||||
| 
 | ||||
|         return ggml_rope_custom(ctx, | ||||
|             t, KQ_pos, n_rot, rope_mode, n_ctx, | ||||
|             rope_freq_base, rope_freq_scale); | ||||
|             t, KQ_pos, n_rot, rope_mode, n_ctx, 0, | ||||
|             rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f | ||||
|         ); | ||||
|     }; | ||||
| 
 | ||||
|     set_name(tokens_input, "tokens_input"); | ||||
|  |  | |||
|  | @ -1758,8 +1758,14 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, | |||
|     printf("  -t N, --threads N         number of threads to use during computation (default: %d)\n", params.n_threads); | ||||
|     printf("  -tb N, --threads-batch N  number of threads to use during batch and prompt processing (default: same as --threads)\n"); | ||||
|     printf("  -c N, --ctx-size N        size of the prompt context (default: %d)\n", params.n_ctx); | ||||
|     printf("  --rope-scaling {none,linear,yarn}\n"); | ||||
|     printf("                            RoPE frequency scaling method, defaults to linear unless specified by the model\n"); | ||||
|     printf("  --rope-freq-base N        RoPE base frequency (default: loaded from model)\n"); | ||||
|     printf("  --rope-freq-scale N       RoPE frequency scaling factor (default: loaded from model)\n"); | ||||
|     printf("  --rope-freq-scale N       RoPE frequency scaling factor, expands context by a factor of 1/N\n"); | ||||
|     printf("  --yarn-ext-factor N       YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n"); | ||||
|     printf("  --yarn-attn-factor N      YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); | ||||
|     printf("  --yarn-beta-slow N        YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); | ||||
|     printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); | ||||
|     printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch); | ||||
|     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n"); | ||||
|     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n"); | ||||
|  | @ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, | |||
|             } | ||||
|             params.n_ctx = std::stoi(argv[i]); | ||||
|         } | ||||
|         else if (arg == "--rope-scaling") | ||||
|         { | ||||
|             if (++i >= argc) | ||||
|             { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             std::string value(argv[i]); | ||||
|             /**/ if (value == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } | ||||
|             else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } | ||||
|             else if (value == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } | ||||
|             else { invalid_param = true; break; } | ||||
|         } | ||||
|         else if (arg == "--rope-freq-base") | ||||
|         { | ||||
|             if (++i >= argc) | ||||
|  | @ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, | |||
|             } | ||||
|             params.rope_freq_scale = std::stof(argv[i]); | ||||
|         } | ||||
|         else if (arg == "--yarn-ext-factor") | ||||
|         { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_ext_factor = std::stof(argv[i]); | ||||
|         } | ||||
|         else if (arg == "--yarn-attn-factor") | ||||
|         { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_attn_factor = std::stof(argv[i]); | ||||
|         } | ||||
|         else if (arg == "--yarn-beta-fast") | ||||
|         { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_beta_fast = std::stof(argv[i]); | ||||
|         } | ||||
|         else if (arg == "--yarn-beta-slow") | ||||
|         { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.yarn_beta_slow = std::stof(argv[i]); | ||||
|         } | ||||
|         else if (arg == "--memory-f32" || arg == "--memory_f32") | ||||
|         { | ||||
|             params.memory_f16 = false; | ||||
|  |  | |||
|  | @ -349,9 +349,9 @@ static struct ggml_tensor * llama_build_train_graphs( | |||
|         // not capturing these, to silcence warnings
 | ||||
|         const int rope_mode = 0; | ||||
| 
 | ||||
|         return ggml_rope_custom(ctx, | ||||
|             t, KQ_pos, n_rot, rope_mode, n_ctx, | ||||
|             rope_freq_base, rope_freq_scale); | ||||
|         return ggml_rope_custom( | ||||
|             ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f | ||||
|         ); | ||||
|     }; | ||||
| 
 | ||||
|     set_name(tokens_input, "tokens_input"); | ||||
|  |  | |||
							
								
								
									
										145
									
								
								ggml-cuda.cu
									
										
									
									
									
								
							
							
						
						
									
										145
									
								
								ggml-cuda.cu
									
										
									
									
									
								
							|  | @ -4493,11 +4493,41 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, | |||
|     cpy_1(cx + x_offset, cdst + dst_offset); | ||||
| } | ||||
| 
 | ||||
| // rope == RoPE == rotary positional embedding | ||||
| static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { | ||||
|     const float y = (i0 / 2 - low) / max(0.001f, high - low); | ||||
|     return 1.0f - min(1.0f, max(0.0f, y)); | ||||
| } | ||||
| 
 | ||||
| struct rope_corr_dims { | ||||
|     float v[4]; | ||||
| }; | ||||
| 
 | ||||
| // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn | ||||
| // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. | ||||
| static __device__ void rope_yarn( | ||||
|     float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, | ||||
|     float * cos_theta, float * sin_theta | ||||
| ) { | ||||
|     // Get n-d rotational scaling corrected for extrapolation | ||||
|     float theta_interp = freq_scale * theta_extrap; | ||||
|     float theta = theta_interp; | ||||
|     if (ext_factor != 0.0f) { | ||||
|         float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; | ||||
|         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; | ||||
| 
 | ||||
|         // Get n-d magnitude scaling corrected for interpolation | ||||
|         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); | ||||
|     } | ||||
|     *cos_theta = cosf(theta) * mscale; | ||||
|     *sin_theta = sinf(theta) * mscale; | ||||
| } | ||||
| 
 | ||||
| // rope == RoPE == rotary positional embedding | ||||
| template<typename T, bool has_pos> | ||||
| static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, | ||||
|                             const int p_delta_rows, const float theta_scale) { | ||||
| static __global__ void rope( | ||||
|     const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, | ||||
|     float ext_factor, float attn_factor, rope_corr_dims corr_dims | ||||
| ) { | ||||
|     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); | ||||
| 
 | ||||
|     if (col >= ncols) { | ||||
|  | @ -4509,10 +4539,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t | |||
|     const int i2 = row/p_delta_rows; | ||||
| 
 | ||||
|     const int p = has_pos ? pos[i2] : 0; | ||||
|     const float p0 = p*freq_scale; | ||||
|     const float theta = p0*powf(theta_scale, col/2); | ||||
|     const float sin_theta = sinf(theta); | ||||
|     const float cos_theta = cosf(theta); | ||||
|     const float theta_base = p*powf(freq_base, -col/ncols); | ||||
| 
 | ||||
|     float cos_theta, sin_theta; | ||||
|     rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); | ||||
| 
 | ||||
|     const float x0 = x[i + 0]; | ||||
|     const float x1 = x[i + 1]; | ||||
|  | @ -4522,8 +4552,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t | |||
| } | ||||
| 
 | ||||
| template<typename T, bool has_pos> | ||||
| static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, | ||||
|                                  const int p_delta_rows, const float theta_scale) { | ||||
| static __global__ void rope_neox( | ||||
|     const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, | ||||
|     float ext_factor, float attn_factor, rope_corr_dims corr_dims | ||||
| ) { | ||||
|     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); | ||||
| 
 | ||||
|     if (col >= ncols) { | ||||
|  | @ -4534,11 +4566,14 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in | |||
|     const int i = row*ncols + col/2; | ||||
|     const int i2 = row/p_delta_rows; | ||||
| 
 | ||||
|     // simplified from `(row * ncols + col) * (-1 / ncols)` | ||||
|     const float cur_rot = -col/ncols - row; | ||||
| 
 | ||||
|     const int p = has_pos ? pos[i2] : 0; | ||||
|     const float p0 = p*freq_scale; | ||||
|     const float theta = p0*powf(theta_scale, col/2); | ||||
|     const float sin_theta = sinf(theta); | ||||
|     const float cos_theta = cosf(theta); | ||||
|     const float theta_base = p*powf(freq_base, cur_rot); | ||||
| 
 | ||||
|     float cos_theta, sin_theta; | ||||
|     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); | ||||
| 
 | ||||
|     const float x0 = x[i + 0]; | ||||
|     const float x1 = x[i + ncols/2]; | ||||
|  | @ -4547,8 +4582,10 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in | |||
|     dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; | ||||
| } | ||||
| 
 | ||||
| static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, | ||||
|                                     const int p_delta_rows, const float theta_scale, const int n_ctx) { | ||||
| static __global__ void rope_glm_f32( | ||||
|     const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, | ||||
|     int n_ctx | ||||
| ) { | ||||
|     const int col = blockDim.x*blockIdx.x + threadIdx.x; | ||||
|     const int half_n_dims = ncols/4; | ||||
| 
 | ||||
|  | @ -4560,7 +4597,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol | |||
|     const int i = row*ncols + col; | ||||
|     const int i2 = row/p_delta_rows; | ||||
| 
 | ||||
|     const float col_theta_scale = powf(theta_scale, col); | ||||
|     const float col_theta_scale = powf(freq_base, -2.0f*col/ncols); | ||||
|      // FIXME: this is likely wrong | ||||
|     const int p = pos != nullptr ? pos[i2] : 0; | ||||
| 
 | ||||
|  | @ -5584,40 +5621,54 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const | |||
| } | ||||
| 
 | ||||
| template<typename T> | ||||
| static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, | ||||
|                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) { | ||||
| static void rope_cuda( | ||||
|     const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, | ||||
|     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream | ||||
| ) { | ||||
|     GGML_ASSERT(ncols % 2 == 0); | ||||
|     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); | ||||
|     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); | ||||
|     const dim3 block_nums(nrows, num_blocks_x, 1); | ||||
|     if (pos == nullptr) { | ||||
|         rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); | ||||
|         rope<T, false><<<block_nums, block_dims, 0, stream>>>( | ||||
|             x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims | ||||
|         ); | ||||
|     } else { | ||||
|         rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); | ||||
|         rope<T, true><<<block_nums, block_dims, 0, stream>>>( | ||||
|             x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims | ||||
|         ); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| template<typename T> | ||||
| static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, | ||||
|                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) { | ||||
| static void rope_neox_cuda( | ||||
|     const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, | ||||
|     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream | ||||
| ) { | ||||
|     GGML_ASSERT(ncols % 2 == 0); | ||||
|     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); | ||||
|     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); | ||||
|     const dim3 block_nums(nrows, num_blocks_x, 1); | ||||
|     if (pos == nullptr) { | ||||
|         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); | ||||
|         rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>( | ||||
|             x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims | ||||
|         ); | ||||
|     } else { | ||||
|         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); | ||||
|         rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>( | ||||
|             x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims | ||||
|         ); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, | ||||
|                               const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { | ||||
| static void rope_glm_f32_cuda( | ||||
|     const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, | ||||
|     float freq_base, int n_ctx, cudaStream_t stream | ||||
| ) { | ||||
|     GGML_ASSERT(ncols % 4 == 0); | ||||
|     const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); | ||||
|     const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; | ||||
|     const dim3 block_nums(num_blocks_x, nrows, 1); | ||||
|     rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); | ||||
|     rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx); | ||||
| } | ||||
| 
 | ||||
| static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, | ||||
|  | @ -6481,13 +6532,16 @@ inline void ggml_cuda_op_rope( | |||
|     const int n_dims      = ((int32_t *) dst->op_params)[1]; | ||||
|     const int mode        = ((int32_t *) dst->op_params)[2]; | ||||
|     const int n_ctx       = ((int32_t *) dst->op_params)[3]; | ||||
|     const int n_orig_ctx  = ((int32_t *) dst->op_params)[4]; | ||||
| 
 | ||||
|     // RoPE alteration for extended context | ||||
| 
 | ||||
|     float freq_base, freq_scale; | ||||
|     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float)); | ||||
|     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); | ||||
| 
 | ||||
|     const float theta_scale = powf(freq_base, -2.0f/n_dims); | ||||
|     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; | ||||
|     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float)); | ||||
|     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float)); | ||||
|     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float)); | ||||
|     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float)); | ||||
|     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float)); | ||||
|     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float)); | ||||
| 
 | ||||
|     const int32_t * pos = nullptr; | ||||
|     if ((mode & 1) == 0) { | ||||
|  | @ -6499,24 +6553,39 @@ inline void ggml_cuda_op_rope( | |||
|     const bool is_neox = mode & 2; | ||||
|     const bool is_glm  = mode & 4; | ||||
| 
 | ||||
|     rope_corr_dims corr_dims; | ||||
|     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); | ||||
| 
 | ||||
|     // compute | ||||
|     if (is_glm) { | ||||
|         GGML_ASSERT(false); | ||||
|         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); | ||||
|         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream); | ||||
|     } else if (is_neox) { | ||||
|         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); | ||||
|         if (src0->type == GGML_TYPE_F32) { | ||||
|             rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); | ||||
|             rope_neox_cuda( | ||||
|                 (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, | ||||
|                 attn_factor, corr_dims, main_stream | ||||
|             ); | ||||
|         } else if (src0->type == GGML_TYPE_F16) { | ||||
|             rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); | ||||
|             rope_neox_cuda( | ||||
|                 (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, | ||||
|                 attn_factor, corr_dims, main_stream | ||||
|             ); | ||||
|         } else { | ||||
|             GGML_ASSERT(false); | ||||
|         } | ||||
|     } else { | ||||
|         if (src0->type == GGML_TYPE_F32) { | ||||
|             rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); | ||||
|             rope_cuda( | ||||
|                 (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, | ||||
|                 attn_factor, corr_dims, main_stream | ||||
|             ); | ||||
|         } else if (src0->type == GGML_TYPE_F16) { | ||||
|             rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); | ||||
|             rope_cuda( | ||||
|                 (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, | ||||
|                 attn_factor, corr_dims, main_stream | ||||
|             ); | ||||
|         } else { | ||||
|             GGML_ASSERT(false); | ||||
|         } | ||||
|  |  | |||
							
								
								
									
										16
									
								
								ggml-metal.m
									
										
									
									
									
								
							
							
						
						
									
										16
									
								
								ggml-metal.m
									
										
									
									
									
								
							|  | @ -1403,11 +1403,15 @@ void ggml_metal_graph_compute( | |||
|                             const int n_past     = ((int32_t *) dst->op_params)[0]; | ||||
|                             const int n_dims     = ((int32_t *) dst->op_params)[1]; | ||||
|                             const int mode       = ((int32_t *) dst->op_params)[2]; | ||||
|                             const int n_orig_ctx = ((int32_t *) dst->op_params)[3]; | ||||
| 
 | ||||
|                             float freq_base; | ||||
|                             float freq_scale; | ||||
|                             memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float)); | ||||
|                             memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); | ||||
|                             float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; | ||||
|                             memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float)); | ||||
|                             memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float)); | ||||
|                             memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float)); | ||||
|                             memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float)); | ||||
|                             memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float)); | ||||
|                             memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float)); | ||||
| 
 | ||||
|                             switch (src0->type) { | ||||
|                                 case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; | ||||
|  | @ -1439,6 +1443,10 @@ void ggml_metal_graph_compute( | |||
|                             [encoder setBytes:&mode    length:sizeof(     int) atIndex:21]; | ||||
|                             [encoder setBytes:&freq_base  length:sizeof(float) atIndex:22]; | ||||
|                             [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; | ||||
|                             [encoder setBytes:&ext_factor  length:sizeof(float) atIndex:24]; | ||||
|                             [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25]; | ||||
|                             [encoder setBytes:&beta_fast   length:sizeof(float) atIndex:26]; | ||||
|                             [encoder setBytes:&beta_slow   length:sizeof(float) atIndex:27]; | ||||
| 
 | ||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; | ||||
|                         } break; | ||||
|  |  | |||
|  | @ -1061,6 +1061,45 @@ kernel void kernel_alibi_f32( | |||
|     } | ||||
| } | ||||
| 
 | ||||
| static float rope_yarn_ramp(const float low, const float high, const int i0) { | ||||
|     const float y = (i0 / 2 - low) / max(0.001f, high - low); | ||||
|     return 1.0f - min(1.0f, max(0.0f, y)); | ||||
| } | ||||
| 
 | ||||
| // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 | ||||
| // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 | ||||
| static void rope_yarn( | ||||
|     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, | ||||
|     float * cos_theta, float * sin_theta | ||||
| ) { | ||||
|     // Get n-d rotational scaling corrected for extrapolation
 | ||||
|     float theta_interp = freq_scale * theta_extrap; | ||||
|     float theta = theta_interp; | ||||
|     if (ext_factor != 0.0f) { | ||||
|         ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; | ||||
|         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; | ||||
| 
 | ||||
|         // Get n-d magnitude scaling corrected for interpolation
 | ||||
|         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); | ||||
|     } | ||||
|     *cos_theta = cosf(theta) * mscale; | ||||
|     *sin_theta = sinf(theta) * mscale; | ||||
| } | ||||
| 
 | ||||
| // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 | ||||
| // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
 | ||||
| static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { | ||||
|     return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); | ||||
| } | ||||
| 
 | ||||
| static void rope_yarn_corr_dims( | ||||
|     int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] | ||||
| ) { | ||||
|     // start and end correction dims
 | ||||
|     dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); | ||||
|     dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); | ||||
| } | ||||
| 
 | ||||
| typedef void (rope_t)( | ||||
|         device const    void * src0, | ||||
|         device const int32_t * src1, | ||||
|  | @ -1116,6 +1155,10 @@ kernel void kernel_rope( | |||
|         constant         int & mode, | ||||
|         constant       float & freq_base, | ||||
|         constant       float & freq_scale, | ||||
|         constant       float & ext_factor, | ||||
|         constant       float & attn_factor, | ||||
|         constant       float & beta_fast, | ||||
|         constant       float & beta_slow, | ||||
|         uint  tiitg[[thread_index_in_threadgroup]], | ||||
|         uint3 tptg[[threads_per_threadgroup]], | ||||
|         uint3 tgpig[[threadgroup_position_in_grid]]) { | ||||
|  | @ -1125,19 +1168,22 @@ kernel void kernel_rope( | |||
| 
 | ||||
|     const bool is_neox = mode & 2; | ||||
| 
 | ||||
|     float corr_dims[2]; | ||||
|     rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); | ||||
| 
 | ||||
|     device const int32_t * pos = src1; | ||||
| 
 | ||||
|     const int64_t p = pos[i2]; | ||||
| 
 | ||||
|     const float theta_0 = freq_scale * (float)p; | ||||
|     const float theta_0 = (float)p; | ||||
|     const float inv_ndims = -1.f/n_dims; | ||||
| 
 | ||||
|     if (!is_neox) { | ||||
|         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { | ||||
| 
 | ||||
|             const float theta = theta_0 * pow(freq_base, inv_ndims*i0); | ||||
|             const float cos_theta = cos(theta); | ||||
|             const float sin_theta = sin(theta); | ||||
|             float cos_theta, sin_theta; | ||||
|             rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); | ||||
| 
 | ||||
|             device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|             device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | @ -1152,9 +1198,12 @@ kernel void kernel_rope( | |||
|         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { | ||||
|             for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { | ||||
| 
 | ||||
|                 const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); | ||||
|                 const float cos_theta = cos(theta); | ||||
|                 const float sin_theta = sin(theta); | ||||
|                 // simplified from `(ib * n_dims + ic) * inv_ndims`
 | ||||
|                 const float cur_rot = inv_ndims*ic - ib; | ||||
| 
 | ||||
|                 const float theta = theta_0 * pow(freq_base, cur_rot); | ||||
|                 float cos_theta, sin_theta; | ||||
|                 rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); | ||||
| 
 | ||||
|                 const int64_t i0 = ib*n_dims + ic/2; | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										225
									
								
								ggml.c
									
										
									
									
									
								
							
							
						
						
									
										225
									
								
								ggml.c
									
										
									
									
									
								
							|  | @ -1,4 +1,5 @@ | |||
| #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
 | ||||
| #define _USE_MATH_DEFINES // For M_PI on MSVC
 | ||||
| 
 | ||||
| #include "ggml-impl.h" | ||||
| #include "ggml-quants.h" | ||||
|  | @ -4845,8 +4846,13 @@ static struct ggml_tensor * ggml_rope_impl( | |||
|         int                   n_dims, | ||||
|         int                   mode, | ||||
|         int                   n_ctx, | ||||
|         int                   n_orig_ctx, | ||||
|         float                 freq_base, | ||||
|         float                 freq_scale, | ||||
|         float                 ext_factor, | ||||
|         float                 attn_factor, | ||||
|         float                 beta_fast, | ||||
|         float                 beta_slow, | ||||
|         float                 xpos_base, | ||||
|         bool                  xpos_down, | ||||
|         bool                  inplace) { | ||||
|  | @ -4862,11 +4868,15 @@ static struct ggml_tensor * ggml_rope_impl( | |||
| 
 | ||||
|     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); | ||||
| 
 | ||||
|     int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; | ||||
|     memcpy(params + 4, &freq_base,  sizeof(float)); | ||||
|     memcpy(params + 5, &freq_scale, sizeof(float)); | ||||
|     memcpy(params + 6, &xpos_base,  sizeof(float)); | ||||
|     memcpy(params + 7, &xpos_down,  sizeof(bool)); | ||||
|     int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; | ||||
|     memcpy(params +  5, &freq_base,    sizeof(float)); | ||||
|     memcpy(params +  6, &freq_scale,   sizeof(float)); | ||||
|     memcpy(params +  7, &ext_factor,   sizeof(float)); | ||||
|     memcpy(params +  8, &attn_factor,  sizeof(float)); | ||||
|     memcpy(params +  9, &beta_fast,    sizeof(float)); | ||||
|     memcpy(params + 10, &beta_slow,    sizeof(float)); | ||||
|     memcpy(params + 11, &xpos_base,    sizeof(float)); | ||||
|     memcpy(params + 12, &xpos_down,    sizeof(bool)); | ||||
|     ggml_set_op_params(result, params, sizeof(params)); | ||||
| 
 | ||||
|     result->op   = GGML_OP_ROPE; | ||||
|  | @ -4884,7 +4894,9 @@ struct ggml_tensor * ggml_rope( | |||
|         int                   n_dims, | ||||
|         int                   mode, | ||||
|         int                   n_ctx) { | ||||
|     return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); | ||||
|     return ggml_rope_impl( | ||||
|         ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false | ||||
|     ); | ||||
| } | ||||
| 
 | ||||
| struct ggml_tensor * ggml_rope_inplace( | ||||
|  | @ -4894,7 +4906,9 @@ struct ggml_tensor * ggml_rope_inplace( | |||
|         int                   n_dims, | ||||
|         int                   mode, | ||||
|         int                   n_ctx) { | ||||
|     return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); | ||||
|     return ggml_rope_impl( | ||||
|         ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true | ||||
|     ); | ||||
| } | ||||
| 
 | ||||
| struct ggml_tensor * ggml_rope_custom( | ||||
|  | @ -4904,9 +4918,17 @@ struct ggml_tensor * ggml_rope_custom( | |||
|         int                   n_dims, | ||||
|         int                   mode, | ||||
|         int                   n_ctx, | ||||
|         int                   n_orig_ctx, | ||||
|         float                 freq_base, | ||||
|         float                 freq_scale) { | ||||
|     return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); | ||||
|         float                 freq_scale, | ||||
|         float                 ext_factor, | ||||
|         float                 attn_factor, | ||||
|         float                 beta_fast, | ||||
|         float                 beta_slow) { | ||||
|     return ggml_rope_impl( | ||||
|         ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, | ||||
|         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false | ||||
|     ); | ||||
| } | ||||
| 
 | ||||
| struct ggml_tensor * ggml_rope_custom_inplace( | ||||
|  | @ -4916,9 +4938,17 @@ struct ggml_tensor * ggml_rope_custom_inplace( | |||
|         int                   n_dims, | ||||
|         int                   mode, | ||||
|         int                   n_ctx, | ||||
|         int                   n_orig_ctx, | ||||
|         float                 freq_base, | ||||
|         float                 freq_scale) { | ||||
|     return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); | ||||
|         float                 freq_scale, | ||||
|         float                 ext_factor, | ||||
|         float                 attn_factor, | ||||
|         float                 beta_fast, | ||||
|         float                 beta_slow) { | ||||
|     return ggml_rope_impl( | ||||
|         ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, | ||||
|         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true | ||||
|     ); | ||||
| } | ||||
| 
 | ||||
| struct ggml_tensor * ggml_rope_xpos_inplace( | ||||
|  | @ -4928,7 +4958,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace( | |||
|         int                   n_dims, | ||||
|         float                 base, | ||||
|         bool                  down) { | ||||
|     return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); | ||||
|     return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); | ||||
| } | ||||
| 
 | ||||
| // ggml_rope_back
 | ||||
|  | @ -10901,6 +10931,45 @@ static void ggml_compute_forward_clamp( | |||
| 
 | ||||
| // ggml_compute_forward_rope
 | ||||
| 
 | ||||
| static float rope_yarn_ramp(const float low, const float high, const int i0) { | ||||
|     const float y = (i0 / 2 - low) / MAX(0.001f, high - low); | ||||
|     return 1 - MIN(1, MAX(0, y)); | ||||
| } | ||||
| 
 | ||||
| // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 | ||||
| // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 | ||||
| static void rope_yarn( | ||||
|     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, | ||||
|     float * cos_theta, float * sin_theta | ||||
| ) { | ||||
|     // Get n-d rotational scaling corrected for extrapolation
 | ||||
|     float theta_interp = freq_scale * theta_extrap; | ||||
|     float theta = theta_interp; | ||||
|     if (ext_factor != 0.0f) { | ||||
|         float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; | ||||
|         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; | ||||
| 
 | ||||
|         // Get n-d magnitude scaling corrected for interpolation
 | ||||
|         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); | ||||
|     } | ||||
|     *cos_theta = cosf(theta) * mscale; | ||||
|     *sin_theta = sinf(theta) * mscale; | ||||
| } | ||||
| 
 | ||||
| // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 | ||||
| // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
 | ||||
| static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { | ||||
|     return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); | ||||
| } | ||||
| 
 | ||||
| void ggml_rope_yarn_corr_dims( | ||||
|     int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] | ||||
| ) { | ||||
|     // start and end correction dims
 | ||||
|     dims[0] = MAX(0,         floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base))); | ||||
|     dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base))); | ||||
| } | ||||
| 
 | ||||
| static void ggml_compute_forward_rope_f32( | ||||
|         const struct ggml_compute_params * params, | ||||
|         const struct ggml_tensor * src0, | ||||
|  | @ -10910,8 +10979,7 @@ static void ggml_compute_forward_rope_f32( | |||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     float freq_base; | ||||
|     float freq_scale; | ||||
|     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; | ||||
| 
 | ||||
|     // these two only relevant for xPos RoPE:
 | ||||
|     float xpos_base; | ||||
|  | @ -10921,10 +10989,16 @@ static void ggml_compute_forward_rope_f32( | |||
|     const int n_dims     = ((int32_t *) dst->op_params)[1]; | ||||
|     const int mode       = ((int32_t *) dst->op_params)[2]; | ||||
|     const int n_ctx      = ((int32_t *) dst->op_params)[3]; | ||||
|     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float)); | ||||
|     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); | ||||
|     memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float)); | ||||
|     memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool)); | ||||
|     const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; | ||||
| 
 | ||||
|     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float)); | ||||
|     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float)); | ||||
|     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float)); | ||||
|     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float)); | ||||
|     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float)); | ||||
|     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float)); | ||||
|     memcpy(&xpos_base,   (int32_t *) dst->op_params + 11, sizeof(float)); | ||||
|     memcpy(&xpos_down,   (int32_t *) dst->op_params + 12, sizeof(bool)); | ||||
| 
 | ||||
|     GGML_TENSOR_UNARY_OP_LOCALS | ||||
| 
 | ||||
|  | @ -10952,6 +11026,9 @@ static void ggml_compute_forward_rope_f32( | |||
|     int ir = 0; | ||||
| 
 | ||||
|     const float theta_scale = powf(freq_base, -2.0f/n_dims); | ||||
|     const float inv_ndims = -1.f/n_dims; | ||||
|     float corr_dims[2]; | ||||
|     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); | ||||
| 
 | ||||
|     const bool is_neox = mode & 2; | ||||
|     const bool is_glm  = mode & 4; | ||||
|  | @ -10965,18 +11042,18 @@ static void ggml_compute_forward_rope_f32( | |||
|                 if (ir++ < ir0) continue; | ||||
|                 if (ir   > ir1) break; | ||||
| 
 | ||||
|                 float theta = freq_scale * (float)p; | ||||
|                 float theta_base = (float)p; | ||||
| 
 | ||||
|                 if (is_glm) { | ||||
|                     theta = MIN(p, n_ctx - 2); | ||||
|                     theta_base = MIN(p, n_ctx - 2); | ||||
|                     float block_theta = MAX(p - (n_ctx - 2), 0); | ||||
|                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { | ||||
|                         const float cos_theta = cosf(theta); | ||||
|                         const float sin_theta = sinf(theta); | ||||
|                         const float cos_theta = cosf(theta_base); | ||||
|                         const float sin_theta = sinf(theta_base); | ||||
|                         const float cos_block_theta = cosf(block_theta); | ||||
|                         const float sin_block_theta = sinf(block_theta); | ||||
| 
 | ||||
|                         theta *= theta_scale; | ||||
|                         theta_base *= theta_scale; | ||||
|                         block_theta *= theta_scale; | ||||
| 
 | ||||
|                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|  | @ -10994,13 +11071,16 @@ static void ggml_compute_forward_rope_f32( | |||
|                     } | ||||
|                 } else if (!is_neox) { | ||||
|                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) { | ||||
|                         const float cos_theta = cosf(theta); | ||||
|                         const float sin_theta = sinf(theta); | ||||
|                         float cos_theta, sin_theta; | ||||
|                         rope_yarn( | ||||
|                             theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta | ||||
|                         ); | ||||
| 
 | ||||
|                         // zeta scaling for xPos only:
 | ||||
|                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; | ||||
|                         if (xpos_down) zeta = 1.0f / zeta; | ||||
| 
 | ||||
|                         theta *= theta_scale; | ||||
|                         theta_base *= theta_scale; | ||||
| 
 | ||||
|                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|                               float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | @ -11014,12 +11094,19 @@ static void ggml_compute_forward_rope_f32( | |||
|                 } else { | ||||
|                     // TODO: this might be wrong for ne0 != n_dims - need double check
 | ||||
|                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
 | ||||
|                     theta_base *= freq_scale; | ||||
|                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { | ||||
|                         for (int64_t ic = 0; ic < n_dims; ic += 2) { | ||||
|                             const float cos_theta = cosf(theta); | ||||
|                             const float sin_theta = sinf(theta); | ||||
|                             // simplified from `(ib * n_dims + ic) * inv_ndims`
 | ||||
|                             float cur_rot = inv_ndims * ic - ib; | ||||
| 
 | ||||
|                             theta *= theta_scale; | ||||
|                             float cos_theta, sin_theta; | ||||
|                             rope_yarn( | ||||
|                                 theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, | ||||
|                                 &cos_theta, &sin_theta | ||||
|                             ); | ||||
| 
 | ||||
|                             theta_base *= theta_scale; | ||||
| 
 | ||||
|                             const int64_t i0 = ib*n_dims + ic/2; | ||||
| 
 | ||||
|  | @ -11048,15 +11135,19 @@ static void ggml_compute_forward_rope_f16( | |||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     float freq_base; | ||||
|     float freq_scale; | ||||
|     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; | ||||
| 
 | ||||
|     //const int n_past     = ((int32_t *) dst->op_params)[0];
 | ||||
|     const int n_dims     = ((int32_t *) dst->op_params)[1]; | ||||
|     const int mode       = ((int32_t *) dst->op_params)[2]; | ||||
|     const int n_ctx      = ((int32_t *) dst->op_params)[3]; | ||||
|     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float)); | ||||
|     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); | ||||
|     const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; | ||||
|     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float)); | ||||
|     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float)); | ||||
|     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float)); | ||||
|     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float)); | ||||
|     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float)); | ||||
|     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float)); | ||||
| 
 | ||||
|     GGML_TENSOR_UNARY_OP_LOCALS | ||||
| 
 | ||||
|  | @ -11084,6 +11175,9 @@ static void ggml_compute_forward_rope_f16( | |||
|     int ir = 0; | ||||
| 
 | ||||
|     const float theta_scale = powf(freq_base, -2.0f/n_dims); | ||||
|     const float inv_ndims = -1.f/n_dims; | ||||
|     float corr_dims[2]; | ||||
|     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); | ||||
| 
 | ||||
|     const bool is_neox = mode & 2; | ||||
|     const bool is_glm  = mode & 4; | ||||
|  | @ -11097,18 +11191,18 @@ static void ggml_compute_forward_rope_f16( | |||
|                 if (ir++ < ir0) continue; | ||||
|                 if (ir   > ir1) break; | ||||
| 
 | ||||
|                 float theta = freq_scale * (float)p; | ||||
|                 float theta_base = (float)p; | ||||
| 
 | ||||
|                 if (is_glm) { | ||||
|                     theta = MIN(p, n_ctx - 2); | ||||
|                     theta_base = MIN(p, n_ctx - 2); | ||||
|                     float block_theta = MAX(p - (n_ctx - 2), 0); | ||||
|                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { | ||||
|                         const float cos_theta = cosf(theta); | ||||
|                         const float sin_theta = sinf(theta); | ||||
|                         const float cos_theta = cosf(theta_base); | ||||
|                         const float sin_theta = sinf(theta_base); | ||||
|                         const float cos_block_theta = cosf(block_theta); | ||||
|                         const float sin_block_theta = sinf(block_theta); | ||||
| 
 | ||||
|                         theta *= theta_scale; | ||||
|                         theta_base *= theta_scale; | ||||
|                         block_theta *= theta_scale; | ||||
| 
 | ||||
|                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|  | @ -11126,10 +11220,12 @@ static void ggml_compute_forward_rope_f16( | |||
|                     } | ||||
|                 } else if (!is_neox) { | ||||
|                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) { | ||||
|                         const float cos_theta = cosf(theta); | ||||
|                         const float sin_theta = sinf(theta); | ||||
|                         float cos_theta, sin_theta; | ||||
|                         rope_yarn( | ||||
|                             theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta | ||||
|                         ); | ||||
| 
 | ||||
|                         theta *= theta_scale; | ||||
|                         theta_base *= theta_scale; | ||||
| 
 | ||||
|                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|                               ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | @ -11143,12 +11239,19 @@ static void ggml_compute_forward_rope_f16( | |||
|                 } else { | ||||
|                     // TODO: this might be wrong for ne0 != n_dims - need double check
 | ||||
|                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
 | ||||
|                     theta_base *= freq_scale; | ||||
|                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { | ||||
|                         for (int64_t ic = 0; ic < n_dims; ic += 2) { | ||||
|                             const float cos_theta = cosf(theta); | ||||
|                             const float sin_theta = sinf(theta); | ||||
|                             // simplified from `(ib * n_dims + ic) * inv_ndims`
 | ||||
|                             float cur_rot = inv_ndims * ic - ib; | ||||
| 
 | ||||
|                             theta *= theta_scale; | ||||
|                             float cos_theta, sin_theta; | ||||
|                             rope_yarn( | ||||
|                                 theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, | ||||
|                                 &cos_theta, &sin_theta | ||||
|                             ); | ||||
| 
 | ||||
|                             theta_base *= theta_scale; | ||||
| 
 | ||||
|                             const int64_t i0 = ib*n_dims + ic/2; | ||||
| 
 | ||||
|  | @ -11256,17 +11359,18 @@ static void ggml_compute_forward_rope_back_f32( | |||
|                 if (ir++ < ir0) continue; | ||||
|                 if (ir   > ir1) break; | ||||
| 
 | ||||
|                 float theta = freq_scale * (float)p; | ||||
|                 float theta_base = freq_scale * (float)p; | ||||
| 
 | ||||
|                 if (!is_neox) { | ||||
|                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) { | ||||
|                         const float cos_theta = cosf(theta); | ||||
|                         const float sin_theta = sinf(theta); | ||||
|                         const float cos_theta = cosf(theta_base); | ||||
|                         const float sin_theta = sinf(theta_base); | ||||
| 
 | ||||
|                         // zeta scaling for xPos only:
 | ||||
|                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; | ||||
|                         if (xpos_down) zeta = 1.0f / zeta; | ||||
| 
 | ||||
|                         theta *= theta_scale; | ||||
|                         theta_base *= theta_scale; | ||||
| 
 | ||||
|                         const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|                               float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | @ -11280,10 +11384,10 @@ static void ggml_compute_forward_rope_back_f32( | |||
|                 } else { | ||||
|                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { | ||||
|                         for (int64_t ic = 0; ic < n_dims; ic += 2) { | ||||
|                             const float cos_theta = cosf(theta); | ||||
|                             const float sin_theta = sinf(theta); | ||||
|                             const float cos_theta = cosf(theta_base); | ||||
|                             const float sin_theta = sinf(theta_base); | ||||
| 
 | ||||
|                             theta *= theta_scale; | ||||
|                             theta_base *= theta_scale; | ||||
| 
 | ||||
|                             const int64_t i0 = ib*n_dims + ic/2; | ||||
| 
 | ||||
|  | @ -11356,14 +11460,14 @@ static void ggml_compute_forward_rope_back_f16( | |||
|                 if (ir++ < ir0) continue; | ||||
|                 if (ir   > ir1) break; | ||||
| 
 | ||||
|                 float theta = (float)p; | ||||
|                 float theta_base = (float)p; | ||||
| 
 | ||||
|                 if (!is_neox) { | ||||
|                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) { | ||||
|                         const float cos_theta = cosf(theta); | ||||
|                         const float sin_theta = sinf(theta); | ||||
|                         const float cos_theta = cosf(theta_base); | ||||
|                         const float sin_theta = sinf(theta_base); | ||||
| 
 | ||||
|                         theta *= theta_scale; | ||||
|                         theta_base *= theta_scale; | ||||
| 
 | ||||
|                         const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|                               ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | @ -11377,10 +11481,10 @@ static void ggml_compute_forward_rope_back_f16( | |||
|                 } else { | ||||
|                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { | ||||
|                         for (int64_t ic = 0; ic < n_dims; ic += 2) { | ||||
|                             const float cos_theta = cosf(theta); | ||||
|                             const float sin_theta = sinf(theta); | ||||
|                             const float cos_theta = cosf(theta_base); | ||||
|                             const float sin_theta = sinf(theta_base); | ||||
| 
 | ||||
|                             theta *= theta_scale; | ||||
|                             theta_base *= theta_scale; | ||||
| 
 | ||||
|                             const int64_t i0 = ib*n_dims + ic/2; | ||||
| 
 | ||||
|  | @ -15505,9 +15609,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor | |||
|                                 src1, | ||||
|                                 n_dims, | ||||
|                                 mode, | ||||
|                                 0, | ||||
|                                 n_ctx, | ||||
|                                 freq_base, | ||||
|                                 freq_scale, | ||||
|                                 0.0f, | ||||
|                                 1.0f, | ||||
|                                 0.0f, | ||||
|                                 0.0f, | ||||
|                                 xpos_base, | ||||
|                                 xpos_down, | ||||
|                                 false), | ||||
|  |  | |||
							
								
								
									
										20
									
								
								ggml.h
									
										
									
									
									
								
							
							
						
						
									
										20
									
								
								ggml.h
									
										
									
									
									
								
							|  | @ -219,7 +219,7 @@ | |||
| #define GGML_MAX_CONTEXTS      64 | ||||
| #define GGML_MAX_SRC           6 | ||||
| #define GGML_MAX_NAME          64 | ||||
| #define GGML_MAX_OP_PARAMS     32 | ||||
| #define GGML_MAX_OP_PARAMS     64 | ||||
| #define GGML_DEFAULT_N_THREADS 4 | ||||
| 
 | ||||
| #if UINTPTR_MAX == 0xFFFFFFFF | ||||
|  | @ -1326,8 +1326,13 @@ extern "C" { | |||
|             int                   n_dims, | ||||
|             int                   mode, | ||||
|             int                   n_ctx, | ||||
|             int                   n_orig_ctx, | ||||
|             float                 freq_base, | ||||
|             float                 freq_scale); | ||||
|             float                 freq_scale, | ||||
|             float                 ext_factor, | ||||
|             float                 attn_factor, | ||||
|             float                 beta_fast, | ||||
|             float                 beta_slow); | ||||
| 
 | ||||
|     // in-place, returns view(a)
 | ||||
|     GGML_API struct ggml_tensor * ggml_rope_custom_inplace( | ||||
|  | @ -1337,8 +1342,17 @@ extern "C" { | |||
|             int                   n_dims, | ||||
|             int                   mode, | ||||
|             int                   n_ctx, | ||||
|             int                   n_orig_ctx, | ||||
|             float                 freq_base, | ||||
|             float                 freq_scale); | ||||
|             float                 freq_scale, | ||||
|             float                 ext_factor, | ||||
|             float                 attn_factor, | ||||
|             float                 beta_fast, | ||||
|             float                 beta_slow); | ||||
| 
 | ||||
|     // compute correction dims for YaRN RoPE scaling
 | ||||
|     void ggml_rope_yarn_corr_dims( | ||||
|         int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); | ||||
| 
 | ||||
|     // xPos RoPE, in-place, returns view(a)
 | ||||
|     GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( | ||||
|  |  | |||
|  | @ -7,7 +7,7 @@ import shutil | |||
| import struct | ||||
| import sys | ||||
| import tempfile | ||||
| from enum import IntEnum, auto | ||||
| from enum import Enum, IntEnum, auto | ||||
| from io import BufferedWriter | ||||
| from pathlib import Path | ||||
| from typing import IO, Any, BinaryIO, Callable, Sequence | ||||
|  | @ -55,7 +55,10 @@ KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" | |||
| # RoPE | ||||
| KEY_ROPE_DIMENSION_COUNT         = "{arch}.rope.dimension_count" | ||||
| KEY_ROPE_FREQ_BASE               = "{arch}.rope.freq_base" | ||||
| KEY_ROPE_SCALE_LINEAR    = "{arch}.rope.scale_linear" | ||||
| KEY_ROPE_SCALING_TYPE            = "{arch}.rope.scaling.type" | ||||
| KEY_ROPE_SCALING_FACTOR          = "{arch}.rope.scaling.factor" | ||||
| KEY_ROPE_SCALING_ORIG_CTX_LEN    = "{arch}.rope.scaling.original_context_length" | ||||
| KEY_ROPE_SCALING_FINETUNED       = "{arch}.rope.scaling.finetuned" | ||||
| 
 | ||||
| # tokenization | ||||
| KEY_TOKENIZER_MODEL      = "tokenizer.ggml.model" | ||||
|  | @ -577,6 +580,11 @@ class TokenType(IntEnum): | |||
|     UNUSED       = 5 | ||||
|     BYTE         = 6 | ||||
| 
 | ||||
| class RopeScalingType(Enum): | ||||
|     NONE   = 'none' | ||||
|     LINEAR = 'linear' | ||||
|     YARN   = 'yarn' | ||||
| 
 | ||||
| # | ||||
| # implementation | ||||
| # | ||||
|  | @ -948,8 +956,17 @@ class GGUFWriter: | |||
|     def add_rope_freq_base(self, value: float): | ||||
|         self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value) | ||||
| 
 | ||||
|     def add_rope_scale_linear(self, value: float): | ||||
|         self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value) | ||||
|     def add_rope_scaling_type(self, value: RopeScalingType): | ||||
|         self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value) | ||||
| 
 | ||||
|     def add_rope_scaling_factor(self, value: float): | ||||
|         self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value) | ||||
| 
 | ||||
|     def add_rope_scaling_orig_ctx_len(self, value: int): | ||||
|         self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value) | ||||
| 
 | ||||
|     def add_rope_scaling_finetuned(self, value: bool): | ||||
|         self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value) | ||||
| 
 | ||||
|     def add_tokenizer_model(self, model: str): | ||||
|         self.add_string(KEY_TOKENIZER_MODEL, model) | ||||
|  |  | |||
							
								
								
									
										192
									
								
								llama.cpp
									
										
									
									
									
								
							
							
						
						
									
										192
									
								
								llama.cpp
									
										
									
									
									
								
							|  | @ -54,6 +54,7 @@ | |||
| #include <cassert> | ||||
| #include <cinttypes> | ||||
| #include <climits> | ||||
| #include <cmath> | ||||
| #include <cstdarg> | ||||
| #include <cstddef> | ||||
| #include <cstdint> | ||||
|  | @ -235,6 +236,10 @@ enum llm_kv { | |||
|     LLM_KV_ROPE_DIMENSION_COUNT, | ||||
|     LLM_KV_ROPE_FREQ_BASE, | ||||
|     LLM_KV_ROPE_SCALE_LINEAR, | ||||
|     LLM_KV_ROPE_SCALING_TYPE, | ||||
|     LLM_KV_ROPE_SCALING_FACTOR, | ||||
|     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, | ||||
|     LLM_KV_ROPE_SCALING_FINETUNED, | ||||
| 
 | ||||
|     LLM_KV_TOKENIZER_MODEL, | ||||
|     LLM_KV_TOKENIZER_LIST, | ||||
|  | @ -279,6 +284,10 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = { | |||
|     { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 }, | ||||
|     { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       }, | ||||
|     { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    }, | ||||
|     { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    }, | ||||
|     { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  }, | ||||
|     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" }, | ||||
|     { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned"               }, | ||||
| 
 | ||||
|     { LLM_KV_TOKENIZER_MODEL,               "tokenizer.ggml.model"              }, | ||||
|     { LLM_KV_TOKENIZER_LIST,                "tokenizer.ggml.tokens"             }, | ||||
|  | @ -552,6 +561,22 @@ do { \ | |||
|     } \ | ||||
| } while (0) | ||||
| 
 | ||||
| static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = { | ||||
|     { LLAMA_ROPE_SCALING_NONE,   "none"   }, | ||||
|     { LLAMA_ROPE_SCALING_LINEAR, "linear" }, | ||||
|     { LLAMA_ROPE_SCALING_YARN,   "yarn"   }, | ||||
| }; | ||||
| 
 | ||||
| static int8_t llama_rope_scaling_type_from_string(const std::string & name) { | ||||
|     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { | ||||
|         if (kv.second == name) { | ||||
|             return kv.first; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     return LLAMA_ROPE_SCALING_UNSPECIFIED; | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| // ggml helpers
 | ||||
| //
 | ||||
|  | @ -1037,6 +1062,9 @@ struct llama_hparams { | |||
| 
 | ||||
|     float    rope_freq_base_train; | ||||
|     float    rope_freq_scale_train; | ||||
|     uint32_t n_yarn_orig_ctx; | ||||
|     int8_t   rope_scaling_type_train : 3; | ||||
|     bool     rope_finetuned : 1; | ||||
| 
 | ||||
|     float f_clamp_kqv; | ||||
|     float f_max_alibi_bias; | ||||
|  | @ -1051,6 +1079,8 @@ struct llama_hparams { | |||
|         if (this->n_layer     != other.n_layer)     return true; | ||||
|         if (this->n_rot       != other.n_rot)       return true; | ||||
|         if (this->n_ff        != other.n_ff)        return true; | ||||
|         if (this->rope_finetuned  != other.rope_finetuned)  return true; | ||||
|         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; | ||||
| 
 | ||||
|         const float EPSILON = 1e-9; | ||||
| 
 | ||||
|  | @ -1084,6 +1114,14 @@ struct llama_cparams { | |||
|     float    rope_freq_base; | ||||
|     float    rope_freq_scale; | ||||
| 
 | ||||
|     uint32_t n_yarn_orig_ctx; | ||||
|     // These hyperparameters are not exposed in GGUF, because all
 | ||||
|     // existing YaRN models use the same values for them.
 | ||||
|     float yarn_ext_factor; | ||||
|     float yarn_attn_factor; | ||||
|     float yarn_beta_fast; | ||||
|     float yarn_beta_slow; | ||||
| 
 | ||||
|     bool mul_mat_q; | ||||
| }; | ||||
| 
 | ||||
|  | @ -2014,14 +2052,30 @@ static void llm_load_hparams( | |||
|     hparams.n_head_kv = hparams.n_head; | ||||
|     GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); | ||||
| 
 | ||||
|     hparams.rope_finetuned = false; | ||||
|     GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false, | ||||
|                  kv(LLM_KV_ROPE_SCALING_FINETUNED)); | ||||
| 
 | ||||
|     hparams.n_yarn_orig_ctx = hparams.n_ctx_train; | ||||
|     GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, | ||||
|                  kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN)); | ||||
| 
 | ||||
|     // rope_freq_base (optional)
 | ||||
|     hparams.rope_freq_base_train = 10000.0f; | ||||
|     GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); | ||||
| 
 | ||||
|     std::string rope_scaling("linear"); | ||||
|     GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE)); | ||||
|     hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); | ||||
|     GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED); | ||||
| 
 | ||||
|     // rope_freq_scale (inverse of the kv) is optional
 | ||||
|     float ropescale = 1.0f; | ||||
|     float ropescale = 0.0f; | ||||
|     GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR)); | ||||
|     if (ropescale == 0.0f) { // try the old key name
 | ||||
|         GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); | ||||
|     hparams.rope_freq_scale_train = 1.0f/ropescale; | ||||
|     } | ||||
|     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; | ||||
| 
 | ||||
|     // sanity check for n_rot (optional)
 | ||||
|     { | ||||
|  | @ -2371,6 +2425,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { | |||
|     const auto & hparams = model.hparams; | ||||
|     const auto & vocab   = model.vocab; | ||||
| 
 | ||||
|     const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train); | ||||
| 
 | ||||
|     // hparams
 | ||||
|     LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver)); | ||||
|     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); | ||||
|  | @ -2389,8 +2445,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { | |||
|     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv); | ||||
|     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias); | ||||
|     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff); | ||||
|     LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type.c_str()); | ||||
|     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train); | ||||
|     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train); | ||||
|     LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n",     __func__, hparams.n_yarn_orig_ctx); | ||||
|     LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown"); | ||||
|     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type)); | ||||
|     LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str()); | ||||
|     LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9); | ||||
|  | @ -3047,21 +3106,11 @@ static void llm_load_tensors( | |||
|     model.t_load_us = ggml_time_us() - model.t_start_us; | ||||
| } | ||||
| 
 | ||||
| static bool llama_model_load( | ||||
|         const std::string & fname, | ||||
|         llama_model & model, | ||||
|         int n_gpu_layers, | ||||
|         int main_gpu, | ||||
|         const float * tensor_split, | ||||
|         bool use_mmap, | ||||
|         bool use_mlock, | ||||
|         bool vocab_only, | ||||
|         llama_progress_callback progress_callback, | ||||
|         void *progress_callback_user_data) { | ||||
| static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) { | ||||
|     try { | ||||
|         llama_model_loader ml(fname, use_mmap); | ||||
|         llama_model_loader ml(fname, params.use_mmap); | ||||
| 
 | ||||
|         model.hparams.vocab_only = vocab_only; | ||||
|         model.hparams.vocab_only = params.vocab_only; | ||||
| 
 | ||||
|         llm_load_arch   (ml, model); | ||||
|         llm_load_hparams(ml, model); | ||||
|  | @ -3073,15 +3122,15 @@ static bool llama_model_load( | |||
|             throw std::runtime_error("vocab size mismatch"); | ||||
|         } | ||||
| 
 | ||||
|         if (vocab_only) { | ||||
|         if (params.vocab_only) { | ||||
|             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); | ||||
|             return true; | ||||
|         } | ||||
| 
 | ||||
|         llm_load_tensors( | ||||
|                 ml, model, n_gpu_layers, | ||||
|                 main_gpu, tensor_split, | ||||
|                 use_mlock, progress_callback, progress_callback_user_data); | ||||
|             ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock, | ||||
|             params.progress_callback, params.progress_callback_user_data | ||||
|         ); | ||||
|     } catch (const std::exception & err) { | ||||
|         LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); | ||||
|         return false; | ||||
|  | @ -3150,6 +3199,7 @@ static struct ggml_tensor * llm_build_inp_embd( | |||
| static void llm_build_k_shift( | ||||
|       struct ggml_context * ctx, | ||||
|       const llama_hparams & hparams, | ||||
|       const llama_cparams & cparams, | ||||
|      const llama_kv_cache & kv, | ||||
|        struct ggml_cgraph * graph, | ||||
|             llm_rope_type   type, | ||||
|  | @ -3162,6 +3212,11 @@ static void llm_build_k_shift( | |||
|     const int64_t n_head_kv   = hparams.n_head_kv; | ||||
|     const int64_t n_embd_gqa  = hparams.n_embd_gqa(); | ||||
|     const int64_t n_embd_head = hparams.n_embd_head(); | ||||
|     const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx; | ||||
|     const float   ext_factor  = cparams.yarn_ext_factor; | ||||
|     const float   attn_factor = cparams.yarn_attn_factor; | ||||
|     const float   beta_fast   = cparams.yarn_beta_fast; | ||||
|     const float   beta_slow   = cparams.yarn_beta_slow; | ||||
| 
 | ||||
|     GGML_ASSERT(n_embd_head % n_rot == 0); | ||||
| 
 | ||||
|  | @ -3185,7 +3240,8 @@ static void llm_build_k_shift( | |||
|                         ggml_element_size(kv.k)*n_embd_head, | ||||
|                         ggml_element_size(kv.k)*n_embd_gqa, | ||||
|                         ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il), | ||||
|                     K_shift, n_rot, rope_type, 0, freq_base, freq_scale); | ||||
|                     K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, | ||||
|                     ext_factor, attn_factor, beta_fast, beta_slow); | ||||
|         cb(tmp, "K_shifted", il); | ||||
|         ggml_build_forward_expand(graph, tmp); | ||||
|     } | ||||
|  | @ -3442,12 +3498,17 @@ struct llm_build_context { | |||
| 
 | ||||
|     const float freq_base; | ||||
|     const float freq_scale; | ||||
|     const float ext_factor; | ||||
|     const float attn_factor; | ||||
|     const float beta_fast; | ||||
|     const float beta_slow; | ||||
|     const float norm_eps; | ||||
|     const float norm_rms_eps; | ||||
| 
 | ||||
|     const int32_t n_tokens; | ||||
|     const int32_t n_kv;     // size of KV cache to consider (n_kv <= n_ctx)
 | ||||
|     const int32_t kv_head;  // index of where we store new KV data in the cache
 | ||||
|     const int32_t n_orig_ctx; | ||||
| 
 | ||||
|     const bool do_rope_shift; | ||||
| 
 | ||||
|  | @ -3477,11 +3538,16 @@ struct llm_build_context { | |||
|         n_embd_gqa    (hparams.n_embd_gqa()), | ||||
|         freq_base     (cparams.rope_freq_base), | ||||
|         freq_scale    (cparams.rope_freq_scale), | ||||
|         ext_factor    (cparams.yarn_ext_factor), | ||||
|         attn_factor   (cparams.yarn_attn_factor), | ||||
|         beta_fast     (cparams.yarn_beta_fast), | ||||
|         beta_slow     (cparams.yarn_beta_slow), | ||||
|         norm_eps      (hparams.f_norm_eps), | ||||
|         norm_rms_eps  (hparams.f_norm_rms_eps), | ||||
|         n_tokens      (batch.n_tokens), | ||||
|         n_kv          (worst_case ? n_ctx            : kv_self.n), | ||||
|         kv_head       (worst_case ? n_ctx - n_tokens : kv_self.head), | ||||
|         n_orig_ctx    (cparams.n_yarn_orig_ctx), | ||||
|         do_rope_shift (worst_case || kv_self.has_shift), | ||||
|         cb            (cb), | ||||
|         buf_compute   (lctx.buf_compute) { | ||||
|  | @ -3532,7 +3598,7 @@ struct llm_build_context { | |||
| 
 | ||||
|         // shift the entire K-cache if needed
 | ||||
|         if (do_rope_shift) { | ||||
|             llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); | ||||
|             llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); | ||||
|         } | ||||
| 
 | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|  | @ -3556,10 +3622,18 @@ struct llm_build_context { | |||
|                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); | ||||
|                 cb(Vcur, "Vcur", il); | ||||
| 
 | ||||
|                 Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); | ||||
|                 Qcur = ggml_rope_custom( | ||||
|                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, | ||||
|                     n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, | ||||
|                     ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(Qcur, "Qcur", il); | ||||
| 
 | ||||
|                 Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); | ||||
|                 Kcur = ggml_rope_custom( | ||||
|                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, | ||||
|                     n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, | ||||
|                     ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
| 
 | ||||
|                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); | ||||
|  | @ -3634,7 +3708,7 @@ struct llm_build_context { | |||
| 
 | ||||
|         // shift the entire K-cache if needed
 | ||||
|         if (do_rope_shift) { | ||||
|             llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); | ||||
|             llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); | ||||
|         } | ||||
| 
 | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|  | @ -3658,8 +3732,16 @@ struct llm_build_context { | |||
| 
 | ||||
|                 switch (model.type) { | ||||
|                     case MODEL_7B: | ||||
|                         Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens),    inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); | ||||
|                         Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); | ||||
|                         Qcur = ggml_rope_custom( | ||||
|                             ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, | ||||
|                             n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, | ||||
|                             ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                         ); | ||||
|                         Kcur = ggml_rope_custom( | ||||
|                             ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, | ||||
|                             n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, | ||||
|                             ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                         ); | ||||
|                         break; | ||||
|                     case MODEL_13B: | ||||
|                         Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens); | ||||
|  | @ -3746,7 +3828,7 @@ struct llm_build_context { | |||
| 
 | ||||
|         // shift the entire K-cache if needed
 | ||||
|         if (do_rope_shift) { | ||||
|             llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); | ||||
|             llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); | ||||
|         } | ||||
| 
 | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|  | @ -3786,10 +3868,16 @@ struct llm_build_context { | |||
|                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); | ||||
| 
 | ||||
|                 // using mode = 2 for neox mode
 | ||||
|                 Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); | ||||
|                 Qcur = ggml_rope_custom( | ||||
|                     ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, | ||||
|                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(Qcur, "Qcur", il); | ||||
| 
 | ||||
|                 Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); | ||||
|                 Kcur = ggml_rope_custom( | ||||
|                     ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, | ||||
|                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(Kcur, "Kcur", il); | ||||
| 
 | ||||
|                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); | ||||
|  | @ -3960,7 +4048,7 @@ struct llm_build_context { | |||
|         cb(KQ_mask, "KQ_mask", -1); | ||||
| 
 | ||||
|         if (do_rope_shift) { | ||||
|             llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); | ||||
|             llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); | ||||
|         } | ||||
| 
 | ||||
|         for (int il = 0; il < n_layer; ++il) { | ||||
|  | @ -4053,12 +4141,14 @@ struct llm_build_context { | |||
|                 cb(kpass, "kpass", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * qrotated = ggml_rope_custom( | ||||
|                         ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale | ||||
|                     ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx, | ||||
|                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(qrotated, "qrotated", il); | ||||
| 
 | ||||
|                 struct ggml_tensor * krotated = ggml_rope_custom( | ||||
|                         ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale | ||||
|                     ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx, | ||||
|                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow | ||||
|                 ); | ||||
|                 cb(krotated, "krotated", il); | ||||
| 
 | ||||
|  | @ -7883,8 +7973,13 @@ struct llama_context_params llama_context_default_params() { | |||
|         /*.n_batch                     =*/ 512, | ||||
|         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
 | ||||
|         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS, | ||||
|         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_UNSPECIFIED, | ||||
|         /*.rope_freq_base              =*/ 0.0f, | ||||
|         /*.rope_freq_scale             =*/ 0.0f, | ||||
|         /*.yarn_ext_factor             =*/ NAN, | ||||
|         /*.yarn_attn_factor            =*/ 1.0f, | ||||
|         /*.yarn_beta_fast              =*/ 32.0f, | ||||
|         /*.yarn_beta_slow              =*/ 1.0f, | ||||
|         /*.mul_mat_q                   =*/ true, | ||||
|         /*.f16_kv                      =*/ true, | ||||
|         /*.logits_all                  =*/ false, | ||||
|  | @ -7971,10 +8066,7 @@ struct llama_model * llama_load_model_from_file( | |||
|         }; | ||||
|     } | ||||
| 
 | ||||
|     if (!llama_model_load(path_model, *model, params.n_gpu_layers, | ||||
|                 params.main_gpu, params.tensor_split, | ||||
|                 params.use_mmap, params.use_mlock, params.vocab_only, | ||||
|                 params.progress_callback, params.progress_callback_user_data)) { | ||||
|     if (!llama_model_load(path_model, *model, params)) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); | ||||
|         delete model; | ||||
|         return nullptr; | ||||
|  | @ -8001,13 +8093,35 @@ struct llama_context * llama_new_context_with_model( | |||
|     auto       & cparams = ctx->cparams; | ||||
| 
 | ||||
|     cparams.n_batch          = params.n_batch; | ||||
|     cparams.n_ctx           = params.n_ctx == 0           ? hparams.n_ctx_train           : params.n_ctx; | ||||
|     cparams.rope_freq_base  = params.rope_freq_base == 0  ? hparams.rope_freq_base_train  : params.rope_freq_base; | ||||
|     cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale; | ||||
|     cparams.n_threads        = params.n_threads; | ||||
|     cparams.n_threads_batch  = params.n_threads_batch; | ||||
|     cparams.yarn_ext_factor  = params.yarn_ext_factor; | ||||
|     cparams.yarn_attn_factor = params.yarn_attn_factor; | ||||
|     cparams.yarn_beta_fast   = params.yarn_beta_fast; | ||||
|     cparams.yarn_beta_slow   = params.yarn_beta_slow; | ||||
|     cparams.mul_mat_q        = params.mul_mat_q; | ||||
| 
 | ||||
|     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx; | ||||
|     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base; | ||||
|     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; | ||||
| 
 | ||||
|     cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    : | ||||
|                                hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : | ||||
|                                                               hparams.n_ctx_train; | ||||
| 
 | ||||
|     auto rope_scaling_type = params.rope_scaling_type; | ||||
|     if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) { | ||||
|         rope_scaling_type = hparams.rope_scaling_type_train; | ||||
|     } | ||||
| 
 | ||||
|     if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) { | ||||
|         cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
 | ||||
|     } | ||||
| 
 | ||||
|     if (std::isnan(cparams.yarn_ext_factor)) { // NaN indicates 'not set'
 | ||||
|         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f; | ||||
|     } | ||||
| 
 | ||||
|     if (params.seed == LLAMA_DEFAULT_SEED) { | ||||
|         params.seed = time(NULL); | ||||
|     } | ||||
|  |  | |||
							
								
								
									
										14
									
								
								llama.h
									
										
									
									
									
								
							
							
						
						
									
										14
									
								
								llama.h
									
										
									
									
									
								
							|  | @ -106,6 +106,14 @@ extern "C" { | |||
|         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
 | ||||
|     }; | ||||
| 
 | ||||
|     enum llama_rope_scaling_type { | ||||
|         LLAMA_ROPE_SCALING_UNSPECIFIED = -1, | ||||
|         LLAMA_ROPE_SCALING_NONE        = 0, | ||||
|         LLAMA_ROPE_SCALING_LINEAR      = 1, | ||||
|         LLAMA_ROPE_SCALING_YARN        = 2, | ||||
|         LLAMA_ROPE_SCALING_MAX_VALUE   = LLAMA_ROPE_SCALING_YARN, | ||||
|     }; | ||||
| 
 | ||||
|     typedef struct llama_token_data { | ||||
|         llama_token id; // token id
 | ||||
|         float logit;    // log-odds of the token
 | ||||
|  | @ -172,10 +180,16 @@ extern "C" { | |||
|         uint32_t n_batch;         // prompt processing maximum batch size
 | ||||
|         uint32_t n_threads;       // number of threads to use for generation
 | ||||
|         uint32_t n_threads_batch; // number of threads to use for batch processing
 | ||||
|         int8_t   rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 | ||||
| 
 | ||||
|         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
 | ||||
|         float    rope_freq_base;   // RoPE base frequency, 0 = from model
 | ||||
|         float    rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
 | ||||
|         float    yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
 | ||||
|         float    yarn_attn_factor; // YaRN magnitude scaling factor
 | ||||
|         float    yarn_beta_fast;   // YaRN low correction dim
 | ||||
|         float    yarn_beta_slow;   // YaRN high correction dim
 | ||||
|         uint32_t yarn_orig_ctx;    // YaRN original context size
 | ||||
| 
 | ||||
|         // Keep the booleans together to avoid misalignment during copy-by-value.
 | ||||
|         bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue