improve training memory usage with scratch buffers
Instead of relying on ggml's automatic backward pass, we build the graph for the backward pass manually. It turns out that all backward-pass operations need only temporary memory, which can be reused after each layer. The backward pass is computed for ALL model parameters.
parent 765b290010
commit 0d4b87de3d
1 changed file with 577 additions and 16 deletions
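The mechanism can be illustrated outside of the actual training code. The standalone sketch below is not part of this commit (buffer sizes, tensor shapes and the layer count are made up); it only shows how ggml_set_scratch lets one buffer keep growing with values that must survive, while a second buffer is rewound every iteration so that per-layer temporaries reuse the same memory - the use_buf/clr_buf pattern introduced in the diff:

#include "ggml.h"
#include <cstdint>
#include <vector>

int main() {
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    std::vector<uint8_t> persistent(1024*1024); // like buf(0): values kept for the backward pass
    std::vector<uint8_t> temporary (1024*1024); // like buf(1): per-layer scratch, reused every iteration

    size_t offs_keep = 0; // running offset inside the persistent buffer
    for (int il = 0; il < 4; ++il) {
        // per-layer temporaries always restart at offset 0, so the same memory is reused
        ggml_set_scratch(ctx, { 0, temporary.size(), temporary.data() });
        struct ggml_tensor * tmp  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);

        // tensors that must survive resume the persistent buffer where the last layer stopped
        ggml_set_scratch(ctx, { offs_keep, persistent.size(), persistent.data() });
        struct ggml_tensor * keep = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);

        // leaving scratch mode returns the current offset; remember it for the next layer
        offs_keep = ggml_set_scratch(ctx, { 0, 0, nullptr, });
        (void) tmp; (void) keep;
    }

    ggml_free(ctx);
    return 0;
}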

@@ -1337,6 +1337,505 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
    return inpL;
}

struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
        struct my_llama_model * model,
        struct ggml_context * ctx0,
        struct ggml_cgraph * gf,
        struct ggml_cgraph * gb,
        struct ggml_tensor ** logits,
        struct ggml_tensor * tokens_input,
        struct ggml_tensor * targets,
        void * compute_buf_0,
        void * compute_buf_1,
        void * compute_buf_2,
        size_t size_buf_0,
        size_t size_buf_1,
        size_t size_buf_2,
        const int n_tokens,
        const int n_batch) {

    ggml_set_scratch(ctx0, { 0, 0, nullptr, });

    const int n_past = 0;
    const int N = n_tokens;

    gf->n_nodes = 0;
    gf->n_leafs = 0;
    gf->work_size = 0;
    gf->perf_runs = 0;
    gf->perf_cycles = 0;
    gf->perf_time_us = 0;
    gf->work = NULL;

    const auto & hparams = model->hparams;
    const int n_ctx = hparams.n_ctx;
    const int n_vocab = hparams.n_vocab;
    const int n_embd = hparams.n_embd;
    const int n_layer = hparams.n_layer;
    const int n_head = hparams.n_head;
    const int n_rot = hparams.n_rot;
    const int n_ff = get_n_ff(&hparams);
    const int rope_mode = 0;

    auto expand = [] (struct ggml_cgraph * g, struct ggml_tensor * t) -> struct ggml_tensor * {
        ggml_build_forward_expand(g, t);
        return t;
    };

    int last_buf = -1;
    size_t buf_offs[3] = { 0, 0, 0 };
    size_t buf_size[3] = { size_buf_0,
                           size_buf_1,
                           size_buf_2 };
    void * buf_data[3] = { compute_buf_0,
                           compute_buf_1,
                           compute_buf_2 };
    auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) {
        size_t last_offs = 0;
        last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
        if (last_buf >= 0) {
            buf_offs[last_buf] = last_offs;
        }
        if (buf >= 0) {
            size_t offs = buf_offs[buf];
            size_t size = buf_size[buf];
            void * data = buf_data[buf];
            ggml_set_scratch(ctx0, { offs, size, data, });
        }
        last_buf = buf;
    };

    auto clr_buf = [&buf_offs] (int buf) {
        if (buf < 0) return;
        // size_t last_offs = 0;
        // last_offs = ggml_set_scratch(ctx, { 0, 0, nullptr, });
        // if (last_buf >= 0) {
        //     buf_offs[last_buf] = last_offs;
        // }
        // buf_max_size[buf] = std::max(buf_max_size[buf], buf_offs[buf]);
        buf_offs[buf] = 0;
        // if (last_buf >= 0) {
        //     size_t offs = buf_offs[last_buf];
        //     size_t size = buf_size[last_buf];
        //     void * data = buf_data[last_buf];
        //     ggml_set_scratch(ctx0, { offset, size, data, });
        // }
    };
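
    // Scratch buffer roles used below:
    //   use_buf(0)  - persistent scratch: forward activations the backward pass still needs
    //   use_buf(1)  - temporary scratch: per-layer intermediates, reset with clr_buf(1) when the layer is done
    //   use_buf(2)  - scratch for gradients that must survive from one per-layer backward step to the next
    //   use_buf(-1) - no scratch: allocate from the main context (e.g. the parameter gradients)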

    auto view__q = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
        int64_t ne0 = n_embd/n_head;
        int64_t ne1 = N;
        int64_t ne2 = n_head;
        int64_t ne3 = n_batch;
        size_t nb0 = ggml_element_size(t);
        size_t nb1 = nb0*ne0;
        size_t nb2 = nb1*ne1;
        size_t nb3 = nb2*ne2;
        size_t offset = 0;
        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
    };

    auto view__k = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
        int64_t ne0 = n_embd/n_head;
        int64_t ne1 = N;
        int64_t ne2 = n_head;
        int64_t ne3 = n_batch;
        size_t nb0 = ggml_element_size(t);
        size_t nb1 = nb0*ne0;
        size_t nb2 = nb1*ne1;
        size_t nb3 = nb2*ne2;
        size_t offset = nb3*ne3;
        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
    };

    auto view__v = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
        int64_t ne0 = N;
        int64_t ne1 = n_embd/n_head;
        int64_t ne2 = n_head;
        int64_t ne3 = n_batch;
        size_t nb0 = ggml_element_size(t);
        size_t nb1 = nb0*ne0;
        size_t nb2 = nb1*ne1;
        size_t nb3 = nb2*ne2;
        size_t offset = 2*nb3*ne3;
        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
    };
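
    // ggml_flash_attn_back returns the gradients of q, k and v concatenated into a single tensor;
    // view__q/view__k/view__v pick out the respective slice via the offsets 0, nb3*ne3 and 2*nb3*ne3.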

    auto add_or_set = [ctx0] (struct ggml_tensor * a, struct ggml_tensor * b) -> struct ggml_tensor * {
        if (a == NULL) {
            return b;
        } else {
            return ggml_add_inplace(ctx0, a, b);
        }
    };
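
    // add_or_set accumulates a gradient contribution: the first contribution is taken as-is,
    // every further contribution is added in place.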

    use_buf(-1);
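
    // (re)allocate the parameter gradient tensors in the compute context; the manually built
    // backward graph accumulates into them through add_or_set further below.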

    model->tok_embeddings->grad = ggml_dup_tensor(ctx0, model->tok_embeddings->grad);
    model->norm->grad = ggml_dup_tensor(ctx0, model->norm->grad);
    model->output->grad = ggml_dup_tensor(ctx0, model->output->grad);

    for (int il = 0; il < n_layer; ++il) {
        struct my_llama_layer & layer = model->layers[il];
        layer.attention_norm->grad = ggml_dup_tensor(ctx0, layer.attention_norm->grad);
        layer.wq->grad = ggml_dup_tensor(ctx0, layer.wq->grad);
        layer.wk->grad = ggml_dup_tensor(ctx0, layer.wk->grad);
        layer.wv->grad = ggml_dup_tensor(ctx0, layer.wv->grad);
        layer.wo->grad = ggml_dup_tensor(ctx0, layer.wo->grad);
        layer.ffn_norm->grad = ggml_dup_tensor(ctx0, layer.ffn_norm->grad);
        layer.w1->grad = ggml_dup_tensor(ctx0, layer.w1->grad);
        layer.w2->grad = ggml_dup_tensor(ctx0, layer.w2->grad);
        layer.w3->grad = ggml_dup_tensor(ctx0, layer.w3->grad);
    }

    clr_buf(1);
    clr_buf(2);

    use_buf(0);

    struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
    memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);

    struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch);

    // need to remember these for the backward pass
    std::vector<struct ggml_tensor *> t02L; t02L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t03L; t03L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t04L; t04L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t05L; t05L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t06L; t06L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t07L; t07L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t08L; t08L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t09L; t09L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t10L; t10L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t11L; t11L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t12L; t12L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t13L; t13L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t14L; t14L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t15L; t15L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t16L; t16L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t17L; t17L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t18L; t18L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t19L; t19L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t20L; t20L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t21L; t21L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t22L; t22L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t23L; t23L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t24L; t24L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t25L; t25L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t26L; t26L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t27L; t27L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t28L; t28L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t29L; t29L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t30L; t30L.resize(n_layer, NULL);

    struct ggml_tensor * cur = t01;

    for (int il = 0; il < n_layer; ++il) {
        clr_buf(1);
        struct my_llama_layer & layer = model->layers[il];
        // tensors with values necessary for the backward pass are in persistent buf(0)
        // other tensors with buf(1) are only temporarily needed, and their memory is reused after the layer is completed.
        use_buf(0); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch); // n_embd, N*n_batch
        use_buf(1); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
        use_buf(0); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch); // n_embd, N*n_batch
        use_buf(1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
        use_buf(1); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
        use_buf(1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
        use_buf(1); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch);
        use_buf(1); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
        use_buf(1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
        use_buf(1); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd);
        use_buf(1); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
        use_buf(0); struct ggml_tensor * t13 = expand(gf, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch); // n_embd/n_head, N, n_head, n_batch
        use_buf(0); struct ggml_tensor * t14 = expand(gf, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch); // n_embd/n_head, N, n_head, n_batch
        use_buf(0); struct ggml_tensor * t15 = expand(gf, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch); // N, n_embd/n_head, n_head, n_batch
        use_buf(1); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
        use_buf(1); struct ggml_tensor * t17 = expand(gf, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
        use_buf(1); struct ggml_tensor * t18 = expand(gf, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
        use_buf(0); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch); // n_embd, N*n_batch
        use_buf(1); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
        use_buf(0); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch); // n_embd, N*n_batch
        use_buf(0); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch); // n_embd, N*n_batch
        use_buf(1); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
        use_buf(0); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch); // n_embd, N*n_batch
        use_buf(0); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch); // n_ff, N*n_batch
        use_buf(0); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch); // n_ff, N*n_batch
        use_buf(0); struct ggml_tensor * t27 = expand(gf, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch); // n_ff, N*n_batch
        use_buf(0); struct ggml_tensor * t28 = expand(gf, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch); // n_ff, N*n_batch
        use_buf(1); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch);
        use_buf(0); struct ggml_tensor * t30 = expand(gf, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch); // n_embd, N*n_batch
        t02L[il] = t02;
        t03L[il] = t03;
        t04L[il] = t04;
        t05L[il] = t05;
        t06L[il] = t06;
        t07L[il] = t07;
        t08L[il] = t08;
        t09L[il] = t09;
        t10L[il] = t10;
        t11L[il] = t11;
        t12L[il] = t12;
        t13L[il] = t13;
        t14L[il] = t14;
        t15L[il] = t15;
        t16L[il] = t16;
        t17L[il] = t17;
        t18L[il] = t18;
        t19L[il] = t19;
        t20L[il] = t20;
        t21L[il] = t21;
        t22L[il] = t22;
        t23L[il] = t23;
        t24L[il] = t24;
        t25L[il] = t25;
        t26L[il] = t26;
        t27L[il] = t27;
        t28L[il] = t28;
        t29L[il] = t29;
        t30L[il] = t30;

        cur = t30;
    }
    clr_buf(1);
    use_buf(1);
    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
    struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
    struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
    struct ggml_tensor * t34 = expand(gf, ggml_mul_mat (ctx0, model->output, t33)); assert_shape_2d(t34, n_vocab, N*n_batch);
    struct ggml_tensor * t35 = expand(gf, ggml_reshape_3d(ctx0, t34, n_vocab, N, n_batch)); assert_shape_3d(t35, n_vocab, N, n_batch);
    struct ggml_tensor * t36 = expand(gf, ggml_cross_entropy_loss(ctx0, t35, targets)); assert_shape_1d(t36, 1);
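
    // t35 holds the logits (exported through *logits at the end), t36 the scalar cross entropy
    // loss that this function returns and that the optimizer minimizes.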

    {
        /*
            tok_embeddings | grad_tok_embeddings = ggml_get_rows_back(grad_t01, t00)
            L0_att_norm | grad_L0_att_norm = ggml_repeat_back(grad_t03L0, L0_att_norm.shape)
            L0_wq | grad_L0_wq = ggml_out_prod(t04L0, grad_t05L0)
            L0_wk | grad_L0_wk = ggml_out_prod(t04L0, grad_t08L0)
            L0_wv | grad_L0_wv = ggml_out_prod(t04L0, ggml_transpose(grad_t11L0))
            L0_wo | grad_L0_wo = ggml_out_prod(t19L0, grad_t20L0)
            L0_ffn_norm | grad_L0_ffn_norm = ggml_repeat_back(grad_t23L0, L0_ffn_norm.shape)
            L0_w1 | grad_L0_w1 = ggml_out_prod(t24L0, grad_t26L0)
            L0_w2 | grad_L0_w2 = ggml_out_prod(t28L0, grad_t29L0)
            L0_w3 | grad_L0_w3 = ggml_out_prod(t24L0, grad_t25L0)
            L1_att_norm | grad_L1_att_norm = ggml_repeat_back(grad_t03L1, L1_att_norm.shape)
            L1_wq | grad_L1_wq = ggml_out_prod(t04L1, grad_t05L1)
            L1_wk | grad_L1_wk = ggml_out_prod(t04L1, grad_t08L1)
            L1_wv | grad_L1_wv = ggml_out_prod(t04L1, ggml_transpose(grad_t11L1))
            L1_wo | grad_L1_wo = ggml_out_prod(t19L1, grad_t20L1)
            L1_ffn_norm | grad_L1_ffn_norm = ggml_repeat_back(grad_t23L1, L1_ffn_norm.shape)
            L1_w1 | grad_L1_w1 = ggml_out_prod(t24L1, grad_t26L1)
            L1_w2 | grad_L1_w2 = ggml_out_prod(t28L1, grad_t29L1)
            L1_w3 | grad_L1_w3 = ggml_out_prod(t24L1, grad_t25L1)
            norm | grad_norm = ggml_repeat_back(grad_t32, norm.shape)
            output | grad_output = ggml_out_prod(t33, grad_t34)

            t01 = ggml_get_rows(tok_embeddings, t00) | grad_t01 = grad_t21L0 + ggml_rms_norm_back(t01, grad_t02L0)
            for layer:
            t02L0*= ggml_rms_norm (t01) | grad_t02L0 = ggml_mul(grad_t04L0, t03L0)
            t03L0 = ggml_repeat (L0_att_norm, t02L0_shape) | grad_t03L0 = ggml_mul(grad_t04L0, t02L0)
            t04L0*= ggml_mul (t02L0, t03L0) | grad_t04L0 = ggml_out_prod(L0_wv, grad_t11L0) + ggml_out_prod(L0_wk, ggml_transpose(grad_t08L0)) + ggml_out_prod(L0_wq, ggml_transpose(grad_t05L0))
            t05L0 = ggml_mul_mat (L0_wq, t04L0) | grad_t05L0 = ggml_reshape(grad_t06L0, t05L0_shape)
            t06L0 = ggml_reshape_4d (t05L0, n_embd/n_head, n_head, N, n_batch) | grad_t06L0 = ggml_rope_back(grad_t07L0)
            t07L0 = ggml_rope_inplace (t06L0) | grad_t07L0 = ggml_permute_back(grad_t13L0, 0, 2, 1, 3) = ggml_permute(grad_t13L0, 0, 2, 1, 3)
            t08L0 = ggml_mul_mat (L0_wk, t04L0) | grad_t08L0 = ggml_reshape(grad_t09L0, t08L0_shape)
            t09L0 = ggml_reshape_4d (t08L0, n_embd/n_head, n_head, N, n_batch) | grad_t09L0 = ggml_rope_back(grad_t10L0)
            t10L0 = ggml_rope_inplace (t09L0) | grad_t10L0 = ggml_permute_back(grad_t14L0, 0, 2, 1, 3) = ggml_permute(grad_t14L0, 0, 2, 1, 3)
            t11L0 = ggml_mul_mat (t04L0, L0_wv) | grad_t11L0 = ggml_reshape(grad_t12L0, t11L0_shape)
            t12L0 = ggml_reshape_4d (t11L0, N, n_batch, n_embd/n_head, n_head) | grad_t12L0 = ggml_permute_back(grad_t15L0, 0, 3, 1, 2) = ggml_permute(grad_t15L0, 0, 2, 3, 1)
            t13L0*= ggml_permute (t07L0, 0, 2, 1, 3) | grad_t13L0 = view__q(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
            t14L0*= ggml_permute (t10L0, 0, 2, 1, 3) | grad_t14L0 = view__k(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
            t15L0*= ggml_permute (t12L0, 0, 3, 1, 2) | grad_t15L0 = view__v(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0))
            t16L0 = ggml_flash_attn (t13L0, t14L0, t15L0) | grad_t16L0 = ggml_permute_back(grad_t17L0, 0, 2, 1, 3) = ggml_permute(grad_t17L0, 0, 2, 1, 3)
            t17L0 = ggml_permute (t16L0, 0, 2, 1, 3) | grad_t17L0 = grad_t18L0
            t18L0 = ggml_cont (t17L0) | grad_t18L0 = ggml_reshape(grad_t19L0, t18L0_shape)
            t19L0*= ggml_reshape_2d (t18L0, n_embd, N*n_batch) | grad_t19L0 = ggml_out_prod(L0_wo, ggml_transpose(grad_t20L0))
            t20L0 = ggml_mul_mat (L0_wo, t19L0) | grad_t20L0 = grad_t21L0
            t21L0*= ggml_add (t20L0, t01) | grad_t21L0 = grad_t30L0 + ggml_rms_norm_back(t21L0, grad_t22L0)
            t22L0*= ggml_rms_norm (t21L0) | grad_t22L0 = ggml_mul(grad_t24L0, t23L0)
            t23L0 = ggml_repeat (L0_ffn_norm, t22L0_shape) | grad_t23L0 = ggml_mul(grad_t24L0, t22L0)
            t24L0*= ggml_mul (t23L0, t22L0) | grad_t24L0 = ggml_out_prod(L0_w1, ggml_transpose(grad_t26L0)) + ggml_out_prod(L0_w3, ggml_transpose(grad_t25L0))
            t25L0*= ggml_mul_mat (L0_w3, t24L0) | grad_t25L0 = ggml_mul(grad_t28L0, t27L0)
            t26L0*= ggml_mul_mat (L0_w1, t24L0) | grad_t26L0 = ggml_silu_back(t26L0, grad_t27L0)
            t27L0*= ggml_silu (t26L0) | grad_t27L0 = ggml_mul(grad_t28L0, t25L0)
            t28L0*= ggml_mul (t27L0, t25L0) | grad_t28L0 = ggml_out_prod(L0_w2, ggml_transpose(grad_t29L0))
            t29L0 = ggml_mul_mat (L0_w2, t28L0) | grad_t29L0 = grad_t30L0
            t30L0*= ggml_add (t21L0, t29L0) | grad_t30L0 = ggml_rms_norm_back(t30L0, grad_t02L1) + grad_t21L1
                                                            ^
            t02L1*= ggml_rms_norm (t30L0) | grad_t02L1 = ggml_mul(grad_t04L1, t03L1)
            t03L1 = ggml_repeat (L1_att_norm, t02L1_shape) | grad_t03L1 = ggml_mul(grad_t04L1, t02L1)
            t04L1*= ggml_mul (t02L1, t03L1) | grad_t04L1 = ggml_out_prod(L1_wv, grad_t11L1) + ggml_out_prod(L1_wk, ggml_transpose(grad_t08L1)) + ggml_out_prod(L1_wq, ggml_transpose(grad_t05L1))
            t05L1 = ggml_mul_mat (L1_wq, t04L1) | grad_t05L1 = ggml_reshape(grad_t06L1, t05L1_shape)
            t06L1 = ggml_reshape_4d (t05L1, n_embd/n_head, n_head, N, n_batch) | grad_t06L1 = ggml_rope_back(grad_t07L1)
            t07L1 = ggml_rope_inplace (t06L1) | grad_t07L1 = ggml_permute_back(grad_t13L1, 0, 2, 1, 3) = ggml_permute(grad_t13L1, 0, 2, 1, 3)
            t08L1 = ggml_mul_mat (L1_wk, t04L1) | grad_t08L1 = ggml_reshape(grad_t09L1, t08L1_shape)
            t09L1 = ggml_reshape_4d (t08L1, n_embd/n_head, n_head, N, n_batch) | grad_t09L1 = ggml_rope_back(grad_t10L1)
            t10L1 = ggml_rope_inplace (t09L1) | grad_t10L1 = ggml_permute_back(grad_t14L1, 0, 2, 1, 3) = ggml_permute(grad_t14L1, 0, 2, 1, 3)
            t11L1 = ggml_mul_mat (t04L1, L1_wv) | grad_t11L1 = ggml_reshape(grad_t12L1, t11L1_shape)
            t12L1 = ggml_reshape_4d (t11L1, N, n_batch, n_embd/n_head, n_head) | grad_t12L1 = ggml_permute_back(grad_t15L1, 0, 3, 1, 2) = ggml_permute(grad_t15L1, 0, 2, 3, 1)
            t13L1*= ggml_permute (t07L1, 0, 2, 1, 3) | grad_t13L1 = view__q(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
            t14L1*= ggml_permute (t10L1, 0, 2, 1, 3) | grad_t14L1 = view__k(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
            t15L1*= ggml_permute (t12L1, 0, 3, 1, 2) | grad_t15L1 = view__v(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1))
            t16L1 = ggml_flash_attn (t13L1, t14L1, t15L1) | grad_t16L1 = ggml_permute_back(grad_t17L1, 0, 2, 1, 3) = ggml_permute(grad_t17L1, 0, 2, 1, 3)
            t17L1 = ggml_permute (t16L1, 0, 2, 1, 3) | grad_t17L1 = grad_t18L1
            t18L1 = ggml_cont (t17L1) | grad_t18L1 = ggml_reshape(grad_t19L1, t18L1_shape)
            t19L1*= ggml_reshape_2d (t18L1, n_embd, N*n_batch) | grad_t19L1 = ggml_out_prod(L1_wo, ggml_transpose(grad_t20L1))
            t20L1 = ggml_mul_mat (L1_wo, t19L1) | grad_t20L1 = grad_t21L1
            t21L1*= ggml_add (t20L1, t30L0) | grad_t21L1 = grad_t30L1 + ggml_rms_norm_back(t21L1, grad_t22L1)
            t22L1*= ggml_rms_norm (t21L1) | grad_t22L1 = ggml_mul(grad_t24L1, t23L1)
            t23L1 = ggml_repeat (L1_ffn_norm, t22L1_shape) | grad_t23L1 = ggml_mul(grad_t24L1, t22L1)
            t24L1*= ggml_mul (t23L1, t22L1) | grad_t24L1 = ggml_out_prod(L1_w1, ggml_transpose(grad_t26L1)) + ggml_out_prod(L1_w3, ggml_transpose(grad_t25L1))
            t25L1*= ggml_mul_mat (L1_w3, t24L1) | grad_t25L1 = ggml_mul(grad_t28L1, t27L1)
            t26L1*= ggml_mul_mat (L1_w1, t24L1) | grad_t26L1 = ggml_silu_back(t26L1, grad_t27L1)
            t27L1*= ggml_silu (t26L1) | grad_t27L1 = ggml_mul(grad_t28L1, t25L1)
            t28L1*= ggml_mul (t27L1, t25L1) | grad_t28L1 = ggml_out_prod(L1_w2, ggml_transpose(grad_t29L1))
            t29L1 = ggml_mul_mat (L1_w2, t28L1) | grad_t29L1 = grad_t30L1
            t30L1*= ggml_add (t21L1, t29L1) | grad_t30L1 = ggml_rms_norm_back(t30L1, grad_t31)
                                                            ^
            t31 = ggml_rms_norm (t30L1) | grad_t31 = ggml_mul(grad_t33, t32)
            t32 = ggml_repeat (norm, t31.shape) | grad_t32 = ggml_mul(grad_t33, t31)
            t33 = ggml_mul (t32, t31) | grad_t33 = ggml_out_prod(output, ggml_transpose(grad_t34))
            t34 = ggml_mul_mat (output, t33) | grad_t34 = ggml_reshape(grad_t35, t34.shape)
            t35 = ggml_reshape_3d (t34, n_vocab, N, n_batch) | grad_t35 = ggml_cross_entropy_loss_back(t35, targets, grad_t36)
            t36 = ggml_cross_entropy_loss(t35, targets) | grad_t36 = 1 (optimizer)
            tensors marked with * need to be stored until grad computation
            tensors during grad computation are all temporary
        */
    }

    *gb = *gf;
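
    // the backward graph starts out as a copy of the forward graph;
    // all gradient operations below are appended to it via expand(gb, ...).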

    use_buf(-1);
    // t36->grad gets set to one by optimizer, so we need to create the tensor.
    // initialize it with 1.0f to make sure.
    t36->grad = ggml_new_f32(ctx0, 1.0f);

    use_buf(1);
    t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);
    t34->grad = expand(gb, ggml_reshape_2d (ctx0, t35->grad, n_vocab, N*n_batch)); assert_shape_2d(t34->grad, n_vocab, N*n_batch);
    t33->grad = expand(gb, ggml_out_prod (ctx0, model->output, ggml_transpose(ctx0, t34->grad))); assert_shape_2d(t33->grad, n_embd, N*n_batch);
    t32->grad = expand(gb, ggml_mul (ctx0, t33->grad, t31)); assert_shape_2d(t32->grad, n_embd, N*n_batch);

    use_buf(-1);

    model->norm->grad = expand(gb, add_or_set(model->norm->grad, ggml_repeat_back(ctx0, t32->grad, model->norm))); assert_shape_1d(model->norm->grad, n_embd);
    model->output->grad = expand(gb, add_or_set(model->output->grad, ggml_out_prod(ctx0, t33, t34->grad))); assert_shape_2d(model->output->grad, n_embd, n_vocab);

    clr_buf(2);
    use_buf(2);
    t31->grad = expand(gb, ggml_mul(ctx0, t33->grad, t32)); assert_shape_2d(t31->grad, n_embd, N*n_batch);

    struct ggml_tensor * back_layer_inp = t31;
    struct ggml_tensor * grad_layer_inp = NULL;
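
    // walk the layers in reverse; back_layer_inp is the forward tensor whose gradient feeds the
    // rms_norm_back of the next lower layer, grad_layer_inp carries the gradient of the residual
    // branch (t21) across the layer boundary.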

    for (int k = 0; k < n_layer; ++k) {
        int il = n_layer-1-k;
        struct my_llama_layer & layer = model->layers[il];

        struct ggml_tensor * t02 = t02L[il];
        struct ggml_tensor * t03 = t03L[il];
        struct ggml_tensor * t04 = t04L[il];
        struct ggml_tensor * t05 = t05L[il];
        struct ggml_tensor * t06 = t06L[il];
        struct ggml_tensor * t07 = t07L[il];
        struct ggml_tensor * t08 = t08L[il];
        struct ggml_tensor * t09 = t09L[il];
        struct ggml_tensor * t10 = t10L[il];
        struct ggml_tensor * t11 = t11L[il];
        struct ggml_tensor * t12 = t12L[il];
        struct ggml_tensor * t13 = t13L[il];
        struct ggml_tensor * t14 = t14L[il];
        struct ggml_tensor * t15 = t15L[il];
        struct ggml_tensor * t16 = t16L[il];
        struct ggml_tensor * t17 = t17L[il];
        struct ggml_tensor * t18 = t18L[il];
        struct ggml_tensor * t19 = t19L[il];
        struct ggml_tensor * t20 = t20L[il];
        struct ggml_tensor * t21 = t21L[il];
        struct ggml_tensor * t22 = t22L[il];
        struct ggml_tensor * t23 = t23L[il];
        struct ggml_tensor * t24 = t24L[il];
        struct ggml_tensor * t25 = t25L[il];
        struct ggml_tensor * t26 = t26L[il];
        struct ggml_tensor * t27 = t27L[il];
        struct ggml_tensor * t28 = t28L[il];
        struct ggml_tensor * t29 = t29L[il];
        struct ggml_tensor * t30 = t30L[il];

        clr_buf(1);
        use_buf(1);
        t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
        if (grad_layer_inp) {
            t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
        }
        clr_buf(2);
        t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
        t28->grad = expand(gb, ggml_out_prod(ctx0, layer.w2, ggml_transpose(ctx0, t29->grad))); assert_shape_2d(t28->grad, n_ff, N*n_batch);
        t27->grad = expand(gb, ggml_mul(ctx0, t28->grad, t25)); assert_shape_2d(t27->grad, n_ff, N*n_batch);
        t26->grad = expand(gb, ggml_silu_back(ctx0, t26, t27->grad)); assert_shape_2d(t26->grad, n_ff, N*n_batch);
        t25->grad = expand(gb, ggml_mul(ctx0, t28->grad, t27)); assert_shape_2d(t25->grad, n_ff, N*n_batch);
        t24->grad = expand(gb, ggml_add_inplace(ctx0,
                ggml_out_prod(ctx0, layer.w1, ggml_transpose(ctx0, t26->grad)),
                ggml_out_prod(ctx0, layer.w3, ggml_transpose(ctx0, t25->grad)))); assert_shape_2d(t24->grad, n_embd, N*n_batch);
        t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
        t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
        use_buf(2);
        t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
        grad_layer_inp = t21;
        use_buf(1);
        t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
        t19->grad = expand(gb, ggml_out_prod(ctx0, layer.wo, ggml_transpose(ctx0, t20->grad))); assert_shape_2d(t19->grad, n_embd, N*n_batch);
        t18->grad = expand(gb, ggml_reshape(ctx0, t19->grad, t18)); assert_shape_4d(t18->grad, n_embd/n_head, n_head, N, n_batch);
        t17->grad = t18->grad; assert_shape_4d(t17->grad, n_embd/n_head, n_head, N, n_batch);
        t16->grad = expand(gb, ggml_permute(ctx0, t17->grad, 0, 2, 1, 3)); assert_shape_4d(t16->grad, n_embd/n_head, N, n_head, n_batch);
        struct ggml_tensor * flash_attn = expand(gb, ggml_flash_attn_back(ctx0, t13, t14, t15, t16->grad, true)); assert_shape_4d(flash_attn, n_embd/n_head, N*3, n_head, n_batch);
        t15->grad = expand(gb, view__v(flash_attn)); assert_shape_4d(t15->grad, N, n_embd/n_head, n_head, n_batch);
        t14->grad = expand(gb, view__k(flash_attn)); assert_shape_4d(t14->grad, n_embd/n_head, N, n_head, n_batch);
        t13->grad = expand(gb, view__q(flash_attn)); assert_shape_4d(t13->grad, n_embd/n_head, N, n_head, n_batch);
        t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
        t11->grad = expand(gb, ggml_reshape(ctx0, ggml_cont(ctx0, t12->grad), t11)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
        t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
        t08->grad = expand(gb, ggml_reshape(ctx0, t09->grad, t08)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
        t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
        t05->grad = expand(gb, ggml_reshape(ctx0, t06->grad, t05)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
        t04->grad = expand(gb, ggml_add_inplace(ctx0,
                ggml_add_inplace(ctx0,
                    ggml_out_prod(ctx0, layer.wv, t11->grad),
                    ggml_out_prod(ctx0, layer.wk, ggml_transpose(ctx0, t08->grad))),
                ggml_out_prod(ctx0, layer.wq, ggml_transpose(ctx0, t05->grad)))); assert_shape_2d(t04->grad, n_embd, N*n_batch);
        t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t04->grad, n_embd, N*n_batch);
        use_buf(2);
        t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, t03)); assert_shape_2d(t02->grad, n_embd, N*n_batch);
        back_layer_inp = t02;
        use_buf(1);

        use_buf(-1);
        layer.attention_norm->grad = expand(gb, add_or_set(layer.attention_norm->grad, ggml_repeat_back(ctx0, t03->grad, layer.attention_norm))); assert_shape_1d(layer.attention_norm->grad, n_embd);
        layer.wq->grad = expand(gb, add_or_set(layer.wq->grad, ggml_out_prod(ctx0, t04, t05->grad))); assert_shape_2d(layer.wq->grad, n_embd, n_embd);
        layer.wk->grad = expand(gb, add_or_set(layer.wk->grad, ggml_out_prod(ctx0, t04, t08->grad))); assert_shape_2d(layer.wk->grad, n_embd, n_embd);
        layer.wv->grad = expand(gb, add_or_set(layer.wv->grad, ggml_out_prod(ctx0, t04, ggml_transpose(ctx0, t11->grad)))); assert_shape_2d(layer.wv->grad, n_embd, n_embd);
        layer.wo->grad = expand(gb, add_or_set(layer.wo->grad, ggml_out_prod(ctx0, t19, t20->grad))); assert_shape_2d(layer.wo->grad, n_embd, n_embd);
        layer.ffn_norm->grad = expand(gb, add_or_set(layer.ffn_norm->grad, ggml_repeat_back(ctx0, t23->grad, layer.ffn_norm))); assert_shape_1d(layer.ffn_norm->grad, n_embd);
        layer.w1->grad = expand(gb, add_or_set(layer.w1->grad, ggml_out_prod(ctx0, t24, t26->grad))); assert_shape_2d(layer.w1->grad, n_embd, n_ff);
        layer.w2->grad = expand(gb, add_or_set(layer.w2->grad, ggml_out_prod(ctx0, t28, t29->grad))); assert_shape_2d(layer.w2->grad, n_ff, n_embd);
        layer.w3->grad = expand(gb, add_or_set(layer.w3->grad, ggml_out_prod(ctx0, t24, t25->grad))); assert_shape_2d(layer.w3->grad, n_embd, n_ff);
        use_buf(1);
    }
    clr_buf(1);
    use_buf(1);
    t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
    use_buf(-1);
    model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
    clr_buf(2);
    clr_buf(1);

    *logits = t35;

    return t36;
}

void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
    float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
    *ptr = value;

@@ -2129,6 +2628,9 @@ struct train_params {

    int mem_model_gb;
    int mem_compute_gb;
    int mem_compute0_gb;
    int mem_compute1_gb;
    int mem_compute2_gb;
};

struct train_params get_default_train_params() {

@@ -2172,7 +2674,10 @@ struct train_params get_default_train_params() {
    params.adam_decay = 1e-3;

    params.mem_model_gb = 2;
    params.mem_compute_gb = 32;
    params.mem_compute_gb = 8;
    params.mem_compute0_gb = 24;
    params.mem_compute1_gb = 8;
    params.mem_compute2_gb = 8;

    return params;
}

@@ -2215,6 +2720,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
    fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
    fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
    fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
    fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
    fprintf(stderr, " --mem-compute1 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb);
    fprintf(stderr, " --mem-compute2 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute2_gb);
    fprintf(stderr, "\n");
}

@@ -2408,6 +2916,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                break;
            }
            params->mem_compute_gb = std::stoi(argv[i]);
        } else if (arg == "--mem-compute0") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->mem_compute0_gb = std::stoi(argv[i]);
        } else if (arg == "--mem-compute1") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->mem_compute1_gb = std::stoi(argv[i]);
        } else if (arg == "--mem-compute2") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->mem_compute2_gb = std::stoi(argv[i]);
        } else if (arg == "-h" || arg == "--help") {
            train_print_usage(argc, argv, &default_params);
            exit(0);

@@ -2563,6 +3089,13 @@ int main(int argc, char ** argv) {
    size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
    uint8_t * compute_addr = new uint8_t[compute_size];

    size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
    size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
    size_t size_buf_2 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute2_gb);
    uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
    uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
    uint8_t * compute_buf_2 = new uint8_t[size_buf_2];

    GGML_ASSERT(train_tokens.size() > n_tokens);;
    std::vector<int> train_samples;
    train_samples.push_back(0);

@@ -2601,22 +3134,46 @@ int main(int argc, char ** argv) {

        int n_past = 0;

        ggml_cgraph gf = {};
        gf.n_threads = params.n_threads;
        struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
        struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));

        struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
        struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
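
        // gf and gb live inside ctx0: the two dummy I32 tensors above merely provide storage
        // for the forward and backward graph structs.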

        // ggml_cgraph gf = {};
        gf->n_threads = params.n_threads;
        gb->n_threads = params.n_threads;

        get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);

        struct ggml_tensor * logits =
            (n_past == 0)
            ? (params.use_flash
                ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
                : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
            : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
        // struct ggml_tensor * logits =
        //     (n_past == 0)
        //     ? (params.use_flash
        //         ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
        //         : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
        //     : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);

        struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
        // struct ggml_tensor * e = cross_entropy_loss(ctx0, logits, target_probs);
        struct ggml_tensor * logits;
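
        // forward_batch_wo_cache_flash_attn_train builds the forward graph into gf and the manual
        // backward graph into gb, returns the loss tensor and exports the logits via the out-parameter.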
        struct ggml_tensor * e = forward_batch_wo_cache_flash_attn_train(
            &model,
            ctx0,
            gf,
            gb,
            &logits,
            tokens_input,
            target_probs,
            compute_buf_0,
            compute_buf_1,
            compute_buf_2,
            size_buf_0,
            size_buf_1,
            size_buf_2,
            n_tokens,
            n_batch);

        ggml_build_forward_expand(&gf, e);
        ggml_graph_compute(ctx0, &gf);
        // ggml_build_forward_expand(&gf, e);
        ggml_graph_compute(ctx0, gf);

        size_t used_mem_before_opt = ggml_used_mem(ctx0);

@@ -2633,7 +3190,8 @@ int main(int argc, char ** argv) {
        printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);

        // ggml_opt(ctx0, opt->params, e);
        ggml_opt_resume(ctx0, opt, e);
        // ggml_opt_resume(ctx0, opt, e);
        ggml_opt_resume_g(ctx0, opt, e, gf, gb);
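
        // ggml_opt_resume_g runs the optimizer on the caller-provided forward/backward graphs
        // instead of letting ggml construct the backward pass automatically.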

        size_t used_mem_after_opt = ggml_used_mem(ctx0);

@@ -2641,8 +3199,8 @@
        model.train_samples += n_batch;
        model.train_tokens += n_batch * n_tokens;

        ggml_build_forward_expand(&gf, e);
        ggml_graph_compute(ctx0, &gf);
        //ggml_build_forward_expand(&gf, e);
        ggml_graph_compute(ctx0, gf);

        float error_after_opt = ggml_get_f32_1d(e, 0);

@@ -2753,7 +3311,10 @@
        }
    }

    free(compute_addr);
    delete[] compute_addr;
    delete[] compute_buf_0;
    delete[] compute_buf_1;
    delete[] compute_buf_2;
    ggml_free(model.ctx);

    return 0;