llama : implement YaRN RoPE scaling (#2268)

Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
This commit is contained in:
cebtenzzre 2023-11-01 18:04:33 -04:00 committed by GitHub
parent c43c2da8af
commit 898aeca90a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 763 additions and 257 deletions

20
ggml.h
View file

@ -219,7 +219,7 @@
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6
#define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 32
#define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4
#if UINTPTR_MAX == 0xFFFFFFFF
@ -1326,8 +1326,13 @@ extern "C" {
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale);
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
@ -1337,8 +1342,17 @@ extern "C" {
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale);
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// compute correction dims for YaRN RoPE scaling
void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
// xPos RoPE, in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(