SYCL poo2d kernel: set NAN for invalid pooling op

This commit is contained in:
Akarshan Biswas 2024-12-11 11:07:32 +05:30
parent 7dda9aad23
commit cc7cd62ee7
No known key found for this signature in database
GPG key ID: 52A578A14B32134D

View file

@ -1792,12 +1792,15 @@ static void pool2d_nchw_kernel(
const int ew = sycl::min(iw, start_w + kw);
To res = 0;
bool op_valid = true;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
default:
break; // TODO: handle this properly
res = NAN;
op_valid = false;
break;
}
for (int i = bh; i < eh; i += 1) {
@ -1817,11 +1820,16 @@ static void pool2d_nchw_kernel(
case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
default:
break; // TODO: handle this properly
op_valid = false;
break;
}
}
}
o_ptr[cur_oh * ow + cur_ow] = res;
if (op_valid) {
o_ptr[cur_oh * ow + cur_ow] = res;
} else {
o_ptr[cur_oh * ow + cur_ow] = NAN;
}
}
template <int qk, int qr, dequantize_kernel_t dq>