rebase work_space api
This commit is contained in:
parent
ac8a4bd9d5
commit
87098db626
1 changed files with 3 additions and 3 deletions
|
@ -144,9 +144,9 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
|
||||||
static void soft_max_f32_sycl(const float * x, const float * mask,
|
static void soft_max_f32_sycl(const float * x, const float * mask,
|
||||||
float * dst, const int ncols_x, const int nrows_x,
|
float * dst, const int ncols_x, const int nrows_x,
|
||||||
const int nrows_y, const float scale, const float max_bias,
|
const int nrows_y, const float scale, const float max_bias,
|
||||||
queue_ptr stream) {
|
queue_ptr stream, int device) {
|
||||||
int nth = WARP_SIZE;
|
int nth = WARP_SIZE;
|
||||||
int max_block_size = get_work_group_size(stream->get_device());
|
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||||
while (nth < ncols_x && nth < max_block_size) nth *= 2;
|
while (nth < ncols_x && nth < max_block_size) nth *= 2;
|
||||||
if (nth>max_block_size) nth = max_block_size;
|
if (nth>max_block_size) nth = max_block_size;
|
||||||
|
|
||||||
|
@ -246,5 +246,5 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
|
||||||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
|
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
|
||||||
nrows_x, nrows_y, scale, max_bias, main_stream);
|
nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue