WKV7 Vulkan & sycl

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2025-01-16 22:49:12 +08:00
parent 9cd24dd3eb
commit e7794cb274
7 changed files with 334 additions and 87 deletions

View file

@ -26,7 +26,7 @@
#include "softmax.hpp"
#include "tsembd.hpp"
#include "im2col.hpp"
#include "wkv6.hpp"
#include "wkv.hpp"
#include "outprod.hpp"
#include "element_wise.hpp"
#include "gla.hpp"

View file

@ -4123,6 +4123,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_sycl_op_rwkv_wkv7(ctx, dst);
break;
default:
return false;
}
@ -4594,6 +4597,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
return true;
default:
return false;

View file

@ -1,10 +1,10 @@
#include <sycl/sycl.hpp>
#include "wkv6.hpp"
#include "wkv.hpp"
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
// Helper function for the main kernel
static void rwkv_wkv_f32_kernel(
static void rwkv_wkv6_f32_kernel(
const int B, const int T, const int C, const int H,
const float* k, const float* v, const float* r,
const float* tf, const float* td, const float* s,
@ -95,6 +95,88 @@ static void rwkv_wkv_f32_kernel(
}
}
static void rwkv_wkv7_f32_kernel(
const int B, const int T, const int C, const int H,
const float* r, const float* w, const float* k, const float* v,
const float* a, const float* b, const float* s,
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
const int tid = item_ct1.get_local_id(2);
const int bid = item_ct1.get_group(2);
const int head_size = WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float* _r = shared_mem;
float* _w = _r + head_size;
float* _k = _w + head_size;
float* _a = _k + head_size;
float* _b = _a + head_size;
float state[WKV_BLOCK_SIZE];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
}
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
t += C) {
item_ct1.barrier(sycl::access::fence_space::local_space);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
item_ct1.barrier(sycl::access::fence_space::local_space);
const float _v = v[t];
float y = 0, sa = 0;
sycl::float4 a4, s4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
sa += sycl::dot(a4, s4);
}
sycl::float4 r4, w4, k4, b4;
#pragma unroll
for (int j = 0; j < head_size; j += 4) {
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
sycl::float4 kv4 = k4 * _v;
s4 = s4 * w4 + kv4 + sa * b4;
y += sycl::dot(r4, s4);
state[j] = s4.x();
state[j+1] = s4.y();
state[j+2] = s4.z();
state[j+3] = s4.w();
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
}
}
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
@ -131,7 +213,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv_f32_kernel(
rwkv_wkv6_f32_kernel(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
@ -141,3 +223,51 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_UNUSED(src0);
GGML_UNUSED(src1);
}
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const float* r_d = (const float*)dst->src[0]->data;
const float* w_d = (const float*)dst->src[1]->data;
const float* k_d = (const float*)dst->src[2]->data;
const float* v_d = (const float*)dst->src[3]->data;
const float* a_d = (const float*)dst->src[4]->data;
const float* b_d = (const float*)dst->src[5]->data;
const float* s_d = (const float*)dst->src[6]->data;
float* dst_d = (float*)dst->data;
const int64_t B = dst->src[6]->ne[1];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[1];
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == WKV_BLOCK_SIZE);
dpct::queue_ptr stream = ctx.stream();
// Calculate execution configuration
const size_t shared_mem_size = WKV_BLOCK_SIZE * 5 * sizeof(float); // For r, w, k, a, b
sycl::range<3> block_dims(1, 1, C / H);
sycl::range<3> grid_dims(1, 1, B * H);
// Submit kernel
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
);
});
});
GGML_UNUSED(src0);
GGML_UNUSED(src1);
}

View file

@ -1,9 +1,11 @@
#ifndef GGML_SYCL_WKV6_HPP
#define GGML_SYCL_WKV6_HPP
#ifndef GGML_SYCL_WKV_HPP
#define GGML_SYCL_WKV_HPP
#include "common.hpp"
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_WKV6_HPP

View file

@ -257,6 +257,7 @@ struct vk_device_struct {
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@ -543,7 +544,7 @@ struct vk_op_pool2d_push_constants {
int32_t p0; int32_t p1;
};
struct vk_op_rwkv_wkv6_push_constants {
struct vk_op_rwkv_wkv_push_constants {
uint32_t B;
uint32_t T;
uint32_t C;
@ -2164,7 +2165,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
for (auto &c : compiles) {
c.wait();
@ -5311,6 +5314,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv6_f32;
}
return nullptr;
case GGML_OP_RWKV_WKV7:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_LEAKY_RELU:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32;
@ -5769,23 +5777,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * k = dst->src[0];
const ggml_tensor * v = dst->src[1];
const ggml_tensor * r = dst->src[2];
const ggml_tensor * tf = dst->src[3];
const ggml_tensor * td = dst->src[4];
const ggml_tensor * state = dst->src[5];
static void ggml_vk_op_f32_rwkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv_push_constants&& pc, int version, bool dryrun = false) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
for (int i = 0; i < num_srcs; i++) {
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
}
GGML_ASSERT(!ggml_is_quantized(k->type));
GGML_ASSERT(!ggml_is_quantized(v->type));
GGML_ASSERT(!ggml_is_quantized(r->type));
GGML_ASSERT(!ggml_is_quantized(tf->type));
GGML_ASSERT(!ggml_is_quantized(td->type));
GGML_ASSERT(!ggml_is_quantized(state->type));
GGML_ASSERT(dst->buffer != nullptr);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
@ -5794,89 +5796,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
}
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
for (int i = 0; i < num_srcs; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
ggml_vk_sync_buffers(subctx);
vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
for (int i = 0; i < num_srcs; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
R_uma = d_R != nullptr;
TF_uma = d_TF != nullptr;
TD_uma = d_TD != nullptr;
STATE_uma = d_State != nullptr;
DST_uma = d_D != nullptr;
dst_uma = d_D != nullptr;
}
if (!K_uma) {
d_K = k_buf_ctx->dev_buffer;
k_offset = vk_tensor_offset(k) + k->view_offs;
uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
for (int i = 0; i < num_srcs; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
if (!V_uma) {
d_V = v_buf_ctx->dev_buffer;
v_offset = vk_tensor_offset(v) + v->view_offs;
}
if (!R_uma) {
d_R = r_buf_ctx->dev_buffer;
r_offset = vk_tensor_offset(r) + r->view_offs;
}
if (!TF_uma) {
d_TF = tf_buf_ctx->dev_buffer;
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
}
if (!TD_uma) {
d_TD = td_buf_ctx->dev_buffer;
td_offset = vk_tensor_offset(td) + td->view_offs;
}
if (!STATE_uma) {
d_State = state_buf_ctx->dev_buffer;
state_offset = vk_tensor_offset(state) + state->view_offs;
}
if (!DST_uma) {
const uint64_t dst_size = ggml_nbytes(dst);
if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
const uint64_t k_size = ggml_nbytes(k);
const uint64_t v_size = ggml_nbytes(v);
const uint64_t r_size = ggml_nbytes(r);
const uint64_t tf_size = ggml_nbytes(tf);
const uint64_t td_size = ggml_nbytes(td);
const uint64_t state_size = ggml_nbytes(state);
const uint64_t dst_size = ggml_nbytes(dst);
std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B * pc.H),
1,
1
};
if (version == 6) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_K, k_offset, k_size },
vk_subbuffer{ d_V, v_offset, v_size },
vk_subbuffer{ d_R, r_offset, r_size },
vk_subbuffer{ d_TF, tf_offset, tf_size },
vk_subbuffer{ d_TD, td_offset, td_size },
vk_subbuffer{ d_State, state_offset, state_size },
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
}, sizeof(vk_op_rwkv_wkv_push_constants), &pc, elements);
} else if (version == 7) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv_push_constants), &pc, elements);
} else {
// shouldn't happen
GGML_ASSERT(false);
}
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@ -5885,7 +5871,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_rwkv6(
ggml_vk_op_f32_rwkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
@ -5893,6 +5879,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
6,
dryrun
);
}
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[6]->ne[1];
ggml_vk_op_f32_rwkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
7,
dryrun
);
}
@ -7048,6 +7054,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
break;
@ -7258,6 +7265,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_RWKV_WKV7:
ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
break;
default:
return false;
}
@ -7339,6 +7352,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
buf = tensor->buffer;
@ -8265,6 +8279,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
return true;
default:
@ -8865,6 +8880,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
tensor->src[4], tensor->src[5]);
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
tensor->src[4], tensor->src[5], tensor->src[6]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;

View file

@ -497,6 +497,8 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
for (auto &c : compiles) {
c.wait();
}

View file

@ -0,0 +1,91 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#define BLOCK_SIZE 64
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint B;
uint T;
uint C;
uint H;
};
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
void main() {
const uint head_size = BLOCK_SIZE;
const uint batch_id = gl_WorkGroupID.x / H;
const uint head_id = gl_WorkGroupID.x % H;
const uint tid = gl_LocalInvocationID.x;
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
if (batch_id >= B || head_id >= H) {
return;
}
A_TYPE state[BLOCK_SIZE];
[[unroll]] for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
barrier();
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
barrier();
A_TYPE sa = 0.0;
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
vec4 a_vec = vec4(a[j], a[j+1], a[j+2], a[j+3]);
sa += dot(s_vec, a_vec);
}
const A_TYPE v_val = v[t];
A_TYPE y = 0.0;
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
vec4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(r_vec, s_vec);
state[j] = s_vec.x;
state[j+1] = s_vec.y;
state[j+2] = s_vec.z;
state[j+3] = s_vec.w;
}
dst[t] = y;
}
[[unroll]] for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}