don't hardcode max_pos_emb
This commit is contained in:
parent
43eaf06a2f
commit
fe788c45c8
7 changed files with 131 additions and 108 deletions
|
@ -687,7 +687,7 @@ struct ggml_tensor * llama_build_train_graphs(
|
||||||
const int rope_mode = 0;
|
const int rope_mode = 0;
|
||||||
|
|
||||||
return ggml_rope_custom(
|
return ggml_rope_custom(
|
||||||
ctx, t, n_past, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
ctx, t, n_past, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
27
ggml-cuda.cu
27
ggml-cuda.cu
|
@ -4386,7 +4386,7 @@ static __device__ void rope_yarn(
|
||||||
|
|
||||||
// rope == RoPE == rotary positional embedding
|
// rope == RoPE == rotary positional embedding
|
||||||
static __global__ void rope_f32(
|
static __global__ void rope_f32(
|
||||||
float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
|
const float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
|
||||||
float p0, int p_delta_rows, rope_corr_dims corr_dims
|
float p0, int p_delta_rows, rope_corr_dims corr_dims
|
||||||
) {
|
) {
|
||||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||||
|
@ -5396,7 +5396,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_f32_cuda(
|
static void rope_f32_cuda(
|
||||||
float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
|
const float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
|
||||||
float theta_scale, float p0, int p_delta_rows, rope_corr_dims corr_dims, cudaStream_t stream
|
float theta_scale, float p0, int p_delta_rows, rope_corr_dims corr_dims, cudaStream_t stream
|
||||||
) {
|
) {
|
||||||
GGML_ASSERT(ncols % 2 == 0);
|
GGML_ASSERT(ncols % 2 == 0);
|
||||||
|
@ -6109,19 +6109,20 @@ inline void ggml_cuda_op_rope(
|
||||||
const int64_t ne01 = src0->ne[1];
|
const int64_t ne01 = src0->ne[1];
|
||||||
const int64_t nrows = ggml_nrows(src0);
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
// RoPE alteration for extended context
|
// RoPE alteration for extended context
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||||
const float p0 = (mode & 1) == 0 ? n_past : 0;
|
const float p0 = (mode & 1) == 0 ? n_past : 0;
|
||||||
|
@ -6137,7 +6138,7 @@ inline void ggml_cuda_op_rope(
|
||||||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
|
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
|
||||||
} else {
|
} else {
|
||||||
rope_corr_dims corr_dims;
|
rope_corr_dims corr_dims;
|
||||||
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v);
|
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||||
|
|
||||||
rope_f32_cuda(
|
rope_f32_cuda(
|
||||||
src0_dd, dst_dd, ne00, nrows, freq_scale, ext_factor, attn_factor, theta_scale, p0, ne01, corr_dims,
|
src0_dd, dst_dd, ne00, nrows, freq_scale, ext_factor, attn_factor, theta_scale, p0, ne01, corr_dims,
|
||||||
|
|
19
ggml-metal.m
19
ggml-metal.m
|
@ -1176,17 +1176,18 @@ void ggml_metal_graph_compute(
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
{
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
|
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_rope];
|
[encoder setComputePipelineState:ctx->pipeline_rope];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
|
|
@ -832,18 +832,18 @@ static void rope_yarn(
|
||||||
*sin_theta = sinf(theta) * mscale;
|
*sin_theta = sinf(theta) * mscale;
|
||||||
}
|
}
|
||||||
|
|
||||||
constant float max_pos_emb = 2048;
|
|
||||||
|
|
||||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||||
static float rope_yarn_corr_factor(const int n_dims, const float n_rot, const float base) {
|
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||||
return n_dims * log(max_pos_emb / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) {
|
static void rope_yarn_corr_dims(
|
||||||
|
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||||
|
) {
|
||||||
// start and end correction dims
|
// start and end correction dims
|
||||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, beta_fast, freq_base)));
|
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
||||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, beta_slow, freq_base)));
|
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_rope(
|
kernel void kernel_rope(
|
||||||
|
@ -868,6 +868,7 @@ kernel void kernel_rope(
|
||||||
constant int & n_past,
|
constant int & n_past,
|
||||||
constant int & n_dims,
|
constant int & n_dims,
|
||||||
constant int & mode,
|
constant int & mode,
|
||||||
|
constant int & n_orig_ctx,
|
||||||
constant float & freq_base,
|
constant float & freq_base,
|
||||||
constant float & freq_scale,
|
constant float & freq_scale,
|
||||||
constant float & ext_factor,
|
constant float & ext_factor,
|
||||||
|
@ -884,7 +885,7 @@ kernel void kernel_rope(
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
|
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const int64_t p = (mode & 1) == 0 ? n_past + i2 : i2;
|
const int64_t p = (mode & 1) == 0 ? n_past + i2 : i2;
|
||||||
|
|
||||||
|
|
100
ggml.c
100
ggml.c
|
@ -6973,6 +6973,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -6991,15 +6992,15 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
int32_t params[12] = { n_past, n_dims, mode, n_ctx };
|
int32_t params[13] = { n_past, n_dims, mode, n_ctx, n_orig_ctx };
|
||||||
memcpy(params + 4, &freq_base, sizeof(float));
|
memcpy(params + 5, &freq_base, sizeof(float));
|
||||||
memcpy(params + 5, &freq_scale, sizeof(float));
|
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||||
memcpy(params + 6, &ext_factor, sizeof(float));
|
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||||
memcpy(params + 7, &attn_factor, sizeof(float));
|
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||||
memcpy(params + 8, &beta_fast, sizeof(float));
|
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||||
memcpy(params + 9, &beta_slow, sizeof(float));
|
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||||
memcpy(params + 10, &xpos_base, sizeof(float));
|
memcpy(params + 11, &xpos_base, sizeof(float));
|
||||||
memcpy(params + 11, &xpos_down, sizeof(bool));
|
memcpy(params + 12, &xpos_down, sizeof(bool));
|
||||||
ggml_set_op_params(result, params, sizeof(params));
|
ggml_set_op_params(result, params, sizeof(params));
|
||||||
|
|
||||||
result->op = GGML_OP_ROPE;
|
result->op = GGML_OP_ROPE;
|
||||||
|
@ -7017,7 +7018,7 @@ struct ggml_tensor * ggml_rope(
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx) {
|
int n_ctx) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
ctx, a, n_past, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7029,7 +7030,7 @@ struct ggml_tensor * ggml_rope_inplace(
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx) {
|
int n_ctx) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
ctx, a, n_past, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7040,6 +7041,7 @@ struct ggml_tensor * ggml_rope_custom(
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -7047,8 +7049,8 @@ struct ggml_tensor * ggml_rope_custom(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f,
|
ctx, a, n_past, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||||
false, false
|
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7059,6 +7061,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -7066,8 +7069,8 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 0.0f,
|
ctx, a, n_past, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||||
false, true
|
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7078,7 +7081,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||||
int n_dims,
|
int n_dims,
|
||||||
float base,
|
float base,
|
||||||
bool down) {
|
bool down) {
|
||||||
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_rope_back
|
// ggml_rope_back
|
||||||
|
@ -12675,15 +12678,16 @@ static void rope_yarn(
|
||||||
|
|
||||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||||
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||||
static float ggml_rope_yarn_corr_dim(const int n_dims, const float n_rot, const float base) {
|
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
||||||
static const float max_pos_emb = 2048;
|
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||||
return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]) {
|
void ggml_rope_yarn_corr_dims(
|
||||||
|
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||||
|
) {
|
||||||
// start and end correction dims
|
// start and end correction dims
|
||||||
dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, beta_fast, freq_base)));
|
dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
||||||
dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, beta_slow, freq_base)));
|
dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_compute_forward_rope_f32(
|
static void ggml_compute_forward_rope_f32(
|
||||||
|
@ -12701,18 +12705,20 @@ static void ggml_compute_forward_rope_f32(
|
||||||
float xpos_base;
|
float xpos_base;
|
||||||
bool xpos_down;
|
bool xpos_down;
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
||||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
memcpy(&xpos_base, (int32_t *) dst->op_params + 10, sizeof(float));
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 11, sizeof(bool));
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
|
||||||
|
memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
|
||||||
|
|
||||||
assert(n_past >= 0);
|
assert(n_past >= 0);
|
||||||
|
|
||||||
|
@ -12743,7 +12749,7 @@ static void ggml_compute_forward_rope_f32(
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
const bool is_glm = mode & 4;
|
||||||
|
@ -12844,16 +12850,17 @@ static void ggml_compute_forward_rope_f16(
|
||||||
|
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 8, sizeof(float));
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
|
||||||
assert(n_past >= 0);
|
assert(n_past >= 0);
|
||||||
|
|
||||||
|
@ -12884,7 +12891,7 @@ static void ggml_compute_forward_rope_f16(
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
const bool is_glm = mode & 4;
|
||||||
|
@ -16641,6 +16648,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
n_past,
|
n_past,
|
||||||
n_dims,
|
n_dims,
|
||||||
mode,
|
mode,
|
||||||
|
0,
|
||||||
n_ctx,
|
n_ctx,
|
||||||
freq_base,
|
freq_base,
|
||||||
freq_scale,
|
freq_scale,
|
||||||
|
|
7
ggml.h
7
ggml.h
|
@ -219,7 +219,7 @@
|
||||||
#define GGML_MAX_CONTEXTS 64
|
#define GGML_MAX_CONTEXTS 64
|
||||||
#define GGML_MAX_SRC 6
|
#define GGML_MAX_SRC 6
|
||||||
#define GGML_MAX_NAME 64
|
#define GGML_MAX_NAME 64
|
||||||
#define GGML_MAX_OP_PARAMS 48
|
#define GGML_MAX_OP_PARAMS 64
|
||||||
#define GGML_DEFAULT_N_THREADS 4
|
#define GGML_DEFAULT_N_THREADS 4
|
||||||
|
|
||||||
#if UINTPTR_MAX == 0xFFFFFFFF
|
#if UINTPTR_MAX == 0xFFFFFFFF
|
||||||
|
@ -1248,6 +1248,7 @@ extern "C" {
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -1263,6 +1264,7 @@ extern "C" {
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
|
int n_orig_ctx,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -1271,7 +1273,8 @@ extern "C" {
|
||||||
float beta_slow);
|
float beta_slow);
|
||||||
|
|
||||||
// compute correction dims for YaRN RoPE scaling
|
// compute correction dims for YaRN RoPE scaling
|
||||||
void ggml_rope_yarn_corr_dims(int n_dims, const float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
void ggml_rope_yarn_corr_dims(
|
||||||
|
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||||
|
|
||||||
// xPos RoPE, in-place, returns view(a)
|
// xPos RoPE, in-place, returns view(a)
|
||||||
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
|
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||||
|
|
67
llama.cpp
67
llama.cpp
|
@ -2471,13 +2471,14 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
|
|
||||||
GGML_ASSERT(!!kv_self.ctx);
|
GGML_ASSERT(!!kv_self.ctx);
|
||||||
|
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int32_t n_embd = hparams.n_embd;
|
||||||
const int64_t n_layer = hparams.n_layer;
|
const int32_t n_layer = hparams.n_layer;
|
||||||
const int64_t n_ctx = hparams.n_ctx;
|
const int32_t n_ctx = hparams.n_ctx;
|
||||||
const int64_t n_head = hparams.n_head;
|
const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
|
||||||
const int64_t n_head_kv = hparams.n_head_kv;
|
const int32_t n_head = hparams.n_head;
|
||||||
const int64_t n_embd_head = hparams.n_embd_head();
|
const int32_t n_head_kv = hparams.n_head_kv;
|
||||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
const int32_t n_embd_head = hparams.n_embd_head();
|
||||||
|
const int32_t n_embd_gqa = hparams.n_embd_gqa();
|
||||||
|
|
||||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||||
|
|
||||||
|
@ -2599,14 +2600,18 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
ggml_set_name(tmpq, "tmpq");
|
ggml_set_name(tmpq, "tmpq");
|
||||||
|
|
||||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
|
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
|
||||||
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base,
|
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
|
||||||
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
offload_func_kq(Kcur);
|
offload_func_kq(Kcur);
|
||||||
ggml_set_name(Kcur, "Kcur");
|
ggml_set_name(Kcur, "Kcur");
|
||||||
|
|
||||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
|
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
|
||||||
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base,
|
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
|
||||||
freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
offload_func_kq(Qcur);
|
offload_func_kq(Qcur);
|
||||||
ggml_set_name(Qcur, "Qcur");
|
ggml_set_name(Qcur, "Qcur");
|
||||||
|
|
||||||
|
@ -2811,13 +2816,14 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||||
|
|
||||||
GGML_ASSERT(!!kv_self.ctx);
|
GGML_ASSERT(!!kv_self.ctx);
|
||||||
|
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int32_t n_embd = hparams.n_embd;
|
||||||
const int64_t n_layer = hparams.n_layer;
|
const int32_t n_layer = hparams.n_layer;
|
||||||
const int64_t n_ctx = hparams.n_ctx;
|
const int32_t n_ctx = hparams.n_ctx;
|
||||||
const int64_t n_head = hparams.n_head;
|
const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
|
||||||
const int64_t n_head_kv = hparams.n_head_kv;
|
const int32_t n_head = hparams.n_head;
|
||||||
const int64_t n_embd_head = hparams.n_embd_head();
|
const int32_t n_head_kv = hparams.n_head_kv;
|
||||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
const int32_t n_embd_head = hparams.n_embd_head();
|
||||||
|
const int32_t n_embd_gqa = hparams.n_embd_gqa();
|
||||||
|
|
||||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||||
|
|
||||||
|
@ -2945,12 +2951,14 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||||
Kcur = ggml_rope_custom_inplace(
|
Kcur = ggml_rope_custom_inplace(
|
||||||
ctx0,
|
ctx0,
|
||||||
ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
|
ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N),
|
||||||
n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
Qcur = ggml_rope_custom_inplace(
|
Qcur = ggml_rope_custom_inplace(
|
||||||
ctx0,
|
ctx0,
|
||||||
ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
|
ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),
|
||||||
n_past, n_embd_head, 0, 0, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
n_past, n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
case MODEL_13B:
|
case MODEL_13B:
|
||||||
|
@ -3183,13 +3191,14 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||||
|
|
||||||
GGML_ASSERT(!!kv_self.ctx);
|
GGML_ASSERT(!!kv_self.ctx);
|
||||||
|
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int32_t n_embd = hparams.n_embd;
|
||||||
const int64_t n_layer = hparams.n_layer;
|
const int32_t n_layer = hparams.n_layer;
|
||||||
const int64_t n_ctx = hparams.n_ctx;
|
const int32_t n_ctx = hparams.n_ctx;
|
||||||
const int64_t n_head = hparams.n_head;
|
const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
|
||||||
const int64_t n_head_kv = hparams.n_head_kv;
|
const int32_t n_head = hparams.n_head;
|
||||||
const int64_t n_embd_head = hparams.n_embd_head();
|
const int32_t n_head_kv = hparams.n_head_kv;
|
||||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
const int32_t n_embd_head = hparams.n_embd_head();
|
||||||
|
const int32_t n_embd_gqa = hparams.n_embd_gqa();
|
||||||
|
|
||||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||||
|
|
||||||
|
@ -3351,12 +3360,12 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||||
|
|
||||||
// using mode = 2 for neox mode
|
// using mode = 2 for neox mode
|
||||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
|
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
|
||||||
ctx0, tmpq, n_past, n_embd_head, 2, 0,
|
ctx0, tmpq, n_past, n_embd_head, 2, 0, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
offload_func_kq(Qcur);
|
offload_func_kq(Qcur);
|
||||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
|
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
|
||||||
ctx0, tmpk, n_past, n_embd_head, 2, 0,
|
ctx0, tmpk, n_past, n_embd_head, 2, 0, n_orig_ctx,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
offload_func_kq(Kcur);
|
offload_func_kq(Kcur);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue