llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive" * llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask) * ggml : ggml_rope now takes a vector with positions instead of n_past * metal : add rope_f16 kernel + optimize cpy kernels * llama : unified KV cache + batch inference API * llama : add new llama_decode() API that works with llama_batch * llama : add cell_max heuristic for more efficient kv_cache * llama : extend llama_kv_cache API * llama : more robust cell_max heuristic + wip shift * metal : disable concurrency optimization * llama : add llama_kv_cache_shift_seq + no more context swaps * llama : apply K-cache roping for Falcon and Baichuan * speculative : fix KV cache management * parallel : example for serving multiple users in parallel * parallel : disable hot-plug to avoid cache fragmentation * fixes : speculative KV cache + llama worst-case graph * llama : extend batch API to select which logits to output * llama : fix worst case graph build * ggml-cuda : update rope implementation for parallel decoding (#3254) * ggml-cuda : update rope implementation for parallel decoding * better solution for p0 computation * fix rope * simpler rope implementation --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * make : add parallel to build + fix static functions in llama.cpp * simple : fix token counting * parallel : various improvements * llama : fix cell_max logic + rename functions * parallel : try smaller batches when the KV cache is fragmented * parallel : fix sequence termination criteria * llama : silence errors KV cache errors * parallel : remove new line from prompt * parallel : process system prompt once + configurable paramters + llama API * parallel : remove question with short answers * parallel : count cache misses * parallel : print misses on each request * parallel : minor * llama : fix n_kv to never become 0 * parallel : rename hot-plug to continuous-batching * llama : improve llama_batch API + simplify parallel example * simple : add parallel decoding support * simple : improve comments + free batch * ggml-cuda : add rope f16, restore performance with parallel decoding (#3272) * ggml-cuda : add rope f16, restore performance * offload KQ_mask with all models * fix rope shift --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * llama : disable MPI for now ggml-ci * train : make KQ_pos memory buffer permanent via dummy scale op * ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275) ggml-ci * parallel : fix bug (extra BOS) + smaller token_prev array * parallel : fix cases where the input prompts can overflow the batch * parallel : add disabled experimental batch chunking in powers of two * llama : llama.h formatting + comments * simple : add README.md * llama : fix kv cache heuristic when context is less than 32 * parallel : fix crash when `-n -1` * llama : simplify returns if/else branches * metal : use mm kernels for batch size > 2 * examples : utilize new llama_get_logits_ith() * examples : add example for batched decoding * examples : do not eval prompt 2 times (close #3348) * server : clear the KV cache beyond n_past before llama_decode * server : avoid context swaps by shifting the KV cache --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
45855b3f1c
commit
ec893798b7
35 changed files with 2700 additions and 673 deletions
170
ggml.c
170
ggml.c
|
@ -6406,6 +6406,54 @@ struct ggml_tensor * ggml_cont_inplace(
|
|||
return ggml_cont_impl(ctx, a, true);
|
||||
}
|
||||
|
||||
|
||||
// make contiguous, with new shape
|
||||
GGML_API struct ggml_tensor * ggml_cont_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0) {
|
||||
return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
|
||||
}
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cont_2d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0,
|
||||
int64_t ne1) {
|
||||
return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
|
||||
}
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cont_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2) {
|
||||
return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_cont_4d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3) {
|
||||
GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
|
||||
ggml_format_name(result, "%s (cont)", a->name);
|
||||
|
||||
result->op = GGML_OP_CONT;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// ggml_reshape
|
||||
|
||||
struct ggml_tensor * ggml_reshape(
|
||||
|
@ -6968,7 +7016,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
|
|||
static struct ggml_tensor * ggml_rope_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
|
@ -6977,7 +7025,10 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
float xpos_base,
|
||||
bool xpos_down,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(n_past >= 0);
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
|
@ -6986,7 +7037,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[8] = { n_past, n_dims, mode, n_ctx };
|
||||
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
|
||||
memcpy(params + 4, &freq_base, sizeof(float));
|
||||
memcpy(params + 5, &freq_scale, sizeof(float));
|
||||
memcpy(params + 6, &xpos_base, sizeof(float));
|
||||
|
@ -6996,6 +7047,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
result->op = GGML_OP_ROPE;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -7003,55 +7055,55 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
struct ggml_tensor * ggml_rope(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_custom(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
float freq_base,
|
||||
float freq_scale) {
|
||||
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
float freq_base,
|
||||
float freq_scale) {
|
||||
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
float base,
|
||||
bool down) {
|
||||
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
|
||||
}
|
||||
|
||||
// ggml_rope_back
|
||||
|
@ -7059,7 +7111,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
|
|||
struct ggml_tensor * ggml_rope_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
|
@ -7067,7 +7119,10 @@ struct ggml_tensor * ggml_rope_back(
|
|||
float freq_scale,
|
||||
float xpos_base,
|
||||
bool xpos_down) {
|
||||
GGML_ASSERT(n_past >= 0);
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
|
||||
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
|
||||
|
||||
bool is_node = false;
|
||||
|
@ -7078,7 +7133,7 @@ struct ggml_tensor * ggml_rope_back(
|
|||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
int32_t params[8] = { n_past, n_dims, mode, n_ctx };
|
||||
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
|
||||
memcpy(params + 4, &freq_base, sizeof(float));
|
||||
memcpy(params + 5, &freq_scale, sizeof(float));
|
||||
memcpy(params + 6, &xpos_base, sizeof(float));
|
||||
|
@ -7088,6 +7143,7 @@ struct ggml_tensor * ggml_rope_back(
|
|||
result->op = GGML_OP_ROPE_BACK;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -8798,8 +8854,6 @@ static void ggml_compute_forward_add_f32(
|
|||
#else
|
||||
ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
|
||||
#endif
|
||||
// }
|
||||
// }
|
||||
}
|
||||
} else {
|
||||
// src1 is not contiguous
|
||||
|
@ -12456,13 +12510,11 @@ static void ggml_compute_forward_alibi_f16(
|
|||
return;
|
||||
}
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
||||
float max_bias;
|
||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
||||
const int ne1 = src0->ne[1]; // seq_len_without_past
|
||||
const int ne2 = src0->ne[2]; // n_head -> this is k
|
||||
|
@ -12477,7 +12529,7 @@ static void ggml_compute_forward_alibi_f16(
|
|||
//const int nb3 = src0->nb[3];
|
||||
|
||||
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
|
||||
//GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
|
||||
GGML_ASSERT(n_head == ne2);
|
||||
|
||||
// add alibi to src0 (KQ_scaled)
|
||||
|
@ -12623,8 +12675,8 @@ static void ggml_compute_forward_clamp(
|
|||
static void ggml_compute_forward_rope_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
@ -12634,9 +12686,9 @@ static void ggml_compute_forward_rope_f32(
|
|||
|
||||
// these two only relevant for xPos RoPE:
|
||||
float xpos_base;
|
||||
bool xpos_down;
|
||||
bool xpos_down;
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
|
@ -12645,8 +12697,6 @@ static void ggml_compute_forward_rope_f32(
|
|||
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
||||
|
@ -12677,9 +12727,11 @@ static void ggml_compute_forward_rope_f32(
|
|||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
@ -12716,7 +12768,7 @@ static void ggml_compute_forward_rope_f32(
|
|||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
// zeta scaling for xPos only:
|
||||
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
|
||||
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
||||
if (xpos_down) zeta = 1.0f / zeta;
|
||||
|
||||
theta *= theta_scale;
|
||||
|
@ -12761,8 +12813,8 @@ static void ggml_compute_forward_rope_f32(
|
|||
static void ggml_compute_forward_rope_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
@ -12770,15 +12822,13 @@ static void ggml_compute_forward_rope_f16(
|
|||
float freq_base;
|
||||
float freq_scale;
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
||||
|
@ -12809,9 +12859,11 @@ static void ggml_compute_forward_rope_f16(
|
|||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
@ -12890,15 +12942,16 @@ static void ggml_compute_forward_rope_f16(
|
|||
static void ggml_compute_forward_rope(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_rope_f16(params, src0, dst);
|
||||
ggml_compute_forward_rope_f16(params, src0, src1, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rope_f32(params, src0, dst);
|
||||
ggml_compute_forward_rope_f32(params, src0, src1, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -12912,6 +12965,7 @@ static void ggml_compute_forward_rope(
|
|||
static void ggml_compute_forward_rope_back_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
|
@ -12929,7 +12983,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|||
float xpos_base;
|
||||
bool xpos_down;
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
|
||||
|
@ -12938,8 +12992,6 @@ static void ggml_compute_forward_rope_back_f32(
|
|||
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
||||
|
@ -12966,9 +13018,11 @@ static void ggml_compute_forward_rope_back_f32(
|
|||
|
||||
const bool is_neox = mode & 2;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
@ -12980,7 +13034,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|||
const float cos_theta = cosf(theta);
|
||||
const float sin_theta = sinf(theta);
|
||||
// zeta scaling for xPos only:
|
||||
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
|
||||
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
||||
if (xpos_down) zeta = 1.0f / zeta;
|
||||
|
||||
theta *= theta_scale;
|
||||
|
@ -13023,6 +13077,7 @@ static void ggml_compute_forward_rope_back_f32(
|
|||
static void ggml_compute_forward_rope_back_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
|
@ -13033,12 +13088,10 @@ static void ggml_compute_forward_rope_back_f16(
|
|||
// dx = rope_back(dy, src1)
|
||||
// src0 is dy, src1 contains options
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
||||
|
@ -13065,9 +13118,11 @@ static void ggml_compute_forward_rope_back_f16(
|
|||
|
||||
const bool is_neox = mode & 2;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
const int64_t p = pos[i2];
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
@ -13119,15 +13174,16 @@ static void ggml_compute_forward_rope_back_f16(
|
|||
static void ggml_compute_forward_rope_back(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_rope_back_f16(params, src0, dst);
|
||||
ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rope_back_f32(params, src0, dst);
|
||||
ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -15864,11 +15920,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
ggml_compute_forward_rope(params, tensor->src[0], tensor);
|
||||
ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
|
||||
} break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
|
||||
ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
|
||||
} break;
|
||||
case GGML_OP_ALIBI:
|
||||
{
|
||||
|
@ -16506,7 +16562,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
|
@ -16523,7 +16579,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad,
|
||||
ggml_rope_back(ctx,
|
||||
tensor->grad,
|
||||
n_past,
|
||||
src1,
|
||||
n_dims,
|
||||
mode,
|
||||
n_ctx,
|
||||
|
@ -16537,7 +16593,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
if (src0->grad) {
|
||||
const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
|
@ -16554,7 +16610,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad,
|
||||
ggml_rope_impl(ctx,
|
||||
tensor->grad,
|
||||
n_past,
|
||||
src1,
|
||||
n_dims,
|
||||
mode,
|
||||
n_ctx,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue