RoPE: fix back, CUDA support for back + noncont. (#11240)

* RoPE: fix back, CUDA support for back + noncont.

* fix comments reg. non-cont. RoPE support [no-ci]
Johannes Gäßler 2025-01-15 12:51:37 +01:00 committed by GitHub
parent 0ccd7f3eb2
commit 432df2d5f9
9 changed files with 269 additions and 258 deletions

ggml/src/ggml.c

@@ -3695,7 +3695,7 @@ void ggml_rope_yarn_corr_dims(
 // ggml_rope_back
 
-struct ggml_tensor * ggml_rope_back(
+struct ggml_tensor * ggml_rope_ext_back(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
@@ -3709,29 +3709,32 @@ struct ggml_tensor * ggml_rope_back(
         float attn_factor,
         float beta_fast,
         float beta_slow) {
-    GGML_ASSERT(ggml_is_vector(b));
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] == b->ne[0]);
-
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-
-    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-    memcpy(params +  5, &freq_base,   sizeof(float));
-    memcpy(params +  6, &freq_scale,  sizeof(float));
-    memcpy(params +  7, &ext_factor,  sizeof(float));
-    memcpy(params +  8, &attn_factor, sizeof(float));
-    memcpy(params +  9, &beta_fast,   sizeof(float));
-    memcpy(params + 10, &beta_slow,   sizeof(float));
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op     = GGML_OP_ROPE_BACK;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
+    struct ggml_tensor * result = ggml_rope_ext(
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+    result->op = GGML_OP_ROPE_BACK;
 
     return result;
 }
+
+struct ggml_tensor * ggml_rope_multi_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c,
+        int n_dims,
+        int sections[4],
+        int mode,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow) {
+    struct ggml_tensor * result = ggml_rope_multi(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+    result->op = GGML_OP_ROPE_BACK;
+
+    return result;
+}
 
 // ggml_clamp
 
 struct ggml_tensor * ggml_clamp(
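The simplification above is sound because RoPE's backward pass is itself a RoPE with the rotation angle negated: building the node via ggml_rope_ext / ggml_rope_multi gives it exactly the forward op's shapes and op_params, and relabeling it GGML_OP_ROPE_BACK makes the backend kernels apply the inverse rotation. A short derivation (standard RoPE algebra, not part of the diff): RoPE rotates each feature pair by a position-dependent angle,

\[
y = R(\theta_i)\,x, \qquad
R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}, \qquad
\theta_i = p \cdot \mathrm{base}^{-2i/d}.
\]

Since each \(R(\theta_i)\) is orthogonal, the vector-Jacobian product is

\[
\frac{\partial L}{\partial x} = R(\theta_i)^{\top}\,\frac{\partial L}{\partial y} = R(-\theta_i)\,\frac{\partial L}{\partial y},
\]

i.e. the gradient is the same rotation applied with the sign of the angle flipped, which is what the ROPE_BACK kernels compute.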
@@ -5594,6 +5597,7 @@ static void ggml_compute_backward(
                 //const int n_ctx = ((int32_t *) tensor->op_params)[3];
                 const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
                 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                int sections[4] = {0, 0, 0, 0};
 
                 memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));
                 memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));
@@ -5601,10 +5605,14 @@ static void ggml_compute_backward(
                 memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));
                 memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));
                 memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));
+                memcpy(&sections,    tensor->op_params + 11, sizeof(sections));
 
-                ggml_add_or_set(ctx, cgraph, isrc0,
-                    ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
-                        freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
+                struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
+                    ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
+                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
+                    ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
+                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+                ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
             }
             GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
         } break;
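The grad->ne[2] == src1->ne[0] check above dispatches on the shape of the position tensor: regular RoPE carries one int32 position per token, while the multi-section (M-RoPE) variant packs several section positions per token, so its position tensor is longer than the number of tokens. A minimal usage sketch, not from the commit — the tensors dy and pos, the helper name rope_grad, and all hyperparameter values are illustrative placeholders that must match whatever the forward call used:

```c
#include "ggml.h"

// Sketch: compute the gradient w.r.t. the input of a RoPE node.
// `dy` is the incoming gradient (same shape as the forward output),
// `pos` is the same I32 position tensor the forward call used.
static struct ggml_tensor * rope_grad(struct ggml_context * ctx,
        struct ggml_tensor * dy, struct ggml_tensor * pos) {
    const int n_dims = 128; // placeholder: number of rotated dimensions

    if (dy->ne[2] == pos->ne[0]) {
        // one position per token -> regular RoPE backward
        return ggml_rope_ext_back(ctx, dy, pos, /*freq factors*/ NULL, n_dims,
            GGML_ROPE_TYPE_NEOX, /*n_ctx_orig*/ 0, /*freq_base*/ 10000.0f,
            /*freq_scale*/ 1.0f, /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f,
            /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);
    }

    // several section positions per token -> multi-section RoPE backward
    int sections[4] = {16, 24, 24, 0}; // placeholder section split
    return ggml_rope_multi_back(ctx, dy, pos, NULL, n_dims, sections,
        GGML_ROPE_TYPE_MROPE, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 32.0f, 1.0f);
}
```

All hyperparameters passed to the backward helper must equal those of the forward node, since the diff above builds the backward node with the forward constructor and only relabels its op.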