diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index f9fe90add..e624b6ba3 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -144,9 +144,9 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float * static void soft_max_f32_sycl(const float * x, const float * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, - queue_ptr stream) { + queue_ptr stream, int device) { int nth = WARP_SIZE; - int max_block_size = get_work_group_size(stream->get_device()); + int max_block_size = ggml_sycl_info().max_work_group_sizes[device]; while (nth < ncols_x && nth < max_block_size) nth *= 2; if (nth>max_block_size) nth = max_block_size; @@ -246,5 +246,5 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s memcpy(&max_bias, dst->op_params + 1, sizeof(float)); soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, - nrows_x, nrows_y, scale, max_bias, main_stream); + nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); }