diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index ff5c2b5de..5e12ea9dd 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -6516,15 +6516,10 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const } GGML_CALL static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) { - const ggml_tensor * dst = op; - const int min_batch_size = 32; - if (dst->ne[1] > min_batch_size && dst->op != GGML_OP_GET_ROWS) { - return true; - } - - return false; + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); UNUSED(backend); }