From 9cd24dd3ebb0550be92b99c6873c0f317484128c Mon Sep 17 00:00:00 2001
From: Molly Sophia
Date: Thu, 16 Jan 2025 15:50:56 +0800
Subject: [PATCH] wkv7 CUDA impl

Signed-off-by: Molly Sophia
---
 ggml/src/ggml-cuda/ggml-cuda.cu          |   6 +-
 ggml/src/ggml-cuda/wkv.cu                | 187 +++++++++++++++++++++++
 ggml/src/ggml-cuda/{wkv6.cuh => wkv.cuh} |   2 +
 ggml/src/ggml-cuda/wkv6.cu               |  89 -----------
 tests/test-backend-ops.cpp               |   3 +
 5 files changed, 197 insertions(+), 90 deletions(-)
 create mode 100644 ggml/src/ggml-cuda/wkv.cu
 rename ggml/src/ggml-cuda/{wkv6.cuh => wkv.cuh} (62%)
 delete mode 100644 ggml/src/ggml-cuda/wkv6.cu

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 694081a89..c6c182320 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml.h"
 
@@ -2296,6 +2296,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_cuda_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
@@ -3191,6 +3194,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
         case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
diff --git a/ggml/src/ggml-cuda/wkv.cu b/ggml/src/ggml-cuda/wkv.cu
new file mode 100644
index 000000000..96df96cfa
--- /dev/null
+++ b/ggml/src/ggml-cuda/wkv.cu
@@ -0,0 +1,187 @@
+#include "common.cuh"
+#include "wkv.cuh"
+
+static __global__ void rwkv_wkv6_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = CUDA_WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = CUDA_WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        __syncthreads();
+
+        float sa = 0;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4)
+        {
+            const float4& a = (float4&)(_a[j]);
+            const float4& s = (float4&)(state[j]);
+            sa += a.x * s.x;
+            sa += a.y * s.y;
+            sa += a.z * s.z;
+            sa += a.w * s.w;
+        }
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& r = (float4&)(_r[j]);
+            const float4& w = (float4&)(_w[j]);
+            const float4& k = (float4&)(_k[j]);
+            const float4& b = (float4&)(_b[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * w.x + kv.x + sa * b.x;
+            s.y = s.y * w.y + kv.y + sa * b.y;
+            s.z = s.z * w.z + kv.z + sa * b.z;
+            s.w = s.w * w.w + kv.w + sa * b.w;
+
+            y += s.x * r.x;
+            y += s.y * r.y;
+            y += s.z * r.z;
+            y += s.w * r.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d = (const float *)dst->src[0]->data;
+    const float * v_d = (const float *)dst->src[1]->data;
+    const float * r_d = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    rwkv_wkv6_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+}
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * r_d = (const float *)dst->src[0]->data;
+    const float * w_d = (const float *)dst->src[1]->data;
+    const float * k_d = (const float *)dst->src[2]->data;
+    const float * v_d = (const float *)dst->src[3]->data;
+    const float * a_d = (const float *)dst->src[4]->data;
+    const float * b_d = (const float *)dst->src[5]->data;
+    const float * s_d = (const float *)dst->src[6]->data;
+
+    const int64_t B = dst->src[6]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
+
+    rwkv_wkv7_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+}
diff --git a/ggml/src/ggml-cuda/wkv6.cuh b/ggml/src/ggml-cuda/wkv.cuh
similarity index 62%
rename from ggml/src/ggml-cuda/wkv6.cuh
rename to ggml/src/ggml-cuda/wkv.cuh
index a7124ee51..9623dd7f8 100644
--- a/ggml/src/ggml-cuda/wkv6.cuh
+++ b/ggml/src/ggml-cuda/wkv.cuh
@@ -3,3 +3,5 @@
 #define CUDA_WKV_BLOCK_SIZE 64
 
 void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/wkv6.cu b/ggml/src/ggml-cuda/wkv6.cu
deleted file mode 100644
index bbdafbee5..000000000
--- a/ggml/src/ggml-cuda/wkv6.cu
+++ /dev/null
@@ -1,89 +0,0 @@
-#include "common.cuh"
-#include "wkv6.cuh"
-
-static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
-    const int tid = threadIdx.x;
-    const int bid = blockIdx.x;
-
-    const int head_size = CUDA_WKV_BLOCK_SIZE;
-    const int batch_i = bid / H;
-    const int head_i = bid % H;
-    const int state_size = C * head_size;
-    const int n_seq_tokens = T / B;
-
-    float state[head_size];
-    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
-
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
-    }
-
-    __syncthreads();
-    _tf[tid] = tf[head_i * head_size + tid];
-    __syncthreads();
-
-    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
-        __syncthreads();
-        _k[tid] = k[t];
-        _r[tid] = r[t];
-        _td[tid] = td[t];
-        __syncthreads();
-
-        const float _v = v[t];
-        float y = 0;
-        for (int j = 0; j < head_size; j += 4) {
-            const float4& k = (float4&)(_k[j]);
-            const float4& r = (float4&)(_r[j]);
-            const float4& tf = (float4&)(_tf[j]);
-            const float4& td = (float4&)(_td[j]);
-            float4& s = (float4&)(state[j]);
-            float4 kv;
-
-            kv.x = k.x * _v;
-            kv.y = k.y * _v;
-            kv.z = k.z * _v;
-            kv.w = k.w * _v;
-
-            y += r.x * (tf.x * kv.x + s.x);
-            y += r.y * (tf.y * kv.y + s.y);
-            y += r.z * (tf.z * kv.z + s.z);
-            y += r.w * (tf.w * kv.w + s.w);
-
-            s.x = s.x * td.x + kv.x;
-            s.y = s.y * td.y + kv.y;
-            s.z = s.z * td.z + kv.z;
-            s.w = s.w * td.w + kv.w;
-        }
-        dst[t] = y;
-    }
-
-    #pragma unroll
-    for (int i = 0; i < head_size; i++) {
-        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
-    }
-}
-
-void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const float * k_d = (const float *)dst->src[0]->data;
-    const float * v_d = (const float *)dst->src[1]->data;
-    const float * r_d = (const float *)dst->src[2]->data;
-    const float * tf_d = (const float *)dst->src[3]->data;
-    const float * td_d = (const float *)dst->src[4]->data;
-    const float * s_d = (const float *)dst->src[5]->data;
-
-    const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[2];
-    const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[1];
-
-    float * dst_d = (float *)dst->data;
-
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
-    GGML_ASSERT(C % H == 0);
-    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
-
-    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
-}
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 5418dea48..e5b6d0d35 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1893,6 +1893,9 @@ struct test_rwkv_wkv7 : public test_case {
         ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
         ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
         ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        // Outputs may become NaN with long seqlen without these normalizations
+        a = ggml_l2_norm(ctx, a, 1e-7F);
+        b = ggml_l2_norm(ctx, b, 1e-7F);
         ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
         ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
         return out;
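
Reviewer note (not part of the patch): in rwkv_wkv7_f32, each block handles one (batch, head) pair and each thread owns one row of the head's head_size x head_size state. Per token, with per-head vectors r, w, k, v, a, b, row i is updated as S_i <- S_i (*) w + v_i * k + (a . S_i) * b, and the output is y_i = S_i . r. A minimal host-side C++ sketch of that recurrence follows; wkv7_head_step and its layout are illustrative assumptions, not code from this patch.

#include <cstddef>
#include <vector>

// Illustrative reference for one (batch, head) pair and one token, mirroring
// what a single block of rwkv_wkv7_f32 computes. "state" holds head_size rows
// of head_size floats; row i corresponds to the CUDA thread with tid == i.
void wkv7_head_step(std::size_t head_size,
                    const float * r, const float * w, const float * k,
                    const float * v, const float * a, const float * b,
                    std::vector<float> & state,   // head_size * head_size floats
                    float * y) {                  // head_size outputs
    for (std::size_t i = 0; i < head_size; i++) { // i plays the role of tid
        float * s = state.data() + i * head_size;

        float sa = 0.0f;                          // a . state_row_i ("sa" in the kernel)
        for (std::size_t j = 0; j < head_size; j++) {
            sa += a[j] * s[j];
        }

        float yi = 0.0f;
        for (std::size_t j = 0; j < head_size; j++) {
            const float kv = k[j] * v[i];         // rank-1 update outer(v, k)
            s[j] = s[j] * w[j] + kv + sa * b[j];  // decay, write, redistribute
            yi  += s[j] * r[j];                   // receptance read-out
        }
        y[i] = yi;
    }
}

This also makes the test change above easier to read: a and b are L2-normalized there so the (a . S_i) term stays bounded over long sequences instead of drifting toward NaN.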