From 1c58096f6f869e5c5a242de1e17b4bfb13ded8b2 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Sun, 3 Nov 2024 16:36:42 +1100 Subject: [PATCH] sycl: Enhance OP support judgment --- ggml/src/ggml-sycl.cpp | 19 ++++++++++++++++--- ggml/src/ggml-sycl/concat.cpp | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 82eaec271..fff9b76f5 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -5733,8 +5733,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - int dim = op->op_params[0]; - return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; case GGML_OP_DUP: case GGML_OP_ARGMAX: @@ -5797,9 +5796,23 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_ return buft_ctx->device == sycl_ctx->device; } +static int64_t get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return op->ne[1]; // this will increse the speed of prefill in test + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { const int min_batch_size = 32; - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID; + return get_op_batch_size(op) >= min_batch_size; GGML_UNUSED(dev); } diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp index 632eedb9d..cec313781 100644 --- a/ggml/src/ggml-sycl/concat.cpp +++ b/ggml/src/ggml-sycl/concat.cpp @@ -106,7 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst, concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); }); break; - default: + case 2: stream->parallel_for( sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),