ggml : ggml_rope now takes a vector with positions instead of n_past

This commit is contained in:
Georgi Gerganov 2023-09-17 21:12:51 +03:00
parent 3b4bab6a38
commit 1fb033fd85
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 270 additions and 131 deletions

113
ggml.c
View file

@ -6968,7 +6968,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
static struct ggml_tensor * ggml_rope_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
@ -6977,6 +6977,10 @@ static struct ggml_tensor * ggml_rope_impl(
float xpos_base,
bool xpos_down,
bool inplace) {
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
bool is_node = false;
if (a->grad) {
@ -6985,7 +6989,7 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
int32_t params[8] = { n_past, n_dims, mode, n_ctx };
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float));
@ -6995,6 +6999,7 @@ static struct ggml_tensor * ggml_rope_impl(
result->op = GGML_OP_ROPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
@ -7002,55 +7007,55 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
}
struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
}
struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
float freq_base,
float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
}
struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
float freq_base,
float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
}
struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
float base,
bool down) {
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
}
// ggml_rope_back
@ -7058,7 +7063,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
@ -7066,7 +7071,10 @@ struct ggml_tensor * ggml_rope_back(
float freq_scale,
float xpos_base,
bool xpos_down) {
GGML_ASSERT(n_past >= 0);
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
bool is_node = false;
@ -7077,7 +7085,7 @@ struct ggml_tensor * ggml_rope_back(
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[8] = { n_past, n_dims, mode, n_ctx };
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float));
@ -7087,6 +7095,7 @@ struct ggml_tensor * ggml_rope_back(
result->op = GGML_OP_ROPE_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
@ -12620,8 +12629,8 @@ static void ggml_compute_forward_clamp(
static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
@ -12631,9 +12640,9 @@ static void ggml_compute_forward_rope_f32(
// these two only relevant for xPos RoPE:
float xpos_base;
bool xpos_down;
bool xpos_down;
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
@ -12669,14 +12678,14 @@ static void ggml_compute_forward_rope_f32(
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const bool is_skip = mode & 1;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const bool is_diff = mode & 8; // TODO: temporary
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = (is_skip ? n_past : 0); i2 < ne2; i2++) {
const int64_t p = is_diff ? n_past : is_skip ? i2 : n_past + i2;
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
@ -12713,7 +12722,7 @@ static void ggml_compute_forward_rope_f32(
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
theta *= theta_scale;
@ -12758,8 +12767,8 @@ static void ggml_compute_forward_rope_f32(
static void ggml_compute_forward_rope_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
@ -12767,15 +12776,13 @@ static void ggml_compute_forward_rope_f16(
float freq_base;
float freq_scale;
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@ -12806,9 +12813,11 @@ static void ggml_compute_forward_rope_f16(
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
@ -12887,15 +12896,16 @@ static void ggml_compute_forward_rope_f16(
static void ggml_compute_forward_rope(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_f16(params, src0, dst);
ggml_compute_forward_rope_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_f32(params, src0, dst);
ggml_compute_forward_rope_f32(params, src0, src1, dst);
} break;
default:
{
@ -12909,6 +12919,7 @@ static void ggml_compute_forward_rope(
static void ggml_compute_forward_rope_back_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@ -12926,7 +12937,7 @@ static void ggml_compute_forward_rope_back_f32(
float xpos_base;
bool xpos_down;
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
@ -12935,8 +12946,6 @@ static void ggml_compute_forward_rope_back_f32(
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@ -12963,9 +12972,11 @@ static void ggml_compute_forward_rope_back_f32(
const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
@ -12977,7 +12988,7 @@ static void ggml_compute_forward_rope_back_f32(
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
theta *= theta_scale;
@ -13020,6 +13031,7 @@ static void ggml_compute_forward_rope_back_f32(
static void ggml_compute_forward_rope_back_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@ -13030,12 +13042,10 @@ static void ggml_compute_forward_rope_back_f16(
// dx = rope_back(dy, src1)
// src0 is dy, src1 contains options
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@ -13062,9 +13072,11 @@ static void ggml_compute_forward_rope_back_f16(
const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
@ -13116,15 +13128,16 @@ static void ggml_compute_forward_rope_back_f16(
static void ggml_compute_forward_rope_back(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_back_f16(params, src0, dst);
ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_back_f32(params, src0, dst);
ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
} break;
default:
{
@ -15861,11 +15874,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_ROPE:
{
ggml_compute_forward_rope(params, tensor->src[0], tensor);
ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ROPE_BACK:
{
ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ALIBI:
{
@ -16503,7 +16516,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
//const int n_past = ((int32_t *) tensor->op_params)[0];
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
@ -16520,7 +16533,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad,
ggml_rope_back(ctx,
tensor->grad,
n_past,
src1,
n_dims,
mode,
n_ctx,
@ -16534,7 +16547,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_ROPE_BACK:
{
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
//const int n_past = ((int32_t *) tensor->op_params)[0];
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
@ -16551,7 +16564,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad,
ggml_rope_impl(ctx,
tensor->grad,
n_past,
src1,
n_dims,
mode,
n_ctx,