rwkv6: rename to wkv6
This commit is contained in:
parent
42cadc74bd
commit
f66c75a495
7 changed files with 38 additions and 36 deletions
|
@ -509,7 +509,7 @@ extern "C" {
|
|||
GGML_OP_WIN_UNPART,
|
||||
GGML_OP_GET_REL_POS,
|
||||
GGML_OP_ADD_REL_POS,
|
||||
GGML_OP_RWKV_WKV,
|
||||
GGML_OP_RWKV_WKV6,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
|
@ -1879,7 +1879,7 @@ extern "C" {
|
|||
struct ggml_tensor * pw,
|
||||
struct ggml_tensor * ph);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_rwkv_wkv(
|
||||
GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
|
|
|
@ -2313,8 +2313,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV:
|
||||
ggml_cuda_op_rwkv_wkv(ctx, dst);
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_cuda_op_rwkv_wkv6(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
|
@ -3147,7 +3147,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
|
|
|
@ -64,7 +64,7 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
|
|
|
@ -2,4 +2,4 @@
|
|||
|
||||
#define CUDA_WKV_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
@ -3169,7 +3169,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"win_unpart(x)",
|
||||
"get_rel_pos(x)",
|
||||
"add_rel_pos(x)",
|
||||
"rwkv_wkv(k, v, r, tf, td, s)",
|
||||
"rwkv_wkv6(k, v, r, tf, td, s)",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
|
@ -7361,9 +7361,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
|
|||
return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
|
||||
}
|
||||
|
||||
// ggml_rwkv_wkv
|
||||
// ggml_rwkv_wkv6
|
||||
|
||||
struct ggml_tensor * ggml_rwkv_wkv(
|
||||
struct ggml_tensor * ggml_rwkv_wkv6(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
|
@ -7395,7 +7395,7 @@ struct ggml_tensor * ggml_rwkv_wkv(
|
|||
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
result->op = GGML_OP_RWKV_WKV;
|
||||
result->op = GGML_OP_RWKV_WKV6;
|
||||
result->src[0] = k;
|
||||
result->src[1] = v;
|
||||
result->src[2] = r;
|
||||
|
@ -16604,15 +16604,16 @@ static void ggml_compute_forward_add_rel_pos(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv
|
||||
// ggml_compute_forward_rwkv_wkv6
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv_f32(
|
||||
static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
const size_t T = dst->src[1]->ne[3];
|
||||
const size_t C = dst->ne[0];
|
||||
const size_t H = dst->src[1]->ne[2];
|
||||
const size_t n_seqs = dst->src[5]->ne[1];
|
||||
const size_t head_size = C / H;
|
||||
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
@ -16629,10 +16630,10 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
float * time_faaaa = (float *) dst->src[3]->data;
|
||||
float * time_decay = (float *) dst->src[4]->data;
|
||||
|
||||
size_t t_stride = H * (C / H);
|
||||
size_t t_stride = H * head_size;
|
||||
|
||||
size_t h_stride = C / H;
|
||||
size_t h_stride_2d = (C / H) * (C / H);
|
||||
size_t h_stride_2d = head_size * head_size;
|
||||
|
||||
// basically fused operations:
|
||||
// dst = r @ (time_faaaa * (k @ v) + state),
|
||||
|
@ -16640,7 +16641,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
// recursive through each token
|
||||
for (size_t t = 0; t < T; t++) {
|
||||
size_t t_offset = t * t_stride;
|
||||
size_t state_offset = (C / H) * C * (t / (T / n_seqs));
|
||||
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
|
@ -16649,7 +16650,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (size_t i = 0; i < C / H; i++) {
|
||||
for (size_t i = 0; i < head_size; i++) {
|
||||
size_t t_h_i_offset = t_h_offset + i;
|
||||
size_t h_i_offset = h_offset + i;
|
||||
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
|
@ -16660,7 +16661,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
// RWKV v6: different time_decay for each token.
|
||||
float time_decay_val = time_decay[t_h_i_offset];
|
||||
|
||||
for (size_t j = 0; j < C / H; j ++) {
|
||||
for (size_t j = 0; j < head_size; j ++) {
|
||||
size_t t_h_j_offset = t_h_offset + j;
|
||||
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
|
@ -16676,7 +16677,8 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv(
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv6(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
|
@ -16685,7 +16687,7 @@ static void ggml_compute_forward_rwkv_wkv(
|
|||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv_f32(params, dst);
|
||||
ggml_compute_forward_rwkv_wkv6_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -17437,9 +17439,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_add_rel_pos(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv(params, tensor);
|
||||
ggml_compute_forward_rwkv_wkv6(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_UNARY:
|
||||
{
|
||||
|
@ -18628,7 +18630,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_ADD_REL_POS:
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
|
@ -19278,7 +19280,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_OP_WIN_PART:
|
||||
case GGML_OP_WIN_UNPART:
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
|
|
|
@ -7011,7 +7011,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
|
|||
{LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
|
||||
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
@ -7127,7 +7127,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
|
|||
ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
|
||||
op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
// FIXME
|
||||
const int64_t S = 123;
|
||||
|
@ -7140,7 +7140,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
|
|||
ggml_tensor * tf = w;
|
||||
ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
|
||||
ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
|
||||
op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
|
||||
op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
|
||||
|
@ -10083,7 +10083,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
|
|||
v = ggml_transpose(ctx, v);
|
||||
r = ggml_transpose(ctx, r);
|
||||
|
||||
struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
|
||||
struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
|
||||
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
|
||||
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
|
||||
|
||||
|
|
|
@ -1613,8 +1613,8 @@ struct test_ssm_scan : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
// GGML_OP_RWKV_WKV
|
||||
struct test_rwkv_wkv : public test_case {
|
||||
// GGML_OP_RWKV_WKV6
|
||||
struct test_rwkv_wkv6 : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t head_count;
|
||||
|
@ -1626,7 +1626,7 @@ struct test_rwkv_wkv : public test_case {
|
|||
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
||||
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
|
||||
test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
|
||||
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||
|
||||
|
@ -1638,7 +1638,7 @@ struct test_rwkv_wkv : public test_case {
|
|||
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
|
||||
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
|
||||
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
|
||||
ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
@ -3498,10 +3498,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
|
||||
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
#if 1
|
||||
for (ggml_type type_a : base_types) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue