Helps to pass args in the correct order
This commit is contained in:
parent
bc17e11590
commit
4bf45a7dbe
1 changed files with 22 additions and 18 deletions
40
ggml.c
40
ggml.c
|
@ -6640,7 +6640,7 @@ struct ggml_tensor * ggml_rope(
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode) {
|
int mode) {
|
||||||
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false, 1.0);
|
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 1.0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_rope_inplace(
|
struct ggml_tensor * ggml_rope_inplace(
|
||||||
|
@ -6649,7 +6649,7 @@ struct ggml_tensor * ggml_rope_inplace(
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode) {
|
int mode) {
|
||||||
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true, 1.0);
|
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 1.0, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_rope_scaled(
|
struct ggml_tensor * ggml_rope_scaled(
|
||||||
|
@ -6659,7 +6659,7 @@ struct ggml_tensor * ggml_rope_scaled(
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
float p_scale) {
|
float p_scale) {
|
||||||
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false, p_scale);
|
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, p_scale, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_rope_scaled_inplace(
|
struct ggml_tensor * ggml_rope_scaled_inplace(
|
||||||
|
@ -6669,7 +6669,7 @@ struct ggml_tensor * ggml_rope_scaled_inplace(
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int mode,
|
int mode,
|
||||||
float p_scale) {
|
float p_scale) {
|
||||||
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true, p_scale);
|
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, p_scale, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -15763,18 +15763,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
{
|
{
|
||||||
// necessary for llama
|
// necessary for llama
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
assert(src1->type == GGML_TYPE_I32);
|
assert(src1->type == GGML_TYPE_F32);
|
||||||
assert(ggml_nelements(src1) == 3);
|
assert(ggml_nelements(src1) == 4);
|
||||||
const int n_past = ((int32_t *) src1->data)[0];
|
const int n_past = (int)((float *) src1->data)[0];
|
||||||
const int n_dims = ((int32_t *) src1->data)[1];
|
const int n_dims = (int)((float *) src1->data)[1];
|
||||||
const int mode = ((int32_t *) src1->data)[2];
|
const int mode = (int)((float *) src1->data)[2];
|
||||||
|
const float p_scale = ((float *) src1->data)[3];
|
||||||
src0->grad = ggml_add_impl(ctx,
|
src0->grad = ggml_add_impl(ctx,
|
||||||
src0->grad,
|
src0->grad,
|
||||||
ggml_rope_back(ctx,
|
ggml_rope_back_scaled(ctx,
|
||||||
tensor->grad,
|
tensor->grad,
|
||||||
n_past,
|
n_past,
|
||||||
n_dims,
|
n_dims,
|
||||||
mode),
|
mode,
|
||||||
|
p_scale),
|
||||||
inplace);
|
inplace);
|
||||||
}
|
}
|
||||||
if (src1->grad) {
|
if (src1->grad) {
|
||||||
|
@ -15784,18 +15786,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
case GGML_OP_ROPE_BACK:
|
case GGML_OP_ROPE_BACK:
|
||||||
{
|
{
|
||||||
if (src0->grad) {
|
if (src0->grad) {
|
||||||
assert(src1->type == GGML_TYPE_I32);
|
assert(src1->type == GGML_TYPE_F32);
|
||||||
assert(ggml_nelements(src1) == 3);
|
assert(ggml_nelements(src1) == 4);
|
||||||
const int n_past = ((int32_t *) src1->data)[0];
|
const int n_past = (int)((float *) src1->data)[0];
|
||||||
const int n_dims = ((int32_t *) src1->data)[1];
|
const int n_dims = (int)((float *) src1->data)[1];
|
||||||
const int mode = ((int32_t *) src1->data)[2];
|
const int mode = (int)((float *) src1->data)[2];
|
||||||
|
const float p_scale = ((float *) src1->data)[3];
|
||||||
src0->grad = ggml_add_impl(ctx,
|
src0->grad = ggml_add_impl(ctx,
|
||||||
src0->grad,
|
src0->grad,
|
||||||
ggml_rope(ctx,
|
ggml_rope_scaled(ctx,
|
||||||
tensor->grad,
|
tensor->grad,
|
||||||
n_past,
|
n_past,
|
||||||
n_dims,
|
n_dims,
|
||||||
mode),
|
mode,
|
||||||
|
p_scale),
|
||||||
inplace);
|
inplace);
|
||||||
}
|
}
|
||||||
if (src1->grad) {
|
if (src1->grad) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue