cuda : fix warnings and formatting

This commit is contained in:
slaren 2024-01-30 14:05:35 +01:00
parent 04f10a2287
commit caf2fc8294

View file

@@ -6041,8 +6041,10 @@ static __global__ void pool2d_nchw_kernel(
const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if(idx >= parallel_elements)
if (idx >= parallel_elements) {
return;
}
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / O_HW;
@@ -6058,10 +6060,12 @@ static __global__ void pool2d_nchw_kernel(
const int ew = min(iw, start_w + kw);
const To scale = 1. / (kh * kw);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
}
for(int i = bh; i < eh; i += 1) {
for(int j = bw; j < ew; j += 1) {
#if __CUDA_ARCH__ >= 350
@@ -8741,11 +8745,10 @@ static void ggml_cuda_op_pool2d(
dim3 block_nums(num_blocks);
pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
(void) src0;
(void) src0_dd;
(void) src1;
(void) src1_dd;
}
static void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {