add uma support

Author: Zhiyuan Li
Date:   2024-12-16 18:43:22 +08:00
Parent: 6ea605ddfc
Commit: 353c5f8c7b


@@ -5471,21 +5471,58 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
     ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
-    vk_buffer d_K = k_buf_ctx->dev_buffer;
-    vk_buffer d_V = v_buf_ctx->dev_buffer;
-    vk_buffer d_R = r_buf_ctx->dev_buffer;
-    vk_buffer d_TF = tf_buf_ctx->dev_buffer;
-    vk_buffer d_TD = td_buf_ctx->dev_buffer;
-    vk_buffer d_State = state_buf_ctx->dev_buffer;
     ggml_vk_sync_buffers(subctx);
+    vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
+    uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
+    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
-    const uint64_t k_offset = vk_tensor_offset(k);
-    const uint64_t v_offset = vk_tensor_offset(v);
-    const uint64_t r_offset = vk_tensor_offset(r);
-    const uint64_t tf_offset = vk_tensor_offset(tf);
-    const uint64_t td_offset = vk_tensor_offset(td);
-    const uint64_t state_offset = vk_tensor_offset(state);
-    const uint64_t dst_offset = vk_tensor_offset(dst);
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
+        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
+        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
+        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
+        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        R_uma = d_R != nullptr;
+        TF_uma = d_TF != nullptr;
+        TD_uma = d_TD != nullptr;
+        STATE_uma = d_State != nullptr;
+        DST_uma = d_D != nullptr;
+    }
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!R_uma) {
+        d_R = r_buf_ctx->dev_buffer;
+        r_offset = vk_tensor_offset(r) + r->view_offs;
+    }
+    if (!TF_uma) {
+        d_TF = tf_buf_ctx->dev_buffer;
+        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
+    }
+    if (!TD_uma) {
+        d_TD = td_buf_ctx->dev_buffer;
+        td_offset = vk_tensor_offset(td) + td->view_offs;
+    }
+    if (!STATE_uma) {
+        d_State = state_buf_ctx->dev_buffer;
+        state_offset = vk_tensor_offset(state) + state->view_offs;
+    }
+    if (!DST_uma) {
+        d_D = dst_buf_ctx->dev_buffer;
+        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }

     const uint64_t k_size = ggml_nbytes(k);
     const uint64_t v_size = ggml_nbytes(v);
@@ -5501,7 +5538,7 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx
         1
     };

     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
         vk_subbuffer{ d_K, k_offset, k_size },
         vk_subbuffer{ d_V, v_offset, v_size },
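The pattern the diff repeats for each of the seven tensors is: on a unified-memory (UMA) device, first ask ggml_vk_host_get whether the tensor's host pointer lies inside a pinned allocation the device can address directly; only if that lookup returns nullptr fall back to the buffer context's dev_buffer plus vk_tensor_offset(t) + t->view_offs. Below is a minimal stand-alone C++ sketch of that resolve-or-fallback step. Buffer, PinnedRange, Device, host_get, and resolve_buffer are illustrative stand-ins, not ggml API; only the nullptr-on-miss behavior of the lookup is taken from the diff itself.

#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-ins for vk_buffer and the device's pinned-memory table (illustrative only).
struct Buffer { int id; };

struct PinnedRange {
    const char * base;   // start of a host allocation visible to the device
    size_t       size;
    Buffer *     buf;    // device-side handle covering that allocation
};

struct Device {
    bool uma = false;
    std::vector<PinnedRange> pinned;  // plays the role of ggml's pinned-memory list
};

// Mirrors the role of ggml_vk_host_get: if `ptr` lies inside a pinned range,
// return that range's buffer and the offset of `ptr` within it; else nullptr.
static void host_get(const Device & dev, const void * ptr, Buffer *& buf, uint64_t & offset) {
    buf    = nullptr;
    offset = 0;
    for (const PinnedRange & r : dev.pinned) {
        const char * p = static_cast<const char *>(ptr);
        if (p >= r.base && p < r.base + r.size) {
            buf    = r.buf;
            offset = static_cast<uint64_t>(p - r.base);
            return;
        }
    }
}

// The per-tensor step the commit spells out for K, V, R, TF, TD, State and dst:
// try the UMA path first, fall back to the device buffer and precomputed offset.
static void resolve_buffer(const Device & dev, const void * host_ptr,
                           Buffer * dev_buffer, uint64_t base_offset,
                           Buffer *& out_buf, uint64_t & out_offset) {
    bool uma_hit = false;
    if (dev.uma) {
        host_get(dev, host_ptr, out_buf, out_offset);
        uma_hit = out_buf != nullptr;
    }
    if (!uma_hit) {                 // non-UMA path: device buffer + offset
        out_buf    = dev_buffer;
        out_offset = base_offset;   // stands in for vk_tensor_offset(t) + t->view_offs
    }
}

int main() {
    static char host_data[256];
    Buffer pinned_buf{1}, dev_buf{2};

    Device dev;
    dev.uma = true;
    dev.pinned.push_back({host_data, sizeof(host_data), &pinned_buf});

    Buffer * buf = nullptr;
    uint64_t off = 0;

    // Tensor data inside a pinned range: resolved through the UMA path.
    resolve_buffer(dev, host_data + 64, &dev_buf, 0, buf, off);
    std::printf("uma hit:  buf=%d offset=%llu\n", buf->id, (unsigned long long) off);

    // Unknown pointer: falls back to the device buffer.
    int other = 0;
    resolve_buffer(dev, &other, &dev_buf, 128, buf, off);
    std::printf("fallback: buf=%d offset=%llu\n", buf->id, (unsigned long long) off);
    return 0;
}

The commit writes this out seven times rather than looping because each tensor pairs a different buffer context with a different offset variable; folding them into a helper like the hypothetical resolve_buffer above would be a possible refactor, not what the commit does.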