From 01dcf4bb7706c11bfd4754139693d2df40cfb3d3 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Fri, 9 Aug 2024 20:51:00 +0800 Subject: [PATCH] Fix parallel inferencing for RWKV Signed-off-by: Molly Sophia --- ggml/include/ggml.h | 10 ++- ggml/src/ggml.c | 157 +++++++++++++++++++++++++++++++++++++++++--- src/llama.cpp | 118 ++++++++++++++++++++++----------- 3 files changed, 237 insertions(+), 48 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 39aff9e39..76a3176a1 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -513,6 +513,7 @@ extern "C" { GGML_OP_GET_REL_POS, GGML_OP_ADD_REL_POS, GGML_OP_RWKV_WKV, + GGML_OP_RWKV_TOKEN_SHIFT, GGML_OP_UNARY, @@ -1904,7 +1905,14 @@ extern "C" { struct ggml_tensor * r, struct ggml_tensor * tf, struct ggml_tensor * td, - struct ggml_tensor * state); + struct ggml_tensor * state, + struct ggml_tensor * state_seq); + + GGML_API struct ggml_tensor * ggml_rwkv_token_shift( + struct ggml_context * ctx, + struct ggml_tensor * x_carry, + struct ggml_tensor * x_norm, + struct ggml_tensor * state_seq); // custom operators diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 17e92eff2..06d8a8654 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2836,6 +2836,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GET_REL_POS", "ADD_REL_POS", "RWKV_WKV", + "RWKV_TOKEN_SHIFT", "UNARY", @@ -2854,7 +2855,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2928,7 +2929,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "win_unpart(x)", "get_rel_pos(x)", "add_rel_pos(x)", - "rwkv_wkv(x, k, v, r, tf, td, s)", + "rwkv_wkv(k, v, r, tf, td, s, sq)", + "rwkv_token_shift(xc, xn, sq)", "unary(x)", @@ -2947,7 +2949,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -7648,35 +7650,39 @@ struct ggml_tensor * ggml_rwkv_wkv( struct ggml_tensor * r, struct ggml_tensor * tf, struct ggml_tensor * td, - struct ggml_tensor * state) { + struct ggml_tensor * state, + struct ggml_tensor * state_seq) { GGML_ASSERT(ggml_is_contiguous(k)); GGML_ASSERT(ggml_is_contiguous(v)); GGML_ASSERT(ggml_is_contiguous(r)); GGML_ASSERT(ggml_is_contiguous(tf)); GGML_ASSERT(ggml_is_contiguous(td)); GGML_ASSERT(ggml_is_contiguous(state)); + GGML_ASSERT(ggml_is_contiguous(state_seq)); + GGML_ASSERT(state_seq->type == GGML_TYPE_I32); const int64_t S = k->ne[0]; const int64_t H = k->ne[2]; const int64_t n_tokens = k->ne[3]; + const int64_t n_kv = state_seq->ne[0]; { GGML_ASSERT(k->ne[1] == 1); GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens); GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens); // TODO: RWKV v4 and v5 GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens); - GGML_ASSERT(ggml_nelements(state) == S * S * H); + GGML_ASSERT(ggml_nelements(state) == S * S * H * n_kv); } bool is_node = false; - if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) { + if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad || state_seq->grad) { GGML_ABORT("fatal error"); // TODO: 
implement backward is_node = true; } // concat output and new_state - const int64_t ne[4] = { S * H, n_tokens + S, 1, 1 }; + const int64_t ne[4] = { S * H, n_tokens + S * n_kv, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_RWKV_WKV; @@ -7687,6 +7693,48 @@ struct ggml_tensor * ggml_rwkv_wkv( result->src[3] = tf; result->src[4] = td; result->src[5] = state; + result->src[6] = state_seq; + + return result; +} + +// ggml_rwkv_token_shift + +struct ggml_tensor * ggml_rwkv_token_shift( + struct ggml_context * ctx, + struct ggml_tensor * x_carry, + struct ggml_tensor * x_norm, + struct ggml_tensor * state_seq) { + GGML_ASSERT(ggml_is_contiguous(x_carry)); + GGML_ASSERT(ggml_is_contiguous(x_norm)); + GGML_ASSERT(ggml_is_contiguous(state_seq)); + GGML_ASSERT(state_seq->type == GGML_TYPE_I32); + + const int64_t n_embd = x_norm->ne[0]; + const int64_t n_kv = state_seq->ne[0]; + const int64_t n_tokens = state_seq->ne[1]; + { + GGML_ASSERT(x_norm->ne[0] == n_embd); + GGML_ASSERT(x_norm->ne[1] == n_tokens); + GGML_ASSERT(ggml_nelements(x_carry) == n_embd * n_kv); + } + + bool is_node = false; + + if (x_carry->grad || x_norm->grad || state_seq->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + // concat output and new_state + const int64_t ne[4] = { n_embd, n_tokens + n_kv, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_RWKV_TOKEN_SHIFT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = x_carry; + result->src[1] = x_norm; + result->src[2] = state_seq; return result; } @@ -16857,6 +16905,7 @@ static void ggml_compute_forward_rwkv_wkv_f32( const size_t T = dst->src[1]->ne[3]; const size_t C = dst->ne[0]; const size_t H = dst->src[1]->ne[2]; + const size_t n_kv = dst->src[6]->ne[0]; float * dst_data = (float *) dst->data; float * state = ((float *) dst->data) + C * T; @@ -16872,7 +16921,8 @@ static void ggml_compute_forward_rwkv_wkv_f32( float * r = (float *) dst->src[2]->data; float * time_faaaa = (float *) dst->src[3]->data; float * time_decay = (float *) dst->src[4]->data; - memcpy(state, dst->src[5]->data, (C / H) * C * sizeof(float)); + int32_t * seq_data = (int32_t *) dst->src[6]->data; + memcpy(state, dst->src[5]->data, (C / H) * C * n_kv * sizeof(float)); size_t t_stride = H * (C / H); @@ -16885,6 +16935,7 @@ static void ggml_compute_forward_rwkv_wkv_f32( // recursive through each token for (size_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; + float * state_cur = state + (C / H) * C * seq_data[t * n_kv]; for (size_t h = 0; h < H; h++) { size_t h_offset = h * h_stride; @@ -16908,14 +16959,23 @@ static void ggml_compute_forward_rwkv_wkv_f32( float v_val = v[t_h_j_offset]; float kv_val = v_val * k_val; - float prev_state_val = state[h_2d_i_j_offset]; + float prev_state_val = state_cur[h_2d_i_j_offset]; float temp_val = kv_val * time_faaaa_val + prev_state_val; dst_data[t_h_j_offset] += temp_val * r_val; - state[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; } } } } + + for (size_t t = 0; t < T; t++) { + for (size_t kv = 1; kv < n_kv; kv++) { + int64_t seq = seq_data[t * n_kv + kv]; + if (seq >= 0 && seq_data[(t + 1) * n_kv + kv] != seq) { + memcpy(state + (C / H) * C * seq, state + (C / H) * C * seq_data[t * n_kv], (C / H) * C * sizeof(float)); + } + } + } } static void ggml_compute_forward_rwkv_wkv( @@ -16936,6 
+16996,77 @@ static void ggml_compute_forward_rwkv_wkv( } } +static void ggml_compute_forward_rwkv_token_shift_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const int64_t n_embd = dst->ne[0]; + const int64_t n_kv = dst->src[2]->ne[0]; + const int64_t n_tokens = dst->src[1]->ne[1]; + float * dst_data = (float *) dst->data; + float * x_carry = (float *) dst->src[0]->data; + float * x_norm = (float *) dst->src[1]->data; + int32_t * sq_data = (int32_t *) dst->src[2]->data; + + if (params->ith != 0) { + return; + } + + int32_t seq_start = 0; + int32_t seq_length = 0; + + for (int i1 = 0; i1 < n_kv; ++i1) { + seq_start = -1; + // assume that the tokens for each sequence are contiguous + for (int i2 = 0; i2 < n_tokens; ++i2) { + int32_t seq = sq_data[i2*n_kv]; + if (seq == i1 && seq_start < 0) { + seq_start = i2; + } + + if ((seq_start >= 0 && seq != i1) || i2 == n_tokens - 1) { + seq_length = i2 - seq_start + (i2 == n_tokens - 1); + break; + } + } + + if (seq_start >= 0) { + int32_t seq = sq_data[seq_start*n_kv]; + memcpy(dst_data + seq_start*n_embd, x_carry + seq*n_embd, n_embd*sizeof(float)); + memcpy(dst_data + (seq_start+1)*n_embd, x_norm + seq_start*n_embd, (seq_length-1)*n_embd*sizeof(float)); + } + } + + for (int i3 = 0; i3 < n_kv; ++i3) { + int32_t last_token_pos = 0; + for (int i4 = 0; i4 < n_tokens; ++i4) { + for (int i5 = 0; i5 < n_kv; ++i5) { + if (sq_data[i4*n_kv + i5] == i3) { + last_token_pos = i4; + } + } + } + memcpy(dst_data + (n_tokens + i3)*n_embd, x_norm + last_token_pos*n_embd, n_embd*sizeof(float)); + } +} + +static void ggml_compute_forward_rwkv_token_shift( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rwkv_token_shift_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_map_unary static void ggml_compute_forward_map_unary_f32( @@ -17591,6 +17722,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rwkv_wkv(params, tensor); } break; + case GGML_OP_RWKV_TOKEN_SHIFT: + { + ggml_compute_forward_rwkv_token_shift(params, tensor); + } break; case GGML_OP_MAP_UNARY: { ggml_unary_op_f32_t fun; @@ -18715,6 +18850,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_GET_REL_POS: case GGML_OP_ADD_REL_POS: case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_TOKEN_SHIFT: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: @@ -19290,6 +19426,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: case GGML_OP_RWKV_WKV: + case GGML_OP_RWKV_TOKEN_SHIFT: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: diff --git a/src/llama.cpp b/src/llama.cpp index 5e474d61d..9606ae0b9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9366,11 +9366,13 @@ static struct ggml_tensor * llm_build_time_mix( const struct llama_layer * layer, struct ggml_tensor * current, struct ggml_tensor * x_prev, - struct ggml_tensor ** wkv_state) { + struct ggml_tensor ** wkv_state, + struct ggml_tensor * state_seq) { size_t n_embed = current->ne[0]; size_t n_tokens = current->ne[1]; size_t head_size = layer->time_mix_first->ne[0]; size_t head_count = layer->time_mix_first->ne[1]; + size_t n_kv = state_seq->ne[0]; struct ggml_tensor * sx = ggml_sub(ctx, 
x_prev, current); struct ggml_tensor * xxx = ggml_add_inplace( @@ -9515,9 +9517,9 @@ static struct ggml_tensor * llm_build_time_mix( k = ggml_transpose(ctx, k); v = ggml_transpose(ctx, v); r = ggml_transpose(ctx, r); - struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state); + struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state, state_seq); current = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0); - *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size, n_embed * n_tokens * sizeof(float)); + *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_kv, n_embed * n_tokens * sizeof(float)); // ggml_group_norm considers groups in the third dimension. current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens); @@ -15092,58 +15094,81 @@ struct llm_build_context { // Input embeddings, start of the model after tokenizing ({n_embd, n_tokens}) ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + struct ggml_tensor * state_mask = build_inp_s_mask(); + struct ggml_tensor * state_seq = build_inp_s_seq(); + ggml_tensor * x = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); for (int layer_i = 0; layer_i < n_layer; ++layer_i) { const llama_layer * layer = &model.layers[layer_i]; - // TODO: handle multiple kv cache cells - struct ggml_tensor * wkv_state = ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type)); - struct ggml_tensor * att_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, (kv_self.size - 1) * 2 * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)); - struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, ((kv_self.size - 1) * 2 + 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)); + struct ggml_tensor * token_shift = ggml_reshape_2d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s(), kv_self.size); + struct ggml_tensor * wkv_states = ggml_reshape_2d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), kv_self.size); - struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i); - struct ggml_tensor * x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - x_prev = ggml_set_1d(ctx0, x_prev, att_shift, 0); - x_prev = ggml_set_1d( + { + token_shift = ggml_mul(ctx0, + ggml_view_2d(ctx0, token_shift, token_shift->ne[0], n_kv, token_shift->nb[1], kv_head*token_shift->nb[1]), + state_mask); + wkv_states = ggml_mul(ctx0, + ggml_view_2d(ctx0, wkv_states, wkv_states->ne[0], n_kv, wkv_states->nb[1], kv_head*wkv_states->nb[1]), + state_mask); + } + + token_shift = ggml_cont( ctx0, - x_prev, - ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0), - n_embd * ggml_type_size(x_prev->type) + ggml_permute( + ctx0, + ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_kv), + 0, 2, 1, 3 + ) ); - x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_state)); + struct ggml_tensor * att_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0); + struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_embd * n_kv * ggml_element_size(kv_self.k_l[layer_i])); + + struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i); + struct ggml_tensor * tmp = ggml_rwkv_token_shift(ctx0, 
att_shift, x_norm, state_seq); + struct ggml_tensor * x_prev = ggml_reshape_2d( + ctx0, + ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0), + n_embd, n_tokens + ); + + x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_states, state_seq)); ggml_build_forward_expand(gf, x); + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + wkv_states, + ggml_view_1d( + ctx0, + kv_self.v_l[layer_i], + hparams.n_embd_v_s() * n_kv, + hparams.n_embd_v_s() * kv_head * ggml_type_size(kv_self.v_l[layer_i]->type) + ) + ) + ); ggml_build_forward_expand( gf, ggml_cpy( ctx0, ggml_view_1d( ctx0, - x_norm, - n_embd, - (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type) + tmp, + n_embd * n_kv, + n_tokens * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type) ), - att_shift - ) - ); - ggml_build_forward_expand( - gf, - ggml_cpy( - ctx0, - wkv_state, - ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type)) + ggml_view_1d(ctx0, token_shift, n_embd * n_kv, 0) ) ); x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i); - x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - x_prev = ggml_set_1d(ctx0, x_prev, ffn_shift, 0); - x_prev = ggml_set_1d( + tmp = ggml_rwkv_token_shift(ctx0, ffn_shift, x_norm, state_seq); + x_prev = ggml_reshape_2d( ctx0, - x_prev, - ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0), - n_embd * ggml_type_size(x_prev->type) + ggml_view_1d(ctx0, tmp, n_embd * n_tokens, 0), + n_embd, n_tokens ); x = ggml_add(ctx0, x, llm_build_channel_mix(ctx0, layer, x_norm, x_prev)); ggml_build_forward_expand(gf, x); @@ -15153,13 +15178,32 @@ struct llm_build_context { ctx0, ggml_view_1d( ctx0, - x_norm, - n_embd, - (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type) + tmp, + n_embd * n_kv, + n_tokens * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type) ), - ffn_shift + ggml_view_1d(ctx0, token_shift, n_embd * n_kv, n_kv * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)) ) ); + + token_shift = ggml_cont( + ctx0, + ggml_permute( + ctx0, + ggml_reshape_3d(ctx0, token_shift, n_embd, n_kv, 2), + 0, 2, 1, 3 + ) + ); + + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + ggml_view_1d(ctx0, token_shift, n_embd * n_kv * 2, 0), + ggml_view_1d(ctx0, kv_self.k_l[layer_i], hparams.n_embd_k_s() * n_kv, hparams.n_embd_k_s() * kv_head * ggml_type_size(kv_self.k_l[layer_i]->type)) + ) + ); + if ((layer_i + 1) % hparams.rescale_every_n_layers == 0) { x = ggml_scale(ctx0, x, 0.5F); }
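
Reviewer note (not part of the patch): the semantics of the new GGML_OP_RWKV_TOKEN_SHIFT op are easier to check outside the graph machinery. The plain-C sketch below is a reference model of what ggml_compute_forward_rwkv_token_shift_f32 computes; the helper name token_shift_ref and the toy data are illustrative only, and it assumes, as the kernel does, that the tokens of each sequence are contiguous in the batch and that the first entry of each token's state_seq row names its owning KV cell.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

// dst layout matches the op's concatenated result tensor:
//   dst[0 .. n_tokens*n_embd)                      -> shifted inputs (x_prev)
//   dst[n_tokens*n_embd .. (n_tokens+n_kv)*n_embd) -> new per-cell carry
static void token_shift_ref(float *dst,
                            const float *x_carry,   // [n_kv,     n_embd] carried last-token state
                            const float *x_norm,    // [n_tokens, n_embd] normalized inputs
                            const int32_t *seq_ids, // seq_ids[t*n_kv] = cell owning token t
                            int64_t n_embd, int64_t n_tokens, int64_t n_kv) {
    for (int64_t t = 0; t < n_tokens; ++t) {
        const int32_t seq      = seq_ids[t * n_kv];
        const int     is_first = (t == 0) || (seq_ids[(t - 1) * n_kv] != seq);
        // the first token of a sequence sees that sequence's carry;
        // every later token sees the previous token of the same sequence
        const float *src = is_first ? x_carry + (int64_t) seq * n_embd
                                    : x_norm  + (t - 1) * n_embd;
        memcpy(dst + t * n_embd, src, (size_t) n_embd * sizeof(float));
    }
    // the last token of each sequence becomes that sequence's new carry
    for (int64_t s = 0; s < n_kv; ++s) {
        int64_t last = -1;
        for (int64_t t = 0; t < n_tokens; ++t) {
            for (int64_t k = 0; k < n_kv; ++k) {
                if (seq_ids[t * n_kv + k] == s) last = t;
            }
        }
        if (last >= 0) {
            memcpy(dst + (n_tokens + s) * n_embd, x_norm + last * n_embd,
                   (size_t) n_embd * sizeof(float));
        }
    }
}

int main(void) {
    // three tokens: two from sequence 0 followed by one from sequence 1, n_embd = 2;
    // -1 marks unused slots in a token's sequence-id row
    enum { N_EMBD = 2, N_TOKENS = 3, N_KV = 2 };
    const float   x_carry[N_KV * N_EMBD]     = { 10, 10,   20, 20 };
    const float   x_norm [N_TOKENS * N_EMBD] = {  1,  1,    2,  2,    3,  3 };
    const int32_t seq_ids[N_TOKENS * N_KV]   = {  0, -1,    0, -1,    1, -1 };

    float dst[(N_TOKENS + N_KV) * N_EMBD];
    token_shift_ref(dst, x_carry, x_norm, seq_ids, N_EMBD, N_TOKENS, N_KV);

    // expected: x_prev = {10,10, 1,1, 20,20}, new carry = {2,2, 3,3}
    for (int64_t i = 0; i < N_TOKENS + N_KV; ++i) {
        printf("%s %2lld: %4.0f %4.0f\n", i < N_TOKENS ? "x_prev" : "carry ",
               (long long) (i < N_TOKENS ? i : i - N_TOKENS),
               dst[i * N_EMBD], dst[i * N_EMBD + 1]);
    }
    return 0;
}

The concatenated result mirrors how the graph code in the patch consumes the op: the first n_tokens rows are viewed and reshaped into x_prev for llm_build_time_mix / llm_build_channel_mix, and the trailing n_kv rows are copied back into the token_shift view of kv_self.k_l[layer_i].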
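
The ggml_rwkv_wkv extension follows the same per-cell layout: the state argument now holds one S*S*H slab per KV cell, and state_seq selects which slab each token's recurrence reads and updates. A minimal sketch of that indexing (illustrative helper name, same contiguity assumption as above):

#include <stdint.h>

// Mirrors the kernel's
//   float * state_cur = state + (C / H) * C * seq_data[t * n_kv];
// where C = S * H, so (C / H) * C == S * S * H floats per cell.
static float * wkv_state_for_token(float *states,           // n_kv slabs of S*S*H floats
                                   const int32_t *state_seq, // [n_kv, n_tokens]
                                   int64_t t, int64_t n_kv,
                                   int64_t S, int64_t H) {
    const int32_t cell = state_seq[t * n_kv];                // primary cell of token t
    return states + S * S * H * cell;
}

The separate loop the patch adds after the recurrence then copies the primary cell's slab into any secondary cells listed in a token's state_seq row, which is what lets several sequences fork from a shared prompt.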