Pass args to ggml_rope_impl in the correct order

KerfuffleV2 2023-06-22 06:37:21 -06:00
parent bc17e11590
commit 4bf45a7dbe
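
Context: after this change every call site passes ggml_rope_impl its trailing arguments as the float p_scale followed by the bool inplace flag. With the two the other way around the code still compiles, because C implicitly converts between bool and float; the scale silently becomes 0.0 and the inplace flag silently becomes true. A minimal standalone sketch of that hazard (the stand-in function below is illustrative, not ggml's actual API):

    #include <stdbool.h>
    #include <stdio.h>

    // Hypothetical stand-in with the same trailing parameters as
    // ggml_rope_impl after this commit: float scale, then bool flag.
    static void rope_impl(int mode, float p_scale, bool inplace) {
        printf("mode=%d p_scale=%.1f inplace=%d\n", mode, p_scale, inplace);
    }

    int main(void) {
        rope_impl(0, 1.0f, false); // correct order: p_scale=1.0 inplace=0
        // Swapped arguments still compile: false converts to 0.0f and
        // 1.0f converts to true, so the scale is lost without warning.
        rope_impl(0, false, 1.0f); // prints p_scale=0.0 inplace=1
        return 0;
    }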

ggml.c (40 lines changed)

@@ -6640,7 +6640,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_past,
         int                   n_dims,
         int                   mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false, 1.0);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 1.0, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -6649,7 +6649,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_past,
         int                   n_dims,
         int                   mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true, 1.0);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 1.0, true);
 }
 
 struct ggml_tensor * ggml_rope_scaled(
@@ -6659,7 +6659,7 @@ struct ggml_tensor * ggml_rope_scaled(
         int                   n_dims,
         int                   mode,
         float                 p_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false, p_scale);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, p_scale, false);
 }
 
 struct ggml_tensor * ggml_rope_scaled_inplace(
@@ -6669,7 +6669,7 @@ struct ggml_tensor * ggml_rope_scaled_inplace(
         int                   n_dims,
         int                   mode,
         float                 p_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true, p_scale);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, p_scale, true);
 }
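
The four wrappers above keep their public signatures; only the argument order they forward to ggml_rope_impl changes, so callers are unaffected. A sketch of how they are used on this branch (ggml_rope_scaled exists here, not in upstream ggml; the context size and tensor shape are made up for illustration):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // activations: [head_dim, n_head, n_tokens] (illustrative shape)
        struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 32, 8);

        // the plain wrapper fixes p_scale at 1.0 ...
        struct ggml_tensor * r1 = ggml_rope(ctx, a, /*n_past=*/0, /*n_dims=*/128, /*mode=*/0);
        // ... while the scaled variant passes it through explicitly
        struct ggml_tensor * r2 = ggml_rope_scaled(ctx, a, /*n_past=*/0, /*n_dims=*/128, /*mode=*/0, 0.5f);

        (void) r1; (void) r2;
        ggml_free(ctx);
        return 0;
    }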
@@ -15763,18 +15763,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
-                    const int n_past = ((int32_t *) src1->data)[0];
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
+                    assert(src1->type == GGML_TYPE_F32);
+                    assert(ggml_nelements(src1) == 4);
+                    const int n_past = (int)((float *) src1->data)[0];
+                    const int n_dims = (int)((float *) src1->data)[1];
+                    const int mode   = (int)((float *) src1->data)[2];
+                    const float p_scale = ((float *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_rope_back(ctx,
+                            ggml_rope_back_scaled(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                p_scale),
                             inplace);
                 }
                 if (src1->grad) {
@@ -15784,18 +15786,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_ROPE_BACK:
             {
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
-                    const int n_past = ((int32_t *) src1->data)[0];
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
+                    assert(src1->type == GGML_TYPE_F32);
+                    assert(ggml_nelements(src1) == 4);
+                    const int n_past = (int)((float *) src1->data)[0];
+                    const int n_dims = (int)((float *) src1->data)[1];
+                    const int mode   = (int)((float *) src1->data)[2];
+                    const float p_scale = ((float *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_rope(ctx,
+                            ggml_rope_scaled(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                p_scale),
                             inplace);
                 }
                 if (src1->grad) {
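
In the two backward hunks above, the op parameters arrive through src1, which now holds four F32 values (n_past, n_dims, mode, p_scale) instead of three I32 values, and the gradient is rebuilt with the *_scaled variants so p_scale is not dropped during differentiation. Storing the integer parameters in a float tensor round-trips exactly as long as they stay within float's exact integer range (magnitude up to 2^24). A small sketch of the unpacking pattern (hypothetical helper, not part of ggml):

    #include <assert.h>
    #include <stddef.h>

    // Hypothetical helper mirroring the unpacking in ggml_compute_backward:
    // the integer params are recovered from the F32 buffer by truncation,
    // and p_scale is read as-is.
    static void unpack_rope_params(const float * data, size_t n,
                                   int * n_past, int * n_dims,
                                   int * mode, float * p_scale) {
        assert(n == 4);
        *n_past  = (int) data[0];
        *n_dims  = (int) data[1];
        *mode    = (int) data[2];
        *p_scale =       data[3];
    }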