add forward function without using cache, for more performant training
During training on whole samples no cache is required. Removing the cache and simplifying the remaining code improves performance and reduces memory usage.
parent 2afd218479
commit 93eb8f7752
1 changed file with 233 additions and 1 deletion
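In short: during whole-sample training the position offset n_past is always zero, so the training loop (second hunk below) dispatches on it and takes the cache-free path. Reproduced from that hunk, with comments added here for orientation:

    struct ggml_tensor * logits =
        (n_past == 0)
        ? forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)            // new path: no kv_self, no n_past
        : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);  // existing cached path

The improvement described above comes from the cache-free graph never touching the kv_self tensors: nothing is allocated for the full context and no per-layer K/V copies are made.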
@@ -911,6 +911,234 @@ struct ggml_tensor * forward_batch(
    return inpL;
}

struct ggml_tensor * forward_batch_wo_cache(
        struct my_llama_model * model,
        struct ggml_context * ctx0,
        struct ggml_cgraph * gf,
        struct ggml_tensor * tokens_input,
        const int n_tokens,
        const int n_batch) {

    const int n_past = 0;
    const int N = n_tokens;

    const auto & hparams = model->hparams;
    const int n_ctx   = hparams.n_ctx;
    const int n_vocab = hparams.n_vocab;
    const int n_embd  = hparams.n_embd;
    const int n_layer = hparams.n_layer;
    const int n_head  = hparams.n_head;
    const int n_rot   = hparams.n_rot;
    const int n_ff    = get_n_ff(&hparams);

    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
    memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);

    // inpL shape [n_embd,N*n_batch,1]
    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
    assert_shape_2d(inpL, n_embd, N*n_batch);
    for (int il = 0; il < n_layer; ++il) {
        struct ggml_tensor * inpSA = inpL;

        struct ggml_tensor * cur;

        // lctx.use_buf(ctx0, 0);

        // norm
        {
            // cur shape [n_embd,N*n_batch,1,1]
            cur = ggml_rms_norm(ctx0, inpL);
            assert_shape_2d(cur, n_embd, N*n_batch);

            // cur = attention_norm*cur
            cur = ggml_mul(ctx0,
                        ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
                        cur);
            assert_shape_2d(cur, n_embd, N*n_batch);
        }

        // self-attention
        {
            // compute Q and K and RoPE them
            // wq   shape [n_embd, n_embd, 1, 1]
            // wk   shape [n_embd, n_embd, 1, 1]
            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
            // Kcur shape [n_embd/n_head, n_head, N, n_batch]
            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
            assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
            assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);

            // Vcur shape [N, n_batch, n_embd/n_head, n_head]
            struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head);
            assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head);

            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
            // Q shape    [n_embd/n_head, N, n_head, n_batch]
            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        Qcur,
                        0, 2, 1, 3);
            assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);

            // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
            // K shape [n_embd/n_head, N, n_head, n_batch]
            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        Kcur,
                        0, 2, 1, 3);
            assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch);

            // K * Q
            // KQ shape [N, N, n_head, n_batch]
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
            assert_shape_4d(KQ, N, N, n_head, n_batch);

            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            // KQ_scaled shape [N, N, n_head, n_batch]
            struct ggml_tensor * KQ_scaled =
                ggml_scale_inplace(ctx0,
                        KQ,
                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
            assert_shape_4d(KQ_scaled, N, N, n_head, n_batch);

            // KQ_masked = mask_past(KQ_scaled)
            // KQ_masked shape [N, N, n_head, n_batch]
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
            assert_shape_4d(KQ_masked, N, N, n_head, n_batch);

            // KQ = soft_max(KQ_masked)
            // KQ_soft_max shape [N, N, n_head, n_batch]
            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
            assert_shape_4d(KQ_soft_max, N, N, n_head, n_batch);

            // Vcur shape [N, n_batch, n_embd/n_head, n_head]
            // V shape    [N, n_embd/n_head, n_head, n_batch]
            struct ggml_tensor * V =
                ggml_permute(ctx0,
                        Vcur,
                        0, 3, 1, 2);
            assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch);

            // KQV shape [n_embd/n_head, N, n_head, n_batch]
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
            assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);

            // KQV_merged = KQV.permute(0, 2, 1, 3)
            // KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
            assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
            // KQV_merged shape

            // cur shape [n_embd,N*n_batch,1,1]
            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
            assert_shape_2d(cur, n_embd, N*n_batch);

            // projection (no bias)
            // cur shape [n_embd,N*n_batch,1,1]
            cur = ggml_mul_mat(ctx0,
                    model->layers[il].wo,
                    cur);
            assert_shape_2d(cur, n_embd, N*n_batch);
        }

        // lctx.use_buf(ctx0, 1);

        // inpFF shape [n_embd,N*n_batch,1,1]
        struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA);
        assert_shape_2d(inpFF, n_embd, N*n_batch);

        // feed-forward network
        {
            // norm
            {
                // cur shape [n_embd,N*n_batch,1,1]
                cur = ggml_rms_norm(ctx0, inpFF);
                assert_shape_2d(cur, n_embd, N*n_batch);

                // cur = ffn_norm*cur
                // cur shape [n_embd,N*n_batch,1,1]
                cur = ggml_mul(ctx0,
                        ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
                        cur);
                assert_shape_2d(cur, n_embd, N*n_batch);
            }

            // tmp shape [n_ff,N*n_batch,1,1]
            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                    model->layers[il].w3,
                    cur);
            assert_shape_2d(tmp, n_ff, N*n_batch);

            // cur shape [n_ff,N*n_batch,1,1]
            cur = ggml_mul_mat(ctx0,
                    model->layers[il].w1,
                    cur);
            assert_shape_2d(cur, n_ff, N*n_batch);

            // SILU activation
            // cur shape [n_ff,N*n_batch,1,1]
            cur = ggml_silu(ctx0, cur);
            assert_shape_2d(cur, n_ff, N*n_batch);

            // cur shape [n_ff,N*n_batch,1,1]
            cur = ggml_mul(ctx0, cur, tmp);
            assert_shape_2d(cur, n_ff, N*n_batch);

            // cur shape [n_embd,N*n_batch,1,1]
            cur = ggml_mul_mat(ctx0,
                    model->layers[il].w2,
                    cur);
            assert_shape_2d(cur, n_embd, N*n_batch);
        }

        // cur shape [n_embd,N*n_batch,1,1]
        cur = ggml_add_inplace(ctx0, cur, inpFF);
        assert_shape_2d(cur, n_embd, N*n_batch);

        // input for next layer
        // inpL shape [n_embd,N*n_batch,1,1]
        inpL = cur;
        assert_shape_2d(inpL, n_embd, N*n_batch);
    }

    // norm
    {

        // inpL shape [n_embd,N*n_batch,1,1]
        inpL = ggml_rms_norm(ctx0, inpL);
        assert_shape_2d(inpL, n_embd, N*n_batch);

        // inpL = norm*inpL
        // inpL shape [n_embd,N*n_batch,1,1]
        inpL = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model->norm, inpL),
                    inpL);

        assert_shape_2d(inpL, n_embd, N*n_batch);

        //embeddings = inpL;
    }

    // lm_head
    // inpL shape [n_vocab,N*n_batch,1,1]
    inpL = ggml_mul_mat(ctx0, model->output, inpL);
    assert_shape_2d(inpL, n_vocab, N*n_batch);

    {
        // inpL shape [n_vocab,N,n_batch,1]
        inpL = ggml_reshape_3d(ctx0,
                        inpL,
                        n_vocab, N, n_batch);
        assert_shape_3d(inpL, n_vocab, N, n_batch);
    }

    // run the computation
    ggml_build_forward_expand(gf, inpL);

    return inpL;
}

void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) {
    assert(logits->n_dims == 2);
    assert(probs->n_dims == 2);
@@ -1627,7 +1855,11 @@ int main(int argc, char ** argv) {
 
         get_example_targets_batch(ctx0, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
 
-        struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
+        struct ggml_tensor * logits =
+            (n_past == 0)
+            ? forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
+            : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
+
         // struct ggml_tensor * se = square_error_loss(ctx0, logits, target_logits);
         struct ggml_tensor * ce = cross_entropy_loss(ctx0, logits, target_probs);
         // struct ggml_tensor * e = ggml_add(ctx0, se, ce);