SYCL ggml-sycl: pool2D use sycl::nan and remove if-else block
This commit is contained in:
parent
b828f4aa5f
commit
6b0848ceaf
1 changed files with 3 additions and 9 deletions
|
@ -1789,14 +1789,12 @@ static void pool2d_nchw_kernel(
|
|||
const int ew = sycl::min(iw, start_w + kw);
|
||||
|
||||
To res = 0;
|
||||
bool op_valid = true;
|
||||
|
||||
switch (op) {
|
||||
case GGML_OP_POOL_AVG: res = 0; break;
|
||||
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
||||
default:
|
||||
res = NAN;
|
||||
op_valid = false;
|
||||
res = (To) sycl::nan(uint32_t(0));
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -1817,16 +1815,12 @@ static void pool2d_nchw_kernel(
|
|||
case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
|
||||
case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
|
||||
default:
|
||||
op_valid = false;
|
||||
res = (To) sycl::nan(uint32_t(0));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (op_valid) {
|
||||
o_ptr[cur_oh * ow + cur_ow] = res;
|
||||
} else {
|
||||
o_ptr[cur_oh * ow + cur_ow] = NAN;
|
||||
}
|
||||
o_ptr[cur_oh * ow + cur_ow] = res;
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dq>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue