build_rwkv: Avoid using inplace operations

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
commit 8bc1f9ae80 (parent 6ae2f4866f)
Author: Molly Sophia <mollysophia379@gmail.com>
Date:   2024-08-11 12:06:16 +08:00

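The commit message gives the what; the why, as generally understood for ggml, is that the `_inplace` op variants make the output share the first operand's buffer, which is only safe when nothing else in the graph reads that operand afterwards, and ggml's graph allocator can already reuse buffers automatically where that is legal, so explicit `_inplace` calls add risk without saving memory. A minimal sketch of the hazard, assuming ggml's single-context CPU API (`ggml_graph_compute_with_ctx`); the values and variable names are illustrative, not code from this commit:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);

        // Two branches read x. The inplace add makes y share x's buffer,
        // so z, computed after y, sees x already overwritten with x + b.
        struct ggml_tensor * y = ggml_add_inplace(ctx, x, b);
        struct ggml_tensor * z = ggml_scale(ctx, x, 2.0f);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_build_forward_expand(gf, z);

        ggml_set_f32(x, 1.0f);
        ggml_set_f32(b, 10.0f);
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        // With plain ggml_add, z[0] would be 2; the inplace form
        // typically yields 22 here because y clobbers x first.
        printf("z[0] = %g\n", ggml_get_f32_1d(z, 0));

        ggml_free(ctx);
        return 0;
    }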
@@ -9364,36 +9364,29 @@ static struct ggml_tensor * llm_build_mamba(
 static struct ggml_tensor * llm_build_time_mix(
         struct ggml_context * ctx,
   const struct llama_layer * layer,
-        struct ggml_tensor * current,
+        struct ggml_tensor * cur,
         struct ggml_tensor * x_prev,
         struct ggml_tensor ** wkv_state,
         struct ggml_tensor * state_seq) {
-    size_t n_embed = current->ne[0];
-    size_t n_tokens = current->ne[1];
+    size_t n_embed = cur->ne[0];
+    size_t n_tokens = cur->ne[1];
     size_t head_size = layer->time_mix_first->ne[0];
     size_t head_count = layer->time_mix_first->ne[1];
     size_t n_kv = state_seq->ne[0];
 
-    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
-    struct ggml_tensor * xxx = ggml_add_inplace(
-        ctx,
-        ggml_mul(ctx, sx, layer->time_mix_lerp_x),
-        current
-    );
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
 
     xxx = ggml_reshape_4d(
         ctx,
-        ggml_tanh_inplace(
+        ggml_tanh(
             ctx,
             ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
         ),
         layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
     );
 
-    xxx = ggml_cont(
-        ctx,
-        ggml_permute(ctx, xxx, 0, 1, 3, 2)
-    );
+    xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));
 
     xxx = ggml_mul_mat(
         ctx,
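For reference, the first statements of the rewritten function implement the RWKV6 token-shift interpolation that feeds the low-rank `time_mix_w1`/`time_mix_w2` projection. Writing $x$ for `cur` and $x_{-1}$ for `x_prev`:

    $sx = x_{-1} - x, \qquad xxx = x + \mu_x \odot sx = (1 - \mu_x) \odot x + \mu_x \odot x_{-1}$

a per-channel lerp between the current and previous token, which the tanh-activated low-rank projection then expands into five data-dependent mixing offsets, one each for $w, k, v, r, g$.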
@@ -9415,85 +9408,85 @@ static struct ggml_tensor * llm_build_time_mix(
     struct ggml_tensor *mk = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
     mk = ggml_reshape_2d(
         ctx,
-        ggml_set_1d_inplace(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0),
+        ggml_set_1d(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0),
         n_embed, n_tokens
     );
 
     struct ggml_tensor *mv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
     mv = ggml_reshape_2d(
         ctx,
-        ggml_set_1d_inplace(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0),
+        ggml_set_1d(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0),
         n_embed, n_tokens
     );
 
     struct ggml_tensor *mr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
     mr = ggml_reshape_2d(
         ctx,
-        ggml_set_1d_inplace(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0),
+        ggml_set_1d(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0),
         n_embed, n_tokens
     );
 
     struct ggml_tensor *mg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens);
     mg = ggml_reshape_2d(
         ctx,
-        ggml_set_1d_inplace(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0),
+        ggml_set_1d(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0),
         n_embed, n_tokens
     );
 
-    struct ggml_tensor * xw = ggml_add_inplace(
+    struct ggml_tensor * xw = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mw, layer->time_mix_lerp_w),
             sx
         ),
-        current
+        cur
     );
 
-    struct ggml_tensor * xk = ggml_add_inplace(
+    struct ggml_tensor * xk = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mk, layer->time_mix_lerp_k),
             sx
         ),
-        current
+        cur
     );
 
-    struct ggml_tensor * xv = ggml_add_inplace(
+    struct ggml_tensor * xv = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mv, layer->time_mix_lerp_v),
             sx
         ),
-        current
+        cur
     );
 
-    struct ggml_tensor * xr = ggml_add_inplace(
+    struct ggml_tensor * xr = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mr, layer->time_mix_lerp_r),
             sx
         ),
-        current
+        cur
     );
 
-    struct ggml_tensor * xg = ggml_add_inplace(
+    struct ggml_tensor * xg = ggml_add(
         ctx,
-        ggml_mul_inplace(
+        ggml_mul(
             ctx,
             ggml_add(ctx, mg, layer->time_mix_lerp_g),
             sx
         ),
-        current
+        cur
    );
 
     struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
     struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
     struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
-    struct ggml_tensor * g = ggml_silu_inplace(
+    struct ggml_tensor * g = ggml_silu(
         ctx,
         ggml_mul_mat(ctx, layer->time_mix_gate, xg)
     );
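Each of the `m*` chunks lives at an offset of `i * n_embed * n_tokens` floats inside the fused projection `xxx`, and the non-inplace `ggml_set_1d` copies it into a freshly allocated tensor rather than writing through the destination's own buffer. The same slicing pattern, reduced to a hypothetical standalone helper (not part of the commit):

    // Pick the i-th of several [n, t] chunks stacked inside a flat tensor.
    // ggml_view_1d selects the bytes; the non-inplace ggml_set_1d copies
    // them into a fresh destination instead of scattering in place.
    static struct ggml_tensor * chunk_2d(struct ggml_context * ctx,
                                         struct ggml_tensor  * fused,
                                         int i, int n, int t) {
        struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n * t);
        struct ggml_tensor * src = ggml_view_1d(ctx, fused, n * t,
                                                (size_t) i * n * t * sizeof(float));
        return ggml_reshape_2d(ctx, ggml_set_1d(ctx, dst, src, 0), n, t);
    }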
@@ -9501,16 +9494,12 @@ static struct ggml_tensor * llm_build_time_mix(
     struct ggml_tensor * w = ggml_mul_mat(
         ctx,
         layer->time_mix_decay_w2,
-        ggml_tanh_inplace(
+        ggml_tanh(
             ctx,
             ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
         )
     );
-    w = ggml_add_inplace(
-        ctx,
-        w,
-        ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed)
-    );
+    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
     w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
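The `exp(neg(exp(w)))` chain is the standard RWKV decay parameterization: for any real logit $\tilde w$ produced by the low-rank decay projection,

    $w = e^{-e^{\tilde w}} \in (0, 1),$

so every channel's decay entering the WKV recurrence is a valid forgetting factor regardless of what the model learned.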
@@ -9518,48 +9507,39 @@ static struct ggml_tensor * llm_build_time_mix(
     v = ggml_transpose(ctx, v);
     r = ggml_transpose(ctx, r);
 
     struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state, state_seq);
-    current = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
+    cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float));
 
     // ggml_group_norm considers groups in the third dimension.
-    current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens);
-    current = ggml_group_norm(ctx, current, head_count, 64e-5f);
+    cur = ggml_reshape_4d(ctx, cur, 1, 1, n_embed, n_tokens);
+    cur = ggml_group_norm(ctx, cur, head_count, 64e-5f);
     // Convert back to a regular vector.
-    current = ggml_reshape_2d(ctx, current, n_embed, n_tokens);
-    current = ggml_add_inplace(
-        ctx,
-        ggml_mul_inplace(
-            ctx,
-            current,
-            layer->time_mix_ln
-        ),
-        layer->time_mix_ln_b
-    );
+    cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
+    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
 
-    current = ggml_mul(ctx, current, g);
+    cur = ggml_mul(ctx, cur, g);
 
-    return ggml_mul_mat(ctx, layer->time_mix_output, current);
+    return ggml_mul_mat(ctx, layer->time_mix_output, cur);
 }
 
 static struct ggml_tensor * llm_build_channel_mix(
         struct ggml_context * ctx,
   const struct llama_layer * layer,
-        struct ggml_tensor * current,
+        struct ggml_tensor * cur,
         struct ggml_tensor * x_prev) {
-    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current);
-    struct ggml_tensor * xk = ggml_add_inplace(
-        ctx,
-        ggml_mul(ctx, sx, layer->channel_mix_lerp_k),
-        current
-    );
-    struct ggml_tensor * xr = ggml_add_inplace(
-        ctx,
-        ggml_mul(ctx, sx, layer->channel_mix_lerp_r),
-        current
-    );
-    struct ggml_tensor * r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
-    struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_key, xk)));
-    return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
+    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
+    struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
+    struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
+
+    struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr));
+    struct ggml_tensor * k = ggml_sqr(
+        ctx,
+        ggml_relu(
+            ctx,
+            ggml_mul_mat(ctx, layer->channel_mix_key, xk)
+        )
+    );
+    return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k));
 }
 
 struct llm_build_context {
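Written out, `llm_build_channel_mix` is the standard RWKV channel mix. With $\sigma$ the logistic sigmoid, $\odot$ the elementwise product, and $x_k$, $x_r$ the two token-shift lerps above:

    $k = \mathrm{relu}(W_k x_k)^2, \qquad \mathrm{out} = \sigma(W_r x_r) \odot (W_v k)$

The squared-ReLU keys and the sigmoid receptance gate come straight from the RWKV formulation; only the inplace/non-inplace choice changes in this commit.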
@@ -15090,13 +15070,12 @@ struct llm_build_context {
         // Token shift state dimensions should be 2 * n_emb
         GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
 
-        // Input embeddings, start of the model after tokenizing ({n_embd, n_tokens})
         ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
 
         struct ggml_tensor * state_mask = build_inp_s_mask();
         struct ggml_tensor * state_seq = build_inp_s_seq();
 
-        ggml_tensor * x = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
+        ggml_tensor * cur = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
 
         for (int layer_i = 0; layer_i < n_layer; ++layer_i) {
             const llama_layer * layer = &model.layers[layer_i];
@@ -15125,7 +15104,7 @@ struct llm_build_context {
             struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0);
             struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_embd * n_kv * ggml_element_size(kv_self.k_l[layer_i]));
 
-            struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
+            struct ggml_tensor * x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i);
             struct ggml_tensor * tmp = ggml_rwkv_token_shift(ctx0, att_shift, x_norm, state_seq);
             struct ggml_tensor * x_prev = ggml_reshape_2d(
                 ctx0,
@@ -15133,8 +15112,8 @@ struct llm_build_context {
                 n_embd, n_tokens
             );
 
-            x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
-            ggml_build_forward_expand(gf, x);
+            cur = ggml_add(ctx0, cur, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq));
+            ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
                 ggml_cpy(
@@ -15162,15 +15141,15 @@ struct llm_build_context {
                 )
             );
 
-            x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
+            x_norm = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
             tmp = ggml_rwkv_token_shift(ctx0, ffn_shift, x_norm, state_seq);
             x_prev = ggml_reshape_2d(
                 ctx0,
                 ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0),
                 n_embd, n_tokens
             );
 
-            x = ggml_add(ctx0, x, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
-            ggml_build_forward_expand(gf, x);
+            cur = ggml_add(ctx0, cur, llm_build_channel_mix(ctx0, layer, x_norm, x_prev));
+            ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
                 ggml_cpy(
@@ -15204,21 +15183,18 @@ struct llm_build_context {
             );
 
             if ((layer_i + 1) % hparams.rescale_every_n_layers == 0) {
-                x = ggml_scale(ctx0, x, 0.5F);
+                cur = ggml_scale(ctx0, cur, 0.5F);
             }
         }
 
-        // Something related to skipping tokens, specifics unclear
         ggml_tensor * inp_out_ids = build_inp_out_ids();
-        x = ggml_get_rows(ctx0, x, inp_out_ids);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
-        // Output head, convert result vector to logits
-        x = llm_build_norm(ctx0, x, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-        x = ggml_mul_mat(ctx0, model.output, x);
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        cur = ggml_mul_mat(ctx0, model.output, cur);
 
-        // Mark the output as being the result
-        cb(x, "result_output", -1);
-        ggml_build_forward_expand(gf, x);
+        cb(cur, "result_output", -1);
+        ggml_build_forward_expand(gf, cur);
 
         return gf;
     }
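A note on the `0.5F` scale that survives the rewrite: RWKV checkpoints are commonly converted with the attention and FFN output weights of layer $i$ pre-divided by $2^{\lfloor i/N \rfloor}$, where $N$ is `rescale_every_n_layers`; halving the hidden state after every $N$ layers then compensates, keeping activations inside fp16 range. That reading is inferred from the RWKV reference implementations, not stated in this commit.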