From cc7cd62ee70e2705e71679beb37e71a281742900 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Wed, 11 Dec 2024 11:07:32 +0530 Subject: [PATCH] SYCL poo2d kernel: set NAN for invalid pooling op --- ggml/src/ggml-sycl/ggml-sycl.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 0cea15ca4..76576c569 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1792,12 +1792,15 @@ static void pool2d_nchw_kernel( const int ew = sycl::min(iw, start_w + kw); To res = 0; + bool op_valid = true; switch (op) { case GGML_OP_POOL_AVG: res = 0; break; case GGML_OP_POOL_MAX: res = -FLT_MAX; break; default: - break; // TODO: handle this properly + res = NAN; + op_valid = false; + break; } for (int i = bh; i < eh; i += 1) { @@ -1817,11 +1820,16 @@ static void pool2d_nchw_kernel( case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break; case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break; default: - break; // TODO: handle this properly + op_valid = false; + break; } } } - o_ptr[cur_oh * ow + cur_ow] = res; + if (op_valid) { + o_ptr[cur_oh * ow + cur_ow] = res; + } else { + o_ptr[cur_oh * ow + cur_ow] = NAN; + } } template