implement backward pass of ggml_rope and ggml_rope_back
This commit is contained in:
parent
36d8a051d4
commit
488decfdc5
2 changed files with 73 additions and 66 deletions
135
ggml.c
135
ggml.c
|
@ -6271,8 +6271,7 @@ struct ggml_tensor * ggml_rope_impl(
|
|||
GGML_ASSERT(n_past >= 0);
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
GGML_ASSERT(false); // TODO: implement backward
|
||||
if (!inplace && a->grad) {
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
|
@ -6314,7 +6313,6 @@ struct ggml_tensor * ggml_rope_inplace(
|
|||
struct ggml_tensor * ggml_rope_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int mode) {
|
||||
|
@ -6328,16 +6326,15 @@ struct ggml_tensor * ggml_rope_back(
|
|||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
|
||||
((int32_t *) c->data)[0] = n_past;
|
||||
((int32_t *) c->data)[1] = n_dims;
|
||||
((int32_t *) c->data)[2] = mode;
|
||||
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
|
||||
((int32_t *) b->data)[0] = n_past;
|
||||
((int32_t *) b->data)[1] = n_dims;
|
||||
((int32_t *) b->data)[2] = mode;
|
||||
|
||||
result->op = GGML_OP_ROPE_BACK;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src0 = a;
|
||||
result->src1 = b;
|
||||
result->opt[0] = c;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -10710,22 +10707,21 @@ static void ggml_compute_forward_rope_back_f32(
|
|||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
const struct ggml_tensor * opt,
|
||||
struct ggml_tensor * dst) {
|
||||
assert(opt->type == GGML_TYPE_I32);
|
||||
assert(ggml_nelements(opt) == 3);
|
||||
assert(src1->type == GGML_TYPE_I32);
|
||||
assert(ggml_nelements(src1) == 3);
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
// y = rope(x, opt)
|
||||
// dx = rope_back(x, dy)
|
||||
// src0 is x, src1 is dy
|
||||
// y = rope(x, src1)
|
||||
// dx = rope_back(dy, src1)
|
||||
// src0 is dy, src1 contains options
|
||||
|
||||
const int n_past = ((int32_t *) opt->data)[0];
|
||||
const int n_dims = ((int32_t *) opt->data)[1];
|
||||
const int mode = ((int32_t *) opt->data)[2];
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
const int n_dims = ((int32_t *) src1->data)[1];
|
||||
const int mode = ((int32_t *) src1->data)[2];
|
||||
|
||||
//const int64_t ne0 = src0->ne[0];
|
||||
const int64_t ne1 = src0->ne[1];
|
||||
|
@ -10761,9 +10757,6 @@ static void ggml_compute_forward_rope_back_f32(
|
|||
|
||||
const bool is_neox = mode & 2;
|
||||
|
||||
// TODO
|
||||
GGML_ASSERT(false);
|
||||
//*
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
||||
const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
|
@ -10780,47 +10773,49 @@ static void ggml_compute_forward_rope_back_f32(
|
|||
theta *= theta_scale;
|
||||
|
||||
if (!is_neox) {
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
const float * const dy = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
const float dy0 = dy[0];
|
||||
const float dy1 = dy[1];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
dx[0] = dy0*cos_theta + dy1*sin_theta;
|
||||
dx[1] = - dy0*sin_theta + dy1*cos_theta;
|
||||
} else {
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
||||
const float * const dy = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
||||
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
const float dy0 = dy[0];
|
||||
const float dy1 = dy[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
dx[0] = dy0*cos_theta + dy1*sin_theta;
|
||||
dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//*/
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_back_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
const struct ggml_tensor * opt,
|
||||
struct ggml_tensor * dst) {
|
||||
assert(opt->type == GGML_TYPE_I32);
|
||||
assert(ggml_nelements(opt) == 3);
|
||||
assert(src1->type == GGML_TYPE_I32);
|
||||
assert(ggml_nelements(src1) == 3);
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int n_past = ((int32_t *) opt->data)[0];
|
||||
const int n_dims = ((int32_t *) opt->data)[1];
|
||||
const int mode = ((int32_t *) opt->data)[2];
|
||||
// y = rope(x, src1)
|
||||
// dx = rope_back(dy, src1)
|
||||
// src0 is dy, src1 contains options
|
||||
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
const int n_dims = ((int32_t *) src1->data)[1];
|
||||
const int mode = ((int32_t *) src1->data)[2];
|
||||
|
||||
//const int64_t ne0 = src0->ne[0];
|
||||
const int64_t ne1 = src0->ne[1];
|
||||
|
@ -10856,9 +10851,6 @@ static void ggml_compute_forward_rope_back_f16(
|
|||
|
||||
const bool is_neox = mode & 2;
|
||||
|
||||
// TODO
|
||||
GGML_ASSERT(false);
|
||||
/*
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
||||
const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
|
@ -10875,45 +10867,43 @@ static void ggml_compute_forward_rope_back_f16(
|
|||
theta *= theta_scale;
|
||||
|
||||
if (!is_neox) {
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[1]);
|
||||
const float dy0 = GGML_FP16_TO_FP32(dy[0]);
|
||||
const float dy1 = GGML_FP16_TO_FP32(dy[1]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
|
||||
dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
|
||||
} else {
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
||||
const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
||||
ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
||||
|
||||
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
||||
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
||||
const float dy0 = GGML_FP16_TO_FP32(dy[0]);
|
||||
const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
|
||||
|
||||
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
||||
dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
|
||||
dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_back(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
const struct ggml_tensor * opt,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_rope_back_f16(params, src0, src1, opt, dst);
|
||||
ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rope_back_f32(params, src0, src1, opt, dst);
|
||||
ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -12373,14 +12363,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_soft_max(params, tensor->src0, tensor);
|
||||
} break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
|
||||
} break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor);
|
||||
} break;
|
||||
case GGML_OP_ALIBI:
|
||||
{
|
||||
ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
|
||||
|
@ -12865,7 +12855,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad = ggml_sub_impl(ctx,
|
||||
src0->grad,
|
||||
ggml_rope_back(ctx,
|
||||
src0,
|
||||
tensor->grad,
|
||||
n_past,
|
||||
n_dims,
|
||||
|
@ -12878,7 +12867,24 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
if (src0->grad) {
|
||||
assert(src1->type == GGML_TYPE_I32);
|
||||
assert(ggml_nelements(src1) == 3);
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
const int n_dims = ((int32_t *) src1->data)[1];
|
||||
const int mode = ((int32_t *) src1->data)[2];
|
||||
src0->grad = ggml_sub_impl(ctx,
|
||||
src0->grad,
|
||||
ggml_rope(ctx,
|
||||
tensor->grad,
|
||||
n_past,
|
||||
n_dims,
|
||||
mode),
|
||||
inplace);
|
||||
}
|
||||
if (src1->grad) {
|
||||
// noop
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CONV_1D_1S:
|
||||
{
|
||||
|
@ -13369,6 +13375,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
node->n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
{
|
||||
node->n_tasks = n_threads;
|
||||
} break;
|
||||
|
|
4
ggml.h
4
ggml.h
|
@ -252,7 +252,7 @@ extern "C" {
|
|||
|
||||
GGML_OP_DUP,
|
||||
GGML_OP_ADD,
|
||||
GGML_OP_ADD1
|
||||
GGML_OP_ADD1,
|
||||
GGML_OP_ADD_AT,
|
||||
GGML_OP_SUB,
|
||||
GGML_OP_MUL,
|
||||
|
@ -746,7 +746,7 @@ extern "C" {
|
|||
int n_dims,
|
||||
int mode);
|
||||
|
||||
// rotary position embedding backward, i.e compute dx
|
||||
// rotary position embedding backward, i.e compute dx from dy
|
||||
GGML_API struct ggml_tensor * ggml_rope_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * x,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue