rebase work_space api
This commit is contained in:
parent
ac8a4bd9d5
commit
87098db626
1 changed files with 3 additions and 3 deletions
|
@ -144,9 +144,9 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
|
|||
static void soft_max_f32_sycl(const float * x, const float * mask,
|
||||
float * dst, const int ncols_x, const int nrows_x,
|
||||
const int nrows_y, const float scale, const float max_bias,
|
||||
queue_ptr stream) {
|
||||
queue_ptr stream, int device) {
|
||||
int nth = WARP_SIZE;
|
||||
int max_block_size = get_work_group_size(stream->get_device());
|
||||
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
while (nth < ncols_x && nth < max_block_size) nth *= 2;
|
||||
if (nth>max_block_size) nth = max_block_size;
|
||||
|
||||
|
@ -246,5 +246,5 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
|
|||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||
|
||||
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
|
||||
nrows_x, nrows_y, scale, max_bias, main_stream);
|
||||
nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue