don't hardcode max_pos_emb

Cebtenzzre 2023-09-21 00:01:48 -04:00
parent 43eaf06a2f
commit fe788c45c8
7 changed files with 131 additions and 108 deletions
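This change threads the model's original training context length, n_orig_ctx, through the whole RoPE call chain instead of the max_pos_emb = 2048 constant that the YaRN correction math previously hardcoded. Call sites that don't use YaRN context scaling simply pass 0. A minimal before/after sketch of a caller (the argument values here are illustrative, not taken from the commit):

    // before: the YaRN correction dims silently assumed a 2048-token original context
    // cur = ggml_rope_custom(ctx0, cur, n_past, n_rot, 0, n_ctx,
    //                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

    // after: the original context length is an explicit argument right after n_ctx
    cur = ggml_rope_custom(ctx0, cur, n_past, n_rot, 0, n_ctx, /*n_orig_ctx=*/2048,
                           freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);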

examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -687,7 +687,7 @@ struct ggml_tensor * llama_build_train_graphs(
 const int rope_mode = 0;
 return ggml_rope_custom(
-ctx, t, n_past, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+ctx, t, n_past, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
 );
 };

ggml-cuda.cu

@@ -4386,7 +4386,7 @@ static __device__ void rope_yarn(
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(
-float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
+const float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
 float p0, int p_delta_rows, rope_corr_dims corr_dims
 ) {
 const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -5396,7 +5396,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 }
 static void rope_f32_cuda(
-float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
+const float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
 float theta_scale, float p0, int p_delta_rows, rope_corr_dims corr_dims, cudaStream_t stream
 ) {
 GGML_ASSERT(ncols % 2 == 0);
@@ -6109,19 +6109,20 @@ inline void ggml_cuda_op_rope(
 const int64_t ne01 = src0->ne[1];
 const int64_t nrows = ggml_nrows(src0);
-const int n_past = ((int32_t *) dst->op_params)[0];
-const int n_dims = ((int32_t *) dst->op_params)[1];
-const int mode = ((int32_t *) dst->op_params)[2];
-const int n_ctx = ((int32_t *) dst->op_params)[3];
+const int n_past     = ((int32_t *) dst->op_params)[0];
+const int n_dims     = ((int32_t *) dst->op_params)[1];
+const int mode       = ((int32_t *) dst->op_params)[2];
+const int n_ctx      = ((int32_t *) dst->op_params)[3];
+const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
 // RoPE alteration for extended context
 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
-memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
-memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
-memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
+memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 const float theta_scale = powf(freq_base, -2.0f/n_dims);
 const float p0 = (mode & 1) == 0 ? n_past : 0;
@@ -6137,7 +6138,7 @@ inline void ggml_cuda_op_rope(
 rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
 } else {
 rope_corr_dims corr_dims;
-ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v);
+ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
 rope_f32_cuda(
 src0_dd, dst_dd, ne00, nrows, freq_scale, ext_factor, attn_factor, theta_scale, p0, ne01, corr_dims,

ggml-metal.m

@@ -1176,17 +1176,18 @@ void ggml_metal_graph_compute(
 } break;
 case GGML_OP_ROPE:
 {
-const int n_past = ((int32_t *) dst->op_params)[0];
-const int n_dims = ((int32_t *) dst->op_params)[1];
-const int mode = ((int32_t *) dst->op_params)[2];
+const int n_past     = ((int32_t *) dst->op_params)[0];
+const int n_dims     = ((int32_t *) dst->op_params)[1];
+const int mode       = ((int32_t *) dst->op_params)[2];
+const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
-memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
-memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
-memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
+memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 [encoder setComputePipelineState:ctx->pipeline_rope];
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];

ggml-metal.metal

@@ -832,18 +832,18 @@ static void rope_yarn(
 *sin_theta = sinf(theta) * mscale;
 }
-constant float max_pos_emb = 2048;
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(const int n_dims, const float n_rot, const float base) {
-return n_dims * log(max_pos_emb / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }
-static void rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) {
+static void rope_yarn_corr_dims(
+int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
 // start and end correction dims
-dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, beta_fast, freq_base)));
-dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, beta_slow, freq_base)));
+dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
 }
 kernel void kernel_rope(
@@ -868,6 +868,7 @@ kernel void kernel_rope(
 constant int & n_past,
 constant int & n_dims,
 constant int & mode,
+constant int & n_orig_ctx,
 constant float & freq_base,
 constant float & freq_scale,
 constant float & ext_factor,
@@ -884,7 +885,7 @@ kernel void kernel_rope(
 const bool is_neox = mode & 2;
 float corr_dims[2];
-rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
+rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 const int64_t p = (mode & 1) == 0 ? n_past + i2 : i2;
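For intuition on the formula these kernels share: RoPE dimension d rotates with wavelength 2π·base^(2d/n_dims) positions, so the dimension that completes exactly n_rot rotations over the original context satisfies n_orig_ctx / (2π·base^(2d/n_dims)) = n_rot, which solved for d gives the corr_factor expression above. A standalone sketch of the same math in plain C, with illustrative LLaMA-like values that are not taken from the commit (n_dims = 128, base = 10000, n_orig_ctx = 2048, and the common YaRN defaults beta_fast = 32, beta_slow = 1):

    #include <math.h>
    #include <stdio.h>

    // same math as rope_yarn_corr_factor / ggml_rope_yarn_corr_dim in this commit
    static float corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
        return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
    }

    int main(void) {
        // dims below ~16 rotate fast (32+ full turns over 2048 positions),
        // dims above ~40 never complete a single turn; YaRN ramps in between
        printf("beta_fast dim: %.1f\n", corr_factor(128, 2048, 32.0f, 10000.0f)); // ~16.1
        printf("beta_slow dim: %.1f\n", corr_factor(128, 2048,  1.0f, 10000.0f)); // ~40.2
        return 0;
    }

After the floor/ceil and clamping in rope_yarn_corr_dims, these would become correction dims of roughly [16, 41] for that configuration.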

ggml.c

@@ -6973,6 +6973,7 @@ static struct ggml_tensor * ggml_rope_impl(
 int n_dims,
 int mode,
 int n_ctx,
+int n_orig_ctx,
 float freq_base,
 float freq_scale,
 float ext_factor,
@@ -6991,15 +6992,15 @@ static struct ggml_tensor * ggml_rope_impl(
 struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-int32_t params[12] = { n_past, n_dims, mode, n_ctx };
-memcpy(params + 4, &freq_base, sizeof(float));
-memcpy(params + 5, &freq_scale, sizeof(float));
-memcpy(params + 6, &ext_factor, sizeof(float));
-memcpy(params + 7, &attn_factor, sizeof(float));
-memcpy(params + 8, &beta_fast, sizeof(float));
-memcpy(params + 9, &beta_slow, sizeof(float));
-memcpy(params + 10, &xpos_base, sizeof(float));
-memcpy(params + 11, &xpos_down, sizeof(bool));
+int32_t params[13] = { n_past, n_dims, mode, n_ctx, n_orig_ctx };
+memcpy(params + 5, &freq_base, sizeof(float));
+memcpy(params + 6, &freq_scale, sizeof(float));
+memcpy(params + 7, &ext_factor, sizeof(float));
+memcpy(params + 8, &attn_factor, sizeof(float));
+memcpy(params + 9, &beta_fast, sizeof(float));
+memcpy(params + 10, &beta_slow, sizeof(float));
+memcpy(params + 11, &xpos_base, sizeof(float));
+memcpy(params + 12, &xpos_down, sizeof(bool));
 ggml_set_op_params(result, params, sizeof(params));
 result->op = GGML_OP_ROPE;
@@ -7017,7 +7018,7 @@ struct ggml_tensor * ggml_rope(
 int mode,
 int n_ctx) {
 return ggml_rope_impl(
-ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+ctx, a, n_past, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
 );
 }
@@ -7029,7 +7030,7 @@ struct ggml_tensor * ggml_rope_inplace(
 int mode,
 int n_ctx) {
 return ggml_rope_impl(
-ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ctx, a, n_past, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
 );
 }
@@ -7040,6 +7041,7 @@ struct ggml_tensor * ggml_rope_custom(
 int n_dims,
 int mode,
 int n_ctx,
+int n_orig_ctx,
 float freq_base,
 float freq_scale,
 float ext_factor,
@@ -7047,8 +7049,8 @@ struct ggml_tensor * ggml_rope_custom(
 float beta_fast,
 float beta_slow) {
 return ggml_rope_impl(
-ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f,
-false, false
+ctx, a, n_past, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
 );
 }
@@ -7059,6 +7061,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
 int n_dims,
 int mode,
 int n_ctx,
+int n_orig_ctx,
 float freq_base,
 float freq_scale,
 float ext_factor,
@@ -7066,8 +7069,8 @@ struct ggml_tensor * ggml_rope_custom_inplace(
 float beta_fast,
 float beta_slow) {
 return ggml_rope_impl(
-ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f,
-false, true
+ctx, a, n_past, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
 );
 }
@@ -7078,7 +7081,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
 int n_dims,
 float base,
 bool down) {
-return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
 }
 // ggml_rope_back
@@ -12675,15 +12678,16 @@ static void rope_yarn(
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float ggml_rope_yarn_corr_dim(const int n_dims, const float n_rot, const float base) {
-static const float max_pos_emb = 2048;
-return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
 }
-void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) {
+void ggml_rope_yarn_corr_dims(
+int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
 // start and end correction dims
-dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, beta_fast, freq_base)));
-dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, beta_slow, freq_base)));
+dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
 }
 static void ggml_compute_forward_rope_f32(
@@ -12701,18 +12705,20 @@ static void ggml_compute_forward_rope_f32(
 float xpos_base;
 bool xpos_down;
-const int n_past = ((int32_t *) dst->op_params)[0];
-const int n_dims = ((int32_t *) dst->op_params)[1];
-const int mode = ((int32_t *) dst->op_params)[2];
-const int n_ctx = ((int32_t *) dst->op_params)[3];
-memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
-memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
-memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
-memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
-memcpy(&xpos_base, (int32_t *) dst->op_params + 10, sizeof(float));
-memcpy(&xpos_down, (int32_t *) dst->op_params + 11, sizeof(bool));
+const int n_past     = ((int32_t *) dst->op_params)[0];
+const int n_dims     = ((int32_t *) dst->op_params)[1];
+const int mode       = ((int32_t *) dst->op_params)[2];
+const int n_ctx      = ((int32_t *) dst->op_params)[3];
+const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
+memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
 assert(n_past >= 0);
@@ -12743,7 +12749,7 @@ static void ggml_compute_forward_rope_f32(
 const float theta_scale = powf(freq_base, -2.0f/n_dims);
 float corr_dims[2];
-ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
+ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 const bool is_neox = mode & 2;
 const bool is_glm = mode & 4;
@@ -12844,16 +12850,17 @@ static void ggml_compute_forward_rope_f16(
 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-const int n_past = ((int32_t *) dst->op_params)[0];
-const int n_dims = ((int32_t *) dst->op_params)[1];
-const int mode = ((int32_t *) dst->op_params)[2];
-const int n_ctx = ((int32_t *) dst->op_params)[3];
-memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
-memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
-memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
-memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
+const int n_past     = ((int32_t *) dst->op_params)[0];
+const int n_dims     = ((int32_t *) dst->op_params)[1];
+const int mode       = ((int32_t *) dst->op_params)[2];
+const int n_ctx      = ((int32_t *) dst->op_params)[3];
+const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 assert(n_past >= 0);
@@ -12884,7 +12891,7 @@ static void ggml_compute_forward_rope_f16(
 const float theta_scale = powf(freq_base, -2.0f/n_dims);
 float corr_dims[2];
-ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
+ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 const bool is_neox = mode & 2;
 const bool is_glm = mode & 4;
@@ -16641,6 +16648,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 n_past,
 n_dims,
 mode,
+0,
 n_ctx,
 freq_base,
 freq_scale,
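
Since n_orig_ctx lands in slot 4, every float that follows it shifts up by one slot; that is why all three backends bump their memcpy offsets in lockstep, and why GGML_MAX_OP_PARAMS grows in the header below (13 slots of 4 bytes = 52 bytes no longer fit in 48). The resulting layout as a reference comment, with field order taken from ggml_rope_impl in this diff:

    // GGML_OP_ROPE op_params after this commit, one 4-byte slot each:
    //   [0] n_past       [1] n_dims      [2] mode        [3] n_ctx
    //   [4] n_orig_ctx   [5] freq_base   [6] freq_scale  [7] ext_factor
    //   [8] attn_factor  [9] beta_fast   [10] beta_slow  [11] xpos_base
    //   [12] xpos_down (bool stored in its own slot)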

ggml.h

@@ -219,7 +219,7 @@
 #define GGML_MAX_CONTEXTS 64
 #define GGML_MAX_SRC 6
 #define GGML_MAX_NAME 64
-#define GGML_MAX_OP_PARAMS 48
+#define GGML_MAX_OP_PARAMS 64
 #define GGML_DEFAULT_N_THREADS 4
 #if UINTPTR_MAX == 0xFFFFFFFF
@@ -1248,6 +1248,7 @@ extern "C" {
 int n_dims,
 int mode,
 int n_ctx,
+int n_orig_ctx,
 float freq_base,
 float freq_scale,
 float ext_factor,
@@ -1263,6 +1264,7 @@ extern "C" {
 int n_dims,
 int mode,
 int n_ctx,
+int n_orig_ctx,
 float freq_base,
 float freq_scale,
 float ext_factor,
@@ -1271,7 +1273,8 @@ extern "C" {
 float beta_slow);
 // compute correction dims for YaRN RoPE scaling
-void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]);
+void ggml_rope_yarn_corr_dims(
+int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 // xPos RoPE, in-place, returns view(a)
 GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(

llama.cpp

@@ -2471,13 +2471,14 @@ static struct ggml_cgraph * llm_build_llama(
 GGML_ASSERT(!!kv_self.ctx);
-const int64_t n_embd = hparams.n_embd;
-const int64_t n_layer = hparams.n_layer;
-const int64_t n_ctx = hparams.n_ctx;
-const int64_t n_head = hparams.n_head;
-const int64_t n_head_kv = hparams.n_head_kv;
-const int64_t n_embd_head = hparams.n_embd_head();
-const int64_t n_embd_gqa = hparams.n_embd_gqa();
+const int32_t n_embd = hparams.n_embd;
+const int32_t n_layer = hparams.n_layer;
+const int32_t n_ctx = hparams.n_ctx;
+const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+const int32_t n_head = hparams.n_head;
+const int32_t n_head_kv = hparams.n_head_kv;
+const int32_t n_embd_head = hparams.n_embd_head();
+const int32_t n_embd_gqa = hparams.n_embd_gqa();
 GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -2599,14 +2600,18 @@ static struct ggml_cgraph * llm_build_llama(
 ggml_set_name(tmpq, "tmpq");
 struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
-ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base,
-freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
+n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ext_factor, attn_factor, beta_fast, beta_slow
+);
 offload_func_kq(Kcur);
 ggml_set_name(Kcur, "Kcur");
 struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
-ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base,
-freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
+n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ext_factor, attn_factor, beta_fast, beta_slow
+);
 offload_func_kq(Qcur);
 ggml_set_name(Qcur, "Qcur");
@@ -2811,13 +2816,14 @@ static struct ggml_cgraph * llm_build_baichaun(
 GGML_ASSERT(!!kv_self.ctx);
-const int64_t n_embd = hparams.n_embd;
-const int64_t n_layer = hparams.n_layer;
-const int64_t n_ctx = hparams.n_ctx;
-const int64_t n_head = hparams.n_head;
-const int64_t n_head_kv = hparams.n_head_kv;
-const int64_t n_embd_head = hparams.n_embd_head();
-const int64_t n_embd_gqa = hparams.n_embd_gqa();
+const int32_t n_embd = hparams.n_embd;
+const int32_t n_layer = hparams.n_layer;
+const int32_t n_ctx = hparams.n_ctx;
+const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+const int32_t n_head = hparams.n_head;
+const int32_t n_head_kv = hparams.n_head_kv;
+const int32_t n_embd_head = hparams.n_embd_head();
+const int32_t n_embd_gqa = hparams.n_embd_gqa();
 GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -2945,12 +2951,14 @@ static struct ggml_cgraph * llm_build_baichaun(
 Kcur = ggml_rope_custom_inplace(
 ctx0,
 ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
-n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ext_factor, attn_factor, beta_fast, beta_slow
 );
 Qcur = ggml_rope_custom_inplace(
 ctx0,
 ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
-n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ext_factor, attn_factor, beta_fast, beta_slow
 );
 break;
 case MODEL_13B:
@@ -3183,13 +3191,14 @@ static struct ggml_cgraph * llm_build_falcon(
 GGML_ASSERT(!!kv_self.ctx);
-const int64_t n_embd = hparams.n_embd;
-const int64_t n_layer = hparams.n_layer;
-const int64_t n_ctx = hparams.n_ctx;
-const int64_t n_head = hparams.n_head;
-const int64_t n_head_kv = hparams.n_head_kv;
-const int64_t n_embd_head = hparams.n_embd_head();
-const int64_t n_embd_gqa = hparams.n_embd_gqa();
+const int32_t n_embd = hparams.n_embd;
+const int32_t n_layer = hparams.n_layer;
+const int32_t n_ctx = hparams.n_ctx;
+const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+const int32_t n_head = hparams.n_head;
+const int32_t n_head_kv = hparams.n_head_kv;
+const int32_t n_embd_head = hparams.n_embd_head();
+const int32_t n_embd_gqa = hparams.n_embd_gqa();
 GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -3351,12 +3360,12 @@ static struct ggml_cgraph * llm_build_falcon(
 // using mode = 2 for neox mode
 struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
-ctx0, tmpq, n_past, n_embd_head, 2, 0,
+ctx0, tmpq, n_past, n_embd_head, 2, 0, n_orig_ctx,
 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
 );
 offload_func_kq(Qcur);
 struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
-ctx0, tmpk, n_past, n_embd_head, 2, 0,
+ctx0, tmpk, n_past, n_embd_head, 2, 0, n_orig_ctx,
 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
 );
 offload_func_kq(Kcur);
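
On the llama.cpp side, each graph builder now reads the value once from the model hyperparameters and forwards it to every RoPE call. A condensed sketch of the pattern, assuming the surrounding llm_build_* scaffolding (tmpk, N, offload_func_kq, and friends) shown in the hunks above:

    // hparams.n_yarn_orig_ctx holds the pre-scaling training context of the model
    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;

    // every rope call in the builder passes it immediately after n_ctx
    struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
        ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
        n_past, n_embd_head, /*mode=*/0, /*n_ctx=*/0, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
    );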