diff --git a/ggml/src/ggml-kompute.cpp b/ggml/src/ggml-kompute.cpp
index 911827dbd..602d9e97e 100644
--- a/ggml/src/ggml-kompute.cpp
+++ b/ggml/src/ggml-kompute.cpp
@@ -2002,6 +2002,14 @@ static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_back
     return &ctx->buft == buft;
 }
 
+static bool ggml_backend_kompute_offload_op(ggml_backend_t backend, const ggml_tensor * op) { // Predicate installed in the .offload_op slot of kompute_backend_i; decides from op's type and its ne[1]/ne[2] extents only.
+    GGML_UNUSED(backend); // the backend handle is not consulted by this check
+    const int min_batch_size = 32; // threshold for ne[1]/ne[2]; presumably these are batch-like dimensions - TODO confirm against ggml tensor layout
+
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || // any op other than GET_ROWS whose ne[1] reaches the threshold
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); // MUL_MAT_ID additionally qualifies when ne[2] reaches the threshold
+}
+
 static struct ggml_backend_i kompute_backend_i = {
     /* .get_name = */ ggml_backend_kompute_name,
     /* .free = */ ggml_backend_kompute_free,
@@ -2017,7 +2025,7 @@ static struct ggml_backend_i kompute_backend_i = {
     /* .graph_compute = */ ggml_backend_kompute_graph_compute,
     /* .supports_op = */ ggml_backend_kompute_supports_op,
     /* .supports_buft = */ ggml_backend_kompute_supports_buft,
-    /* .offload_op = */ NULL,
+    /* .offload_op = */ ggml_backend_kompute_offload_op,
     /* .event_new = */ NULL,
     /* .event_free = */ NULL,
     /* .event_record = */ NULL,