cuda : more style fixes
This commit is contained in:
parent
8824e42786
commit
0d94da7cbb
1 changed files with 6 additions and 5 deletions
11
ggml-cuda.cu
11
ggml-cuda.cu
|
@ -6066,14 +6066,14 @@ static __global__ void pool2d_nchw_kernel(
|
||||||
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int i = bh; i < eh; i += 1) {
|
for (int i = bh; i < eh; i += 1) {
|
||||||
for(int j = bw; j < ew; j += 1) {
|
for (int j = bw; j < ew; j += 1) {
|
||||||
#if __CUDA_ARCH__ >= 350
|
#if __CUDA_ARCH__ >= 350
|
||||||
Ti cur = __ldg(i_ptr + i * iw + j);
|
Ti cur = __ldg(i_ptr + i * iw + j);
|
||||||
#else
|
#else
|
||||||
Ti cur = i_ptr[i * iw + j];
|
Ti cur = i_ptr[i * iw + j];
|
||||||
#endif
|
#endif
|
||||||
switch(op){
|
switch (op) {
|
||||||
case GGML_OP_POOL_AVG: res += cur * scale; break;
|
case GGML_OP_POOL_AVG: res += cur * scale; break;
|
||||||
case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
|
case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
|
||||||
}
|
}
|
||||||
|
@ -8780,10 +8780,11 @@ static void ggml_cuda_op_im2col(
|
||||||
const int64_t batch = src1->ne[3];
|
const int64_t batch = src1->ne[3];
|
||||||
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
|
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
|
||||||
|
|
||||||
if(dst->type == GGML_TYPE_F16)
|
if(dst->type == GGML_TYPE_F16) {
|
||||||
im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
||||||
else
|
} else {
|
||||||
im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
||||||
|
}
|
||||||
|
|
||||||
(void) src0;
|
(void) src0;
|
||||||
(void) src0_dd;
|
(void) src0_dd;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue