Fix parallel inferencing for RWKV
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Parent: 98ce5f43f0
Commit: 01dcf4bb77
3 changed files with 237 additions and 48 deletions
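In brief, as read from the diff below: ggml_rwkv_wkv gains a state_seq input (an I32 tensor with n_kv entries per token, naming the KV cells of the sequence that token belongs to), a new GGML_OP_RWKV_TOKEN_SHIFT op performs the per-sequence token shift, and the RWKV graph build in src/llama.cpp keeps one token-shift carry and one WKV state per KV cell, so several sequences can be decoded in one batch. A minimal sketch of the state-selection indexing used by the new kernel follows (names taken from the diff; an illustration, not the kernel itself):

    #include <stddef.h>
    #include <stdint.h>

    // Hedged sketch: state holds n_kv slabs of S*C floats; seq_data holds n_kv
    // int32 entries per token (as in the diff). The WKV update for token t reads
    // and writes the slab of its primary KV cell, seq_data[t * n_kv].
    static float * rwkv_state_for_token(float * state, const int32_t * seq_data,
                                        size_t S, size_t C, size_t n_kv, size_t t) {
        return state + S * C * (size_t) seq_data[t * n_kv];
    }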
ggml/include/ggml.h

@@ -513,6 +513,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV,
+        GGML_OP_RWKV_TOKEN_SHIFT,

         GGML_OP_UNARY,

@@ -1904,7 +1905,14 @@ extern "C" {
             struct ggml_tensor * r,
             struct ggml_tensor * tf,
             struct ggml_tensor * td,
-            struct ggml_tensor * state);
+            struct ggml_tensor * state,
+            struct ggml_tensor * state_seq);
+
+    GGML_API struct ggml_tensor * ggml_rwkv_token_shift(
+            struct ggml_context * ctx,
+            struct ggml_tensor * x_carry,
+            struct ggml_tensor * x_norm,
+            struct ggml_tensor * state_seq);

     // custom operators
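Note on the two declarations above (a hedged reading, mirrored by the views taken in src/llama.cpp further down): both ops concatenate their per-token output and their updated recurrent state into one result tensor (see the "concat output and new_state" comments in ggml.c), so the caller splits the result with two 1-D views. A sketch for ggml_rwkv_token_shift; the helper name is hypothetical, the view calls are the ones the diff itself uses:

    #include "ggml.h"

    // Rows 0..n_tokens-1 of the result are the shifted inputs for the batch,
    // rows n_tokens..n_tokens+n_kv-1 are the new per-cell carry state.
    static void split_token_shift_output(struct ggml_context * ctx, struct ggml_tensor * out,
                                         int64_t n_embd, int64_t n_tokens, int64_t n_kv,
                                         struct ggml_tensor ** x_prev, struct ggml_tensor ** new_carry) {
        *x_prev    = ggml_view_1d(ctx, out, n_embd * n_tokens, 0);
        *new_carry = ggml_view_1d(ctx, out, n_embd * n_kv, n_embd * n_tokens * ggml_element_size(out));
    }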
ggml/src/ggml.c (157 changes)

@@ -2836,6 +2836,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV",
+    "RWKV_TOKEN_SHIFT",

     "UNARY",

@@ -2854,7 +2855,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -2928,7 +2929,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
-    "rwkv_wkv(x, k, v, r, tf, td, s)",
+    "rwkv_wkv(k, v, r, tf, td, s, sq)",
+    "rwkv_token_shift(xc, xn, sq)",

     "unary(x)",

@@ -2947,7 +2949,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -7648,35 +7650,39 @@ struct ggml_tensor * ggml_rwkv_wkv(
         struct ggml_tensor * r,
         struct ggml_tensor * tf,
         struct ggml_tensor * td,
-        struct ggml_tensor * state) {
+        struct ggml_tensor * state,
+        struct ggml_tensor * state_seq) {
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(r));
     GGML_ASSERT(ggml_is_contiguous(tf));
     GGML_ASSERT(ggml_is_contiguous(td));
     GGML_ASSERT(ggml_is_contiguous(state));
+    GGML_ASSERT(ggml_is_contiguous(state_seq));
+    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);

     const int64_t S = k->ne[0];
     const int64_t H = k->ne[2];
     const int64_t n_tokens = k->ne[3];
+    const int64_t n_kv = state_seq->ne[0];
     {
         GGML_ASSERT(k->ne[1] == 1);
         GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
         GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
         // TODO: RWKV v4 and v5
         GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
-        GGML_ASSERT(ggml_nelements(state) == S * S * H);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_kv);
     }

     bool is_node = false;

-    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad || state_seq->grad) {
         GGML_ABORT("fatal error"); // TODO: implement backward
         is_node = true;
     }

     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S, 1, 1 };
+    const int64_t ne[4] = { S * H, n_tokens + S * n_kv, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

     result->op = GGML_OP_RWKV_WKV;
@@ -7687,6 +7693,48 @@ struct ggml_tensor * ggml_rwkv_wkv(
     result->src[3] = tf;
     result->src[4] = td;
     result->src[5] = state;
+    result->src[6] = state_seq;
+
+    return result;
+}
+
+// ggml_rwkv_token_shift
+
+struct ggml_tensor * ggml_rwkv_token_shift(
+        struct ggml_context * ctx,
+        struct ggml_tensor * x_carry,
+        struct ggml_tensor * x_norm,
+        struct ggml_tensor * state_seq) {
+    GGML_ASSERT(ggml_is_contiguous(x_carry));
+    GGML_ASSERT(ggml_is_contiguous(x_norm));
+    GGML_ASSERT(ggml_is_contiguous(state_seq));
+    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
+
+    const int64_t n_embd = x_norm->ne[0];
+    const int64_t n_kv = state_seq->ne[0];
+    const int64_t n_tokens = state_seq->ne[1];
+    {
+        GGML_ASSERT(x_norm->ne[0] == n_embd);
+        GGML_ASSERT(x_norm->ne[1] == n_tokens);
+        GGML_ASSERT(ggml_nelements(x_carry) == n_embd * n_kv);
+    }
+
+    bool is_node = false;
+
+    if (x_carry->grad || x_norm->grad || state_seq->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { n_embd, n_tokens + n_kv, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op = GGML_OP_RWKV_TOKEN_SHIFT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = x_carry;
+    result->src[1] = x_norm;
+    result->src[2] = state_seq;

     return result;
 }
@@ -16857,6 +16905,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     const size_t T = dst->src[1]->ne[3];
     const size_t C = dst->ne[0];
     const size_t H = dst->src[1]->ne[2];
+    const size_t n_kv = dst->src[6]->ne[0];

     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;

@@ -16872,7 +16921,8 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * r = (float *) dst->src[2]->data;
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
-    memcpy(state, dst->src[5]->data, (C / H) * C * sizeof(float));
+    int32_t * seq_data = (int32_t *) dst->src[6]->data;
+    memcpy(state, dst->src[5]->data, (C / H) * C * n_kv * sizeof(float));

     size_t t_stride = H * (C / H);

@@ -16885,6 +16935,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     // recursive through each token
     for (size_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
+        float * state_cur = state + (C / H) * C * seq_data[t * n_kv];

         for (size_t h = 0; h < H; h++) {
             size_t h_offset = h * h_stride;

@@ -16908,14 +16959,23 @@ static void ggml_compute_forward_rwkv_wkv_f32(

                 float v_val = v[t_h_j_offset];
                 float kv_val = v_val * k_val;
-                float prev_state_val = state[h_2d_i_j_offset];
+                float prev_state_val = state_cur[h_2d_i_j_offset];
                 float temp_val = kv_val * time_faaaa_val + prev_state_val;
                 dst_data[t_h_j_offset] += temp_val * r_val;
-                state[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
                 }
             }
         }
     }
+
+    for (size_t t = 0; t < T; t++) {
+        for (size_t kv = 1; kv < n_kv; kv++) {
+            int64_t seq = seq_data[t * n_kv + kv];
+            if (seq >= 0 && seq_data[(t + 1) * n_kv + kv] != seq) {
+                memcpy(state + (C / H) * C * seq, state + (C / H) * C * seq_data[t * n_kv], (C / H) * C * sizeof(float));
+            }
+        }
+    }
 }

 static void ggml_compute_forward_rwkv_wkv(
@@ -16936,6 +16996,77 @@ static void ggml_compute_forward_rwkv_wkv(
     }
 }

+static void ggml_compute_forward_rwkv_token_shift_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t n_embd = dst->ne[0];
+    const int64_t n_kv = dst->src[2]->ne[0];
+    const int64_t n_tokens = dst->src[1]->ne[1];
+    float * dst_data = (float *) dst->data;
+    float * x_carry = (float *) dst->src[0]->data;
+    float * x_norm = (float *) dst->src[1]->data;
+    int32_t * sq_data = (int32_t *) dst->src[2]->data;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    int32_t seq_start = 0;
+    int32_t seq_length = 0;
+
+    for (int i1 = 0; i1 < n_kv; ++i1) {
+        seq_start = -1;
+        // assume that the tokens for each sequence are contiguous
+        for (int i2 = 0; i2 < n_tokens; ++i2) {
+            int32_t seq = sq_data[i2*n_kv];
+            if (seq == i1 && seq_start < 0) {
+                seq_start = i2;
+            }
+
+            if ((seq_start >= 0 && seq != i1) || i2 == n_tokens - 1) {
+                seq_length = i2 - seq_start + (i2 == n_tokens - 1);
+                break;
+            }
+        }
+
+        if (seq_start >= 0) {
+            int32_t seq = sq_data[seq_start*n_kv];
+            memcpy(dst_data + seq_start*n_embd, x_carry + seq*n_embd, n_embd*sizeof(float));
+            memcpy(dst_data + (seq_start+1)*n_embd, x_norm + seq_start*n_embd, (seq_length-1)*n_embd*sizeof(float));
+        }
+    }
+
+    for (int i3 = 0; i3 < n_kv; ++i3) {
+        int32_t last_token_pos = 0;
+        for (int i4 = 0; i4 < n_tokens; ++i4) {
+            for (int i5 = 0; i5 < n_kv; ++i5) {
+                if (sq_data[i4*n_kv + i5] == i3) {
+                    last_token_pos = i4;
+                }
+            }
+        }
+        memcpy(dst_data + (n_tokens + i3)*n_embd, x_norm + last_token_pos*n_embd, n_embd*sizeof(float));
+    }
+}
+
+static void ggml_compute_forward_rwkv_token_shift(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_token_shift_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(

@@ -17591,6 +17722,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv(params, tensor);
             } break;
+        case GGML_OP_RWKV_TOKEN_SHIFT:
+            {
+                ggml_compute_forward_rwkv_token_shift(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;

@@ -18715,6 +18850,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
         case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:

@@ -19290,6 +19426,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
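Note on the src/llama.cpp changes below (hedged, inferred from the reshapes in the diff): the RWKV recurrent state lives in the regular KV cache with one slot per cell. kv_self.k_l[layer_i] holds the two token-shift carries per cell (attention and FFN, hparams.n_embd_k_s() floats) and kv_self.v_l[layer_i] holds the WKV state per cell (hparams.n_embd_v_s() floats). A small sketch of the implied per-cell element offsets (helper names are hypothetical):

    #include <stddef.h>

    // Offsets in elements into a layer's k/v cache tensors, assuming
    // n_embd_k_s == 2 * n_embd and n_embd_v_s == head_size * n_embd,
    // which is what the reshape and view sizes in the diff suggest.
    static size_t rwkv_token_shift_offset(size_t cell, size_t n_embd) {
        return cell * 2 * n_embd;            // att carry, then ffn carry
    }
    static size_t rwkv_wkv_state_offset(size_t cell, size_t head_size, size_t n_embd) {
        return cell * head_size * n_embd;    // S*S floats per head, S = head_size
    }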
src/llama.cpp (118 changes)

@@ -9366,11 +9366,13 @@ static struct ggml_tensor * llm_build_time_mix(
         const struct llama_layer * layer,
         struct ggml_tensor * current,
         struct ggml_tensor * x_prev,
-        struct ggml_tensor ** wkv_state) {
+        struct ggml_tensor ** wkv_state,
+        struct ggml_tensor * state_seq) {
     size_t n_embed = current->ne[0];
     size_t n_tokens = current->ne[1];
     size_t head_size = layer->time_mix_first->ne[0];
     size_t head_count = layer->time_mix_first->ne[1];
+    size_t n_kv = state_seq->ne[0];

     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
     struct ggml_tensor * xxx = ggml_add_inplace(

@@ -9515,9 +9517,9 @@ static struct ggml_tensor * llm_build_time_mix(
     k = ggml_transpose(ctx, k);
     v = ggml_transpose(ctx, v);
     r = ggml_transpose(ctx, r);
-    struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state, state_seq);
     current = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
-    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size, n_embed * n_tokens * sizeof(float));
+    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float));

     // ggml_group_norm considers groups in the third dimension.
     current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens);
@@ -15092,58 +15094,81 @@ struct llm_build_context {
         // Input embeddings, start of the model after tokenizing ({n_embd, n_tokens})
         ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

+        struct ggml_tensor * state_mask = build_inp_s_mask();
+        struct ggml_tensor * state_seq = build_inp_s_seq();
+
         ggml_tensor * x = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);

         for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
             const llama_layer * layer = &model.layers[layer_i];

-            // TODO: handle multiple kv cache cells
-            struct ggml_tensor * wkv_state = ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type));
-            struct ggml_tensor * att_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, (kv_self.size - 1) * 2 * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type));
-            struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, ((kv_self.size - 1) * 2 + 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type));
-
-            struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
-            struct ggml_tensor * x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-            x_prev = ggml_set_1d(ctx0, x_prev, att_shift, 0);
-            x_prev = ggml_set_1d(
+            struct ggml_tensor * token_shift = ggml_reshape_2d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s(), kv_self.size);
+            struct ggml_tensor * wkv_states = ggml_reshape_2d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), kv_self.size);
+
+            {
+                token_shift = ggml_mul(ctx0,
+                    ggml_view_2d(ctx0, token_shift, token_shift->ne[0], n_kv, token_shift->nb[1], kv_head*token_shift->nb[1]),
+                    state_mask);
+                wkv_states = ggml_mul(ctx0,
+                    ggml_view_2d(ctx0, wkv_states, wkv_states->ne[0], n_kv, wkv_states->nb[1], kv_head*wkv_states->nb[1]),
+                    state_mask);
+            }
+
+            token_shift = ggml_cont(
                 ctx0,
-                x_prev,
-                ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0),
-                n_embd * ggml_type_size(x_prev->type)
+                ggml_permute(
+                    ctx0,
+                    ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_kv),
+                    0, 2, 1, 3
+                )
             );

-            x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_state));
+            struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0);
+            struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_embd * n_kv * ggml_element_size(kv_self.k_l[layer_i]));
+
+            struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
+            struct ggml_tensor * tmp = ggml_rwkv_token_shift(ctx0, att_shift, x_norm, state_seq);
+            struct ggml_tensor * x_prev = ggml_reshape_2d(
+                ctx0,
+                ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
+                n_embd, n_tokens
+            );
+
+            x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
             ggml_build_forward_expand(gf, x);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[layer_i],
+                        hparams.n_embd_v_s() * n_kv,
+                        hparams.n_embd_v_s() * kv_head * ggml_type_size(kv_self.v_l[layer_i]->type)
+                    )
+                )
+            );
             ggml_build_forward_expand(
                 gf,
                 ggml_cpy(
                     ctx0,
                     ggml_view_1d(
                         ctx0,
-                        x_norm,
-                        n_embd,
-                        (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
+                        tmp,
+                        n_embd * n_kv,
+                        n_tokens * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
                     ),
-                    att_shift
-                )
-            );
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    wkv_state,
-                    ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type))
+                    ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0)
                 )
             );

             x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
-            x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-            x_prev = ggml_set_1d(ctx0, x_prev, ffn_shift, 0);
-            x_prev = ggml_set_1d(
+            tmp = ggml_rwkv_token_shift(ctx0, ffn_shift, x_norm, state_seq);
+            x_prev = ggml_reshape_2d(
                 ctx0,
-                x_prev,
-                ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0),
-                n_embd * ggml_type_size(x_prev->type)
+                ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
+                n_embd, n_tokens
             );
             x = ggml_add(ctx0, x, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
             ggml_build_forward_expand(gf, x);
@@ -15153,13 +15178,32 @@ struct llm_build_context {
                     ctx0,
                     ggml_view_1d(
                         ctx0,
-                        x_norm,
-                        n_embd,
-                        (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
+                        tmp,
+                        n_embd * n_kv,
+                        n_tokens * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
                     ),
-                    ffn_shift
+                    ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_kv * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type))
                 )
             );
+
+            token_shift = ggml_cont(
+                ctx0,
+                ggml_permute(
+                    ctx0,
+                    ggml_reshape_3d(ctx0, token_shift, n_embd, n_kv, 2),
+                    0, 2, 1, 3
+                )
+            );
+
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(ctx0, token_shift, n_embd * n_kv * 2, 0),
+                    ggml_view_1d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s() * n_kv, hparams.n_embd_k_s() * kv_head * ggml_type_size(kv_self.k_l[layer_i]->type))
+                )
+            );

             if ((layer_i + 1) % hparams.rescale_every_n_layers == 0) {
                 x = ggml_scale(ctx0, x, 0.5F);
             }