Apply code format changes
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
77fe4fd982
commit
60bbd4ebf1
3 changed files with 60 additions and 114 deletions
|
@ -524,15 +524,13 @@ struct vk_op_pool2d_push_constants {
|
|||
int32_t p0; int32_t p1;
|
||||
};
|
||||
|
||||
|
||||
struct vk_op_rwkv_wkv6_push_constants {
|
||||
uint32_t B; // Batch size (原n_seqs)
|
||||
uint32_t T; // Sequence length
|
||||
uint32_t C; // Total channels
|
||||
uint32_t H; // Number of heads (原HEADS)
|
||||
uint32_t B;
|
||||
uint32_t T;
|
||||
uint32_t C;
|
||||
uint32_t H;
|
||||
};
|
||||
|
||||
|
||||
// Allow pre-recording command buffers
|
||||
struct vk_staging_memcpy {
|
||||
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
||||
|
@ -1952,19 +1950,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(
|
||||
device,
|
||||
device->pipeline_rwkv_wkv6_f32,
|
||||
"rwkv_wkv6_f32",
|
||||
rwkv_wkv6_f32_len,
|
||||
rwkv_wkv6_f32_data,
|
||||
"main",
|
||||
7,
|
||||
sizeof(vk_op_rwkv_wkv6_push_constants),
|
||||
{1, 1, 1}, // work group
|
||||
{device->subgroup_size},
|
||||
1
|
||||
);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
|
@ -5348,28 +5334,14 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
|
||||
const ggml_tensor * k = dst->src[0];
|
||||
const ggml_tensor * v = dst->src[1];
|
||||
const ggml_tensor * r = dst->src[2];
|
||||
const ggml_tensor * tf = dst->src[3];
|
||||
const ggml_tensor * td = dst->src[4];
|
||||
const ggml_tensor * state = dst->src[5];
|
||||
|
||||
|
||||
template<typename PC>
|
||||
static void ggml_vk_op_f32_rwkv6(
|
||||
ggml_backend_vk_context * ctx,
|
||||
vk_context& subctx,
|
||||
ggml_tensor * dst,
|
||||
const PC&& pc,
|
||||
bool dryrun = false) {
|
||||
|
||||
// Get source tensors
|
||||
const ggml_tensor * k = dst->src[0]; // keys
|
||||
const ggml_tensor * v = dst->src[1]; // values
|
||||
const ggml_tensor * r = dst->src[2]; // reset gates
|
||||
const ggml_tensor * tf = dst->src[3]; // time first
|
||||
const ggml_tensor * td = dst->src[4]; // time decay
|
||||
const ggml_tensor * state = dst->src[5]; // states
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", "
|
||||
<< tf << ", " << td << ", " << state << ", " << dst << ")");
|
||||
|
||||
// Verify input types
|
||||
GGML_ASSERT(!ggml_is_quantized(k->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(v->type));
|
||||
GGML_ASSERT(!ggml_is_quantized(r->type));
|
||||
|
@ -5378,7 +5350,6 @@ static void ggml_vk_op_f32_rwkv6(
|
|||
GGML_ASSERT(!ggml_is_quantized(state->type));
|
||||
GGML_ASSERT(dst->buffer != nullptr);
|
||||
|
||||
// Get pipeline
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
|
||||
GGML_ASSERT(pipeline != nullptr);
|
||||
|
||||
|
@ -5387,7 +5358,6 @@ static void ggml_vk_op_f32_rwkv6(
|
|||
return;
|
||||
}
|
||||
|
||||
// Get buffer contexts
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
|
||||
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
|
||||
|
@ -5396,7 +5366,6 @@ static void ggml_vk_op_f32_rwkv6(
|
|||
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
|
||||
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
|
||||
|
||||
// Get device buffers
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
vk_buffer d_K = k_buf_ctx->dev_buffer;
|
||||
vk_buffer d_V = v_buf_ctx->dev_buffer;
|
||||
|
@ -5405,7 +5374,6 @@ static void ggml_vk_op_f32_rwkv6(
|
|||
vk_buffer d_TD = td_buf_ctx->dev_buffer;
|
||||
vk_buffer d_State = state_buf_ctx->dev_buffer;
|
||||
|
||||
// Calculate buffer offsets
|
||||
const uint64_t k_offset = vk_tensor_offset(k);
|
||||
const uint64_t v_offset = vk_tensor_offset(v);
|
||||
const uint64_t r_offset = vk_tensor_offset(r);
|
||||
|
@ -5414,7 +5382,6 @@ static void ggml_vk_op_f32_rwkv6(
|
|||
const uint64_t state_offset = vk_tensor_offset(state);
|
||||
const uint64_t dst_offset = vk_tensor_offset(dst);
|
||||
|
||||
// Calculate buffer sizes
|
||||
const uint64_t k_size = ggml_nbytes(k);
|
||||
const uint64_t v_size = ggml_nbytes(v);
|
||||
const uint64_t r_size = ggml_nbytes(r);
|
||||
|
@ -5423,14 +5390,12 @@ static void ggml_vk_op_f32_rwkv6(
|
|||
const uint64_t state_size = ggml_nbytes(state);
|
||||
const uint64_t dst_size = ggml_nbytes(dst);
|
||||
|
||||
// Set work elements based on tensor dimensions
|
||||
std::array<uint32_t, 3> elements = {
|
||||
(uint32_t)(pc.B*pc.H), // B * H workgroups
|
||||
1, // 每个workgroup 64个线程
|
||||
(uint32_t)(pc.B * pc.H),
|
||||
1,
|
||||
1
|
||||
};
|
||||
|
||||
// Synchronize buffers and dispatch compute pipeline
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_K, k_offset, k_size },
|
||||
|
@ -5440,35 +5405,27 @@ static void ggml_vk_op_f32_rwkv6(
|
|||
vk_subbuffer{ d_TD, td_offset, td_size },
|
||||
vk_subbuffer{ d_State, state_offset, state_size },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, sizeof(PC), &pc, elements);
|
||||
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
||||
}
|
||||
|
||||
static void ggml_vk_rwkv_wkv6(
|
||||
ggml_backend_vk_context * ctx,
|
||||
vk_context& subctx,
|
||||
ggml_tensor * dst,
|
||||
bool dryrun = false) {
|
||||
|
||||
// Extract dimensions from tensors
|
||||
const size_t T = dst->src[0]->ne[3]; // Sequence length
|
||||
const size_t C = dst->ne[0]; // Channel dimension
|
||||
const size_t HEADS = dst->src[0]->ne[2]; // Number of heads
|
||||
const size_t n_seqs = dst->src[5]->ne[1]; // Batch size
|
||||
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
const size_t seq_length = dst->src[0]->ne[3];
|
||||
const size_t n_embed = dst->ne[0];
|
||||
const size_t n_heads = dst->src[0]->ne[2];
|
||||
const size_t n_seqs = dst->src[5]->ne[1];
|
||||
|
||||
// Call implementation with push constants
|
||||
ggml_vk_op_f32_rwkv6<vk_op_rwkv_wkv6_push_constants>(
|
||||
ggml_vk_op_f32_rwkv6(
|
||||
ctx, subctx, dst,
|
||||
{
|
||||
(uint32_t)n_seqs, // B
|
||||
(uint32_t)T, // T
|
||||
(uint32_t)C, // C
|
||||
(uint32_t)HEADS, // H
|
||||
(uint32_t)n_seqs,
|
||||
(uint32_t)seq_length,
|
||||
(uint32_t)n_embed,
|
||||
(uint32_t)n_heads,
|
||||
},
|
||||
dryrun
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
int * op_params = (int *)dst->op_params;
|
||||
|
||||
|
@ -8344,10 +8301,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
|
||||
const float * op_params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
||||
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
|
||||
tensor->src[4], tensor->src[5]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
|
@ -479,7 +479,7 @@ void process_shaders() {
|
|||
|
||||
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}, {"F_TYPE", "float"}, {"S_TYPE", "float"}}));
|
||||
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
|
|
|
@ -1,96 +1,85 @@
|
|||
#version 450
|
||||
|
||||
|
||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||
#define BLOCK_SIZE 64
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
uint B; // Batch size
|
||||
uint T; // Sequence length
|
||||
uint C; // Total number of channels
|
||||
uint H; // Number of heads
|
||||
uint B;
|
||||
uint T;
|
||||
uint C;
|
||||
uint H;
|
||||
};
|
||||
|
||||
layout(set = 0, binding = 0) readonly buffer KBuf { float k[]; };
|
||||
layout(set = 0, binding = 1) readonly buffer VBuf { float v[]; };
|
||||
layout(set = 0, binding = 2) readonly buffer RBuf { float r[]; };
|
||||
layout(set = 0, binding = 3) readonly buffer TimeFBuf { float tf[]; };
|
||||
layout(set = 0, binding = 4) readonly buffer TimeDBuf { float td[]; };
|
||||
layout(set = 0, binding = 5) readonly buffer StateBuf { float state_in[]; };
|
||||
layout(set = 0, binding = 6) buffer DstBuf { float dst[]; };
|
||||
layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; };
|
||||
layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; };
|
||||
layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; };
|
||||
layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; };
|
||||
layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; };
|
||||
layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };
|
||||
layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };
|
||||
|
||||
shared float _k[64], _r[64], _tf[64], _td[64];
|
||||
shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint head_size = 64;
|
||||
const uint head_size = BLOCK_SIZE;
|
||||
const uint batch_id = gl_WorkGroupID.x / H;
|
||||
const uint head_id = gl_WorkGroupID.x % H;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
if (tid >= head_size || batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Load state
|
||||
float state[64]; // Use fixed size matching head_size
|
||||
|
||||
A_TYPE state[BLOCK_SIZE];
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid];
|
||||
}
|
||||
|
||||
_k[tid] = 0.0;
|
||||
_r[tid] = 0.0;
|
||||
_td[tid] = 0.0;
|
||||
|
||||
barrier();
|
||||
_tf[tid] = tf[head_id * head_size + tid];
|
||||
barrier();
|
||||
|
||||
|
||||
// Main loop
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
barrier();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
barrier();
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0;
|
||||
|
||||
|
||||
const A_TYPE v_val = v[t];
|
||||
A_TYPE y = 0.0;
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
// Load values in blocks of 4
|
||||
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
// Compute kv products
|
||||
|
||||
vec4 kv = k_vec * v_val;
|
||||
|
||||
// Accumulate results
|
||||
|
||||
vec4 temp = tf_vec * kv + s_vec;
|
||||
y += dot(r_vec, temp);
|
||||
|
||||
// Update state
|
||||
|
||||
s_vec = s_vec * td_vec + kv;
|
||||
state[j] = s_vec.x;
|
||||
state[j+1] = s_vec.y;
|
||||
state[j+2] = s_vec.z;
|
||||
state[j+3] = s_vec.w;
|
||||
}
|
||||
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
// Write back state
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue