llama: rwkv6: Use the new advanced batch splits

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Author: Molly Sophia
Date:   2024-08-23 10:14:35 +08:00
Commit: f5d955d2fe (parent: 6da6aa48b0)

3 changed files with 66 additions and 204 deletions

View file

@@ -513,7 +513,6 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV,
-        GGML_OP_RWKV_TOKEN_SHIFT,

         GGML_OP_UNARY,
@@ -1905,14 +1904,7 @@ extern "C" {
             struct ggml_tensor * r,
             struct ggml_tensor * tf,
             struct ggml_tensor * td,
-            struct ggml_tensor * state,
-            struct ggml_tensor * state_seq);
-
-    GGML_API struct ggml_tensor * ggml_rwkv_token_shift(
-            struct ggml_context * ctx,
-            struct ggml_tensor * x_carry,
-            struct ggml_tensor * x_norm,
-            struct ggml_tensor * state_seq);
+            struct ggml_tensor * state);

     // custom operators
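
For orientation, a minimal caller-side sketch (not from this commit) of the slimmed-down ggml_rwkv_wkv: the shapes follow the assertions in ggml.c below, the per-sequence state now carries the sequence count in ne[1] instead of a separate state_seq tensor, and the dimension values here are invented for illustration only.

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        // toy dimensions, chosen only for illustration
        const int64_t S = 64, H = 8, n_tokens = 4, n_seqs = 2;

        struct ggml_init_params ip = { /*.mem_size =*/ 64*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        // shapes mirror the GGML_ASSERTs in ggml_rwkv_wkv
        struct ggml_tensor * k  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
        struct ggml_tensor * v  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
        struct ggml_tensor * r  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
        struct ggml_tensor * tf = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S, H);
        struct ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
        // the state holds one S*S*H block per sequence; n_seqs is read from ne[1]
        struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);

        // builds the graph node only; no state_seq argument anymore
        struct ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);

        // the result concatenates the per-token output and the updated states
        printf("wkv output: %lld x %lld\n", (long long) out->ne[0], (long long) out->ne[1]);
        // expected: (S*H) x (n_tokens + S*n_seqs)

        ggml_free(ctx);
        return 0;
    }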

View file

@@ -2836,7 +2836,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "GET_REL_POS",
    "ADD_REL_POS",
    "RWKV_WKV",
-   "RWKV_TOKEN_SHIFT",

    "UNARY",
@@ -2855,7 +2854,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -2929,8 +2928,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "win_unpart(x)",
    "get_rel_pos(x)",
    "add_rel_pos(x)",
-   "rwkv_wkv(k, v, r, tf, td, s, sq)",
-   "rwkv_token_shift(xc, xn, sq)",
+   "rwkv_wkv(k, v, r, tf, td, s)",

    "unary(x)",
@@ -2949,7 +2947,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -7650,39 +7648,36 @@ struct ggml_tensor * ggml_rwkv_wkv(
        struct ggml_tensor * r,
        struct ggml_tensor * tf,
        struct ggml_tensor * td,
-       struct ggml_tensor * state,
-       struct ggml_tensor * state_seq) {
+       struct ggml_tensor * state) {
    GGML_ASSERT(ggml_is_contiguous(k));
    GGML_ASSERT(ggml_is_contiguous(v));
    GGML_ASSERT(ggml_is_contiguous(r));
    GGML_ASSERT(ggml_is_contiguous(tf));
    GGML_ASSERT(ggml_is_contiguous(td));
    GGML_ASSERT(ggml_is_contiguous(state));
-   GGML_ASSERT(ggml_is_contiguous(state_seq));
-   GGML_ASSERT(state_seq->type == GGML_TYPE_I32);

    const int64_t S = k->ne[0];
    const int64_t H = k->ne[2];
    const int64_t n_tokens = k->ne[3];
-   const int64_t n_kv = state_seq->ne[0];
+   const int64_t n_seqs = state->ne[1];
    {
        GGML_ASSERT(k->ne[1] == 1);
        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
        // TODO: RWKV v4 and v5
        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
-       GGML_ASSERT(ggml_nelements(state) == S * S * H * n_kv);
+       GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
    }

    bool is_node = false;

-   if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad || state_seq->grad) {
+   if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
        GGML_ABORT("fatal error"); // TODO: implement backward
        is_node = true;
    }

    // concat output and new_state
-   const int64_t ne[4] = { S * H, n_tokens + S * n_kv, 1, 1 };
+   const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

    result->op = GGML_OP_RWKV_WKV;
@@ -7693,48 +7688,6 @@ struct ggml_tensor * ggml_rwkv_wkv(
    result->src[3] = tf;
    result->src[4] = td;
    result->src[5] = state;
-   result->src[6] = state_seq;
-
-   return result;
-}
-
-// ggml_rwkv_token_shift
-
-struct ggml_tensor * ggml_rwkv_token_shift(
-       struct ggml_context * ctx,
-       struct ggml_tensor * x_carry,
-       struct ggml_tensor * x_norm,
-       struct ggml_tensor * state_seq) {
-   GGML_ASSERT(ggml_is_contiguous(x_carry));
-   GGML_ASSERT(ggml_is_contiguous(x_norm));
-   GGML_ASSERT(ggml_is_contiguous(state_seq));
-   GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
-
-   const int64_t n_embd = x_norm->ne[0];
-   const int64_t n_kv = state_seq->ne[0];
-   const int64_t n_tokens = state_seq->ne[1];
-   {
-       GGML_ASSERT(x_norm->ne[0] == n_embd);
-       GGML_ASSERT(x_norm->ne[1] == n_tokens);
-       GGML_ASSERT(ggml_nelements(x_carry) == n_embd * n_kv);
-   }
-
-   bool is_node = false;
-
-   if (x_carry->grad || x_norm->grad || state_seq->grad) {
-       GGML_ABORT("fatal error"); // TODO: implement backward
-       is_node = true;
-   }
-
-   // concat output and new_state
-   const int64_t ne[4] = { n_embd, n_tokens + n_kv, 1, 1 };
-   struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-   result->op = GGML_OP_RWKV_TOKEN_SHIFT;
-   result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-   result->src[0] = x_carry;
-   result->src[1] = x_norm;
-   result->src[2] = state_seq;

    return result;
 }
@@ -16905,7 +16858,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
    const size_t T = dst->src[1]->ne[3];
    const size_t C = dst->ne[0];
    const size_t H = dst->src[1]->ne[2];
-   const size_t n_kv = dst->src[6]->ne[0];
+   const size_t n_seqs = dst->src[5]->ne[1];

    float * dst_data = (float *) dst->data;
    float * state = ((float *) dst->data) + C * T;
@@ -16921,8 +16874,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
    float * r = (float *) dst->src[2]->data;
    float * time_faaaa = (float *) dst->src[3]->data;
    float * time_decay = (float *) dst->src[4]->data;
-   int32_t * seq_data = (int32_t *) dst->src[6]->data;
-   memcpy(state, dst->src[5]->data, (C / H) * C * n_kv * sizeof(float));
+   memcpy(state, dst->src[5]->data, (C / H) * C * n_seqs * sizeof(float));

    size_t t_stride = H * (C / H);
@@ -16935,7 +16887,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
    // recursive through each token
    for (size_t t = 0; t < T; t++) {
        size_t t_offset = t * t_stride;
-       float * state_cur = state + (C / H) * C * seq_data[t * n_kv];
+       float * state_cur = state + (C / H) * C * (t / (T / n_seqs));

        for (size_t h = 0; h < H; h++) {
            size_t h_offset = h * h_stride;
@@ -16967,15 +16919,6 @@ static void ggml_compute_forward_rwkv_wkv_f32(
            }
        }
    }
-
-   for (size_t t = 0; t < T; t++) {
-       for (size_t kv = 1; kv < n_kv; kv++) {
-           int64_t seq = seq_data[t * n_kv + kv];
-           if (seq >= 0 && seq_data[(t + 1) * n_kv + kv] != seq) {
-               memcpy(state + (C / H) * C * seq, state + (C / H) * C * seq_data[t * n_kv], (C / H) * C * sizeof(float));
-           }
-       }
-   }
 }

 static void ggml_compute_forward_rwkv_wkv(
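
The forward kernel no longer reads a per-token state_seq buffer: with equal batch splits, the batch is assumed to hold n_seqs sequences of identical length laid out contiguously, so the owning sequence falls out of the token index itself. A standalone plain-C illustration of that index arithmetic (the names and sizes here are invented, not from the source):

    #include <stdio.h>

    int main(void) {
        // assume 2 sequences, 3 tokens each, laid out as s0t0 s0t1 s0t2 s1t0 s1t1 s1t2
        const size_t T = 6, n_seqs = 2;
        const size_t tokens_per_seq = T / n_seqs;

        for (size_t t = 0; t < T; t++) {
            // same expression the kernel now uses to pick the state slot
            size_t seq = t / tokens_per_seq;   // previously: seq_data[t * n_kv]
            printf("token %zu -> state of sequence %zu\n", t, seq);
        }
        return 0;
    }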
@@ -16996,77 +16939,6 @@ static void ggml_compute_forward_rwkv_wkv(
    }
 }

-static void ggml_compute_forward_rwkv_token_shift_f32(
-       const struct ggml_compute_params * params,
-       struct ggml_tensor * dst) {
-   const int64_t n_embd = dst->ne[0];
-   const int64_t n_kv = dst->src[2]->ne[0];
-   const int64_t n_tokens = dst->src[1]->ne[1];
-   float * dst_data = (float *) dst->data;
-   float * x_carry = (float *) dst->src[0]->data;
-   float * x_norm = (float *) dst->src[1]->data;
-   int32_t * sq_data = (int32_t *) dst->src[2]->data;
-
-   if (params->ith != 0) {
-       return;
-   }
-
-   int32_t seq_start = 0;
-   int32_t seq_length = 0;
-
-   for (int i1 = 0; i1 < n_kv; ++i1) {
-       seq_start = -1;
-       // assume that the tokens for each sequence are contiguous
-       for (int i2 = 0; i2 < n_tokens; ++i2) {
-           int32_t seq = sq_data[i2*n_kv];
-           if (seq == i1 && seq_start < 0) {
-               seq_start = i2;
-           }
-
-           if ((seq_start >= 0 && seq != i1) || i2 == n_tokens - 1) {
-               seq_length = i2 - seq_start + (i2 == n_tokens - 1);
-               break;
-           }
-       }
-
-       if (seq_start >= 0) {
-           int32_t seq = sq_data[seq_start*n_kv];
-           memcpy(dst_data + seq_start*n_embd, x_carry + seq*n_embd, n_embd*sizeof(float));
-           memcpy(dst_data + (seq_start+1)*n_embd, x_norm + seq_start*n_embd, (seq_length-1)*n_embd*sizeof(float));
-       }
-   }
-
-   for (int i3 = 0; i3 < n_kv; ++i3) {
-       int32_t last_token_pos = 0;
-       for (int i4 = 0; i4 < n_tokens; ++i4) {
-           for (int i5 = 0; i5 < n_kv; ++i5) {
-               if (sq_data[i4*n_kv + i5] == i3) {
-                   last_token_pos = i4;
-               }
-           }
-       }
-       memcpy(dst_data + (n_tokens + i3)*n_embd, x_norm + last_token_pos*n_embd, n_embd*sizeof(float));
-   }
-}
-
-static void ggml_compute_forward_rwkv_token_shift(
-       const struct ggml_compute_params * params,
-       struct ggml_tensor * dst) {
-   const struct ggml_tensor * src0 = dst->src[0];
-
-   switch (src0->type) {
-       case GGML_TYPE_F32:
-           {
-               ggml_compute_forward_rwkv_token_shift_f32(params, dst);
-           } break;
-       default:
-           {
-               GGML_ABORT("fatal error");
-           }
-   }
-}
-
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(
@@ -17722,10 +17594,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
            {
                ggml_compute_forward_rwkv_wkv(params, tensor);
            } break;
-       case GGML_OP_RWKV_TOKEN_SHIFT:
-           {
-               ggml_compute_forward_rwkv_token_shift(params, tensor);
-           } break;
        case GGML_OP_MAP_UNARY:
            {
                ggml_unary_op_f32_t fun;
@@ -18859,7 +18727,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
        case GGML_OP_GET_REL_POS:
        case GGML_OP_ADD_REL_POS:
        case GGML_OP_RWKV_WKV:
-       case GGML_OP_RWKV_TOKEN_SHIFT:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:
@@ -19435,7 +19302,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
        case GGML_OP_RWKV_WKV:
-       case GGML_OP_RWKV_TOKEN_SHIFT:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:

View file

@@ -9378,15 +9378,20 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
        const struct llama_layer * layer,
        struct ggml_tensor * cur,
        struct ggml_tensor * x_prev,
-       struct ggml_tensor ** wkv_state,
-       struct ggml_tensor * state_seq) {
+       struct ggml_tensor ** wkv_state) {
    size_t n_embed = cur->ne[0];
-   size_t n_tokens = cur->ne[1];
+   size_t n_seq_tokens = cur->ne[1];
+   size_t n_seqs = cur->ne[2];

    size_t head_size = layer->time_mix_first->ne[0];
    size_t head_count = layer->time_mix_first->ne[1];
-   size_t n_kv = state_seq->ne[0];
+
+   size_t n_tokens = n_seqs * n_seq_tokens;

    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+
+   sx  = ggml_reshape_2d(ctx, sx, n_embed, n_tokens);
+   cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);

    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);

    xxx = ggml_reshape_4d(
@@ -9489,9 +9494,9 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
    k = ggml_transpose(ctx, k);
    v = ggml_transpose(ctx, v);
    r = ggml_transpose(ctx, r);
-   struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state, state_seq);
+   struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
    cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
-   *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float));
+   *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));

    // ggml_group_norm considers groups in the third dimension.
    cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens);
@@ -9501,8 +9506,9 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);

    cur = ggml_mul(ctx, cur, g);
+   cur = ggml_mul_mat(ctx, layer->time_mix_output, cur);

-   return ggml_mul_mat(ctx, layer->time_mix_output, cur);
+   return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
 }

 static struct ggml_tensor * llm_build_channel_mix_rwkv6(
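
As the two ggml_view_1d calls above show, ggml_rwkv_wkv still returns one flat F32 buffer: n_embed * n_tokens per-token outputs followed by n_embed * head_size * n_seqs updated states. A hedged plain-C sketch of the same bookkeeping, useful mainly to check the offsets (all names and sizes here are illustrative, not from the source):

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        // illustrative dimensions only
        const size_t n_embed = 512, head_size = 64, n_tokens = 8, n_seqs = 2;

        const size_t out_elems   = n_embed * n_tokens;           // per-token wkv output
        const size_t state_elems = n_embed * head_size * n_seqs; // updated per-sequence states

        float * wkv_output = calloc(out_elems + state_elems, sizeof(float));

        float * cur       = wkv_output;             // what the graph views as "cur"
        float * wkv_state = wkv_output + out_elems; // what is written back as the new state

        printf("output floats: %zu, state floats: %zu, state offset (bytes): %zu\n",
               out_elems, state_elems, out_elems * sizeof(float));

        (void) cur; (void) wkv_state;
        free(wkv_output);
        return 0;
    }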
@@ -15053,49 +15059,56 @@ struct llm_build_context {
        // Token shift state dimensions should be 2 * n_emb
        GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);

+       const int64_t n_seqs = batch.n_seqs;
+       const int64_t n_seq_tokens = batch.n_seq_tokens;
+       const int64_t n_tokens = batch.n_tokens;
+       GGML_ASSERT(n_seqs != 0);
+       GGML_ASSERT(batch.equal_seqs);
+       GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
        ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

+       struct ggml_tensor * state_copy = build_inp_s_copy();
        struct ggml_tensor * state_mask = build_inp_s_mask();
-       struct ggml_tensor * state_seq = build_inp_s_seq();

        ggml_tensor * cur = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);

        for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
            const llama_layer * layer = &model.layers[layer_i];

-           struct ggml_tensor * token_shift = ggml_reshape_2d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s(), kv_self.size);
-           struct ggml_tensor * wkv_states = ggml_reshape_2d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), kv_self.size);
+           // (ab)using the KV cache to store the states
+           struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                   gf, kv_self.k_l[layer_i], state_copy, state_mask,
+                   hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+           struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                   gf, kv_self.v_l[layer_i], state_copy, state_mask,
+                   hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);

-           {
-               token_shift = ggml_mul(ctx0,
-                   ggml_view_2d(ctx0, token_shift, token_shift->ne[0], n_kv, token_shift->nb[1], kv_head*token_shift->nb[1]),
-                   state_mask);
-               wkv_states = ggml_mul(ctx0,
-                   ggml_view_2d(ctx0, wkv_states, wkv_states->ne[0], n_kv, wkv_states->nb[1], kv_head*wkv_states->nb[1]),
-                   state_mask);
-           }
+           cur = ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);

            token_shift = ggml_cont(
                ctx0,
                ggml_permute(
                    ctx0,
-                   ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_kv),
+                   ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs),
                    0, 2, 1, 3
                )
            );

-           struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0);
-           struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_embd * n_kv * ggml_element_size(kv_self.k_l[layer_i]));
+           struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, 0);
+           struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, n_embd * n_seqs * ggml_element_size(token_shift));
+
+           att_shift = ggml_reshape_3d(ctx0, att_shift, n_embd, 1, n_seqs);
+           ffn_shift = ggml_reshape_3d(ctx0, ffn_shift, n_embd, 1, n_seqs);

            struct ggml_tensor * x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
-           struct ggml_tensor * tmp = ggml_rwkv_token_shift(ctx0, att_shift, x_norm, state_seq);
-           struct ggml_tensor * x_prev = ggml_reshape_2d(
+           struct ggml_tensor * x_prev = ggml_concat(
                ctx0,
-               ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
-               n_embd, n_tokens
+               att_shift,
+               ggml_view_3d(ctx0, x_norm, n_embd, n_seq_tokens - 1, n_seqs, x_norm->nb[1], x_norm->nb[2], 0),
+               1
            );

-           cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
+           cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(ctx0, layer, x_norm, x_prev, &wkv_states));
            ggml_build_forward_expand(gf, cur);
            ggml_build_forward_expand(
                gf,
@@ -15105,45 +15118,35 @@ struct llm_build_context {
                    ggml_view_1d(
                        ctx0,
                        kv_self.v_l[layer_i],
-                       hparams.n_embd_v_s() * n_kv,
+                       hparams.n_embd_v_s() * n_seqs,
                        hparams.n_embd_v_s() * kv_head * ggml_type_size(kv_self.v_l[layer_i]->type)
                    )
                )
            );

+           struct ggml_tensor * last_norm = ggml_view_3d(ctx0, x_norm, n_embd, 1, n_seqs, x_norm->nb[1], x_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm));
+
            ggml_build_forward_expand(
                gf,
                ggml_cpy(
-                   ctx0,
-                   ggml_view_1d(
-                       ctx0,
-                       tmp,
-                       n_embd * n_kv,
-                       n_tokens * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
-                   ),
-                   ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0)
+                   ctx0, last_norm,
+                   ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, 0)
                )
            );

            x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
-           tmp = ggml_rwkv_token_shift(ctx0, ffn_shift, x_norm, state_seq);
-           x_prev = ggml_reshape_2d(
+           x_prev = ggml_concat(
                ctx0,
-               ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
-               n_embd, n_tokens
+               ffn_shift,
+               ggml_view_3d(ctx0, x_norm, n_embd, n_seq_tokens - 1, n_seqs, x_norm->nb[1], x_norm->nb[2], 0),
+               1
            );

            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(ctx0, layer, x_norm, x_prev));
+
+           last_norm = ggml_view_3d(ctx0, x_norm, n_embd, 1, n_seqs, x_norm->nb[1], x_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm));
+
            ggml_build_forward_expand(gf, cur);
            ggml_build_forward_expand(
                gf,
                ggml_cpy(
-                   ctx0,
-                   ggml_view_1d(
-                       ctx0,
-                       tmp,
-                       n_embd * n_kv,
-                       n_tokens * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
-                   ),
-                   ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_kv * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type))
+                   ctx0, last_norm,
+                   ggml_view_1d(ctx0, token_shift, n_embd * n_seqs, n_embd * n_seqs * ggml_element_size(token_shift))
                )
            );
@@ -15151,7 +15154,7 @@ struct llm_build_context {
                ctx0,
                ggml_permute(
                    ctx0,
-                   ggml_reshape_3d(ctx0, token_shift, n_embd, n_kv, 2),
+                   ggml_reshape_3d(ctx0, token_shift, n_embd, n_seqs, 2),
                    0, 2, 1, 3
                )
            );
@@ -15160,8 +15163,8 @@ struct llm_build_context {
                gf,
                ggml_cpy(
                    ctx0,
-                   ggml_view_1d(ctx0, token_shift, n_embd * n_kv * 2, 0),
-                   ggml_view_1d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s() * n_kv, hparams.n_embd_k_s() * kv_head * ggml_type_size(kv_self.k_l[layer_i]->type))
+                   ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
+                   ggml_view_1d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_type_size(kv_self.k_l[layer_i]->type))
                )
            );
@@ -15171,6 +15174,7 @@ struct llm_build_context {
        }

        ggml_tensor * inp_out_ids = build_inp_out_ids();
+       cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
        cur = ggml_get_rows(ctx0, cur, inp_out_ids);

        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
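
With the dedicated token-shift op removed, the graph above expresses the shift directly: per sequence, x_prev is the carried shift state followed by x_norm for tokens 0..n-2 (the ggml_concat), and the last token's x_norm is copied out as the next shift state (the last_norm copies). A small plain-C illustration of that per-sequence shift for a single layer, with invented names and sizes:

    #include <stdio.h>
    #include <string.h>

    #define N_EMBD 4
    #define N_SEQ_TOKENS 3

    int main(void) {
        // one sequence: shift state carried from the previous batch + this batch's x_norm rows
        float shift[N_EMBD] = { 9, 9, 9, 9 };
        float x_norm[N_SEQ_TOKENS][N_EMBD] = {
            { 0, 0, 0, 0 }, { 1, 1, 1, 1 }, { 2, 2, 2, 2 },
        };
        float x_prev[N_SEQ_TOKENS][N_EMBD];

        // x_prev = concat(shift, x_norm[0 .. N_SEQ_TOKENS-2]) along the token dimension
        memcpy(x_prev[0], shift, sizeof(shift));
        memcpy(x_prev[1], x_norm[0], (N_SEQ_TOKENS - 1) * N_EMBD * sizeof(float));

        // the last token's x_norm is carried over as the new shift state
        memcpy(shift, x_norm[N_SEQ_TOKENS - 1], sizeof(shift));

        for (int t = 0; t < N_SEQ_TOKENS; t++) {
            printf("x_prev[%d][0] = %g\n", t, x_prev[t][0]); // prints 9, 0, 1
        }
        printf("new shift[0] = %g\n", shift[0]);             // prints 2
        return 0;
    }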