Add UMA support
This commit is contained in:
parent
6ea605ddfc
commit
353c5f8c7b
1 changed file with 52 additions and 15 deletions
|
@@ -5471,21 +5471,58 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
|
|||
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
|
||||
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
|
||||
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
vk_buffer d_K = k_buf_ctx->dev_buffer;
|
||||
vk_buffer d_V = v_buf_ctx->dev_buffer;
|
||||
vk_buffer d_R = r_buf_ctx->dev_buffer;
|
||||
vk_buffer d_TF = tf_buf_ctx->dev_buffer;
|
||||
vk_buffer d_TD = td_buf_ctx->dev_buffer;
|
||||
vk_buffer d_State = state_buf_ctx->dev_buffer;
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
|
||||
vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
|
||||
uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
|
||||
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
|
||||
|
||||
const uint64_t k_offset = vk_tensor_offset(k);
|
||||
const uint64_t v_offset = vk_tensor_offset(v);
|
||||
const uint64_t r_offset = vk_tensor_offset(r);
|
||||
const uint64_t tf_offset = vk_tensor_offset(tf);
|
||||
const uint64_t td_offset = vk_tensor_offset(td);
|
||||
const uint64_t state_offset = vk_tensor_offset(state);
|
||||
const uint64_t dst_offset = vk_tensor_offset(dst);
|
||||
if (ctx->device->uma) {
|
||||
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
|
||||
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
|
||||
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
|
||||
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
|
||||
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
|
||||
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
|
||||
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
|
||||
|
||||
K_uma = d_K != nullptr;
|
||||
V_uma = d_V != nullptr;
|
||||
R_uma = d_R != nullptr;
|
||||
TF_uma = d_TF != nullptr;
|
||||
TD_uma = d_TD != nullptr;
|
||||
STATE_uma = d_State != nullptr;
|
||||
DST_uma = d_D != nullptr;
|
||||
}
|
||||
|
||||
if (!K_uma) {
|
||||
d_K = k_buf_ctx->dev_buffer;
|
||||
k_offset = vk_tensor_offset(k) + k->view_offs;
|
||||
}
|
||||
if (!V_uma) {
|
||||
d_V = v_buf_ctx->dev_buffer;
|
||||
v_offset = vk_tensor_offset(v) + v->view_offs;
|
||||
}
|
||||
if (!R_uma) {
|
||||
d_R = r_buf_ctx->dev_buffer;
|
||||
r_offset = vk_tensor_offset(r) + r->view_offs;
|
||||
}
|
||||
if (!TF_uma) {
|
||||
d_TF = tf_buf_ctx->dev_buffer;
|
||||
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
|
||||
}
|
||||
if (!TD_uma) {
|
||||
d_TD = td_buf_ctx->dev_buffer;
|
||||
td_offset = vk_tensor_offset(td) + td->view_offs;
|
||||
}
|
||||
if (!STATE_uma) {
|
||||
d_State = state_buf_ctx->dev_buffer;
|
||||
state_offset = vk_tensor_offset(state) + state->view_offs;
|
||||
}
|
||||
if (!DST_uma) {
|
||||
d_D = dst_buf_ctx->dev_buffer;
|
||||
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
}
|
||||
|
||||
const uint64_t k_size = ggml_nbytes(k);
|
||||
const uint64_t v_size = ggml_nbytes(v);
|
||||
|
@@ -5501,7 +5538,7 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
|
|||
1
|
||||
};
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_K, k_offset, k_size },
|
||||
vk_subbuffer{ d_V, v_offset, v_size },
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue