sycl: Enhance OP support judgment
This commit is contained in:
parent
bee1cec7d2
commit
1c58096f6f
2 changed files with 17 additions and 4 deletions
|
@ -5733,8 +5733,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
{
|
{
|
||||||
ggml_type src0_type = op->src[0]->type;
|
ggml_type src0_type = op->src[0]->type;
|
||||||
int dim = op->op_params[0];
|
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||||
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
|
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_ARGMAX:
|
case GGML_OP_ARGMAX:
|
||||||
|
@ -5797,9 +5796,23 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_
|
||||||
return buft_ctx->device == sycl_ctx->device;
|
return buft_ctx->device == sycl_ctx->device;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int64_t get_op_batch_size(const ggml_tensor * op) {
|
||||||
|
switch (op->op) {
|
||||||
|
case GGML_OP_GET_ROWS:
|
||||||
|
return op->ne[1]; // this will increase the speed of prefill in test
|
||||||
|
case GGML_OP_MUL_MAT:
|
||||||
|
return op->ne[1];
|
||||||
|
case GGML_OP_MUL_MAT_ID:
|
||||||
|
case GGML_OP_ROPE:
|
||||||
|
return op->ne[2];
|
||||||
|
default:
|
||||||
|
return ggml_nrows(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||||
const int min_batch_size = 32;
|
const int min_batch_size = 32;
|
||||||
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
|
return get_op_batch_size(op) >= min_batch_size;
|
||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -106,7 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
||||||
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
|
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
|
||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
default:
|
case 2:
|
||||||
stream->parallel_for(
|
stream->parallel_for(
|
||||||
sycl::nd_range<3>(gridDim *
|
sycl::nd_range<3>(gridDim *
|
||||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue