RoPE: fix back, CUDA support for back + noncont. (#11240)

* RoPE: fix back, CUDA support for back + noncont.

* fix comments reg. non-cont. RoPE support [no-ci]
This commit is contained in:
Johannes Gäßler 2025-01-15 12:51:37 +01:00 committed by GitHub
parent 0ccd7f3eb2
commit 432df2d5f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 269 additions and 258 deletions

View file

@ -2141,6 +2141,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ROPE:
ggml_cuda_op_rope(ctx, dst);
break;
case GGML_OP_ROPE_BACK:
ggml_cuda_op_rope_back(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
@ -3025,7 +3028,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SOFT_MAX:
return true;
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ROPE_BACK: {
const size_t ts = ggml_type_size(op->src[0]->type);
const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
}
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
@ -3081,6 +3088,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
return op->ne[1];
case GGML_OP_MUL_MAT_ID:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
return op->ne[2];
default:
return ggml_nrows(op);