don't hardcode max_pos_emb
This commit is contained in:
parent
43eaf06a2f
commit
fe788c45c8
7 changed files with 131 additions and 108 deletions
|
@ -687,7 +687,7 @@ struct ggml_tensor * llama_build_train_graphs(
|
|||
const int rope_mode = 0;
|
||||
|
||||
return ggml_rope_custom(
|
||||
ctx, t, n_past, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
||||
ctx, t, n_past, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
||||
);
|
||||
};
|
||||
|
||||
|
|
27
ggml-cuda.cu
27
ggml-cuda.cu
|
@ -4386,7 +4386,7 @@ static __device__ void rope_yarn(
|
|||
|
||||
// rope == RoPE == rotary positional embedding
|
||||
static __global__ void rope_f32(
|
||||
float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
|
||||
const float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
|
||||
float p0, int p_delta_rows, rope_corr_dims corr_dims
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
@ -5396,7 +5396,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
|
|||
}
|
||||
|
||||
static void rope_f32_cuda(
|
||||
float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
|
||||
const float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
|
||||
float theta_scale, float p0, int p_delta_rows, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
|
@ -6109,19 +6109,20 @@ inline void ggml_cuda_op_rope(
|
|||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
// RoPE alteration for extended context
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float p0 = (mode & 1) == 0 ? n_past : 0;
|
||||
|
@ -6137,7 +6138,7 @@ inline void ggml_cuda_op_rope(
|
|||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
|
||||
} else {
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
rope_f32_cuda(
|
||||
src0_dd, dst_dd, ne00, nrows, freq_scale, ext_factor, attn_factor, theta_scale, p0, ne01, corr_dims,
|
||||
|
|
19
ggml-metal.m
19
ggml-metal.m
|
@ -1176,17 +1176,18 @@ void ggml_metal_graph_compute(
|
|||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_rope];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
|
|
|
@ -832,18 +832,18 @@ static void rope_yarn(
|
|||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
||||
constant float max_pos_emb = 2048;
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float rope_yarn_corr_factor(const int n_dims, const float n_rot, const float base) {
|
||||
return n_dims * log(max_pos_emb / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||
}
|
||||
|
||||
static void rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) {
|
||||
static void rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, beta_fast, freq_base)));
|
||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, beta_slow, freq_base)));
|
||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
||||
}
|
||||
|
||||
kernel void kernel_rope(
|
||||
|
@ -868,6 +868,7 @@ kernel void kernel_rope(
|
|||
constant int & n_past,
|
||||
constant int & n_dims,
|
||||
constant int & mode,
|
||||
constant int & n_orig_ctx,
|
||||
constant float & freq_base,
|
||||
constant float & freq_scale,
|
||||
constant float & ext_factor,
|
||||
|
@ -884,7 +885,7 @@ kernel void kernel_rope(
|
|||
const bool is_neox = mode & 2;
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const int64_t p = (mode & 1) == 0 ? n_past + i2 : i2;
|
||||
|
||||
|
|
100
ggml.c
100
ggml.c
|
@ -6973,6 +6973,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -6991,15 +6992,15 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[12] = { n_past, n_dims, mode, n_ctx };
|
||||
memcpy(params + 4, &freq_base, sizeof(float));
|
||||
memcpy(params + 5, &freq_scale, sizeof(float));
|
||||
memcpy(params + 6, &ext_factor, sizeof(float));
|
||||
memcpy(params + 7, &attn_factor, sizeof(float));
|
||||
memcpy(params + 8, &beta_fast, sizeof(float));
|
||||
memcpy(params + 9, &beta_slow, sizeof(float));
|
||||
memcpy(params + 10, &xpos_base, sizeof(float));
|
||||
memcpy(params + 11, &xpos_down, sizeof(bool));
|
||||
int32_t params[13] = { n_past, n_dims, mode, n_ctx, n_orig_ctx };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memcpy(params + 11, &xpos_base, sizeof(float));
|
||||
memcpy(params + 12, &xpos_down, sizeof(bool));
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE;
|
||||
|
@ -7017,7 +7018,7 @@ struct ggml_tensor * ggml_rope(
|
|||
int mode,
|
||||
int n_ctx) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
||||
ctx, a, n_past, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -7029,7 +7030,7 @@ struct ggml_tensor * ggml_rope_inplace(
|
|||
int mode,
|
||||
int n_ctx) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
||||
ctx, a, n_past, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -7040,6 +7041,7 @@ struct ggml_tensor * ggml_rope_custom(
|
|||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -7047,8 +7049,8 @@ struct ggml_tensor * ggml_rope_custom(
|
|||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f,
|
||||
false, false
|
||||
ctx, a, n_past, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -7059,6 +7061,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
|||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -7066,8 +7069,8 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
|||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f,
|
||||
false, true
|
||||
ctx, a, n_past, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -7078,7 +7081,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
|
|||
int n_dims,
|
||||
float base,
|
||||
bool down) {
|
||||
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
||||
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
||||
}
|
||||
|
||||
// ggml_rope_back
|
||||
|
@ -12675,15 +12678,16 @@ static void rope_yarn(
|
|||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float ggml_rope_yarn_corr_dim(const int n_dims, const float n_rot, const float base) {
|
||||
static const float max_pos_emb = 2048;
|
||||
return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||
}
|
||||
|
||||
void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) {
|
||||
void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, beta_fast, freq_base)));
|
||||
dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, beta_slow, freq_base)));
|
||||
dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
||||
dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_f32(
|
||||
|
@ -12701,18 +12705,20 @@ static void ggml_compute_forward_rope_f32(
|
|||
float xpos_base;
|
||||
bool xpos_down;
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 11, sizeof(bool));
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
|
@ -12743,7 +12749,7 @@ static void ggml_compute_forward_rope_f32(
|
|||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
@ -12844,16 +12850,17 @@ static void ggml_compute_forward_rope_f16(
|
|||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
|
@ -12884,7 +12891,7 @@ static void ggml_compute_forward_rope_f16(
|
|||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
@ -16641,6 +16648,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
n_past,
|
||||
n_dims,
|
||||
mode,
|
||||
0,
|
||||
n_ctx,
|
||||
freq_base,
|
||||
freq_scale,
|
||||
|
|
7
ggml.h
7
ggml.h
|
@ -219,7 +219,7 @@
|
|||
#define GGML_MAX_CONTEXTS 64
|
||||
#define GGML_MAX_SRC 6
|
||||
#define GGML_MAX_NAME 64
|
||||
#define GGML_MAX_OP_PARAMS 48
|
||||
#define GGML_MAX_OP_PARAMS 64
|
||||
#define GGML_DEFAULT_N_THREADS 4
|
||||
|
||||
#if UINTPTR_MAX == 0xFFFFFFFF
|
||||
|
@ -1248,6 +1248,7 @@ extern "C" {
|
|||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -1263,6 +1264,7 @@ extern "C" {
|
|||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
|
@ -1271,7 +1273,8 @@ extern "C" {
|
|||
float beta_slow);
|
||||
|
||||
// compute correction dims for YaRN RoPE scaling
|
||||
void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
|
||||
// xPos RoPE, in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
|
|
67
llama.cpp
67
llama.cpp
|
@ -2471,13 +2471,14 @@ static struct ggml_cgraph * llm_build_llama(
|
|||
|
||||
GGML_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = hparams.n_ctx;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
const int32_t n_embd = hparams.n_embd;
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
const int32_t n_ctx = hparams.n_ctx;
|
||||
const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
|
||||
const int32_t n_head = hparams.n_head;
|
||||
const int32_t n_head_kv = hparams.n_head_kv;
|
||||
const int32_t n_embd_head = hparams.n_embd_head();
|
||||
const int32_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
|
@ -2599,14 +2600,18 @@ static struct ggml_cgraph * llm_build_llama(
|
|||
ggml_set_name(tmpq, "tmpq");
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
|
||||
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base,
|
||||
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
|
||||
n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
offload_func_kq(Kcur);
|
||||
ggml_set_name(Kcur, "Kcur");
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
|
||||
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base,
|
||||
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
|
||||
n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
offload_func_kq(Qcur);
|
||||
ggml_set_name(Qcur, "Qcur");
|
||||
|
||||
|
@ -2811,13 +2816,14 @@ static struct ggml_cgraph * llm_build_baichaun(
|
|||
|
||||
GGML_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = hparams.n_ctx;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
const int32_t n_embd = hparams.n_embd;
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
const int32_t n_ctx = hparams.n_ctx;
|
||||
const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
|
||||
const int32_t n_head = hparams.n_head;
|
||||
const int32_t n_head_kv = hparams.n_head_kv;
|
||||
const int32_t n_embd_head = hparams.n_embd_head();
|
||||
const int32_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
|
@ -2945,12 +2951,14 @@ static struct ggml_cgraph * llm_build_baichaun(
|
|||
Kcur = ggml_rope_custom_inplace(
|
||||
ctx0,
|
||||
ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
|
||||
n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
Qcur = ggml_rope_custom_inplace(
|
||||
ctx0,
|
||||
ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
|
||||
n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
break;
|
||||
case MODEL_13B:
|
||||
|
@ -3183,13 +3191,14 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
|
||||
GGML_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = hparams.n_ctx;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
const int32_t n_embd = hparams.n_embd;
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
const int32_t n_ctx = hparams.n_ctx;
|
||||
const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
|
||||
const int32_t n_head = hparams.n_head;
|
||||
const int32_t n_head_kv = hparams.n_head_kv;
|
||||
const int32_t n_embd_head = hparams.n_embd_head();
|
||||
const int32_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
|
@ -3351,12 +3360,12 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
|
||||
// using mode = 2 for neox mode
|
||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
|
||||
ctx0, tmpq, n_past, n_embd_head, 2, 0,
|
||||
ctx0, tmpq, n_past, n_embd_head, 2, 0, n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
offload_func_kq(Qcur);
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
|
||||
ctx0, tmpk, n_past, n_embd_head, 2, 0,
|
||||
ctx0, tmpk, n_past, n_embd_head, 2, 0, n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
offload_func_kq(Kcur);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue