From 353c5f8c7b6a3583b3632d09a1e8893c5f3d2954 Mon Sep 17 00:00:00 2001
From: Zhiyuan Li
Date: Mon, 16 Dec 2024 18:43:22 +0800
Subject: [PATCH] add uma support

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 67 +++++++++++++++++++++-------
 1 file changed, 52 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 79a4c6d9a..3aa1fc07b 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5471,21 +5471,58 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
     ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
     ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
 
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-    vk_buffer d_K = k_buf_ctx->dev_buffer;
-    vk_buffer d_V = v_buf_ctx->dev_buffer;
-    vk_buffer d_R = r_buf_ctx->dev_buffer;
-    vk_buffer d_TF = tf_buf_ctx->dev_buffer;
-    vk_buffer d_TD = td_buf_ctx->dev_buffer;
-    vk_buffer d_State = state_buf_ctx->dev_buffer;
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
+    uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
+    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
 
-    const uint64_t k_offset = vk_tensor_offset(k);
-    const uint64_t v_offset = vk_tensor_offset(v);
-    const uint64_t r_offset = vk_tensor_offset(r);
-    const uint64_t tf_offset = vk_tensor_offset(tf);
-    const uint64_t td_offset = vk_tensor_offset(td);
-    const uint64_t state_offset = vk_tensor_offset(state);
-    const uint64_t dst_offset = vk_tensor_offset(dst);
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
+        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
+        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
+        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
+        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        R_uma = d_R != nullptr;
+        TF_uma = d_TF != nullptr;
+        TD_uma = d_TD != nullptr;
+        STATE_uma = d_State != nullptr;
+        DST_uma = d_D != nullptr;
+    }
+
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!R_uma) {
+        d_R = r_buf_ctx->dev_buffer;
+        r_offset = vk_tensor_offset(r) + r->view_offs;
+    }
+    if (!TF_uma) {
+        d_TF = tf_buf_ctx->dev_buffer;
+        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
+    }
+    if (!TD_uma) {
+        d_TD = td_buf_ctx->dev_buffer;
+        td_offset = vk_tensor_offset(td) + td->view_offs;
+    }
+    if (!STATE_uma) {
+        d_State = state_buf_ctx->dev_buffer;
+        state_offset = vk_tensor_offset(state) + state->view_offs;
+    }
+    if (!DST_uma) {
+        d_D = dst_buf_ctx->dev_buffer;
+        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }
 
     const uint64_t k_size = ggml_nbytes(k);
     const uint64_t v_size = ggml_nbytes(v);
@@ -5501,7 +5538,7 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
         1
     };
 
-    ggml_vk_sync_buffers(subctx);
+
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
         vk_subbuffer{ d_K, k_offset, k_size },
         vk_subbuffer{ d_V, v_offset, v_size },
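
Note: the UMA path probes each tensor with ggml_vk_host_get, which on unified-memory
devices reports whether the tensor's data pointer already lives in a host allocation
the GPU can address directly; only tensors that fail the probe fall back to their
device buffer, with the offset adjusted by view_offs for tensor views. That
lookup-then-fallback sequence repeats seven times in the hunk above. As a sketch
only (a hypothetical helper, not part of this patch, assuming the internal types
and calls shown in the diff: vk_buffer, ggml_backend_vk_buffer_context,
ggml_vk_host_get, vk_tensor_offset), the pattern could be factored as:

    // Hypothetical helper (sketch): resolve one tensor to a Vulkan buffer and
    // byte offset, preferring the UMA host mapping when one exists. Mirrors
    // the per-tensor logic added in the hunk above.
    static void ggml_vk_get_buf_and_offset(ggml_backend_vk_context * ctx,
                                           const ggml_tensor * t,
                                           vk_buffer & buf, uint64_t & offset) {
        buf = nullptr;
        if (ctx->device->uma) {
            // On unified-memory devices the tensor data may already live in a
            // host allocation the device can address directly.
            ggml_vk_host_get(ctx->device, t->data, buf, offset);
        }
        if (buf == nullptr) {
            // Fall back to the tensor's device buffer; views need their byte
            // offset into the parent tensor added on top of the base offset.
            ggml_backend_vk_buffer_context * buf_ctx =
                (ggml_backend_vk_buffer_context *)t->buffer->context;
            buf = buf_ctx->dev_buffer;
            offset = vk_tensor_offset(t) + t->view_offs;
        }
    }

With such a helper, each of the seven tensors would resolve in a single call,
e.g. ggml_vk_get_buf_and_offset(ctx, k, d_K, k_offset);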