CUDA: non-contiguous (RMS) norm support (#11659)

* CUDA: non-contiguous (RMS) norm support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Johannes Gäßler 2025-02-04 22:21:42 +01:00 committed by GitHub
parent 3ec9fd4b77
commit fd08255d0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 97 additions and 47 deletions

View file

@ -38,6 +38,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv6.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml.h"
#include <algorithm>
#include <array>
@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE: