llama : add phi3 128K model support (#7225)

* add phi3 128k support in convert-hf-to-gguf
* add phi3 128k support in cuda
* address build warnings on llama.cpp
* adjust index value in cuda long rope freq factors
* add long rope support in ggml cpu backend
* make freq factors only depend on ctx size
* remove unused rope scaling type 'su' from gguf converter
* fix lint warnings on convert-hf-to-gguf.py
* set to the short freq factor when context size is smaller than trained context size (sketched after the change summary below)
* add one line of comments
* metal : support rope freq_factors
* ggml : update ggml_rope_ext API to support freq. factors
* backends : add dev messages to support rope freq. factors
* minor : style
* tests : update to use new rope API
* backends : fix pragma semicolons
* minor : cleanup
* llama : move rope factors from KV header to tensors
* llama : remove tmp assert
* cuda : fix compile warning
* convert : read/write n_head_kv
* llama : fix uninitialized tensors

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in: parent 6369bf0433, commit 201cc11afa
15 changed files with 484 additions and 233 deletions
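The commit message's rule "set to the short freq factor when context size is smaller than trained context size" amounts to: when the runtime context fits within the model's trained context, keep the short rope frequency factors; only use the long-rope factors when extending beyond it. A minimal sketch of that selection rule, assuming hypothetical names (select_rope_freq_factors, n_ctx_orig and the factor pointers are illustrative, not the actual llama.cpp symbols):

#include <stdint.h>

// Sketch only: illustrates the short/long freq-factor selection described in the
// commit message. Identifiers are hypothetical and do not match llama.cpp internals.
static const float * select_rope_freq_factors(
        uint32_t      n_ctx,               // context size requested at load time
        uint32_t      n_ctx_orig,          // context size the model was trained with
        const float * freq_factors_short,  // factors for contexts <= trained size
        const float * freq_factors_long) { // long-rope factors for extended contexts
    // if the runtime context does not exceed the trained context, the short
    // factors apply; otherwise the long-rope factors stretch the rotary frequencies
    return (n_ctx <= n_ctx_orig) ? freq_factors_short : freq_factors_long;
}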
ggml.c (88 changed lines)
@@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
 
+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6271,6 +6277,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
+    result->src[2] = c;
 
     return result;
 }
@@ -6283,7 +6290,7 @@ struct ggml_tensor * ggml_rope(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
     );
 }
 
@@ -6295,7 +6302,49 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
     );
 }
 
+struct ggml_tensor * ggml_rope_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx,
+        int                   n_orig_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_ext_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx,
+        int                   n_orig_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+    );
+}
+
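The hunks above add an optional third input tensor `c` that carries per-dimension rope frequency factors; passing NULL keeps the classic behavior. A minimal sketch of how a graph-building call might go through the extended API, assuming illustrative tensor names and typical hyperparameter values (not the exact ones llama.cpp uses for phi3):

#include "ggml.h"

// Sketch only: wrapping ggml_rope_ext with an optional freq-factors tensor.
// `cur`, `positions` and `rope_factors` are placeholders supplied by the caller.
static struct ggml_tensor * build_rope_with_factors(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,          // e.g. query states to rotate
        struct ggml_tensor  * positions,    // I32 vector, one position per token
        struct ggml_tensor  * rope_factors, // F32 vector of n_rot/2 factors, or NULL
        int                   n_rot,
        int                   n_orig_ctx) {
    return ggml_rope_ext(
        ctx, cur, positions, rope_factors,
        n_rot, /* mode: neox */ 2, /* n_ctx */ 0, n_orig_ctx,
        /* freq_base   */ 10000.0f, /* freq_scale  */ 1.0f,
        /* ext_factor  */ 0.0f,     /* attn_factor */ 1.0f,
        /* beta_fast   */ 32.0f,    /* beta_slow   */ 1.0f);
}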
@@ -6314,7 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
     );
 }
@@ -6334,27 +6383,18 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }
 
-struct ggml_tensor * ggml_rope_xpos_inplace(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   n_dims,
-        float                 base,
-        bool                  down) {
-    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
-}
-
 // ggml_rope_back
 
 struct ggml_tensor * ggml_rope_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6370,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
+    GGML_ASSERT(c == NULL && "freq factors not implemented yet");
 
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -14304,6 +14345,7 @@ static void ggml_compute_forward_rope_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14363,6 +14405,17 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14439,10 +14492,11 @@ static void ggml_compute_forward_rope_f32(
 
                         // simplified from `(ib * n_dims + ic) * inv_ndims`
                         float cur_rot = inv_ndims * ic - ib;
+                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
                         sin_theta *= sin_sign;
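The CPU hunks above apply the optional factor by dividing the base angle before the YaRN correction: the dimension pair at index ic uses theta_base / freq_factors[ic/2], and a missing table behaves like a factor of 1.0f. A stand-alone sketch of that per-dimension scaling, assuming the helper name and the classic rope angle formula (pos * freq_base^(-ic/n_dims)) as background rather than code taken from this commit:

#include <math.h>

// Sketch only: rotation angle for one rope dimension pair with optional
// frequency factors, mirroring the `theta_base/freq_factor` change above.
static float rope_theta_for_dim(
        int           pos,            // token position
        int           ic,             // even dimension index: 0, 2, 4, ...
        int           n_dims,         // number of rotary dimensions
        float         freq_base,      // e.g. 10000.0f
        const float * freq_factors) { // n_dims/2 entries, or NULL for classic rope
    // classic rope angle for this dimension pair
    float theta  = (float) pos * powf(freq_base, -(float) ic / (float) n_dims);
    // long-rope models (e.g. phi3 128k) divide by a per-dimension factor;
    // a NULL table is equivalent to every factor being 1.0f
    float factor = freq_factors ? freq_factors[ic / 2] : 1.0f;
    return theta / factor;
}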
@@ -18387,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
 static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
+    struct ggml_tensor * src2 = tensor->src[2];
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -18918,6 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_rope_back(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -18957,6 +19013,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_rope_impl(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -19038,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         masked);
             }
 
-            struct ggml_tensor * src2 = tensor->src[2];
             const int64_t elem_q = ggml_nelements(src0);
             const int64_t elem_k = ggml_nelements(src1);
             const int64_t elem_v = ggml_nelements(src2);