RoPE: fix back, CUDA support for back + noncont. (#11240)
* RoPE: fix back, CUDA support for back + noncont. * fix comments reg. non-cont. RoPE support [no-ci]
This commit is contained in:
parent
0ccd7f3eb2
commit
432df2d5f9
9 changed files with 269 additions and 258 deletions
|
@ -2141,6 +2141,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_ROPE:
|
||||
ggml_cuda_op_rope(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
ggml_cuda_op_rope_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_IM2COL:
|
||||
ggml_cuda_op_im2col(ctx, dst);
|
||||
break;
|
||||
|
@ -3025,7 +3028,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_SOFT_MAX:
|
||||
return true;
|
||||
case GGML_OP_ROPE:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_ROPE_BACK: {
|
||||
const size_t ts = ggml_type_size(op->src[0]->type);
|
||||
const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
|
||||
return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
|
||||
}
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
|
@ -3081,6 +3088,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
|
|||
return op->ne[1];
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
return op->ne[2];
|
||||
default:
|
||||
return ggml_nrows(op);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue