rebase work_space api

This commit is contained in:
luoyu-intel 2024-07-05 10:57:05 +08:00
parent ac8a4bd9d5
commit 87098db626

View file

@ -144,9 +144,9 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
static void soft_max_f32_sycl(const float * x, const float * mask,
float * dst, const int ncols_x, const int nrows_x,
const int nrows_y, const float scale, const float max_bias,
queue_ptr stream) {
queue_ptr stream, int device) {
int nth = WARP_SIZE;
int max_block_size = get_work_group_size(stream->get_device());
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
while (nth < ncols_x && nth < max_block_size) nth *= 2;
if (nth>max_block_size) nth = max_block_size;
@ -246,5 +246,5 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
nrows_x, nrows_y, scale, max_bias, main_stream);
nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
}