Fix code style

This commit is contained in:
lijiahao 2023-09-03 17:54:22 +08:00
parent 54ddacaa8b
commit 9dc817e57f

View file

@ -473,7 +473,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
return a; return a;
} }
template <int BLOCK_SIZE> template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols) { static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
@ -482,7 +482,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
float2 mean_var = make_float2(0.f, 0.f); float2 mean_var = make_float2(0.f, 0.f);
for (int col = tid; col < ncols; col += BLOCK_SIZE) { for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col]; const float xi = x[row*ncols + col];
mean_var.x += xi; mean_var.x += xi;
mean_var.y += xi * xi; mean_var.y += xi * xi;
@ -490,7 +490,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
// sum up partial sums // sum up partial sums
mean_var = warp_reduce_sum(mean_var); mean_var = warp_reduce_sum(mean_var);
if (BLOCK_SIZE > WARP_SIZE) { if (block_size > WARP_SIZE) {
__shared__ float2 s_sum[32]; __shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
@ -506,7 +506,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const float var = mean_var.y / ncols - mean * mean; const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps); const float inv_std = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += BLOCK_SIZE) { for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
} }
} }
@ -519,21 +519,21 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
return x; return x;
} }
template <int BLOCK_SIZE> template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
float tmp = 0.0f; // partial sum for thread in warp float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += BLOCK_SIZE) { for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col]; const float xi = x[row*ncols + col];
tmp += xi * xi; tmp += xi * xi;
} }
// sum up partial sums // sum up partial sums
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
if (BLOCK_SIZE > WARP_SIZE) { if (block_size > WARP_SIZE) {
__shared__ float s_sum[32]; __shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE; int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
@ -548,7 +548,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
const float mean = tmp / ncols; const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps); const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += BLOCK_SIZE) { for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * x[row*ncols + col]; dst[row*ncols + col] = scale * x[row*ncols + col];
} }
} }