rwkv6: update cuda file name
This commit is contained in:
parent
b4254c5550
commit
e198f7b9df
4 changed files with 19 additions and 10 deletions
|
@ -36,7 +36,7 @@
|
|||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/rwkv-wkv.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#include "common.cuh"
|
||||
#include "rwkv-wkv.cuh"
|
||||
#include "wkv6.cuh"
|
||||
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
|
@ -3074,7 +3074,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"WIN_UNPART",
|
||||
"GET_REL_POS",
|
||||
"ADD_REL_POS",
|
||||
"RWKV_WKV",
|
||||
"RWKV_WKV6",
|
||||
|
||||
"UNARY",
|
||||
|
||||
|
@ -16618,11 +16618,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
||||
if (params->ith != 0) {
|
||||
if ((size_t)params->ith >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
memset(dst_data, 0, T * C * sizeof(float));
|
||||
size_t h_start = (H * params->ith) / params->nth;
|
||||
size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ?
|
||||
(H * (size_t)(params->ith + 1)) / (size_t)params->nth : H;
|
||||
|
||||
float * k = (float *) dst->src[0]->data;
|
||||
float * v = (float *) dst->src[1]->data;
|
||||
|
@ -16635,6 +16637,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
size_t h_stride = C / H;
|
||||
size_t h_stride_2d = head_size * head_size;
|
||||
|
||||
if (params->ith == 0) {
|
||||
memset(dst_data, 0, T * C * sizeof(float));
|
||||
}
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
|
||||
|
||||
#ifdef __AVX2__
|
||||
// AVX2 uses 256-bit vectors = 8 float32
|
||||
const int vec_size = 8;
|
||||
|
@ -16646,7 +16655,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
for (size_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
@ -16724,7 +16733,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
for (size_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
@ -16806,7 +16815,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
for (size_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
@ -16867,7 +16876,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
for (size_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
@ -16959,7 +16968,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
for (size_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue