From e198f7b9df27385d48e8386acfc4d96348380918 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Fri, 1 Nov 2024 20:58:17 +1100 Subject: [PATCH] rwkv6: update cuda file name --- ggml/src/ggml-cuda.cu | 2 +- ggml/src/ggml-cuda/{rwkv-wkv.cu => wkv6.cu} | 2 +- ggml/src/ggml-cuda/{rwkv-wkv.cuh => wkv6.cuh} | 0 ggml/src/ggml.c | 25 +++++++++++++------ 4 files changed, 19 insertions(+), 10 deletions(-) rename ggml/src/ggml-cuda/{rwkv-wkv.cu => wkv6.cu} (99%) rename ggml/src/ggml-cuda/{rwkv-wkv.cuh => wkv6.cuh} (100%) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 9ae59265e..88d53d7cb 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -36,7 +36,7 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" -#include "ggml-cuda/rwkv-wkv.cuh" +#include "ggml-cuda/wkv6.cuh" #include #include diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cu b/ggml/src/ggml-cuda/wkv6.cu similarity index 99% rename from ggml/src/ggml-cuda/rwkv-wkv.cu rename to ggml/src/ggml-cuda/wkv6.cu index 761a81d75..b4d6eb9a8 100644 --- a/ggml/src/ggml-cuda/rwkv-wkv.cu +++ b/ggml/src/ggml-cuda/wkv6.cu @@ -1,5 +1,5 @@ #include "common.cuh" -#include "rwkv-wkv.cuh" +#include "wkv6.cuh" static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) { const int tid = threadIdx.x; diff --git a/ggml/src/ggml-cuda/rwkv-wkv.cuh b/ggml/src/ggml-cuda/wkv6.cuh similarity index 100% rename from ggml/src/ggml-cuda/rwkv-wkv.cuh rename to ggml/src/ggml-cuda/wkv6.cuh diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2555e7509..16ecf8bf4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3074,7 +3074,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "WIN_UNPART", "GET_REL_POS", "ADD_REL_POS", - "RWKV_WKV", + "RWKV_WKV6", "UNARY", @@ -16618,11 +16618,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32( float * dst_data = (float *) dst->data; float * state = ((float *) dst->data) + C * T; - if (params->ith != 0) { + if ((size_t)params->ith >= H) { return; } - memset(dst_data, 0, T * C * sizeof(float)); + size_t h_start = (H * params->ith) / params->nth; + size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ? + (H * (size_t)(params->ith + 1)) / (size_t)params->nth : H; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -16635,6 +16637,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32( size_t h_stride = C / H; size_t h_stride_2d = head_size * head_size; + if (params->ith == 0) { + memset(dst_data, 0, T * C * sizeof(float)); + } + ggml_barrier(params->threadpool); + + + #ifdef __AVX2__ // AVX2 uses 256-bit vectors = 8 float32 const int vec_size = 8; @@ -16646,7 +16655,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( float * state_cur = state + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - for (size_t h = 0; h < H; h++) { + for (size_t h = h_start; h < h_end; h++) { size_t h_offset = h * h_stride; size_t t_h_offset = t_offset + h_offset; size_t h_2d_offset = h * h_stride_2d; @@ -16724,7 +16733,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( float * state_cur = state + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - for (size_t h = 0; h < H; h++) { + for (size_t h = h_start; h < h_end; h++) { size_t h_offset = h * h_stride; size_t t_h_offset = t_offset + h_offset; size_t h_2d_offset = h * h_stride_2d; @@ -16806,7 +16815,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( float * state_cur = state + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - for (size_t h = 0; h < H; h++) { + for (size_t h = h_start; h < h_end; h++) { size_t h_offset = h * h_stride; size_t t_h_offset = t_offset + h_offset; size_t h_2d_offset = h * h_stride_2d; @@ -16867,7 +16876,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( float * state_cur = state + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - for (size_t h = 0; h < H; h++) { + for (size_t h = h_start; h < h_end; h++) { size_t h_offset = h * h_stride; size_t t_h_offset = t_offset + h_offset; size_t h_2d_offset = h * h_stride_2d; @@ -16959,7 +16968,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( float * state_cur = state + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - for (size_t h = 0; h < H; h++) { + for (size_t h = h_start; h < h_end; h++) { size_t h_offset = h * h_stride; size_t t_h_offset = t_offset + h_offset; size_t h_2d_offset = h * h_stride_2d;