diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index e269d7bea..f9fe90add 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -14,7 +14,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - + const int nthreads = block_size; + const int nwarps = nthreads / WARP_SIZE; + int nreduce = nwarps / WARP_SIZE; float slope = 1.0f; // ALiBi @@ -27,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const slope = sycl::pow(base, float(exp)); } - float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols; + float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols; float max_val = -INFINITY; for (int col0 = 0; col0 < ncols; col0 += block_size) { @@ -51,6 +53,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const if (block_size > WARP_SIZE) { if (warp_id == 0) { buf[lane_id] = -INFINITY; + for (size_t i = 1; i < nreduce; i += 1) + buf[lane_id + i * WARP_SIZE] = -INFINITY; } item_ct1.barrier(sycl::access::fence_space::local_space); @@ -58,13 +62,15 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const buf[warp_id] = max_val; } item_ct1.barrier(sycl::access::fence_space::local_space); - max_val = buf[lane_id]; + for (size_t i = 1; i < nreduce; i += 1) + { + max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]); + } max_val = warp_reduce_max(max_val, item_ct1); } float tmp = 0.f; - #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; @@ -83,6 +89,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); if (warp_id == 0) { buf[lane_id] = 0.f; + for (size_t i = 1; i < nreduce; i += 1) + buf[lane_id + i * WARP_SIZE] = 0.f; } item_ct1.barrier(sycl::access::fence_space::local_space); @@ -92,6 +100,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); tmp = buf[lane_id]; + for (size_t i = 1; i < nreduce; i += 1) + { + tmp += buf[lane_id + i * WARP_SIZE]; + } tmp = warp_reduce_sum(tmp, item_ct1); }