cuda : fix warnings and formatting
This commit is contained in:
parent
04f10a2287
commit
caf2fc8294
1 changed files with 10 additions and 7 deletions
11
ggml-cuda.cu
11
ggml-cuda.cu
|
@ -6041,8 +6041,10 @@ static __global__ void pool2d_nchw_kernel(
|
|||
const int ph, const int pw, const int parallel_elements,
|
||||
const Ti* src, To* dst, const enum ggml_op_pool op) {
|
||||
int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if(idx >= parallel_elements)
|
||||
if (idx >= parallel_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int I_HW = ih * iw;
|
||||
const int O_HW = oh * ow;
|
||||
const int nc = idx / O_HW;
|
||||
|
@ -6058,10 +6060,12 @@ static __global__ void pool2d_nchw_kernel(
|
|||
const int ew = min(iw, start_w + kw);
|
||||
const To scale = 1. / (kh * kw);
|
||||
To res = 0;
|
||||
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: res = 0; break;
|
||||
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
||||
}
|
||||
|
||||
for(int i = bh; i < eh; i += 1) {
|
||||
for(int j = bw; j < ew; j += 1) {
|
||||
#if __CUDA_ARCH__ >= 350
|
||||
|
@ -8741,11 +8745,10 @@ static void ggml_cuda_op_pool2d(
|
|||
dim3 block_nums(num_blocks);
|
||||
pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
|
||||
|
||||
(void) src0;
|
||||
(void) src0_dd;
|
||||
(void) src1;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
|
||||
static void ggml_cuda_op_im2col(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue