remove restrict from pointers

commit 1120d94b60
parent deab32760a

5 changed files with 52 additions and 47 deletions
@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
     return true;
 }
 
+// ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
         case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_LOG:
         case GGML_OP_UNARY:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+        case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
             return true;
 
         default:
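Context for the change above (not part of the diff): ggml_op_can_inplace() marks ops whose backend kernels may be handed the same buffer as both source and destination, while __restrict__ promises the compiler the opposite, that the pointers never alias. The sketch below uses a hypothetical scale_f32 kernel written in the same style as the kernels in this commit to show the in-place launch that a restrict-qualified signature would make undefined behavior.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel (not from the diff), in the same style as the kernels below:
// without __restrict__, x and dst are allowed to point at the same allocation.
static __global__ void scale_f32(const float * x, float * dst, const float s, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = s*x[i];
}

int main() {
    const int k = 256;
    float * buf = nullptr;
    cudaMalloc(&buf, k*sizeof(float));
    cudaMemset(buf, 0, k*sizeof(float));

    // In-place launch: the same buffer is passed as both input and output.
    // With __restrict__ on x and dst this aliasing would be undefined behavior;
    // without it the compiler must assume the pointers may overlap.
    scale_f32<<<(k + 255)/256, 256>>>(buf, buf, 2.0f, k);
    cudaDeviceSynchronize();
    cudaFree(buf);

    printf("done\n");
    return 0;
}

With restrict qualifiers the compiler may cache or reorder loads from x across the store to dst[i], which is only safe when the two buffers are disjoint; dropping the qualifier keeps the in-place case well defined, which is what allows the *_BACK ops to be added to ggml_op_can_inplace().
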
@@ -1,7 +1,7 @@
 #include "norm.cuh"
 
 template <int block_size>
-static __global__ void norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const float eps) {
+static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
@@ -41,7 +41,7 @@ static __global__ void norm_f32(const float * __restrict__ x, float * __restrict
 }
 
 template <int block_size>
-static __global__ void group_norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int group_size, const int ne_elements, const float eps) {
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
     // blockIdx.x: num_groups idx
     // threadIdx.x: block_size idx
     const int start = blockIdx.x*group_size + threadIdx.x;
@@ -97,7 +97,7 @@ static __global__ void group_norm_f32(const float * __restrict__ x, float * __re
 }
 
 template <int block_size>
-static __global__ void rms_norm_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const float eps) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
@@ -136,7 +136,7 @@ static __global__ void rms_norm_f32(const float * __restrict__ x, float * __rest
 
 template <int block_size>
 static __global__ void rms_norm_back_f32(
-        const float * __restrict__ grad, const float * __restrict__ xf, float * __restrict__ dst, const int ncols, const float eps) {
+        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 

@@ -39,9 +39,9 @@ static __device__ void rope_yarn(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_norm(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -83,9 +83,9 @@ static __global__ void rope_norm(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_neox(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -127,9 +127,9 @@ static __global__ void rope_neox(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_multi(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
-        const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
+        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -187,9 +187,9 @@ static __global__ void rope_multi(
 
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_vision(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
-        const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-        const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -234,9 +234,9 @@ static __global__ void rope_vision(
 
 template<bool forward, typename T>
 static void rope_norm_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -257,9 +257,9 @@ static void rope_norm_cuda(
 
 template<bool forward, typename T>
 static void rope_neox_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -280,9 +280,9 @@ static void rope_neox_cuda(
 
 template<bool forward, typename T>
 static void rope_multi_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -303,9 +303,9 @@ static void rope_multi_cuda(
 
 template<bool forward, typename T>
 static void rope_vision_cuda(
-        const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);

|
@ -15,7 +15,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
|||
|
||||
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(
|
||||
const float * __restrict__ x, const T * __restrict__ mask, float * __restrict__ dst, const int ncols_par, const int nrows_y,
|
||||
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
|
||||
const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
|
@ -120,7 +120,7 @@ static __global__ void soft_max_f32(
|
|||
}
|
||||
|
||||
static __global__ void soft_max_back_f32(
|
||||
const float * __restrict__ grad, const float * __restrict__ dstf, float * __restrict__ dst, const int ncols, const float scale) {
|
||||
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
|
||||
const int tid = threadIdx.x;
|
||||
const int rowx = blockIdx.x;
|
||||
|
||||
|
|
|
@@ -1,6 +1,6 @@
 #include "unary.cuh"
 
-static __global__ void neg_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void neg_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -10,7 +10,7 @@ static __global__ void neg_f32(const float * __restrict__ x, float * __restrict_
     dst[i] = -x[i];
 }
 
-static __global__ void step_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void step_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -20,7 +20,7 @@ static __global__ void step_f32(const float * __restrict__ x, float * __restrict
     dst[i] = x[i] > 0.0f;
 }
 
-static __global__ void gelu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -33,7 +33,7 @@ static __global__ void gelu_f32(const float * __restrict__ x, float * __restrict
     dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
 }
 
-static __global__ void gelu_quick_f32(const float * __restrict__ x, float * __restrict__ dst, int k) {
+static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
     const float GELU_QUICK_COEF = -1.702f;
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -42,7 +42,7 @@ static __global__ void gelu_quick_f32(const float * __restrict__ x, float * __re
     dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
 }
 
-static __global__ void silu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void silu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -52,7 +52,7 @@ static __global__ void silu_f32(const float * __restrict__ x, float * __restrict
 }
 
 static __global__ void silu_back_f32(
-        const float * __restrict__ grad, const float * __restrict__ xf, float * __restrict__ dst, const int k) {
+        const float * grad, const float * xf, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -64,7 +64,7 @@ static __global__ void silu_back_f32(
     dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
 }
 
-static __global__ void tanh_f32(const float * __restrict__ x, float * __restrict__ dst, int k) {
+static __global__ void tanh_f32(const float * x, float * dst, int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
@@ -72,7 +72,7 @@ static __global__ void tanh_f32(const float * __restrict__ x, float * __restrict
     dst[i] = tanhf(x[i]);
 }
 
-static __global__ void relu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -81,7 +81,7 @@ static __global__ void relu_f32(const float * __restrict__ x, float * __restrict
     dst[i] = fmaxf(x[i], 0);
 }
 
-static __global__ void sigmoid_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -90,7 +90,7 @@ static __global__ void sigmoid_f32(const float * __restrict__ x, float * __restr
     dst[i] = 1.0f / (1.0f + expf(-x[i]));
 }
 
-static __global__ void hardsigmoid_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -99,7 +99,7 @@ static __global__ void hardsigmoid_f32(const float * __restrict__ x, float * __r
     dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
-static __global__ void hardswish_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -108,7 +108,7 @@ static __global__ void hardswish_f32(const float * __restrict__ x, float * __res
     dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
-static __global__ void exp_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void exp_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -117,7 +117,7 @@ static __global__ void exp_f32(const float * __restrict__ x, float * __restrict_
     dst[i] = expf(x[i]);
 }
 
-static __global__ void leaky_relu_f32(const float * __restrict__ x, float * __restrict__ dst, const int k, const float negative_slope) {
+static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
@@ -125,7 +125,7 @@ static __global__ void leaky_relu_f32(const float * __restrict__ x, float * __re
     dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
 }
 
-static __global__ void sqr_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -134,7 +134,7 @@ static __global__ void sqr_f32(const float * __restrict__ x, float * __restrict_
     dst[i] = x[i] * x[i];
 }
 
-static __global__ void sqrt_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -143,7 +143,7 @@ static __global__ void sqrt_f32(const float * __restrict__ x, float * __restrict
     dst[i] = sqrtf(x[i]);
 }
 
-static __global__ void sin_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void sin_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -152,7 +152,7 @@ static __global__ void sin_f32(const float * __restrict__ x, float * __restrict_
     dst[i] = sinf(x[i]);
 }
 
-static __global__ void cos_f32(const float * __restrict__ x, float * __restrict__ dst, const int k) {
+static __global__ void cos_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {