implement gradient checkpointing for training

reduces memory overhead from O(n_layer) to O(sqrt(n_layer))

as explained in readme of https://github.com/cybertronai/gradient-checkpointing
This commit is contained in:
xaedes 2023-07-28 23:06:05 +02:00
parent d7003a98cc
commit 6e3f95bf06
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1921,6 +1921,556 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
return t36;
}
struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
struct my_llama_model * model,
struct ggml_context * ctx0,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
struct ggml_tensor * * logits,
struct ggml_tensor * tokens_input,
struct ggml_tensor * targets,
void * compute_buf_0,
void * compute_buf_1,
void * compute_buf_2,
void * compute_buf_3,
size_t size_buf_0,
size_t size_buf_1,
size_t size_buf_2,
size_t size_buf_3,
const int n_tokens,
const int n_batch) {
// implements gradient-checkpointing as explained in readme of https://github.com/cybertronai/gradient-checkpointing
ggml_set_scratch(ctx0, { 0, 0, nullptr, });
const int n_past = 0;
const int N = n_tokens;
gf->n_nodes = 0;
gf->n_leafs = 0;
gf->perf_runs = 0;
gf->perf_cycles = 0;
gf->perf_time_us = 0;
const auto & hparams = model->hparams;
//const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_head = hparams.n_head;
const int n_rot = hparams.n_rot;
const int n_ff = get_n_ff(&hparams);
const int rope_mode = 0;
bool track_max_mem = true;
int last_buf = -1;
size_t buf_offs[4] = { 0, 0, 0, 0 };
size_t buf_size[4] = { size_buf_0,
size_buf_1,
size_buf_2,
size_buf_3 };
void * buf_data[4] = { compute_buf_0,
compute_buf_1,
compute_buf_2,
compute_buf_3 };
size_t buf_maxs[4] = { 0, 0, 0, 0 };
auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs] (int buf) {
size_t last_offs = 0;
last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
if (last_buf >= 0) {
buf_offs[last_buf] = last_offs;
buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]);
}
if (buf >= 0) {
size_t offs = buf_offs[buf];
size_t size = buf_size[buf];
void * data = buf_data[buf];
ggml_set_scratch(ctx0, { offs, size, data, });
}
last_buf = buf;
};
auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) {
if (buf < 0) return;
if (track_max_mem) {
size_t last_offs = 0;
last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
if (last_buf >= 0) {
buf_offs[last_buf] = last_offs;
buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]);
}
}
buf_offs[buf] = 0;
if (track_max_mem && last_buf >= 0) {
size_t offs = buf_offs[last_buf];
size_t size = buf_size[last_buf];
void * data = buf_data[last_buf];
ggml_set_scratch(ctx0, { offs, size, data, });
}
};
auto view__q = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
int64_t ne0 = n_embd/n_head;
int64_t ne1 = N;
int64_t ne2 = n_head;
int64_t ne3 = n_batch;
size_t nb0 = ggml_element_size(t);
size_t nb1 = nb0*ne0;
size_t nb2 = nb1*ne1;
size_t nb3 = nb2*ne2;
size_t offset = 0;
return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
};
auto view__k = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
int64_t ne0 = n_embd/n_head;
int64_t ne1 = N;
int64_t ne2 = n_head;
int64_t ne3 = n_batch;
size_t nb0 = ggml_element_size(t);
size_t nb1 = nb0*ne0;
size_t nb2 = nb1*ne1;
size_t nb3 = nb2*ne2;
size_t offset = nb3*ne3;
return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
};
auto view__v = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
int64_t ne0 = N;
int64_t ne1 = n_embd/n_head;
int64_t ne2 = n_head;
int64_t ne3 = n_batch;
size_t nb0 = ggml_element_size(t);
size_t nb1 = nb0*ne0;
size_t nb2 = nb1*ne1;
size_t nb3 = nb2*ne2;
size_t offset = 2*nb3*ne3;
return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
};
auto add_or_set = [ctx0] (struct ggml_tensor * a, struct ggml_tensor * b) -> struct ggml_tensor * {
if (a == NULL) {
return b;
} else {
return ggml_add_inplace(ctx0, a, b);
}
};
use_buf(-1);
model->tok_embeddings->grad = NULL;
model->norm->grad = NULL;
model->output->grad = NULL;
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
layer.attention_norm->grad = NULL;
layer.wq->grad = NULL;
layer.wk->grad = NULL;
layer.wv->grad = NULL;
layer.wo->grad = NULL;
layer.ffn_norm->grad = NULL;
layer.w1->grad = NULL;
layer.w2->grad = NULL;
layer.w3->grad = NULL;
}
clr_buf(0);
clr_buf(1);
clr_buf(2);
clr_buf(3);
use_buf(-1);
struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
use_buf(-1);
struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch);
std::vector<int> checkpoints;
// for (int il = 0; il < n_layer; ++il) {
// checkpoints.push_back(il);
// }
// n_check: number of layers between checkpoints
int n_check = (int)(sqrtf(n_layer) + 0.5f);
printf("%s: n_check = %d\n", __func__, n_check);
for (int chk = n_check-1; chk+1 < n_layer; chk += n_check) {
checkpoints.push_back(chk);
}
for (int i = 0; i < checkpoints.size(); ++i) {
printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]);
}
// example for 16 layers:
// inp ~ implicit zeroth checkpoint == input
// L00 f 4b
// L01 f 4b
// L02 f 4b
// L03 fc4b first checkpoint
// L04 f 3b
// L05 f 3b
// L06 f 3b
// L07 fc3b second checkpoint
// L08 f 2b
// L09 f 2b
// L10 f 2b
// L11 fc2b third checkpoint
// L12 f 1b
// L13 f 1b
// L14 f 1b
// L15 f 1b
// need to remember these for the backward pass
std::vector<struct ggml_tensor *> t02L; t02L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t03L; t03L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t04L; t04L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t05L; t05L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t06L; t06L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t07L; t07L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t08L; t08L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t09L; t09L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t10L; t10L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t11L; t11L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t12L; t12L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t13L; t13L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t14L; t14L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t15L; t15L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t16L; t16L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t17L; t17L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t18L; t18L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t19L; t19L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t20L; t20L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t21L; t21L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t22L; t22L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t23L; t23L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t24L; t24L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t25L; t25L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t26L; t26L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t27L; t27L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t28L; t28L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t29L; t29L.resize(n_layer, NULL);
std::vector<struct ggml_tensor *> t30L; t30L.resize(n_layer, NULL);
struct ggml_tensor * cur = t01;
int chk_idx = 0;
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
// tensors with values necessary for backward pass are in persistent buf(-1)
// other tensors with buf(0), buf(1), etc are only temporary needed, and their memory reused
bool is_checkpoint = (chk_idx < checkpoints.size() && il == checkpoints[chk_idx]);
if (is_checkpoint) {
printf("%s: layer %d is_checkpoint\n", __func__, il);
chk_idx += 1;
}
const int prs = 0; // in first forward pass even persistent tensors are only temporary
const int tmp = 0; // temporary
// nxt is required to compute next layer.
// for checkpoints we need to remember this for usage in backward pass,
// otherwise temporary until next of this kind
const int nxt = is_checkpoint ? -1 : 1;
clr_buf(0);
use_buf(prs); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf(tmp); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd);
use_buf(prs); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
use_buf(prs); struct ggml_tensor * t13 = expand(gf, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
use_buf(prs); struct ggml_tensor * t14 = expand(gf, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
use_buf(prs); struct ggml_tensor * t15 = expand(gf, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
use_buf(prs); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
use_buf(tmp); struct ggml_tensor * t17 = expand(gf, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t18 = expand(gf, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf(tmp); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf(tmp); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
use_buf(prs); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch);
use_buf(prs); struct ggml_tensor * t27 = expand(gf, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch);
use_buf(prs); struct ggml_tensor * t28 = expand(gf, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch);
use_buf(tmp); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch);
clr_buf( 1);
use_buf(nxt); struct ggml_tensor * t30 = expand(gf, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch);
// only t30L is remembered for checkpointing in first forward pass
if (is_checkpoint) {
t30L[il] = t30;
}
cur = t30;
}
clr_buf(0);
use_buf(0);
struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
use_buf(-1);
struct ggml_tensor * t34 = expand(gf, ggml_mul_mat (ctx0, model->output, t33)); assert_shape_2d(t34, n_vocab, N*n_batch);
struct ggml_tensor * t35 = expand(gf, ggml_reshape_3d(ctx0, t34, n_vocab, N, n_batch)); assert_shape_3d(t35, n_vocab, N, n_batch);
struct ggml_tensor * t36 = expand(gf, ggml_cross_entropy_loss(ctx0, t35, targets)); assert_shape_1d(t36, 1);
*gb = *gf;
// t36->grad gets set to one by optimizer, so we need the tensor.
// initialize it with 1.0f to make sure.
use_buf(-1);
t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));
use_buf(0);
t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);
t34->grad = expand(gb, ggml_reshape_2d (ctx0, t35->grad, n_vocab, N*n_batch)); assert_shape_2d(t34->grad, n_vocab, N*n_batch);
t33->grad = expand(gb, ggml_out_prod (ctx0, model->output, ggml_transpose(ctx0, t34->grad))); assert_shape_2d(t33->grad, n_embd, N*n_batch);
t32->grad = expand(gb, ggml_mul (ctx0, t33->grad, t31)); assert_shape_2d(t32->grad, n_embd, N*n_batch);
use_buf(-1);
model->norm->grad = expand(gb, add_or_set(model->norm->grad, ggml_repeat_back(ctx0, t32->grad, model->norm))); assert_shape_1d(model->norm->grad, n_embd);
model->output->grad = expand(gb, add_or_set(model->output->grad, ggml_out_prod(ctx0, t33, t34->grad))); assert_shape_2d(model->output->grad, n_embd, n_vocab);
clr_buf(1);
use_buf(1);
t31->grad = expand(gb, ggml_mul(ctx0, t33->grad, t32)); assert_shape_2d(t31->grad, n_embd, N*n_batch);
struct ggml_tensor * back_layer_inp = t31;
struct ggml_tensor * grad_layer_inp = NULL;
printf("%s: checkpoints.size() = %zu\n", __func__, checkpoints.size());
chk_idx = checkpoints.size()-1;
int avail_begin = n_layer;
int avail_end = n_layer;
printf("%s: chk_idx=%d avail_begin=%d avail_end=%d\n", __func__, chk_idx, avail_begin, avail_end);
for (int k = 0; k < n_layer; ++k) {
// second forward pass for checkpointing
int il = n_layer-1-k;
if (il < avail_begin) {
// make sure, that txxL[il] is available
// forward pass from last checkpoint
GGML_ASSERT(chk_idx >= -1);
int begin = (chk_idx == -1)
? 0
: checkpoints[chk_idx] + 1; // checkpoint[chk_idx] contains t30 for computing following layers -> +1
int end = (chk_idx+1 < checkpoints.size())
? (checkpoints[chk_idx+1] + 1)
: n_layer;
GGML_ASSERT(begin <= il);
GGML_ASSERT(il < end);
cur = (chk_idx == -1) ? t01 : t30L[checkpoints[chk_idx]];
clr_buf(2);
printf("%s: second forward pass chk_idx=%d begin=%d end=%d\n", __func__, chk_idx, begin, end);
for (int i = begin; i < end; ++i) {
struct my_llama_layer & layer = model->layers[i];
const int prs = 2; // persistent until next checkpoint
const int tmp = 0; // temporary for this layer
const bool is_checkpoint = (i == end-1);
clr_buf(0);
use_buf(prs); struct ggml_tensor * t02 = expand(gb, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf(tmp); struct ggml_tensor * t03 = expand(gb, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t04 = expand(gb, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t05 = expand(gb, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t06 = expand(gb, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t07 = expand(gb, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t08 = expand(gb, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t09 = expand(gb, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t10 = expand(gb, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t11 = expand(gb, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd);
use_buf(prs); struct ggml_tensor * t12 = expand(gb, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
use_buf(prs); struct ggml_tensor * t13 = expand(gb, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
use_buf(prs); struct ggml_tensor * t14 = expand(gb, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
use_buf(prs); struct ggml_tensor * t15 = expand(gb, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
use_buf(prs); struct ggml_tensor * t16 = expand(gb, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
use_buf(tmp); struct ggml_tensor * t17 = expand(gb, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t18 = expand(gb, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
use_buf(prs); struct ggml_tensor * t19 = expand(gb, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf(tmp); struct ggml_tensor * t20 = expand(gb, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t21 = expand(gb, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t22 = expand(gb, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf(tmp); struct ggml_tensor * t23 = expand(gb, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t24 = expand(gb, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(prs); struct ggml_tensor * t25 = expand(gb, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
use_buf(prs); struct ggml_tensor * t26 = expand(gb, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch);
use_buf(prs); struct ggml_tensor * t27 = expand(gb, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch);
use_buf(prs); struct ggml_tensor * t28 = expand(gb, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch);
use_buf(tmp); struct ggml_tensor * t29 = expand(gb, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch);
if (t30L[i] == NULL) {
use_buf(prs); struct ggml_tensor * t30 = expand(gb, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch);
t30L[i] = t30;
cur = t30;
}
t02L[i] = t02;
t03L[i] = t03;
t04L[i] = t04;
t05L[i] = t05;
t06L[i] = t06;
t07L[i] = t07;
t08L[i] = t08;
t09L[i] = t09;
t10L[i] = t10;
t11L[i] = t11;
t12L[i] = t12;
t13L[i] = t13;
t14L[i] = t14;
t15L[i] = t15;
t16L[i] = t16;
t17L[i] = t17;
t18L[i] = t18;
t19L[i] = t19;
t20L[i] = t20;
t21L[i] = t21;
t22L[i] = t22;
t23L[i] = t23;
t24L[i] = t24;
t25L[i] = t25;
t26L[i] = t26;
t27L[i] = t27;
t28L[i] = t28;
t29L[i] = t29;
}
--chk_idx;
avail_begin = begin;
avail_end = end;
printf("%s: chk_idx=%d avail_begin=%d avail_end=%d\n", __func__, chk_idx, avail_begin, avail_end);
}
printf("%s: backward pass il=%d\n", __func__, il);
struct my_llama_layer & layer = model->layers[il];
struct ggml_tensor * t02 = t02L[il];
struct ggml_tensor * t03 = t03L[il];
struct ggml_tensor * t04 = t04L[il];
struct ggml_tensor * t05 = t05L[il];
struct ggml_tensor * t06 = t06L[il];
struct ggml_tensor * t07 = t07L[il];
struct ggml_tensor * t08 = t08L[il];
struct ggml_tensor * t09 = t09L[il];
struct ggml_tensor * t10 = t10L[il];
struct ggml_tensor * t11 = t11L[il];
struct ggml_tensor * t12 = t12L[il];
struct ggml_tensor * t13 = t13L[il];
struct ggml_tensor * t14 = t14L[il];
struct ggml_tensor * t15 = t15L[il];
struct ggml_tensor * t16 = t16L[il];
struct ggml_tensor * t17 = t17L[il];
struct ggml_tensor * t18 = t18L[il];
struct ggml_tensor * t19 = t19L[il];
struct ggml_tensor * t20 = t20L[il];
struct ggml_tensor * t21 = t21L[il];
struct ggml_tensor * t22 = t22L[il];
struct ggml_tensor * t23 = t23L[il];
struct ggml_tensor * t24 = t24L[il];
struct ggml_tensor * t25 = t25L[il];
struct ggml_tensor * t26 = t26L[il];
struct ggml_tensor * t27 = t27L[il];
struct ggml_tensor * t28 = t28L[il];
struct ggml_tensor * t29 = t29L[il];
struct ggml_tensor * t30 = t30L[il];
clr_buf(0);
use_buf(0);
t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
if (grad_layer_inp) {
t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
}
clr_buf(1);
t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
t28->grad = expand(gb, ggml_out_prod(ctx0, layer.w2, ggml_transpose(ctx0, t29->grad))); assert_shape_2d(t28->grad, n_ff, N*n_batch);
t27->grad = expand(gb, ggml_mul(ctx0, t28->grad, t25)); assert_shape_2d(t27->grad, n_ff, N*n_batch);
t26->grad = expand(gb, ggml_silu_back(ctx0, t26, t27->grad)); assert_shape_2d(t26->grad, n_ff, N*n_batch);
t25->grad = expand(gb, ggml_mul(ctx0, t28->grad, t27)); assert_shape_2d(t25->grad, n_ff, N*n_batch);
t24->grad = expand(gb, ggml_add_inplace(ctx0,
ggml_out_prod(ctx0, layer.w1, ggml_transpose(ctx0, t26->grad)),
ggml_out_prod(ctx0, layer.w3, ggml_transpose(ctx0, t25->grad)))); assert_shape_2d(t24->grad, n_embd, N*n_batch);
t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
use_buf(1);
t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
grad_layer_inp = t21;
use_buf(0);
t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
t19->grad = expand(gb, ggml_out_prod(ctx0, layer.wo, ggml_transpose(ctx0, t20->grad))); assert_shape_2d(t19->grad, n_embd, N*n_batch);
t18->grad = expand(gb, ggml_reshape_4d(ctx0, t19->grad, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t18->grad, n_embd/n_head, n_head, N, n_batch);
t17->grad = t18->grad; assert_shape_4d(t17->grad, n_embd/n_head, n_head, N, n_batch);
t16->grad = expand(gb, ggml_permute(ctx0, t17->grad, 0, 2, 1, 3)); assert_shape_4d(t16->grad, n_embd/n_head, N, n_head, n_batch);
struct ggml_tensor * flash_attn = expand(gb, ggml_flash_attn_back(ctx0, t13, t14, t15, t16->grad, true)); assert_shape_4d(flash_attn, n_embd/n_head, N*3, n_head, n_batch);
t15->grad = expand(gb, view__v(flash_attn)); assert_shape_4d(t15->grad, N, n_embd/n_head, n_head, n_batch);
t14->grad = expand(gb, view__k(flash_attn)); assert_shape_4d(t14->grad, n_embd/n_head, N, n_head, n_batch);
t13->grad = expand(gb, view__q(flash_attn)); assert_shape_4d(t13->grad, n_embd/n_head, N, n_head, n_batch);
t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
t04->grad = expand(gb, ggml_add_inplace(ctx0,
ggml_add_inplace(ctx0,
ggml_out_prod(ctx0, layer.wv, t11->grad),
ggml_out_prod(ctx0, layer.wk, ggml_transpose(ctx0, t08->grad))),
ggml_out_prod(ctx0, layer.wq, ggml_transpose(ctx0, t05->grad)))); assert_shape_2d(t04->grad, n_embd, N*n_batch);
t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t04->grad, n_embd, N*n_batch);
use_buf(1);
t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, ggml_repeat(ctx0, layer.attention_norm, t02))); assert_shape_2d(t02->grad, n_embd, N*n_batch);
back_layer_inp = t02;
use_buf(-1);
layer.attention_norm->grad = expand(gb, add_or_set(layer.attention_norm->grad, ggml_repeat_back(ctx0, t03->grad, layer.attention_norm))); assert_shape_1d(layer.attention_norm->grad, n_embd);
layer.wq->grad = expand(gb, add_or_set(layer.wq->grad, ggml_out_prod(ctx0, t04, t05->grad))); assert_shape_2d(layer.wq->grad, n_embd, n_embd);
layer.wk->grad = expand(gb, add_or_set(layer.wk->grad, ggml_out_prod(ctx0, t04, t08->grad))); assert_shape_2d(layer.wk->grad, n_embd, n_embd);
layer.wv->grad = expand(gb, add_or_set(layer.wv->grad, ggml_out_prod(ctx0, t04, ggml_transpose(ctx0, t11->grad)))); assert_shape_2d(layer.wv->grad, n_embd, n_embd);
layer.wo->grad = expand(gb, add_or_set(layer.wo->grad, ggml_out_prod(ctx0, t19, t20->grad))); assert_shape_2d(layer.wo->grad, n_embd, n_embd);
layer.ffn_norm->grad = expand(gb, add_or_set(layer.ffn_norm->grad, ggml_repeat_back(ctx0, t23->grad, layer.ffn_norm))); assert_shape_1d(layer.ffn_norm->grad, n_embd);
layer.w1->grad = expand(gb, add_or_set(layer.w1->grad, ggml_out_prod(ctx0, t24, t26->grad))); assert_shape_2d(layer.w1->grad, n_embd, n_ff);
layer.w2->grad = expand(gb, add_or_set(layer.w2->grad, ggml_out_prod(ctx0, t28, t29->grad))); assert_shape_2d(layer.w2->grad, n_ff, n_embd);
layer.w3->grad = expand(gb, add_or_set(layer.w3->grad, ggml_out_prod(ctx0, t24, t25->grad))); assert_shape_2d(layer.w3->grad, n_embd, n_ff);
}
printf("%s: chk_idx=%d avail_begin=%d avail_end=%d\n", __func__, chk_idx, avail_begin, avail_end);
GGML_ASSERT(chk_idx == -2);
GGML_ASSERT(avail_begin == 0);
clr_buf(0);
use_buf(0);
t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
use_buf(-1);
model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
*logits = t35;
clr_buf(0);
clr_buf(1);
clr_buf(2);
clr_buf(3);
if (track_max_mem) {
printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
printf("%s: max size compute buf2: %zu\n", __func__, buf_maxs[2]);
printf("%s: max size compute buf3: %zu\n", __func__, buf_maxs[3]);
}
// now that all grads are created, set the graph leafs and grads
graph_set_leafs_grads(gf);
graph_set_leafs_grads(gb);
return t36;
}
void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
*ptr = value;
@ -2810,6 +3360,7 @@ struct train_params {
bool use_adam;
bool use_flash;
bool use_scratch;
bool use_checkpointing;
// only adam
int warmup;
@ -2829,6 +3380,8 @@ struct train_params {
int mem_compute_gb;
int mem_compute0_gb;
int mem_compute1_gb;
int mem_compute2_gb;
int mem_compute3_gb;
};
struct train_params get_default_train_params() {
@ -2860,6 +3413,7 @@ struct train_params get_default_train_params() {
params.use_adam = true;
params.use_flash = true;
params.use_scratch = true;
params.use_checkpointing = true;
// only adam
params.warmup = 100;
@ -2878,8 +3432,9 @@ struct train_params get_default_train_params() {
params.mem_model_gb = 2;
params.mem_compute_gb = 24;
params.mem_compute0_gb = 8;
params.mem_compute1_gb = 2;
params.mem_compute1_gb = 1;
params.mem_compute2_gb = 2;
params.mem_compute3_gb = 1;
return params;
}
@ -2909,14 +3464,16 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
fprintf(stderr, " --no-flash Don't use flash attention.\n");
fprintf(stderr, " --no-flash Don't use flash attention. Implies no-scratch and no-checkpointing.\n");
fprintf(stderr, " --use-flash Use flash attention (default)\n");
fprintf(stderr, " --no-scratch Don't use scratch buffers\n");
fprintf(stderr, " --use-scratch Use scratch buffers (default)\n");
fprintf(stderr, " --warmup N Number of warmup steps (default %d)\n", params->warmup);
fprintf(stderr, " --cos-decay-steps N Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-alpha N Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
fprintf(stderr, " --no-scratch Don't use scratch buffers. Implies no-checkpointing.\n");
fprintf(stderr, " --use-scratch Use scratch buffers. Implies use-flash. (default)\n");
fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n");
fprintf(stderr, " --use-checkpointing Use gradient checkpointing. Implies use-scratch and use-flash. (default)\n");
fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-alpha N Only for Adam optimizer. Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
@ -2928,6 +3485,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
fprintf(stderr, " --mem-compute1 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb);
fprintf(stderr, " --mem-compute2 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute2_gb);
fprintf(stderr, " --mem-compute3 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute3_gb);
fprintf(stderr, "\n");
}
@ -3065,6 +3624,10 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
params->use_scratch = false;
} else if (arg == "--use-scratch") {
params->use_scratch = true;
} else if (arg == "--no-checkpointing") {
params->use_checkpointing = false;
} else if (arg == "--use-checkpointing") {
params->use_checkpointing = true;
} else if (arg == "--warmup") {
if (++i >= argc) {
invalid_param = true;
@ -3155,6 +3718,18 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break;
}
params->mem_compute1_gb = std::stoi(argv[i]);
} else if (arg == "--mem-compute2") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->mem_compute2_gb = std::stoi(argv[i]);
} else if (arg == "--mem-compute3") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->mem_compute3_gb = std::stoi(argv[i]);
} else if (arg == "-h" || arg == "--help") {
train_print_usage(argc, argv, &default_params);
exit(0);
@ -3316,8 +3891,12 @@ int main(int argc, char ** argv) {
size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
size_t size_buf_2 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute2_gb);
size_t size_buf_3 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute3_gb);
uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
uint8_t * compute_buf_3 = new uint8_t[size_buf_3];
GGML_ASSERT(n_tokens < (int) train_tokens.size());
std::vector<int> train_samples;
@ -3376,7 +3955,15 @@ int main(int argc, char ** argv) {
struct ggml_tensor * loss = NULL;
struct ggml_tensor * logits = NULL;
if (params.use_scratch) {
if (params.use_checkpointing) {
loss = forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
&model, ctx0,
gf, gb,
&logits, tokens_input, target_probs,
compute_buf_0, compute_buf_1, compute_buf_2, compute_buf_3,
size_buf_0, size_buf_1, size_buf_2, size_buf_3,
n_tokens, n_batch);
} else if (params.use_scratch) {
loss = forward_batch_wo_cache_flash_attn_train(
&model, ctx0,
gf, gb,