RWKV v6 graph building
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
parent 5732de89b7
commit 0784a0cf26

3 changed files with 423 additions and 25 deletions
@@ -512,6 +512,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
+        GGML_OP_RWKV_WKV,
 
         GGML_OP_UNARY,
 
@@ -1896,6 +1897,15 @@ extern "C" {
             struct ggml_tensor * pw,
             struct ggml_tensor * ph);
 
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+            struct ggml_context * ctx,
+            struct ggml_tensor * k,
+            struct ggml_tensor * v,
+            struct ggml_tensor * r,
+            struct ggml_tensor * tf,
+            struct ggml_tensor * td,
+            struct ggml_tensor * state);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
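As orientation for the rest of the diff (not part of the commit itself): a minimal sketch of how the new operator would be built on dummy tensors, with shapes inferred from the assertions added in ggml.c below. The helper name build_wkv_example and the exact shape conventions are assumptions drawn only from this commit.

#include "ggml.h"

// Sketch only: S = head size, H = head count, T = tokens in the batch.
static struct ggml_tensor * build_wkv_example(struct ggml_context * ctx,
                                              int64_t S, int64_t H, int64_t T) {
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, T);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
    struct ggml_tensor * r     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
    struct ggml_tensor * tf    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S, H);       // time_first ("time_faaaa")
    struct ggml_tensor * td    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T); // per-token decay (v6)
    struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, S * S * H);  // carried wkv state

    // The result is F32 with ne = { S * H, T, 1, 1 }; the state tensor is updated by the op.
    return ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
}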
ggml/src/ggml.c (149 lines changed)
@@ -2835,6 +2835,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
+    "RWKV_WKV",
 
     "UNARY",
@@ -2853,7 +2854,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2927,6 +2928,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
+    "rwkv_wkv(x, k, v, r, tf, td, s)",
 
     "unary(x)",
@@ -2945,7 +2947,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -7637,6 +7639,57 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
+// ggml_rwkv_wkv
+
+struct ggml_tensor * ggml_rwkv_wkv(
+        struct ggml_context * ctx,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * r,
+        struct ggml_tensor * tf,
+        struct ggml_tensor * td,
+        struct ggml_tensor * state) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(tf));
+    GGML_ASSERT(ggml_is_contiguous(td));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[2];
+    const int64_t n_tokens = k->ne[3];
+    {
+        GGML_ASSERT(k->ne[1] == 1);
+        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+        // TODO: RWKV v4 and v5
+        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H);
+    }
+
+    bool is_node = false;
+
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { S * H, n_tokens, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op = GGML_OP_RWKV_WKV;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = r;
+    result->src[3] = tf;
+    result->src[4] = td;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -16795,6 +16848,92 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
+// ggml_compute_forward_rwkv_wkv
+
+static void ggml_compute_forward_rwkv_wkv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const size_t T = dst->ne[1];
+    const size_t C = dst->ne[0];
+    const size_t H = dst->src[1]->ne[2];
+
+    float * dst_data = (float *) dst->data;
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    memset(dst_data, 0, T * C * sizeof(float));
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * r = (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+    float * state = (float *) dst->src[5]->data;
+
+    size_t t_stride = H * (C / H);
+
+    size_t h_stride = C / H;
+    size_t h_stride_2d = (C / H) * (C / H);
+
+    // basically fused operations:
+    // dst = r @ (time_faaaa * (k @ v) + state),
+    // state = time_decay * state + (k @ v),
+    // recursive through each token
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float r_val = r[t_h_i_offset];
+                float time_faaaa_val = time_faaaa[h_i_offset];
+                // RWKV v6: different time_decay for each token.
+                float time_decay_val = time_decay[t_h_i_offset];
+
+                for (size_t j = 0; j < C / H; j ++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state[h_2d_i_j_offset];
+                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                    dst_data[t_h_j_offset] += temp_val * r_val;
+                    state[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rwkv_wkv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
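For readers who want the recurrence without the offset bookkeeping, here is a minimal standalone sketch of what ggml_compute_forward_rwkv_wkv_f32 computes for a single head, on plain arrays. The function name and layout below are illustrative only and not part of the commit.

#include <stddef.h>

// Standalone sketch: S is the head size, state is an S x S matrix carried across tokens.
static void wkv_one_head(size_t S, size_t T,
                         const float *k,   // [T][S]
                         const float *v,   // [T][S]
                         const float *r,   // [T][S]
                         const float *tf,  // [S]     time_faaaa: bonus applied to the current token
                         const float *td,  // [T][S]  per-token decay (RWKV v6)
                         float *state,     // [S][S]  carried across calls
                         float *out) {     // [T][S]  assumed zero-initialised by the caller
    for (size_t t = 0; t < T; t++) {
        for (size_t i = 0; i < S; i++) {
            const float k_val  = k[t*S + i];
            const float r_val  = r[t*S + i];
            const float tf_val = tf[i];
            const float td_val = td[t*S + i];
            for (size_t j = 0; j < S; j++) {
                const float kv   = k_val * v[t*S + j];
                const float prev = state[i*S + j];
                out[t*S + j]    += r_val * (tf_val * kv + prev); // dst = r @ (time_faaaa * (k @ v) + state)
                state[i*S + j]   = prev * td_val + kv;           // state = time_decay * state + (k @ v)
            }
        }
    }
}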
@@ -17446,6 +17585,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV:
+            {
+                ggml_compute_forward_rwkv_wkv(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18569,6 +18712,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -19143,6 +19287,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
src/llama.cpp (289 lines changed)
@@ -3512,7 +3512,7 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
     if (cache.recurrent) {
-        // For recurrent state architectures (like Mamba),
+        // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.
 
@@ -3761,7 +3761,7 @@ static bool llama_kv_cache_seq_rm(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
-    // models like Mamba can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased
     if (cache.recurrent) {
         if (seq_id >= (int64_t) cache.size) {
             // could be fatal
@@ -3897,7 +3897,7 @@ static void llama_kv_cache_seq_add(
     if (p0 == p1) return;
 
     if (cache.recurrent) {
-        // for Mamba-like models, only the pos needs to be shifted
+        // for Mamba-like or RWKV models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             const int32_t tail_id = cache.cells[seq_id].tail;
             if (tail_id >= 0) {
@@ -3946,7 +3946,7 @@ static void llama_kv_cache_seq_div(
     if (p0 == p1) return;
 
     if (cache.recurrent) {
-        // for Mamba-like models, only the pos needs to be changed
+        // for Mamba-like or RWKV models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             const int32_t tail_id = cache.cells[seq_id].tail;
             if (tail_id >= 0) {
@@ -5885,8 +5885,9 @@ static void llm_load_hparams(
         ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
         // TODO: Re-using mamba keys right now, but RWKV isn't state-space
-        ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
-        ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+        ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+        ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+        ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
     } break;
 default: (void)0;
 }
@@ -8323,7 +8324,7 @@ static bool llm_load_tensors(
 
         // Block 0, LN0
         model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-        model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+        model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
 
         // output
         model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
@@ -8348,8 +8349,8 @@ static bool llm_load_tensors(
         layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
         layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd});
 
-        layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {time_mix_extra_dim * 5, n_embd});
-        layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_embd, time_mix_extra_dim, 5});
+        layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5});
+        layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5});
 
         layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1});
         layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1});
@@ -8361,8 +8362,8 @@ static bool llm_load_tensors(
         // TODO: Parametrize hardcoded dimensions for first & decay
         layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size});
         layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd});
-        layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {time_decay_extra_dim, n_embd});
-        layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {attn_hidden_size, time_decay_extra_dim});
+        layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim});
+        layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size});
         layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd});
         layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd});
         layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd});
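A note on the dimension swaps above: in ggml, ne[0] is the dimension that ggml_mul_mat contracts over, so a weight applied as ggml_mul_mat(w, x) needs the input width in its ne[0]. A hedged sketch of that convention (names illustrative, not from the commit):

#include "ggml.h"

// ggml_mul_mat(A, B) contracts over ne[0] of both operands:
//   A = {k, m}, B = {k, n}  ->  result = {m, n}
// so time_mix_w1 has to be created as {n_embd, time_mix_extra_dim * 5}
// to map an {n_embd, n_tokens} activation to {time_mix_extra_dim * 5, n_tokens}.
static struct ggml_tensor * project_example(struct ggml_context * ctx,
                                            struct ggml_tensor * w1,   // {n_embd, time_mix_extra_dim * 5}
                                            struct ggml_tensor * x) {  // {n_embd, n_tokens}
    return ggml_mul_mat(ctx, w1, x);                                   // {time_mix_extra_dim * 5, n_tokens}
}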
@@ -9344,9 +9345,198 @@ static struct ggml_tensor * llm_build_time_mix(
         struct ggml_context * ctx,
         const struct llama_layer * layer,
         struct ggml_tensor * current,
-        int layer_i) {
+        struct ggml_tensor * x_prev,
+        struct ggml_tensor * wkv_state) {
+    size_t n_embed = current->ne[0];
+    size_t n_tokens = current->ne[1];
+    size_t head_size = layer->time_mix_first->ne[0];
+    size_t head_count = layer->time_mix_first->ne[1];
 
-    return current;
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
+    struct ggml_tensor * xxx = ggml_add_inplace(
+        ctx,
+        ggml_mul(ctx, sx, layer->time_mix_lerp_x),
+        current
+    );
+
+    xxx = ggml_reshape_4d(
+        ctx,
+        ggml_tanh_inplace(
+            ctx,
+            ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
+        ),
+        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+    );
+
+    xxx = ggml_cont(
+        ctx,
+        ggml_permute(ctx, xxx, 0, 1, 3, 2)
+    );
+
+    xxx = ggml_mul_mat(
+        ctx,
+        ggml_reshape_4d(
+            ctx,
+            layer->time_mix_w2,
+            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+        ),
+        xxx
+    );
+
+    struct ggml_tensor *mw = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mw = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d(ctx, mw, ggml_view_1d(ctx, xxx, n_embed * n_tokens, 0), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor *mk = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mk = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d_inplace(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor *mv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mv = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d_inplace(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor *mr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mr = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d_inplace(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor *mg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
+    mg = ggml_reshape_2d(
+        ctx,
+        ggml_set_1d_inplace(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0),
+        n_embed, n_tokens
+    );
+
+    struct ggml_tensor * xw = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mw, layer->time_mix_lerp_w),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * xk = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mk, layer->time_mix_lerp_k),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * xv = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mv, layer->time_mix_lerp_v),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * xr = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mr, layer->time_mix_lerp_r),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * xg = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            ggml_add(ctx, mg, layer->time_mix_lerp_g),
+            sx
+        ),
+        current
+    );
+
+    struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
+    struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
+    struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
+    struct ggml_tensor * g = ggml_silu_inplace(
+        ctx,
+        ggml_mul_mat(ctx, layer->time_mix_gate, xg)
+    );
+
+    struct ggml_tensor * w = ggml_mul_mat(
+        ctx,
+        layer->time_mix_decay_w2,
+        ggml_tanh_inplace(
+            ctx,
+            ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
+        )
+    );
+    w = ggml_add_inplace(
+        ctx,
+        w,
+        ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed)
+    );
+    w = ggml_exp(ctx, ggml_neg_inplace(ctx, ggml_exp(ctx, w)));
+    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+
+    k = ggml_transpose(ctx, k);
+    v = ggml_transpose(ctx, v);
+    r = ggml_transpose(ctx, r);
+    current = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, wkv_state);
+
+    // ggml_group_norm considers groups in the third dimension.
+    current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens);
+    current = ggml_group_norm(ctx, current, head_count, 64e-5f);
+    // Convert back to a regular vector.
+    current = ggml_reshape_2d(ctx, current, n_embed, n_tokens);
+    current = ggml_add_inplace(
+        ctx,
+        ggml_mul_inplace(
+            ctx,
+            current,
+            layer->time_mix_ln
+        ),
+        layer->time_mix_ln_b
+    );
+
+    current = ggml_mul(ctx, current, g);
+
+    return ggml_mul_mat(ctx, layer->time_mix_output, current);
+}
+
+static struct ggml_tensor * llm_build_channel_mix(
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * current,
+        struct ggml_tensor * x_prev) {
+
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
+    struct ggml_tensor * xk = ggml_add_inplace(
+        ctx,
+        ggml_mul(ctx, sx, layer->channel_mix_lerp_k),
+        current
+    );
+    struct ggml_tensor * xr = ggml_add_inplace(
+        ctx,
+        ggml_mul(ctx, sx, layer->channel_mix_lerp_r),
+        current
+    );
+    struct ggml_tensor * r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
+    struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_key, xk)));
+    return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
 }
 
 struct llm_build_context {
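To make the tensor expressions in llm_build_time_mix easier to follow, here is the per-channel scalar math they implement, as a standalone sketch; the helper names are illustrative and not part of the commit.

#include <math.h>

// RWKV v6 data-dependent token shift for one channel: first a coarse mix with
// time_mix_lerp_x (m == 0), then five mixes where the learned, input-dependent
// correction m (from the w1/w2 bottleneck) is added to the per-stream lerp.
static float token_shift(float x, float x_prev, float lerp_const, float m) {
    const float sx = x_prev - x;
    return x + sx * (lerp_const + m);
}

// Decay used for the wkv state: w_raw comes from the decay_w1/decay_w2
// bottleneck plus time_mix_decay, squashed into (0, 1) via exp(-exp(w)),
// matching ggml_exp(ggml_neg_inplace(ggml_exp(w))) above.
static float decay(float w_raw) {
    return expf(-expf(w_raw));
}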
@@ -14874,32 +15064,85 @@ struct llm_build_context {
     ggml_cgraph * build_rwkv() {
         ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
+        // Token shift state dimensions should be 2 * n_emb
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
+
         // Input embeddings, start of the model after tokenizing ({n_embd, n_tokens})
         ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
-        // x = self.layer_norm(x, self.w.blocks[0].ln0)
-        ggml_tensor * current = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+        ggml_tensor * x = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
 
         for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
             const llama_layer * layer = &model.layers[layer_i];
 
-            current = llm_build_norm(ctx0, current, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
-            current = llm_build_time_mix(ctx0, layer, current, layer_i);
-            current = llm_build_norm(ctx0, current, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
+            // TODO: handle multiple kv cache cells
+            struct ggml_tensor * wkv_state = ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type));
+            struct ggml_tensor * att_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, (kv_self.size - 1) * 2 * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type));
+            struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, ((kv_self.size - 1) * 2 + 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type));
+
+            struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
+            struct ggml_tensor * x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+            x_prev = ggml_set_1d(ctx0, x_prev, att_shift, 0);
+            x_prev = ggml_set_1d(
+                ctx0,
+                x_prev,
+                ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0),
+                n_embd * ggml_type_size(x_prev->type)
+            );
+
+            x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, wkv_state));
+            ggml_build_forward_expand(gf, x);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(
+                        ctx0,
+                        x_norm,
+                        n_embd,
+                        (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
+                    ),
+                    att_shift
+                )
+            );
+
+            x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
+            x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+            x_prev = ggml_set_1d(ctx0, x_prev, ffn_shift, 0);
+            x_prev = ggml_set_1d(
+                ctx0,
+                x_prev,
+                ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0),
+                n_embd * ggml_type_size(x_prev->type)
+            );
+            x = ggml_add(ctx0, x, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
+            ggml_build_forward_expand(gf, x);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(
+                        ctx0,
+                        x_norm,
+                        n_embd,
+                        (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)
+                    ),
+                    ffn_shift
+                )
+            );
         }
 
         // Something related to skipping tokens, specifics unclear
         ggml_tensor * inp_out_ids = build_inp_out_ids();
-        current = ggml_get_rows(ctx0, current, inp_out_ids);
+        x = ggml_get_rows(ctx0, x, inp_out_ids);
 
         // Output head, convert result vector to logits
-        current = llm_build_norm(ctx0, current, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-        current = ggml_mul_mat(ctx0, model.output, current);
+        x = llm_build_norm(ctx0, x, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        x = ggml_mul_mat(ctx0, model.output, x);
 
         // Mark the output as being the result
-        cb(current, "result_output", -1);
-        ggml_build_forward_expand(gf, current);
+        cb(x, "result_output", -1);
+        ggml_build_forward_expand(gf, x);
 
         return gf;
     }
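Finally, a sketch of the per-layer cache layout that build_rwkv assumes, expressed as plain element-offset arithmetic; the struct and helper below are illustrative only and not part of the commit. Each cell stores two token-shift vectors of n_embd values in k_l (attention shift, then FFN shift) and the wkv state in v_l, and this commit always addresses the last cell (index kv_self.size - 1).

#include <stddef.h>

struct rwkv_cache_offsets {
    size_t att_shift;   // element offset of the attention token-shift vector in k_l
    size_t ffn_shift;   // element offset of the FFN token-shift vector in k_l
    size_t wkv_state;   // element offset of the wkv state in v_l
};

static struct rwkv_cache_offsets rwkv_offsets(size_t cell, size_t n_embd, size_t n_embd_v_s) {
    struct rwkv_cache_offsets o;
    o.att_shift = cell * 2 * n_embd;         // k_l: [att shift | ffn shift] per cell
    o.ffn_shift = (cell * 2 + 1) * n_embd;
    o.wkv_state = cell * n_embd_v_s;         // v_l: one wkv state per cell
    return o;
}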