From 1120d94b603604c7cdd03191ffdba9a849fe1c73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 16 Jan 2025 13:56:39 +0100 Subject: [PATCH] remove restrict from pointers --- ggml/src/ggml-alloc.c | 5 ++++ ggml/src/ggml-cuda/norm.cu | 8 +++--- ggml/src/ggml-cuda/rope.cu | 48 +++++++++++++++++------------------ ggml/src/ggml-cuda/softmax.cu | 4 +-- ggml/src/ggml-cuda/unary.cu | 34 ++++++++++++------------- 5 files changed, 52 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 8dc8226ac..9a3bf9f29 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml return true; } +// ops that return true for this function must not use restrict pointers for their backend implementations static bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { case GGML_OP_SCALE: @@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_LOG: case GGML_OP_UNARY: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: return true; default: diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 04d40d3a5..aad63a1a0 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -1,7 +1,7 @@ #include "norm.cuh" template -static __global__ void norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const float eps) { +static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; @@ -41,7 +41,7 @@ static __global__ void norm_f32(const float * __restrict__ x, float * __restrict } template -static __global__ void group_norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int group_size, const int ne_elements, const float eps) { +static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { // blockIdx.x: num_groups idx // threadIdx.x: block_size idx const int start = blockIdx.x*group_size + threadIdx.x; @@ -97,7 +97,7 @@ static __global__ void group_norm_f32(const float * __restrict__ x, float * __re } template -static __global__ void rms_norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const float eps) { +static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; @@ -136,7 +136,7 @@ static __global__ void rms_norm_f32(const float * __restrict__ x, float * __rest template static __global__ void rms_norm_back_f32( - const float * __restrict__ grad, const float * __restrict__ xf, float * __restrict__ dst, const int ncols, const float eps) { + const float * grad, const float * xf, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index e1912fee1..18f691b2d 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -39,9 +39,9 @@ static __device__ void rope_yarn( template static __global__ void rope_norm( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -83,9 +83,9 @@ static __global__ void rope_norm( template static __global__ void rope_neox( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -127,9 +127,9 @@ static __global__ void rope_neox( template static __global__ void rope_multi( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, - const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, + const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -187,9 +187,9 @@ static __global__ void rope_multi( template static __global__ void rope_vision( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -234,9 +234,9 @@ static __global__ void rope_vision( template static void rope_norm_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -257,9 +257,9 @@ static void rope_norm_cuda( template static void rope_neox_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -280,9 +280,9 @@ static void rope_neox_cuda( template static void rope_multi_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -303,9 +303,9 @@ static void rope_multi_cuda( template static void rope_vision_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index c8a854d60..9e6cb2637 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -15,7 +15,7 @@ __device__ float __forceinline__ t2f32(half val) { template static __global__ void soft_max_f32( - const float * __restrict__ x, const T * __restrict__ mask, float * __restrict__ dst, const int ncols_par, const int nrows_y, + const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; @@ -120,7 +120,7 @@ static __global__ void soft_max_f32( } static __global__ void soft_max_back_f32( - const float * __restrict__ grad, const float * __restrict__ dstf, float * __restrict__ dst, const int ncols, const float scale) { + const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { const int tid = threadIdx.x; const int rowx = blockIdx.x; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index ce2029f56..6b21f407d 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,6 +1,6 @@ #include "unary.cuh" -static __global__ void neg_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void neg_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -10,7 +10,7 @@ static __global__ void neg_f32(const float * __restrict__ x, float * __restrict_ dst[i] = -x[i]; } -static __global__ void step_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void step_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -20,7 +20,7 @@ static __global__ void step_f32(const float * __restrict__ x, float * __restrict dst[i] = x[i] > 0.0f; } -static __global__ void gelu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void gelu_f32(const float * x, float * dst, const int k) { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -33,7 +33,7 @@ static __global__ void gelu_f32(const float * __restrict__ x, float * __restrict dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); } -static __global__ void gelu_quick_f32(const float * __restrict__ x, float * __restrict__ dst, int k) { +static __global__ void gelu_quick_f32(const float * x, float * dst, int k) { const float GELU_QUICK_COEF = -1.702f; const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -42,7 +42,7 @@ static __global__ void gelu_quick_f32(const float * __restrict__ x, float * __re dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i]))); } -static __global__ void silu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void silu_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -52,7 +52,7 @@ static __global__ void silu_f32(const float * __restrict__ x, float * __restrict } static __global__ void silu_back_f32( - const float * __restrict__ grad, const float * __restrict__ xf, float * __restrict__ dst, const int k) { + const float * grad, const float * xf, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -64,7 +64,7 @@ static __global__ void silu_back_f32( dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s)); } -static __global__ void tanh_f32(const float * __restrict__ x, float * __restrict__ dst, int k) { +static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; @@ -72,7 +72,7 @@ static __global__ void tanh_f32(const float * __restrict__ x, float * __restrict dst[i] = tanhf(x[i]); } -static __global__ void relu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void relu_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -81,7 +81,7 @@ static __global__ void relu_f32(const float * __restrict__ x, float * __restrict dst[i] = fmaxf(x[i], 0); } -static __global__ void sigmoid_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void sigmoid_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -90,7 +90,7 @@ static __global__ void sigmoid_f32(const float * __restrict__ x, float * __restr dst[i] = 1.0f / (1.0f + expf(-x[i])); } -static __global__ void hardsigmoid_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -99,7 +99,7 @@ static __global__ void hardsigmoid_f32(const float * __restrict__ x, float * __r dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } -static __global__ void hardswish_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void hardswish_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -108,7 +108,7 @@ static __global__ void hardswish_f32(const float * __restrict__ x, float * __res dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } -static __global__ void exp_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void exp_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -117,7 +117,7 @@ static __global__ void exp_f32(const float * __restrict__ x, float * __restrict_ dst[i] = expf(x[i]); } -static __global__ void leaky_relu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k, const float negative_slope) { +static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; @@ -125,7 +125,7 @@ static __global__ void leaky_relu_f32(const float * __restrict__ x, float * __re dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope; } -static __global__ void sqr_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void sqr_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -134,7 +134,7 @@ static __global__ void sqr_f32(const float * __restrict__ x, float * __restrict_ dst[i] = x[i] * x[i]; } -static __global__ void sqrt_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void sqrt_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -143,7 +143,7 @@ static __global__ void sqrt_f32(const float * __restrict__ x, float * __restrict dst[i] = sqrtf(x[i]); } -static __global__ void sin_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void sin_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -152,7 +152,7 @@ static __global__ void sin_f32(const float * __restrict__ x, float * __restrict_ dst[i] = sinf(x[i]); } -static __global__ void cos_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) { +static __global__ void cos_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) {