From 9129362c6f3cf5b08c3e591daacf2b490af69e66 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Tue, 10 Dec 2024 20:08:19 +0530 Subject: [PATCH] SYCL softmax: Initialize nreduce as size_t --- ggml/src/ggml-sycl/softmax.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index e73ad7c2e..a9b3fce0d 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -16,7 +16,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; const int nthreads = block_size; const int nwarps = nthreads / WARP_SIZE; - int nreduce = nwarps / WARP_SIZE; + size_t nreduce = nwarps / WARP_SIZE; float slope = 1.0f; // ALiBi @@ -53,7 +53,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const if (block_size > WARP_SIZE) { if (warp_id == 0) { buf[lane_id] = -INFINITY; - for (size_t i = 1; i < (size_t) nreduce; i += 1) { + for (size_t i = 1; i < nreduce; i += 1) { buf[lane_id + i * WARP_SIZE] = -INFINITY; } } @@ -64,7 +64,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const } item_ct1.barrier(sycl::access::fence_space::local_space); max_val = buf[lane_id]; - for (size_t i = 1; i < (size_t) nreduce; i += 1) { + for (size_t i = 1; i < nreduce; i += 1) { max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]); } max_val = warp_reduce_max(max_val, item_ct1); @@ -89,7 +89,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); if (warp_id == 0) { buf[lane_id] = 0.f; - for (size_t i = 1; i < (size_t) nreduce; i += 1) { + for (size_t i = 1; i < nreduce; i += 1) { buf[lane_id + i * WARP_SIZE] = 0.f; } } @@ -101,7 +101,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); tmp = buf[lane_id]; - for (size_t i = 1; i < (size_t) nreduce; i += 1) { + for (size_t i = 1; i < nreduce; i += 1) { tmp += buf[lane_id + i * WARP_SIZE]; } tmp = warp_reduce_sum(tmp, item_ct1);