Helps to pass args in the correct order

Author: KerfuffleV2
Date:   2023-06-22 06:37:21 -06:00
Commit: 4bf45a7dbe
Parent: bc17e11590

ggml.c (40 changed lines)

@@ -6640,7 +6640,7 @@ struct ggml_tensor * ggml_rope(
         int n_past,
         int n_dims,
         int mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false, 1.0);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 1.0, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -6649,7 +6649,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int n_past,
         int n_dims,
         int mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true, 1.0);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 1.0, true);
 }
 
 struct ggml_tensor * ggml_rope_scaled(
@@ -6659,7 +6659,7 @@ struct ggml_tensor * ggml_rope_scaled(
         int n_dims,
         int mode,
         float p_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false, p_scale);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, p_scale, false);
 }
 
 struct ggml_tensor * ggml_rope_scaled_inplace(
@@ -6669,7 +6669,7 @@ struct ggml_tensor * ggml_rope_scaled_inplace(
         int n_dims,
         int mode,
         float p_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true, p_scale);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, p_scale, true);
 }
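
The four wrappers above now pass p_scale ahead of the inplace flag. For context, a minimal sketch of the ggml_rope_impl prototype they assume; only the argument order is taken from the call sites in this diff, while the parameter names and the static qualifier are assumptions.

    // Sketch only: names and the static qualifier are assumed; the order
    // (..., mode, p_scale, inplace) is inferred from the updated calls above.
    static struct ggml_tensor * ggml_rope_impl(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_past,
            int                   n_dims,
            int                   mode,
            float                 p_scale,  // position scale, 1.0 for plain RoPE
            bool                  inplace); // previously came before p_scale
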
@@ -15763,18 +15763,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
-                    const int n_past = ((int32_t *) src1->data)[0];
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
+                    assert(src1->type == GGML_TYPE_F32);
+                    assert(ggml_nelements(src1) == 4);
+                    const int n_past = (int)((float *) src1->data)[0];
+                    const int n_dims = (int)((float *) src1->data)[1];
+                    const int mode   = (int)((float *) src1->data)[2];
+                    const float p_scale = ((float *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_rope_back(ctx,
+                            ggml_rope_back_scaled(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                p_scale),
                             inplace);
                 }
                 if (src1->grad) {
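
The backward rule above now reads the op parameters from src1 as a 4-element F32 tensor laid out as [n_past, n_dims, mode, p_scale]. A minimal sketch, assuming the forward path (ggml_rope_impl) packs that tensor itself; the variable name b is illustrative, and only the type, element count, and layout follow from the asserts and reads above.

    // Assumed forward-side packing matching the backward reads above.
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ((float *) b->data)[0] = (float) n_past;
    ((float *) b->data)[1] = (float) n_dims;
    ((float *) b->data)[2] = (float) mode;
    ((float *) b->data)[3] = p_scale;
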
@@ -15784,18 +15786,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_ROPE_BACK:
             {
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
-                    const int n_past = ((int32_t *) src1->data)[0];
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
+                    assert(src1->type == GGML_TYPE_F32);
+                    assert(ggml_nelements(src1) == 4);
+                    const int n_past = (int)((float *) src1->data)[0];
+                    const int n_dims = (int)((float *) src1->data)[1];
+                    const int mode   = (int)((float *) src1->data)[2];
+                    const float p_scale = ((float *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_rope(ctx,
+                            ggml_rope_scaled(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                p_scale),
                             inplace);
                 }
                 if (src1->grad) {
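
For reference, an illustrative call to the scaled wrapper whose signature appears in the hunks above; the input tensor x, the mode value 0, and the 0.5f scale are made-up example values, not taken from this commit.

    // Example only: apply RoPE with positions scaled by 0.5
    // (mode 0 and the scale value are illustrative).
    struct ggml_tensor * cur = ggml_rope_scaled(ctx, x, n_past, n_dims, 0, 0.5f);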