rwkv6: update cuda file name

This commit is contained in:
Zhiyuan Li 2024-11-01 20:58:17 +11:00
parent b4254c5550
commit e198f7b9df
4 changed files with 19 additions and 10 deletions

View file

@ -36,7 +36,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/rwkv-wkv.cuh"
#include "ggml-cuda/wkv6.cuh"
#include <algorithm>
#include <array>

View file

@ -1,5 +1,5 @@
#include "common.cuh"
#include "rwkv-wkv.cuh"
#include "wkv6.cuh"
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;

View file

@ -3074,7 +3074,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"WIN_UNPART",
"GET_REL_POS",
"ADD_REL_POS",
"RWKV_WKV",
"RWKV_WKV6",
"UNARY",
@ -16618,11 +16618,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T;
if (params->ith != 0) {
if ((size_t)params->ith >= H) {
return;
}
memset(dst_data, 0, T * C * sizeof(float));
size_t h_start = (H * params->ith) / params->nth;
size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ?
(H * (size_t)(params->ith + 1)) / (size_t)params->nth : H;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
@ -16635,6 +16637,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
size_t h_stride = C / H;
size_t h_stride_2d = head_size * head_size;
if (params->ith == 0) {
memset(dst_data, 0, T * C * sizeof(float));
}
ggml_barrier(params->threadpool);
#ifdef __AVX2__
// AVX2 uses 256-bit vectors = 8 float32
const int vec_size = 8;
@ -16646,7 +16655,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = 0; h < H; h++) {
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
@ -16724,7 +16733,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = 0; h < H; h++) {
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
@ -16806,7 +16815,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = 0; h < H; h++) {
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
@ -16867,7 +16876,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = 0; h < H; h++) {
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
@ -16959,7 +16968,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = 0; h < H; h++) {
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;