naming : n_orig_ctx -> n_ctx_orig
ggml-ci
This commit is contained in:
parent
c989fd06b1
commit
572446bfe1
14 changed files with 137 additions and 137 deletions
|
@ -176,7 +176,7 @@ class Params:
|
||||||
rope_scaling_type: gguf.RopeScalingType | None = None
|
rope_scaling_type: gguf.RopeScalingType | None = None
|
||||||
f_rope_freq_base: float | None = None
|
f_rope_freq_base: float | None = None
|
||||||
f_rope_scale: float | None = None
|
f_rope_scale: float | None = None
|
||||||
n_orig_ctx: int | None = None
|
n_ctx_orig: int | None = None
|
||||||
rope_finetuned: bool | None = None
|
rope_finetuned: bool | None = None
|
||||||
|
|
||||||
ftype: GGMLFileType | None = None
|
ftype: GGMLFileType | None = None
|
||||||
|
@ -226,7 +226,7 @@ class Params:
|
||||||
with open(config_path) as f:
|
with open(config_path) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
|
rope_scaling_type = f_rope_scale = n_ctx_orig = rope_finetuned = None
|
||||||
rope_scaling = config.get("rope_scaling")
|
rope_scaling = config.get("rope_scaling")
|
||||||
|
|
||||||
if rope_scaling is not None and (typ := rope_scaling.get("type")):
|
if rope_scaling is not None and (typ := rope_scaling.get("type")):
|
||||||
|
@ -236,7 +236,7 @@ class Params:
|
||||||
rope_scaling_type = gguf.RopeScalingType.LINEAR
|
rope_scaling_type = gguf.RopeScalingType.LINEAR
|
||||||
elif typ == "yarn":
|
elif typ == "yarn":
|
||||||
rope_scaling_type = gguf.RopeScalingType.YARN
|
rope_scaling_type = gguf.RopeScalingType.YARN
|
||||||
n_orig_ctx = rope_scaling['original_max_position_embeddings']
|
n_ctx_orig = rope_scaling['original_max_position_embeddings']
|
||||||
rope_finetuned = rope_scaling['finetuned']
|
rope_finetuned = rope_scaling['finetuned']
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown rope scaling type: {typ}')
|
raise NotImplementedError(f'Unknown rope scaling type: {typ}')
|
||||||
|
@ -272,7 +272,7 @@ class Params:
|
||||||
f_rope_freq_base = config.get("rope_theta"),
|
f_rope_freq_base = config.get("rope_theta"),
|
||||||
rope_scaling_type = rope_scaling_type,
|
rope_scaling_type = rope_scaling_type,
|
||||||
f_rope_scale = f_rope_scale,
|
f_rope_scale = f_rope_scale,
|
||||||
n_orig_ctx = n_orig_ctx,
|
n_ctx_orig = n_ctx_orig,
|
||||||
rope_finetuned = rope_finetuned,
|
rope_finetuned = rope_finetuned,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -864,8 +864,8 @@ class OutputFile:
|
||||||
self.gguf.add_rope_scaling_type(params.rope_scaling_type)
|
self.gguf.add_rope_scaling_type(params.rope_scaling_type)
|
||||||
self.gguf.add_rope_scaling_factor(params.f_rope_scale)
|
self.gguf.add_rope_scaling_factor(params.f_rope_scale)
|
||||||
|
|
||||||
if params.n_orig_ctx is not None:
|
if params.n_ctx_orig is not None:
|
||||||
self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)
|
self.gguf.add_rope_scaling_orig_ctx_len(params.n_ctx_orig)
|
||||||
|
|
||||||
if params.rope_finetuned is not None:
|
if params.rope_finetuned is not None:
|
||||||
self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)
|
self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)
|
||||||
|
|
|
@ -209,7 +209,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
// RoPE alteration for extended context
|
// RoPE alteration for extended context
|
||||||
float freq_base;
|
float freq_base;
|
||||||
|
@ -236,7 +236,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
}
|
}
|
||||||
|
|
||||||
rope_corr_dims corr_dims;
|
rope_corr_dims corr_dims;
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
|
|
|
@ -1183,7 +1183,7 @@ static void ggml_vk_rope(
|
||||||
const std::shared_ptr<kp::Tensor>& inB,
|
const std::shared_ptr<kp::Tensor>& inB,
|
||||||
const std::shared_ptr<kp::Tensor>& out,
|
const std::shared_ptr<kp::Tensor>& out,
|
||||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||||
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
|
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
|
||||||
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
||||||
int32_t ne01, int32_t ne02, int32_t ne03,
|
int32_t ne01, int32_t ne02, int32_t ne03,
|
||||||
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
||||||
|
@ -1212,14 +1212,14 @@ static void ggml_vk_rope(
|
||||||
|
|
||||||
struct PushConstants {
|
struct PushConstants {
|
||||||
uint32_t inAOff, inBOff, outOff;
|
uint32_t inAOff, inBOff, outOff;
|
||||||
int32_t n_dims, mode, n_orig_ctx;
|
int32_t n_dims, mode, n_ctx_orig;
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
uint32_t nb00, nb01, nb02, nb03;
|
uint32_t nb00, nb01, nb02, nb03;
|
||||||
int32_t ne0;
|
int32_t ne0;
|
||||||
uint32_t nb0, nb1, nb2, nb3;
|
uint32_t nb0, nb1, nb2, nb3;
|
||||||
} pushConsts {
|
} pushConsts {
|
||||||
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
||||||
n_dims, mode, n_orig_ctx,
|
n_dims, mode, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
||||||
nb00, nb01, nb02, nb03,
|
nb00, nb01, nb02, nb03,
|
||||||
ne0,
|
ne0,
|
||||||
|
@ -1686,7 +1686,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
|
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
|
@ -1696,7 +1696,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
ggml_vk_rope(
|
ggml_vk_rope(
|
||||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
|
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
||||||
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
|
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
|
||||||
);
|
);
|
||||||
|
|
|
@ -2289,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
float freq_base;
|
float freq_base;
|
||||||
float freq_scale;
|
float freq_scale;
|
||||||
|
@ -2350,7 +2350,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
||||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
||||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
||||||
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
|
||||||
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
||||||
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
||||||
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
||||||
|
|
|
@ -1671,16 +1671,16 @@ static void rope_yarn(
|
||||||
|
|
||||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||||
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||||
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_yarn_corr_dims(
|
static void rope_yarn_corr_dims(
|
||||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||||
) {
|
) {
|
||||||
// start and end correction dims
|
// start and end correction dims
|
||||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
||||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -1707,7 +1707,7 @@ kernel void kernel_rope_norm(
|
||||||
constant uint64_t & nb3,
|
constant uint64_t & nb3,
|
||||||
constant int & n_past,
|
constant int & n_past,
|
||||||
constant int & n_dims,
|
constant int & n_dims,
|
||||||
constant int & n_orig_ctx,
|
constant int & n_ctx_orig,
|
||||||
constant float & freq_base,
|
constant float & freq_base,
|
||||||
constant float & freq_scale,
|
constant float & freq_scale,
|
||||||
constant float & ext_factor,
|
constant float & ext_factor,
|
||||||
|
@ -1722,7 +1722,7 @@ kernel void kernel_rope_norm(
|
||||||
const int64_t i1 = tgpig[0];
|
const int64_t i1 = tgpig[0];
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
device const int32_t * pos = src1;
|
device const int32_t * pos = src1;
|
||||||
|
|
||||||
|
@ -1784,7 +1784,7 @@ kernel void kernel_rope_neox(
|
||||||
constant uint64_t & nb3,
|
constant uint64_t & nb3,
|
||||||
constant int & n_past,
|
constant int & n_past,
|
||||||
constant int & n_dims,
|
constant int & n_dims,
|
||||||
constant int & n_orig_ctx,
|
constant int & n_ctx_orig,
|
||||||
constant float & freq_base,
|
constant float & freq_base,
|
||||||
constant float & freq_scale,
|
constant float & freq_scale,
|
||||||
constant float & ext_factor,
|
constant float & ext_factor,
|
||||||
|
@ -1799,7 +1799,7 @@ kernel void kernel_rope_neox(
|
||||||
const int64_t i1 = tgpig[0];
|
const int64_t i1 = tgpig[0];
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
device const int32_t * pos = src1;
|
device const int32_t * pos = src1;
|
||||||
|
|
||||||
|
|
|
@ -14008,7 +14008,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
// RoPE alteration for extended context
|
// RoPE alteration for extended context
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
|
@ -14040,7 +14040,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||||
}
|
}
|
||||||
|
|
||||||
rope_corr_dims corr_dims;
|
rope_corr_dims corr_dims;
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
|
|
|
@ -4293,7 +4293,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
const float freq_base = ((float *) dst->op_params)[5];
|
const float freq_base = ((float *) dst->op_params)[5];
|
||||||
const float freq_scale = ((float *) dst->op_params)[6];
|
const float freq_scale = ((float *) dst->op_params)[6];
|
||||||
const float ext_factor = ((float *) dst->op_params)[7];
|
const float ext_factor = ((float *) dst->op_params)[7];
|
||||||
|
@ -4304,7 +4304,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
if (is_neox) {
|
if (is_neox) {
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||||
|
@ -6862,15 +6862,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
|
||||||
} else if (tensor->op == GGML_OP_ROPE) {
|
} else if (tensor->op == GGML_OP_ROPE) {
|
||||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||||
const int n_ggml_ctx = ((int32_t *) tensor->op_params)[3];
|
const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
|
||||||
const int n_orig_ggml_ctx = ((int32_t *) tensor->op_params)[4];
|
const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
|
||||||
float freq_base = ((float *) tensor->op_params)[5];
|
float freq_base = ((float *) tensor->op_params)[5];
|
||||||
float freq_scale = ((float *) tensor->op_params)[6];
|
float freq_scale = ((float *) tensor->op_params)[6];
|
||||||
float ext_factor = ((float *) tensor->op_params)[7];
|
float ext_factor = ((float *) tensor->op_params)[7];
|
||||||
float attn_factor = ((float *) tensor->op_params)[8];
|
float attn_factor = ((float *) tensor->op_params)[8];
|
||||||
float beta_fast = ((float *) tensor->op_params)[9];
|
float beta_fast = ((float *) tensor->op_params)[9];
|
||||||
float beta_slow = ((float *) tensor->op_params)[10];
|
float beta_slow = ((float *) tensor->op_params)[10];
|
||||||
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_ggml, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
} else if (tensor->op == GGML_OP_UNARY) {
|
} else if (tensor->op == GGML_OP_UNARY) {
|
||||||
switch (ggml_get_unary_op(tensor)) {
|
switch (ggml_get_unary_op(tensor)) {
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
|
|
50
ggml.c
50
ggml.c
|
@ -6249,7 +6249,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -6276,7 +6276,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_orig_ctx };
|
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||||
memcpy(params + 5, &freq_base, sizeof(float));
|
memcpy(params + 5, &freq_base, sizeof(float));
|
||||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||||
|
@ -6323,7 +6323,7 @@ struct ggml_tensor * ggml_rope_ext(
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -6331,7 +6331,7 @@ struct ggml_tensor * ggml_rope_ext(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, c, n_dims, mode, n_orig_ctx, freq_base, freq_scale,
|
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow, false
|
ext_factor, attn_factor, beta_fast, beta_slow, false
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6343,7 +6343,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -6351,7 +6351,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, c, n_dims, mode, n_orig_ctx, freq_base, freq_scale,
|
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow, true
|
ext_factor, attn_factor, beta_fast, beta_slow, true
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6362,7 +6362,7 @@ struct ggml_tensor * ggml_rope_custom(
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -6370,7 +6370,7 @@ struct ggml_tensor * ggml_rope_custom(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, NULL, n_dims, mode, n_orig_ctx, freq_base, freq_scale,
|
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow, false
|
ext_factor, attn_factor, beta_fast, beta_slow, false
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6381,7 +6381,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -6389,7 +6389,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||||
float beta_fast,
|
float beta_fast,
|
||||||
float beta_slow) {
|
float beta_slow) {
|
||||||
return ggml_rope_impl(
|
return ggml_rope_impl(
|
||||||
ctx, a, b, NULL, n_dims, mode, n_orig_ctx, freq_base, freq_scale,
|
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow, true
|
ext_factor, attn_factor, beta_fast, beta_slow, true
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6403,7 +6403,7 @@ struct ggml_tensor * ggml_rope_back(
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -6425,7 +6425,7 @@ struct ggml_tensor * ggml_rope_back(
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_orig_ctx };
|
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||||
memcpy(params + 5, &freq_base, sizeof(float));
|
memcpy(params + 5, &freq_base, sizeof(float));
|
||||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||||
|
@ -14250,8 +14250,8 @@ static void rope_yarn(
|
||||||
|
|
||||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||||
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||||
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||||
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_rope_cache_init(
|
static void ggml_rope_cache_init(
|
||||||
|
@ -14271,11 +14271,11 @@ static void ggml_rope_cache_init(
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||||
) {
|
) {
|
||||||
// start and end correction dims
|
// start and end correction dims
|
||||||
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
|
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
|
||||||
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
|
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
|
||||||
dims[0] = MAX(0, start);
|
dims[0] = MAX(0, start);
|
||||||
dims[1] = MIN(n_dims - 1, end);
|
dims[1] = MIN(n_dims - 1, end);
|
||||||
}
|
}
|
||||||
|
@ -14299,7 +14299,7 @@ static void ggml_compute_forward_rope_f32(
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
|
@ -14336,7 +14336,7 @@ static void ggml_compute_forward_rope_f32(
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
|
|
||||||
|
@ -14429,7 +14429,7 @@ static void ggml_compute_forward_rope_f16(
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
|
@ -14465,7 +14465,7 @@ static void ggml_compute_forward_rope_f16(
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
|
|
||||||
|
@ -18234,7 +18234,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||||
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
|
|
||||||
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
||||||
|
@ -18252,7 +18252,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
src2,
|
src2,
|
||||||
n_dims,
|
n_dims,
|
||||||
mode,
|
mode,
|
||||||
n_orig_ctx,
|
n_ctx_orig,
|
||||||
freq_base,
|
freq_base,
|
||||||
freq_scale,
|
freq_scale,
|
||||||
ext_factor,
|
ext_factor,
|
||||||
|
@ -18269,7 +18269,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||||
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
//const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||||
const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
|
|
||||||
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
||||||
|
@ -18287,7 +18287,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
src2,
|
src2,
|
||||||
n_dims,
|
n_dims,
|
||||||
mode,
|
mode,
|
||||||
n_orig_ctx,
|
n_ctx_orig,
|
||||||
freq_base,
|
freq_base,
|
||||||
freq_scale,
|
freq_scale,
|
||||||
ext_factor,
|
ext_factor,
|
||||||
|
|
12
ggml.h
12
ggml.h
|
@ -1491,7 +1491,7 @@ extern "C" {
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -1507,7 +1507,7 @@ extern "C" {
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -1521,7 +1521,7 @@ extern "C" {
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -1536,7 +1536,7 @@ extern "C" {
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
@ -1547,7 +1547,7 @@ extern "C" {
|
||||||
|
|
||||||
// compute correction dims for YaRN RoPE scaling
|
// compute correction dims for YaRN RoPE scaling
|
||||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||||
|
|
||||||
// rotary position embedding backward, i.e compute dx from dy
|
// rotary position embedding backward, i.e compute dx from dy
|
||||||
// a - dy
|
// a - dy
|
||||||
|
@ -1558,7 +1558,7 @@ extern "C" {
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
int n_orig_ctx,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
float freq_scale,
|
float freq_scale,
|
||||||
float ext_factor,
|
float ext_factor,
|
||||||
|
|
|
@ -14,7 +14,7 @@ void main() {
|
||||||
const bool is_neox = (pcs.mode & 2) != 0;
|
const bool is_neox = (pcs.mode & 2) != 0;
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
|
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
|
||||||
|
|
||||||
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
|
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ void main() {
|
||||||
const bool is_neox = (pcs.mode & 2) != 0;
|
const bool is_neox = (pcs.mode & 2) != 0;
|
||||||
|
|
||||||
float corr_dims[2];
|
float corr_dims[2];
|
||||||
rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
|
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
|
||||||
|
|
||||||
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
|
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ layout (push_constant) uniform parameter {
|
||||||
uint outOff;
|
uint outOff;
|
||||||
int n_dims;
|
int n_dims;
|
||||||
int mode;
|
int mode;
|
||||||
int n_orig_ctx;
|
int n_ctx_orig;
|
||||||
float freq_base;
|
float freq_base;
|
||||||
float freq_scale;
|
float freq_scale;
|
||||||
float ext_factor;
|
float ext_factor;
|
||||||
|
@ -54,14 +54,14 @@ void rope_yarn(
|
||||||
|
|
||||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||||
float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||||
return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
|
return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
|
||||||
}
|
}
|
||||||
|
|
||||||
void rope_yarn_corr_dims(
|
void rope_yarn_corr_dims(
|
||||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
|
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
|
||||||
) {
|
) {
|
||||||
// start and end correction dims
|
// start and end correction dims
|
||||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
||||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
||||||
}
|
}
|
||||||
|
|
124
llama.cpp
124
llama.cpp
|
@ -1850,7 +1850,7 @@ struct llama_hparams {
|
||||||
float rope_attn_factor = 1.0f;
|
float rope_attn_factor = 1.0f;
|
||||||
float rope_freq_base_train;
|
float rope_freq_base_train;
|
||||||
float rope_freq_scale_train;
|
float rope_freq_scale_train;
|
||||||
uint32_t n_yarn_orig_ctx;
|
uint32_t n_ctx_orig_yarn;
|
||||||
float rope_yarn_log_mul;
|
float rope_yarn_log_mul;
|
||||||
|
|
||||||
// for State Space Models
|
// for State Space Models
|
||||||
|
@ -1892,7 +1892,7 @@ struct llama_hparams {
|
||||||
if (this->n_expert_shared != other.n_expert_shared) return true;
|
if (this->n_expert_shared != other.n_expert_shared) return true;
|
||||||
|
|
||||||
if (this->rope_finetuned != other.rope_finetuned) return true;
|
if (this->rope_finetuned != other.rope_finetuned) return true;
|
||||||
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
|
if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
|
||||||
|
|
||||||
if (this->ssm_d_conv != other.ssm_d_conv) return true;
|
if (this->ssm_d_conv != other.ssm_d_conv) return true;
|
||||||
if (this->ssm_d_inner != other.ssm_d_inner) return true;
|
if (this->ssm_d_inner != other.ssm_d_inner) return true;
|
||||||
|
@ -1951,7 +1951,7 @@ struct llama_cparams {
|
||||||
float rope_freq_base;
|
float rope_freq_base;
|
||||||
float rope_freq_scale;
|
float rope_freq_scale;
|
||||||
|
|
||||||
uint32_t n_yarn_orig_ctx;
|
uint32_t n_ctx_orig_yarn;
|
||||||
// These hyperparameters are not exposed in GGUF, because all
|
// These hyperparameters are not exposed in GGUF, because all
|
||||||
// existing YaRN models use the same values for them.
|
// existing YaRN models use the same values for them.
|
||||||
float yarn_ext_factor;
|
float yarn_ext_factor;
|
||||||
|
@ -4003,8 +4003,8 @@ static void llm_load_hparams(
|
||||||
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
|
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
|
||||||
hparams.rope_finetuned = rope_finetuned;
|
hparams.rope_finetuned = rope_finetuned;
|
||||||
|
|
||||||
hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
|
hparams.n_ctx_orig_yarn = hparams.n_ctx_train;
|
||||||
ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false);
|
ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false);
|
||||||
|
|
||||||
// rope_freq_base (optional)
|
// rope_freq_base (optional)
|
||||||
hparams.rope_freq_base_train = 10000.0f;
|
hparams.rope_freq_base_train = 10000.0f;
|
||||||
|
@ -4904,7 +4904,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||||
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
|
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
|
||||||
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
|
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
|
||||||
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
||||||
LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
|
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
||||||
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
|
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
|
||||||
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
||||||
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
||||||
|
@ -7072,7 +7072,7 @@ struct llm_build_context {
|
||||||
const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
|
const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
|
||||||
const int32_t n_outputs;
|
const int32_t n_outputs;
|
||||||
const int32_t kv_head; // index of where we store new KV data in the cache
|
const int32_t kv_head; // index of where we store new KV data in the cache
|
||||||
const int32_t n_orig_ctx;
|
const int32_t n_ctx_orig;
|
||||||
|
|
||||||
const bool flash_attn;
|
const bool flash_attn;
|
||||||
|
|
||||||
|
@ -7121,7 +7121,7 @@ struct llm_build_context {
|
||||||
n_kv (worst_case ? kv_self.size : kv_self.n),
|
n_kv (worst_case ? kv_self.size : kv_self.n),
|
||||||
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
|
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
|
||||||
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
|
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
|
||||||
n_orig_ctx (cparams.n_yarn_orig_ctx),
|
n_ctx_orig (cparams.n_ctx_orig_yarn),
|
||||||
flash_attn (cparams.flash_attn),
|
flash_attn (cparams.flash_attn),
|
||||||
pooling_type (cparams.pooling_type),
|
pooling_type (cparams.pooling_type),
|
||||||
rope_type (hparams.rope_type),
|
rope_type (hparams.rope_type),
|
||||||
|
@ -7179,7 +7179,7 @@ struct llm_build_context {
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||||
0),
|
0),
|
||||||
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
|
|
||||||
cb(tmp, "K_shifted", il);
|
cb(tmp, "K_shifted", il);
|
||||||
|
@ -7288,7 +7288,7 @@ struct llm_build_context {
|
||||||
// choose long/short freq factors based on the context size
|
// choose long/short freq factors based on the context size
|
||||||
const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
|
const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||||
|
|
||||||
if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) {
|
if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
|
||||||
return model.layers[il].rope_long;
|
return model.layers[il].rope_long;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7404,14 +7404,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -7535,12 +7535,12 @@ struct llm_build_context {
|
||||||
case MODEL_7B:
|
case MODEL_7B:
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
|
@ -7647,14 +7647,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -7767,13 +7767,13 @@ struct llm_build_context {
|
||||||
|
|
||||||
// using mode = 2 for neox mode
|
// using mode = 2 for neox mode
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -7891,14 +7891,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -8044,14 +8044,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -8398,14 +8398,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -8838,14 +8838,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, nullptr,
|
ctx0, Qcur, inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, nullptr,
|
ctx0, Kcur, inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -8957,13 +8957,13 @@ struct llm_build_context {
|
||||||
|
|
||||||
// using mode = 2 for neox mode
|
// using mode = 2 for neox mode
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9069,14 +9069,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9183,14 +9183,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9335,7 +9335,7 @@ struct llm_build_context {
|
||||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
@ -9346,7 +9346,7 @@ struct llm_build_context {
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_orig_ctx,
|
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9457,7 +9457,7 @@ struct llm_build_context {
|
||||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_orig_ctx,
|
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
@ -9466,7 +9466,7 @@ struct llm_build_context {
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_orig_ctx,
|
ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9574,13 +9574,13 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
@ -9782,14 +9782,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
struct ggml_tensor * Qcur = ggml_rope_ext(
|
struct ggml_tensor * Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
struct ggml_tensor * Kcur = ggml_rope_ext(
|
struct ggml_tensor * Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -9898,14 +9898,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10015,14 +10015,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10145,14 +10145,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10265,7 +10265,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head_k, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
@ -10274,7 +10274,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head_k, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
@ -10385,14 +10385,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10675,14 +10675,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10806,14 +10806,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -10920,14 +10920,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -11055,14 +11055,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
@ -11272,7 +11272,7 @@ struct llm_build_context {
|
||||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||||
q_pe = ggml_rope_ext(
|
q_pe = ggml_rope_ext(
|
||||||
ctx0, q_pe, inp_pos, nullptr,
|
ctx0, q_pe, inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(q_pe, "q_pe", il);
|
cb(q_pe, "q_pe", il);
|
||||||
|
@ -11281,7 +11281,7 @@ struct llm_build_context {
|
||||||
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||||
k_pe = ggml_rope_ext(
|
k_pe = ggml_rope_ext(
|
||||||
ctx0, k_pe, inp_pos, nullptr,
|
ctx0, k_pe, inp_pos, nullptr,
|
||||||
n_rot, rope_type, n_orig_ctx, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
cb(k_pe, "k_pe", il);
|
cb(k_pe, "k_pe", il);
|
||||||
|
@ -16264,8 +16264,8 @@ struct llama_context * llama_new_context_with_model(
|
||||||
|
|
||||||
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
||||||
|
|
||||||
cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
|
cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
|
||||||
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
|
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
|
||||||
hparams.n_ctx_train;
|
hparams.n_ctx_train;
|
||||||
|
|
||||||
cparams.cb_eval = params.cb_eval;
|
cparams.cb_eval = params.cb_eval;
|
||||||
|
|
|
@ -1615,7 +1615,7 @@ struct llama_hparams {
|
||||||
|
|
||||||
// cparams
|
// cparams
|
||||||
static constexpr uint32_t n_ctx = 512; // user-specified context size
|
static constexpr uint32_t n_ctx = 512; // user-specified context size
|
||||||
static constexpr uint32_t n_orig_ctx = n_ctx;
|
static constexpr uint32_t n_ctx_orig = n_ctx;
|
||||||
|
|
||||||
// batch
|
// batch
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
|
@ -1806,13 +1806,13 @@ struct test_llama : public test_llm {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
|
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
|
||||||
hp.n_rot, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
|
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
|
||||||
hp.n_rot, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm {
|
||||||
|
|
||||||
// using mode = 2 for neox mode
|
// using mode = 2 for neox mode
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_orig_ctx,
|
ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_orig_ctx,
|
ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
|
||||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue