SYCL poo2d kernel: set NAN for invalid pooling op
This commit is contained in:
parent
7dda9aad23
commit
cc7cd62ee7
1 changed files with 11 additions and 3 deletions
|
@ -1792,12 +1792,15 @@ static void pool2d_nchw_kernel(
|
||||||
const int ew = sycl::min(iw, start_w + kw);
|
const int ew = sycl::min(iw, start_w + kw);
|
||||||
|
|
||||||
To res = 0;
|
To res = 0;
|
||||||
|
bool op_valid = true;
|
||||||
|
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case GGML_OP_POOL_AVG: res = 0; break;
|
case GGML_OP_POOL_AVG: res = 0; break;
|
||||||
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
|
||||||
default:
|
default:
|
||||||
break; // TODO: handle this properly
|
res = NAN;
|
||||||
|
op_valid = false;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = bh; i < eh; i += 1) {
|
for (int i = bh; i < eh; i += 1) {
|
||||||
|
@ -1817,11 +1820,16 @@ static void pool2d_nchw_kernel(
|
||||||
case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
|
case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
|
||||||
case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
|
case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
|
||||||
default:
|
default:
|
||||||
break; // TODO: handle this properly
|
op_valid = false;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
o_ptr[cur_oh * ow + cur_ow] = res;
|
if (op_valid) {
|
||||||
|
o_ptr[cur_oh * ow + cur_ow] = res;
|
||||||
|
} else {
|
||||||
|
o_ptr[cur_oh * ow + cur_ow] = NAN;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dq>
|
template <int qk, int qr, dequantize_kernel_t dq>
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue