Fix code style
This commit is contained in:
parent
54ddacaa8b
commit
9dc817e57f
1 changed files with 8 additions and 8 deletions
16
ggml-cuda.cu
16
ggml-cuda.cu
|
@ -473,7 +473,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
|||
return a;
|
||||
}
|
||||
|
||||
template <int BLOCK_SIZE>
|
||||
template <int block_size>
|
||||
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
@ -482,7 +482,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
|
|||
|
||||
float2 mean_var = make_float2(0.f, 0.f);
|
||||
|
||||
for (int col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
mean_var.x += xi;
|
||||
mean_var.y += xi * xi;
|
||||
|
@ -490,7 +490,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
|
|||
|
||||
// sum up partial sums
|
||||
mean_var = warp_reduce_sum(mean_var);
|
||||
if (BLOCK_SIZE > WARP_SIZE) {
|
||||
if (block_size > WARP_SIZE) {
|
||||
__shared__ float2 s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
@ -506,7 +506,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
|
|||
const float var = mean_var.y / ncols - mean * mean;
|
||||
const float inv_std = rsqrtf(var + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
|
||||
}
|
||||
}
|
||||
|
@ -519,21 +519,21 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
|||
return x;
|
||||
}
|
||||
|
||||
template <int BLOCK_SIZE>
|
||||
template <int block_size>
|
||||
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (BLOCK_SIZE > WARP_SIZE) {
|
||||
if (block_size > WARP_SIZE) {
|
||||
__shared__ float s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
@ -548,7 +548,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
|
|||
const float mean = tmp / ncols;
|
||||
const float scale = rsqrtf(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue