RWKV v6 graph building

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Author: Molly Sophia
Date:   2024-08-02 13:58:34 +08:00
Parent: 5732de89b7
Commit: 0784a0cf26

3 changed files with 423 additions and 25 deletions

ggml.h:

@@ -512,6 +512,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
+        GGML_OP_RWKV_WKV,
         GGML_OP_UNARY,

@@ -1896,6 +1897,15 @@ extern "C" {
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);

+    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
+
     // custom operators
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
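
As a reading aid, a minimal sketch of building a node with the new operator; the helper and tensor names are illustrative, not part of the commit, and the shapes follow the asserts in ggml_rwkv_wkv (ggml.c below), with S = head size, H = head count, T = number of tokens:

#include "ggml.h"

// Hypothetical helper: allocates inputs with the shapes the op expects
// and returns the wkv node. The state tensor is updated in place by the
// CPU kernel, so it carries the recurrence across batches.
static struct ggml_tensor * wkv_node(struct ggml_context * ctx,
                                     int64_t S, int64_t H, int64_t T) {
    struct ggml_tensor * k  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, T);
    struct ggml_tensor * v  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
    struct ggml_tensor * r  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
    struct ggml_tensor * tf = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S, H);        // time_first ("time_faaaa")
    struct ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);  // per-token decay (RWKV v6)
    struct ggml_tensor * s  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, S * S * H);   // recurrent state
    return ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);  // F32, shape { S * H, T, 1, 1 }
}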

ggml.c:

@@ -2835,6 +2835,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
+    "RWKV_WKV",
     "UNARY",

@@ -2853,7 +2854,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -2927,6 +2928,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
+    "rwkv_wkv(x, k, v, r, tf, td, s)",
     "unary(x)",

@@ -2945,7 +2947,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -7637,6 +7639,57 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }

+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * tf,
+        struct ggml_tensor  * td,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(tf));
+    GGML_ASSERT(ggml_is_contiguous(td));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[2];
+    const int64_t n_tokens = k->ne[3];
+    {
+        GGML_ASSERT(k->ne[1] == 1);
+        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+        // TODO: RWKV v4 and v5
+        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H);
+    }
+
+    bool is_node = false;
+
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { S * H, n_tokens, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = r;
+    result->src[3] = tf;
+    result->src[4] = td;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_unary

 static struct ggml_tensor * ggml_unary_impl(
@@ -16795,6 +16848,92 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }

+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const size_t T = dst->ne[1];
+    const size_t C = dst->ne[0];
+    const size_t H = dst->src[1]->ne[2];
+
+    float * dst_data = (float *) dst->data;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    memset(dst_data, 0, T * C * sizeof(float));
+
+    float * k          = (float *) dst->src[0]->data;
+    float * v          = (float *) dst->src[1]->data;
+    float * r          = (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+    float * state      = (float *) dst->src[5]->data;
+
+    size_t t_stride = H * (C / H);
+
+    size_t h_stride = C / H;
+    size_t h_stride_2d = (C / H) * (C / H);
+
+    // basically fused operations:
+    // dst = r @ (time_faaaa * (k @ v) + state),
+    // state = time_decay * state + (k @ v),
+    // recursive through each token
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                // RWKV v6: different time_decay for each token.
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                for (size_t j = 0; j < C / H; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(
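
In equations (my reading of the kernel, matching its fused-ops comment): per token t and head, with S = C/H channels per head, bonus u = time_faaaa and per-token decay w_t = time_decay,

    y_{t,j} = \sum_{i=1}^{S} r_{t,i} \left( u_i \, k_{t,i} \, v_{t,j} + s_{i,j} \right),
    \qquad
    s_{i,j} \leftarrow w_{t,i} \, s_{i,j} + k_{t,i} \, v_{t,j}

so the S x S state per head is read and updated once per token, and because it is updated in place the recurrence also carries across calls.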
@@ -17446,6 +17585,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV:
+            {
+                ggml_compute_forward_rwkv_wkv(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;

@@ -18569,6 +18712,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:

@@ -19143,6 +19287,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
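
Note that GGML_OP_RWKV_WKV joins the ops that get n_tasks = 1 in ggml_get_n_tasks, which matches the early return for params->ith != 0 in the kernel above: the scan is sequential over tokens and runs on a single thread.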

llama.cpp:

@@ -3512,7 +3512,7 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_seq_tokens = batch.n_seq_tokens;

     if (cache.recurrent) {
-        // For recurrent state architectures (like Mamba),
+        // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.

@@ -3761,7 +3761,7 @@ static bool llama_kv_cache_seq_rm(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

-    // models like Mamba can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased
     if (cache.recurrent) {
         if (seq_id >= (int64_t) cache.size) {
             // could be fatal

@@ -3897,7 +3897,7 @@ static void llama_kv_cache_seq_add(
     if (p0 == p1) return;

     if (cache.recurrent) {
-        // for Mamba-like models, only the pos needs to be shifted
+        // for Mamba-like or RWKV models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             const int32_t tail_id = cache.cells[seq_id].tail;
             if (tail_id >= 0) {

@@ -3946,7 +3946,7 @@ static void llama_kv_cache_seq_div(
     if (p0 == p1) return;

     if (cache.recurrent) {
-        // for Mamba-like models, only the pos needs to be changed
+        // for Mamba-like or RWKV models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             const int32_t tail_id = cache.cells[seq_id].tail;
             if (tail_id >= 0) {

@@ -5885,8 +5885,9 @@ static void llm_load_hparams(
         ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

         // TODO: Re-using mamba keys right now, but RWKV isn't state-space
+        ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
         ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
         ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
     } break;
 default: (void)0;
 }
@@ -8323,7 +8324,7 @@ static bool llm_load_tensors(
     // Block 0, LN0
     model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-    model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+    model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});

     // output
     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});

@@ -8348,8 +8349,8 @@ static bool llm_load_tensors(
     layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
     layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd});

-    layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {time_mix_extra_dim * 5, n_embd});
-    layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_embd, time_mix_extra_dim, 5});
+    layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5});
+    layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5});

     layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1});
     layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1});

@@ -8361,8 +8362,8 @@ static bool llm_load_tensors(
     // TODO: Parametrize hardcoded dimensions for first & decay
     layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size});
     layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd});
-    layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {time_decay_extra_dim, n_embd});
-    layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {attn_hidden_size, time_decay_extra_dim});
+    layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim});
+    layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size});
     layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd});
     layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd});
     layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd});
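
The dimension swaps above follow ggml's convention that ne[0] of a weight is its input (contiguous) dimension. A minimal sketch of the rule, with illustrative names that are not part of the commit:

#include "ggml.h"

// y = ggml_mul_mat(ctx, W, x) requires W->ne[0] == x->ne[0] and yields
// shape { W->ne[1], x->ne[1] }, so an n_in -> n_out projection is
// created as { n_in, n_out }.
static struct ggml_tensor * project(struct ggml_context * ctx,
                                    int64_t n_in, int64_t n_out, int64_t n_cols) {
    struct ggml_tensor * W = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_out);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_cols);
    return ggml_mul_mat(ctx, W, x); // { n_out, n_cols }
}

Hence time_mix_w1 becomes { n_embd, time_mix_extra_dim * 5 }: llm_build_time_mix below consumes it as ggml_mul_mat(layer->time_mix_w1, xxx) on n_embd-sized rows.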
@@ -9344,9 +9345,198 @@ static struct ggml_tensor * llm_build_time_mix(
         struct ggml_context * ctx,
         const struct llama_layer * layer,
         struct ggml_tensor * current,
-        int layer_i) {
-
-    return current;
+        struct ggml_tensor * x_prev,
+        struct ggml_tensor * wkv_state) {
+    size_t n_embed = current->ne[0];
+    size_t n_tokens = current->ne[1];
+    size_t head_size = layer->time_mix_first->ne[0];
+    size_t head_count = layer->time_mix_first->ne[1];
+
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
+    struct ggml_tensor * xxx = ggml_add_inplace(
+        ctx,
+        ggml_mul(ctx, sx, layer->time_mix_lerp_x),
+        current
+    );
+
+    xxx = ggml_reshape_4d(
+        ctx,
+        ggml_tanh_inplace(
+            ctx,
+            ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
+        ),
+        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+    );
+
+    xxx = ggml_cont(
+        ctx,
+        ggml_permute(ctx, xxx, 0, 1, 3, 2)
+    );
+
+    xxx = ggml_mul_mat(
+        ctx,
+        ggml_reshape_4d(
+            ctx,
+            layer->time_mix_w2,
+            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+        ),
+        xxx
+    );
+
+    struct ggml_tensor * mw = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mw = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d(ctx, mw, ggml_view_1d(ctx, xxx, n_embed * n_tokens, 0), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor * mk = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mk = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d_inplace(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor * mv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mv = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d_inplace(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor * mr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mr = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d_inplace(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor * mg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mg = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d_inplace(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor * xw = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mw, layer->time_mix_lerp_w),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * xk = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mk, layer->time_mix_lerp_k),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * xv = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mv, layer->time_mix_lerp_v),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * xr = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mr, layer->time_mix_lerp_r),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * xg = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mg, layer->time_mix_lerp_g),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
+    struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
+    struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
+    struct ggml_tensor * g = ggml_silu_inplace(
+        ctx,
+        ggml_mul_mat(ctx, layer->time_mix_gate, xg)
+    );
+
+    struct ggml_tensor * w = ggml_mul_mat(
+        ctx,
+        layer->time_mix_decay_w2,
+        ggml_tanh_inplace(
+            ctx,
+            ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
+        )
+    );
+    w = ggml_add_inplace(
+        ctx,
+        w,
+        ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed)
+    );
+    w = ggml_exp(ctx, ggml_neg_inplace(ctx, ggml_exp(ctx, w)));
+    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+
+    k = ggml_transpose(ctx, k);
+    v = ggml_transpose(ctx, v);
+    r = ggml_transpose(ctx, r);
+    current = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, wkv_state);
+
+    // ggml_group_norm considers groups in the third dimension.
+    current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens);
+    current = ggml_group_norm(ctx, current, head_count, 64e-5f);
+    // Convert back to a regular vector.
+    current = ggml_reshape_2d(ctx, current, n_embed, n_tokens);
+    current = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            current,
+            layer->time_mix_ln
+        ),
+        layer->time_mix_ln_b
+    );
+
+    current = ggml_mul(ctx, current, g);
+
+    return ggml_mul_mat(ctx, layer->time_mix_output, current);
+}
+
+static struct ggml_tensor * llm_build_channel_mix(
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * current,
+        struct ggml_tensor * x_prev) {
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
+    struct ggml_tensor * xk = ggml_add_inplace(
+        ctx,
+        ggml_mul(ctx, sx, layer->channel_mix_lerp_k),
+        current
+    );
+    struct ggml_tensor * xr = ggml_add_inplace(
+        ctx,
+        ggml_mul(ctx, sx, layer->channel_mix_lerp_r),
+        current
+    );
+
+    struct ggml_tensor * r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
+    struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_key, xk)));
+
+    return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
 }

 struct llm_build_context {
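
Taken together, the two builders implement the RWKV v6 block equations (my summary, not text from the commit). Time mix, writing sx_t = x^{prev}_t - x_t:

    \tilde{x}_t = x_t + \mu_x \odot sx_t
    (m^w_t, m^k_t, m^v_t, m^r_t, m^g_t) = \tanh(\tilde{x}_t W_1) \, W_2
    x^a_t = x_t + (\mu_a + m^a_t) \odot sx_t, \quad a \in \{w, k, v, r, g\}
    w_t = \exp\!\left( -\exp\!\left( \tanh(x^w_t W^d_1) \, W^d_2 + d \right) \right)

The double exponential keeps each decay channel of w_t in (0, 1); d is time_mix_decay and the W matrices are the low-rank pairs loaded above. r, k, v are projections of x^r, x^k, x^v, the gate is g = silu(W_g x^g), and the wkv output is group-normalized per head, gated by g, then projected by time_mix_output. Channel mix is the RWKV FFN:

    y_t = \sigma(W_r \, x^r_t) \odot \left( W_v \, \mathrm{relu}(W_k \, x^k_t)^2 \right)

with x^k_t and x^r_t lerped via channel_mix_lerp_k and channel_mix_lerp_r.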
@@ -14874,32 +15064,85 @@
     ggml_cgraph * build_rwkv() {
         ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

+        // Token shift state dimensions should be 2 * n_emb
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
+
         // Input embeddings, start of the model after tokenizing ({n_embd, n_tokens})
         ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

-        // x = self.layer_norm(x, self.w.blocks[0].ln0)
-        ggml_tensor * current = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+        ggml_tensor * x = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);

         for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
             const llama_layer * layer = &model.layers[layer_i];

-            current = llm_build_norm(ctx0, current, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
-            current = llm_build_time_mix(ctx0, layer, current, layer_i);
-
-            current = llm_build_norm(ctx0, current, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
+            // TODO: handle multiple kv cache cells
+            struct ggml_tensor * wkv_state = ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type));
+            struct ggml_tensor * att_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, (kv_self.size - 1) * 2 * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type));
+            struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, ((kv_self.size - 1) * 2 + 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type));
+
+            struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
+            struct ggml_tensor * x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+            x_prev = ggml_set_1d(ctx0, x_prev, att_shift, 0);
+            x_prev = ggml_set_1d(
+                ctx0,
+                x_prev,
+                ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0),
+                n_embd * ggml_type_size(x_prev->type)
+            );
+
+            x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, wkv_state));
+            ggml_build_forward_expand(gf, x);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(
+                        ctx0,
+                        x_norm,
+                        n_embd,
+                        (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
+                    ),
+                    att_shift
+                )
+            );
+
+            x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
+            x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+            x_prev = ggml_set_1d(ctx0, x_prev, ffn_shift, 0);
+            x_prev = ggml_set_1d(
+                ctx0,
+                x_prev,
+                ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0),
+                n_embd * ggml_type_size(x_prev->type)
+            );
+
+            x = ggml_add(ctx0, x, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
+            ggml_build_forward_expand(gf, x);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(
+                        ctx0,
+                        x_norm,
+                        n_embd,
+                        (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
+                    ),
+                    ffn_shift
+                )
+            );
         }

         // Something related to skipping tokens, specifics unclear
         ggml_tensor * inp_out_ids = build_inp_out_ids();
-        current = ggml_get_rows(ctx0, current, inp_out_ids);
+        x = ggml_get_rows(ctx0, x, inp_out_ids);

         // Output head, convert result vector to logits
-        current = llm_build_norm(ctx0, current, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-        current = ggml_mul_mat(ctx0, model.output, current);
+        x = llm_build_norm(ctx0, x, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        x = ggml_mul_mat(ctx0, model.output, x);

         // Mark the output as being the result
-        cb(current, "result_output", -1);
-        ggml_build_forward_expand(gf, current);
+        cb(x, "result_output", -1);
+        ggml_build_forward_expand(gf, x);

         return gf;
     }
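
The x_prev tensors built in the loop implement RWKV's token shift (my summary of the views above): for each sub-block,

    x^{prev}_t = \begin{cases} x^{shift} & t = 0 \\ x^{norm}_{t-1} & t > 0 \end{cases}

and the ggml_cpy calls write x^{norm}_{T-1} back into the shift slot for the next batch. Time mix and channel mix each keep their own shift vector per layer (hence the 2 * n_embd assertion on n_embd_k_s()), while the S * S * H wkv state lives in the layer's v cache slot.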