implement gradient checkpointing for training
reduces memory overhead from O(n_layer) to O(sqrt(n_layer)), as explained in the README of https://github.com/cybertronai/gradient-checkpointing
parent d7003a98cc
commit 6e3f95bf06
1 changed file with 597 additions and 10 deletions
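The schedule stores a checkpoint every k layers: the backward pass then keeps n_layer/k checkpoints plus at most one segment of k recomputed layers alive at once, and n/k + k is minimal at k = sqrt(n). A minimal standalone sketch of the schedule (it mirrors the n_check logic in the diff below; not part of the commit):

    // sketch: sqrt(n) checkpoint schedule, as used by the function in this diff
    #include <math.h>
    #include <stdio.h>
    #include <vector>

    int main() {
        const int n_layer = 16;                           // matches the 16-layer example in the diff
        const int n_check = (int)(sqrtf(n_layer) + 0.5f); // distance between checkpoints
        std::vector<int> checkpoints;
        for (int chk = n_check - 1; chk + 1 < n_layer; chk += n_check) {
            checkpoints.push_back(chk);
        }
        for (size_t i = 0; i < checkpoints.size(); ++i) {
            printf("checkpoint #%zu = layer %d\n", i, checkpoints[i]); // prints layers 3, 7, 11
        }
        return 0;
    }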
@@ -1921,6 +1921,556 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
    return t36;
}

struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
        struct my_llama_model * model,
        struct ggml_context   * ctx0,
        struct ggml_cgraph    * gf,
        struct ggml_cgraph    * gb,
        struct ggml_tensor  * * logits,
        struct ggml_tensor    * tokens_input,
        struct ggml_tensor    * targets,
        void                  * compute_buf_0,
        void                  * compute_buf_1,
        void                  * compute_buf_2,
        void                  * compute_buf_3,
        size_t                  size_buf_0,
        size_t                  size_buf_1,
        size_t                  size_buf_2,
        size_t                  size_buf_3,
        const int               n_tokens,
        const int               n_batch) {

    // implements gradient-checkpointing as explained in readme of https://github.com/cybertronai/gradient-checkpointing

    ggml_set_scratch(ctx0, { 0, 0, nullptr, });

    const int n_past = 0;
    const int N = n_tokens;

    gf->n_nodes = 0;
    gf->n_leafs = 0;
    gf->perf_runs = 0;
    gf->perf_cycles = 0;
    gf->perf_time_us = 0;

    const auto & hparams = model->hparams;
    //const int n_ctx = hparams.n_ctx;
    const int n_vocab = hparams.n_vocab;
    const int n_embd  = hparams.n_embd;
    const int n_layer = hparams.n_layer;
    const int n_head  = hparams.n_head;
    const int n_rot   = hparams.n_rot;
    const int n_ff    = get_n_ff(&hparams);
    const int rope_mode = 0;

    bool track_max_mem = true;

    int last_buf = -1;
    size_t buf_offs[4] = { 0, 0, 0, 0 };
    size_t buf_size[4] = { size_buf_0,
                           size_buf_1,
                           size_buf_2,
                           size_buf_3 };
    void * buf_data[4] = { compute_buf_0,
                           compute_buf_1,
                           compute_buf_2,
                           compute_buf_3 };
    size_t buf_maxs[4] = { 0, 0, 0, 0 };

    auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs] (int buf) {
        size_t last_offs = 0;
        last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
        if (last_buf >= 0) {
            buf_offs[last_buf] = last_offs;
            buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]);
        }
        if (buf >= 0) {
            size_t offs = buf_offs[buf];
            size_t size = buf_size[buf];
            void * data = buf_data[buf];
            ggml_set_scratch(ctx0, { offs, size, data, });
        }
        last_buf = buf;
    };

    auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) {
        if (buf < 0) return;
        if (track_max_mem) {
            size_t last_offs = 0;
            last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
            if (last_buf >= 0) {
                buf_offs[last_buf] = last_offs;
                buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]);
            }
        }
        buf_offs[buf] = 0;
        if (track_max_mem && last_buf >= 0) {
            size_t offs = buf_offs[last_buf];
            size_t size = buf_size[last_buf];
            void * data = buf_data[last_buf];
            ggml_set_scratch(ctx0, { offs, size, data, });
        }
    };

    auto view__q = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
        int64_t ne0 = n_embd/n_head;
        int64_t ne1 = N;
        int64_t ne2 = n_head;
        int64_t ne3 = n_batch;
        size_t  nb0 = ggml_element_size(t);
        size_t  nb1 = nb0*ne0;
        size_t  nb2 = nb1*ne1;
        size_t  nb3 = nb2*ne2;
        size_t offset = 0;
        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
    };

    auto view__k = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
        int64_t ne0 = n_embd/n_head;
        int64_t ne1 = N;
        int64_t ne2 = n_head;
        int64_t ne3 = n_batch;
        size_t  nb0 = ggml_element_size(t);
        size_t  nb1 = nb0*ne0;
        size_t  nb2 = nb1*ne1;
        size_t  nb3 = nb2*ne2;
        size_t offset = nb3*ne3;
        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
    };

    auto view__v = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
        int64_t ne0 = N;
        int64_t ne1 = n_embd/n_head;
        int64_t ne2 = n_head;
        int64_t ne3 = n_batch;
        size_t  nb0 = ggml_element_size(t);
        size_t  nb1 = nb0*ne0;
        size_t  nb2 = nb1*ne1;
        size_t  nb3 = nb2*ne2;
        size_t offset = 2*nb3*ne3;
        return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset);
    };

    auto add_or_set = [ctx0] (struct ggml_tensor * a, struct ggml_tensor * b) -> struct ggml_tensor * {
        if (a == NULL) {
            return b;
        } else {
            return ggml_add_inplace(ctx0, a, b);
        }
    };

    use_buf(-1);

    model->tok_embeddings->grad = NULL;
    model->norm->grad           = NULL;
    model->output->grad         = NULL;

    for (int il = 0; il < n_layer; ++il) {
        struct my_llama_layer & layer = model->layers[il];
        layer.attention_norm->grad = NULL;
        layer.wq->grad             = NULL;
        layer.wk->grad             = NULL;
        layer.wv->grad             = NULL;
        layer.wo->grad             = NULL;
        layer.ffn_norm->grad       = NULL;
        layer.w1->grad             = NULL;
        layer.w2->grad             = NULL;
        layer.w3->grad             = NULL;
    }

    clr_buf(0);
    clr_buf(1);
    clr_buf(2);
    clr_buf(3);

    use_buf(-1);

    struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
    memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);

    use_buf(-1);

    struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch);

    std::vector<int> checkpoints;
    // for (int il = 0; il < n_layer; ++il) {
    //     checkpoints.push_back(il);
    // }
    // n_check: number of layers between checkpoints
    int n_check = (int)(sqrtf(n_layer) + 0.5f);
    printf("%s: n_check = %d\n", __func__, n_check);
    for (int chk = n_check-1; chk+1 < n_layer; chk += n_check) {
        checkpoints.push_back(chk);
    }

    for (int i = 0; i < (int) checkpoints.size(); ++i) {
        printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]);
    }

    // example for 16 layers:
    // inp ~ implicit zeroth checkpoint == input
    // L00 f   4b
    // L01 f   4b
    // L02 f   4b
    // L03 fc4b      first checkpoint
    // L04 f   3b
    // L05 f   3b
    // L06 f   3b
    // L07 fc3b      second checkpoint
    // L08 f   2b
    // L09 f   2b
    // L10 f   2b
    // L11 fc2b      third checkpoint
    // L12 f   1b
    // L13 f   1b
    // L14 f   1b
    // L15 f   1b

    // need to remember these for the backward pass
    std::vector<struct ggml_tensor *> t02L; t02L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t03L; t03L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t04L; t04L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t05L; t05L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t06L; t06L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t07L; t07L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t08L; t08L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t09L; t09L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t10L; t10L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t11L; t11L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t12L; t12L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t13L; t13L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t14L; t14L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t15L; t15L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t16L; t16L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t17L; t17L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t18L; t18L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t19L; t19L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t20L; t20L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t21L; t21L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t22L; t22L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t23L; t23L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t24L; t24L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t25L; t25L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t26L; t26L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t27L; t27L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t28L; t28L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t29L; t29L.resize(n_layer, NULL);
    std::vector<struct ggml_tensor *> t30L; t30L.resize(n_layer, NULL);

    struct ggml_tensor * cur = t01;

    int chk_idx = 0;
    for (int il = 0; il < n_layer; ++il) {
        struct my_llama_layer & layer = model->layers[il];
        // tensors with values necessary for backward pass are in persistent buf(-1)
        // other tensors with buf(0), buf(1), etc are only temporary needed, and their memory reused
        bool is_checkpoint = (chk_idx < (int) checkpoints.size() && il == checkpoints[chk_idx]);
        if (is_checkpoint) {
            printf("%s: layer %d is_checkpoint\n", __func__, il);
            chk_idx += 1;
        }
        const int prs = 0; // in first forward pass even persistent tensors are only temporary
        const int tmp = 0; // temporary
        // nxt is required to compute next layer.
        // for checkpoints we need to remember this for usage in backward pass,
        // otherwise temporary until next of this kind
        const int nxt = is_checkpoint ? -1 : 1;
        clr_buf(0);
        use_buf(prs); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm     (ctx0, cur, rms_norm_eps));                   assert_shape_2d(t02, n_embd, N*n_batch);
        use_buf(tmp); struct ggml_tensor * t03 = expand(gf, ggml_repeat       (ctx0, layer.attention_norm, t02));           assert_shape_2d(t03, n_embd, N*n_batch);
        use_buf(prs); struct ggml_tensor * t04 = expand(gf, ggml_mul          (ctx0, t02, t03));                            assert_shape_2d(t04, n_embd, N*n_batch);
        use_buf(prs); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat      (ctx0, layer.wq, t04));                       assert_shape_2d(t05, n_embd, N*n_batch);
        use_buf(prs); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d   (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
        use_buf(prs); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode));       assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
        use_buf(prs); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat      (ctx0, layer.wk, t04));                       assert_shape_2d(t08, n_embd, N*n_batch);
        use_buf(prs); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d   (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
        use_buf(prs); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode));       assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
        use_buf(prs); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat      (ctx0, t04, layer.wv));                       assert_shape_2d(t11, N*n_batch, n_embd);
        use_buf(prs); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d   (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
        use_buf(prs); struct ggml_tensor * t13 = expand(gf, ggml_permute      (ctx0, t07, 0, 2, 1, 3));                     assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
        use_buf(prs); struct ggml_tensor * t14 = expand(gf, ggml_permute      (ctx0, t10, 0, 2, 1, 3));                     assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
        use_buf(prs); struct ggml_tensor * t15 = expand(gf, ggml_permute      (ctx0, t12, 0, 3, 1, 2));                     assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
        use_buf(prs); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn   (ctx0, t13, t14, t15, true));                 assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
        use_buf(tmp); struct ggml_tensor * t17 = expand(gf, ggml_permute      (ctx0, t16, 0, 2, 1, 3));                     assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
        use_buf(prs); struct ggml_tensor * t18 = expand(gf, ggml_cont         (ctx0, t17));                                 assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
        use_buf(prs); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d   (ctx0, t18, n_embd, N*n_batch));              assert_shape_2d(t19, n_embd, N*n_batch);
        use_buf(tmp); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat      (ctx0, layer.wo, t19));                       assert_shape_2d(t20, n_embd, N*n_batch);
        use_buf(prs); struct ggml_tensor * t21 = expand(gf, ggml_add          (ctx0, t20, cur));                            assert_shape_2d(t21, n_embd, N*n_batch);
        use_buf(prs); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm     (ctx0, t21, rms_norm_eps));                   assert_shape_2d(t22, n_embd, N*n_batch);
        use_buf(tmp); struct ggml_tensor * t23 = expand(gf, ggml_repeat       (ctx0, layer.ffn_norm, t22));                 assert_shape_2d(t23, n_embd, N*n_batch);
        use_buf(prs); struct ggml_tensor * t24 = expand(gf, ggml_mul          (ctx0, t23, t22));                            assert_shape_2d(t24, n_embd, N*n_batch);
        use_buf(prs); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat      (ctx0, layer.w3, t24));                       assert_shape_2d(t25, n_ff, N*n_batch);
        use_buf(prs); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat      (ctx0, layer.w1, t24));                       assert_shape_2d(t26, n_ff, N*n_batch);
        use_buf(prs); struct ggml_tensor * t27 = expand(gf, ggml_silu         (ctx0, t26));                                 assert_shape_2d(t27, n_ff, N*n_batch);
        use_buf(prs); struct ggml_tensor * t28 = expand(gf, ggml_mul          (ctx0, t27, t25));                            assert_shape_2d(t28, n_ff, N*n_batch);
        use_buf(tmp); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat      (ctx0, layer.w2, t28));                       assert_shape_2d(t29, n_embd, N*n_batch);
        clr_buf(1);
        use_buf(nxt); struct ggml_tensor * t30 = expand(gf, ggml_add          (ctx0, t21, t29));                            assert_shape_2d(t30, n_embd, N*n_batch);

        // only t30L is remembered for checkpointing in first forward pass
        if (is_checkpoint) {
            t30L[il] = t30;
        }
        cur = t30;
    }
    clr_buf(0);
    use_buf(0);
    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps));          assert_shape_2d(t31, n_embd, N*n_batch);
    struct ggml_tensor * t32 = expand(gf, ggml_repeat   (ctx0, model->norm, t31));           assert_shape_2d(t32, n_embd, N*n_batch);
    struct ggml_tensor * t33 = expand(gf, ggml_mul      (ctx0, t32, t31));                   assert_shape_2d(t33, n_embd, N*n_batch);
    use_buf(-1);
    struct ggml_tensor * t34 = expand(gf, ggml_mul_mat  (ctx0, model->output, t33));         assert_shape_2d(t34, n_vocab, N*n_batch);
    struct ggml_tensor * t35 = expand(gf, ggml_reshape_3d(ctx0, t34, n_vocab, N, n_batch));  assert_shape_3d(t35, n_vocab, N, n_batch);
    struct ggml_tensor * t36 = expand(gf, ggml_cross_entropy_loss(ctx0, t35, targets));      assert_shape_1d(t36, 1);

    *gb = *gf;

    // t36->grad gets set to one by optimizer, so we need the tensor.
    // initialize it with 1.0f to make sure.
    use_buf(-1);
    t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));

    use_buf(0);
    t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad));                    assert_shape_3d(t35->grad, n_vocab, N, n_batch);
    t34->grad = expand(gb, ggml_reshape_2d (ctx0, t35->grad, n_vocab, N*n_batch));                          assert_shape_2d(t34->grad, n_vocab, N*n_batch);
    t33->grad = expand(gb, ggml_out_prod   (ctx0, model->output, ggml_transpose(ctx0, t34->grad)));         assert_shape_2d(t33->grad, n_embd, N*n_batch);
    t32->grad = expand(gb, ggml_mul        (ctx0, t33->grad, t31));                                         assert_shape_2d(t32->grad, n_embd, N*n_batch);

    use_buf(-1);

    model->norm->grad   = expand(gb, add_or_set(model->norm->grad,   ggml_repeat_back(ctx0, t32->grad, model->norm))); assert_shape_1d(model->norm->grad, n_embd);
    model->output->grad = expand(gb, add_or_set(model->output->grad, ggml_out_prod(ctx0, t33, t34->grad)));            assert_shape_2d(model->output->grad, n_embd, n_vocab);

    clr_buf(1);
    use_buf(1);
    t31->grad = expand(gb, ggml_mul(ctx0, t33->grad, t32)); assert_shape_2d(t31->grad, n_embd, N*n_batch);

    struct ggml_tensor * back_layer_inp = t31;
    struct ggml_tensor * grad_layer_inp = NULL;

    printf("%s: checkpoints.size() = %zu\n", __func__, checkpoints.size());
    chk_idx = (int) checkpoints.size()-1;
    int avail_begin = n_layer;
    int avail_end   = n_layer;
    printf("%s: chk_idx=%d avail_begin=%d avail_end=%d\n", __func__, chk_idx, avail_begin, avail_end);
    for (int k = 0; k < n_layer; ++k) {
        // second forward pass for checkpointing
        int il = n_layer-1-k;
        if (il < avail_begin) {
            // make sure, that txxL[il] is available
            // forward pass from last checkpoint
            GGML_ASSERT(chk_idx >= -1);
            int begin = (chk_idx == -1)
                        ? 0
                        : checkpoints[chk_idx] + 1; // checkpoint[chk_idx] contains t30 for computing following layers -> +1
            int end   = (chk_idx+1 < (int) checkpoints.size())
                        ? (checkpoints[chk_idx+1] + 1)
                        : n_layer;
            GGML_ASSERT(begin <= il);
            GGML_ASSERT(il < end);
            cur = (chk_idx == -1) ? t01 : t30L[checkpoints[chk_idx]];
            clr_buf(2);
            printf("%s: second forward pass chk_idx=%d begin=%d end=%d\n", __func__, chk_idx, begin, end);
            for (int i = begin; i < end; ++i) {
                struct my_llama_layer & layer = model->layers[i];
                const int prs = 2; // persistent until next checkpoint
                const int tmp = 0; // temporary for this layer
                const bool is_checkpoint = (i == end-1);
                clr_buf(0);
                use_buf(prs); struct ggml_tensor * t02 = expand(gb, ggml_rms_norm     (ctx0, cur, rms_norm_eps));                   assert_shape_2d(t02, n_embd, N*n_batch);
                use_buf(tmp); struct ggml_tensor * t03 = expand(gb, ggml_repeat       (ctx0, layer.attention_norm, t02));           assert_shape_2d(t03, n_embd, N*n_batch);
                use_buf(prs); struct ggml_tensor * t04 = expand(gb, ggml_mul          (ctx0, t02, t03));                            assert_shape_2d(t04, n_embd, N*n_batch);
                use_buf(prs); struct ggml_tensor * t05 = expand(gb, ggml_mul_mat      (ctx0, layer.wq, t04));                       assert_shape_2d(t05, n_embd, N*n_batch);
                use_buf(prs); struct ggml_tensor * t06 = expand(gb, ggml_reshape_4d   (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
                use_buf(prs); struct ggml_tensor * t07 = expand(gb, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode));       assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
                use_buf(prs); struct ggml_tensor * t08 = expand(gb, ggml_mul_mat      (ctx0, layer.wk, t04));                       assert_shape_2d(t08, n_embd, N*n_batch);
                use_buf(prs); struct ggml_tensor * t09 = expand(gb, ggml_reshape_4d   (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
                use_buf(prs); struct ggml_tensor * t10 = expand(gb, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode));       assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
                use_buf(prs); struct ggml_tensor * t11 = expand(gb, ggml_mul_mat      (ctx0, t04, layer.wv));                       assert_shape_2d(t11, N*n_batch, n_embd);
                use_buf(prs); struct ggml_tensor * t12 = expand(gb, ggml_reshape_4d   (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
                use_buf(prs); struct ggml_tensor * t13 = expand(gb, ggml_permute      (ctx0, t07, 0, 2, 1, 3));                     assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
                use_buf(prs); struct ggml_tensor * t14 = expand(gb, ggml_permute      (ctx0, t10, 0, 2, 1, 3));                     assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
                use_buf(prs); struct ggml_tensor * t15 = expand(gb, ggml_permute      (ctx0, t12, 0, 3, 1, 2));                     assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
                use_buf(prs); struct ggml_tensor * t16 = expand(gb, ggml_flash_attn   (ctx0, t13, t14, t15, true));                 assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
                use_buf(tmp); struct ggml_tensor * t17 = expand(gb, ggml_permute      (ctx0, t16, 0, 2, 1, 3));                     assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
                use_buf(prs); struct ggml_tensor * t18 = expand(gb, ggml_cont         (ctx0, t17));                                 assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
                use_buf(prs); struct ggml_tensor * t19 = expand(gb, ggml_reshape_2d   (ctx0, t18, n_embd, N*n_batch));              assert_shape_2d(t19, n_embd, N*n_batch);
                use_buf(tmp); struct ggml_tensor * t20 = expand(gb, ggml_mul_mat      (ctx0, layer.wo, t19));                       assert_shape_2d(t20, n_embd, N*n_batch);
                use_buf(prs); struct ggml_tensor * t21 = expand(gb, ggml_add          (ctx0, t20, cur));                            assert_shape_2d(t21, n_embd, N*n_batch);
                use_buf(prs); struct ggml_tensor * t22 = expand(gb, ggml_rms_norm     (ctx0, t21, rms_norm_eps));                   assert_shape_2d(t22, n_embd, N*n_batch);
                use_buf(tmp); struct ggml_tensor * t23 = expand(gb, ggml_repeat       (ctx0, layer.ffn_norm, t22));                 assert_shape_2d(t23, n_embd, N*n_batch);
                use_buf(prs); struct ggml_tensor * t24 = expand(gb, ggml_mul          (ctx0, t23, t22));                            assert_shape_2d(t24, n_embd, N*n_batch);
                use_buf(prs); struct ggml_tensor * t25 = expand(gb, ggml_mul_mat      (ctx0, layer.w3, t24));                       assert_shape_2d(t25, n_ff, N*n_batch);
                use_buf(prs); struct ggml_tensor * t26 = expand(gb, ggml_mul_mat      (ctx0, layer.w1, t24));                       assert_shape_2d(t26, n_ff, N*n_batch);
                use_buf(prs); struct ggml_tensor * t27 = expand(gb, ggml_silu         (ctx0, t26));                                 assert_shape_2d(t27, n_ff, N*n_batch);
                use_buf(prs); struct ggml_tensor * t28 = expand(gb, ggml_mul          (ctx0, t27, t25));                            assert_shape_2d(t28, n_ff, N*n_batch);
                use_buf(tmp); struct ggml_tensor * t29 = expand(gb, ggml_mul_mat      (ctx0, layer.w2, t28));                       assert_shape_2d(t29, n_embd, N*n_batch);
                if (t30L[i] == NULL) {
                    use_buf(prs); struct ggml_tensor * t30 = expand(gb, ggml_add      (ctx0, t21, t29));                            assert_shape_2d(t30, n_embd, N*n_batch);
                    t30L[i] = t30;
                    cur = t30;
                }
                t02L[i] = t02;
                t03L[i] = t03;
                t04L[i] = t04;
                t05L[i] = t05;
                t06L[i] = t06;
                t07L[i] = t07;
                t08L[i] = t08;
                t09L[i] = t09;
                t10L[i] = t10;
                t11L[i] = t11;
                t12L[i] = t12;
                t13L[i] = t13;
                t14L[i] = t14;
                t15L[i] = t15;
                t16L[i] = t16;
                t17L[i] = t17;
                t18L[i] = t18;
                t19L[i] = t19;
                t20L[i] = t20;
                t21L[i] = t21;
                t22L[i] = t22;
                t23L[i] = t23;
                t24L[i] = t24;
                t25L[i] = t25;
                t26L[i] = t26;
                t27L[i] = t27;
                t28L[i] = t28;
                t29L[i] = t29;
            }
            --chk_idx;
            avail_begin = begin;
            avail_end   = end;
            printf("%s: chk_idx=%d avail_begin=%d avail_end=%d\n", __func__, chk_idx, avail_begin, avail_end);
        }
        printf("%s: backward pass il=%d\n", __func__, il);

        struct my_llama_layer & layer = model->layers[il];

        struct ggml_tensor * t02 = t02L[il];
        struct ggml_tensor * t03 = t03L[il];
        struct ggml_tensor * t04 = t04L[il];
        struct ggml_tensor * t05 = t05L[il];
        struct ggml_tensor * t06 = t06L[il];
        struct ggml_tensor * t07 = t07L[il];
        struct ggml_tensor * t08 = t08L[il];
        struct ggml_tensor * t09 = t09L[il];
        struct ggml_tensor * t10 = t10L[il];
        struct ggml_tensor * t11 = t11L[il];
        struct ggml_tensor * t12 = t12L[il];
        struct ggml_tensor * t13 = t13L[il];
        struct ggml_tensor * t14 = t14L[il];
        struct ggml_tensor * t15 = t15L[il];
        struct ggml_tensor * t16 = t16L[il];
        struct ggml_tensor * t17 = t17L[il];
        struct ggml_tensor * t18 = t18L[il];
        struct ggml_tensor * t19 = t19L[il];
        struct ggml_tensor * t20 = t20L[il];
        struct ggml_tensor * t21 = t21L[il];
        struct ggml_tensor * t22 = t22L[il];
        struct ggml_tensor * t23 = t23L[il];
        struct ggml_tensor * t24 = t24L[il];
        struct ggml_tensor * t25 = t25L[il];
        struct ggml_tensor * t26 = t26L[il];
        struct ggml_tensor * t27 = t27L[il];
        struct ggml_tensor * t28 = t28L[il];
        struct ggml_tensor * t29 = t29L[il];
        struct ggml_tensor * t30 = t30L[il];

        clr_buf(0);
        use_buf(0);
        t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
        if (grad_layer_inp) {
            t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
        }
        clr_buf(1);
        t29->grad = t30->grad;                                                                               assert_shape_2d(t29->grad, n_embd, N*n_batch);
        t28->grad = expand(gb, ggml_out_prod(ctx0, layer.w2, ggml_transpose(ctx0, t29->grad)));              assert_shape_2d(t28->grad, n_ff, N*n_batch);
        t27->grad = expand(gb, ggml_mul(ctx0, t28->grad, t25));                                              assert_shape_2d(t27->grad, n_ff, N*n_batch);
        t26->grad = expand(gb, ggml_silu_back(ctx0, t26, t27->grad));                                        assert_shape_2d(t26->grad, n_ff, N*n_batch);
        t25->grad = expand(gb, ggml_mul(ctx0, t28->grad, t27));                                              assert_shape_2d(t25->grad, n_ff, N*n_batch);
        t24->grad = expand(gb, ggml_add_inplace(ctx0,
                        ggml_out_prod(ctx0, layer.w1, ggml_transpose(ctx0, t26->grad)),
                        ggml_out_prod(ctx0, layer.w3, ggml_transpose(ctx0, t25->grad))));                    assert_shape_2d(t24->grad, n_embd, N*n_batch);
        t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22));                                              assert_shape_2d(t23->grad, n_embd, N*n_batch);
        t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad)));     assert_shape_2d(t22->grad, n_embd, N*n_batch);
        use_buf(1);
        t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad)));         assert_shape_2d(t21->grad, n_embd, N*n_batch);
        grad_layer_inp = t21;
        use_buf(0);
        t20->grad = t21->grad;                                                                               assert_shape_2d(t20->grad, n_embd, N*n_batch);
        t19->grad = expand(gb, ggml_out_prod(ctx0, layer.wo, ggml_transpose(ctx0, t20->grad)));              assert_shape_2d(t19->grad, n_embd, N*n_batch);
        t18->grad = expand(gb, ggml_reshape_4d(ctx0, t19->grad, n_embd/n_head, n_head, N, n_batch));         assert_shape_4d(t18->grad, n_embd/n_head, n_head, N, n_batch);
        t17->grad = t18->grad;                                                                               assert_shape_4d(t17->grad, n_embd/n_head, n_head, N, n_batch);
        t16->grad = expand(gb, ggml_permute(ctx0, t17->grad, 0, 2, 1, 3));                                   assert_shape_4d(t16->grad, n_embd/n_head, N, n_head, n_batch);
        struct ggml_tensor * flash_attn = expand(gb, ggml_flash_attn_back(ctx0, t13, t14, t15, t16->grad, true)); assert_shape_4d(flash_attn, n_embd/n_head, N*3, n_head, n_batch);
        t15->grad = expand(gb, view__v(flash_attn));                                                         assert_shape_4d(t15->grad, N, n_embd/n_head, n_head, n_batch);
        t14->grad = expand(gb, view__k(flash_attn));                                                         assert_shape_4d(t14->grad, n_embd/n_head, N, n_head, n_batch);
        t13->grad = expand(gb, view__q(flash_attn));                                                         assert_shape_4d(t13->grad, n_embd/n_head, N, n_head, n_batch);
        t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1));                                   assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
        t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd));        assert_shape_2d(t11->grad, N*n_batch, n_embd);
        t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3));                                   assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
        t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode));                   assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
        t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch));                         assert_shape_2d(t08->grad, n_embd, N*n_batch);
        t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3));                                   assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
        t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode));                   assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
        t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch));                         assert_shape_2d(t05->grad, n_embd, N*n_batch);
        t04->grad = expand(gb, ggml_add_inplace(ctx0,
                        ggml_add_inplace(ctx0,
                            ggml_out_prod(ctx0, layer.wv, t11->grad),
                            ggml_out_prod(ctx0, layer.wk, ggml_transpose(ctx0, t08->grad))),
                        ggml_out_prod(ctx0, layer.wq, ggml_transpose(ctx0, t05->grad))));                    assert_shape_2d(t04->grad, n_embd, N*n_batch);
        t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02));                                              assert_shape_2d(t03->grad, n_embd, N*n_batch);
        use_buf(1);
        t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, ggml_repeat(ctx0, layer.attention_norm, t02)));     assert_shape_2d(t02->grad, n_embd, N*n_batch);
        back_layer_inp = t02;

        use_buf(-1);
        layer.attention_norm->grad = expand(gb, add_or_set(layer.attention_norm->grad, ggml_repeat_back(ctx0, t03->grad, layer.attention_norm))); assert_shape_1d(layer.attention_norm->grad, n_embd);
        layer.wq->grad             = expand(gb, add_or_set(layer.wq->grad,             ggml_out_prod(ctx0, t04, t05->grad)));                     assert_shape_2d(layer.wq->grad, n_embd, n_embd);
        layer.wk->grad             = expand(gb, add_or_set(layer.wk->grad,             ggml_out_prod(ctx0, t04, t08->grad)));                     assert_shape_2d(layer.wk->grad, n_embd, n_embd);
        layer.wv->grad             = expand(gb, add_or_set(layer.wv->grad,             ggml_out_prod(ctx0, t04, ggml_transpose(ctx0, t11->grad)))); assert_shape_2d(layer.wv->grad, n_embd, n_embd);
        layer.wo->grad             = expand(gb, add_or_set(layer.wo->grad,             ggml_out_prod(ctx0, t19, t20->grad)));                     assert_shape_2d(layer.wo->grad, n_embd, n_embd);
        layer.ffn_norm->grad       = expand(gb, add_or_set(layer.ffn_norm->grad,       ggml_repeat_back(ctx0, t23->grad, layer.ffn_norm)));       assert_shape_1d(layer.ffn_norm->grad, n_embd);
        layer.w1->grad             = expand(gb, add_or_set(layer.w1->grad,             ggml_out_prod(ctx0, t24, t26->grad)));                     assert_shape_2d(layer.w1->grad, n_embd, n_ff);
        layer.w2->grad             = expand(gb, add_or_set(layer.w2->grad,             ggml_out_prod(ctx0, t28, t29->grad)));                     assert_shape_2d(layer.w2->grad, n_ff, n_embd);
        layer.w3->grad             = expand(gb, add_or_set(layer.w3->grad,             ggml_out_prod(ctx0, t24, t25->grad)));                     assert_shape_2d(layer.w3->grad, n_embd, n_ff);
    }
    printf("%s: chk_idx=%d avail_begin=%d avail_end=%d\n", __func__, chk_idx, avail_begin, avail_end);
    GGML_ASSERT(chk_idx == -2);
    GGML_ASSERT(avail_begin == 0);
    clr_buf(0);
    use_buf(0);
    t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
    use_buf(-1);
    model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);

    *logits = t35;

    clr_buf(0);
    clr_buf(1);
    clr_buf(2);
    clr_buf(3);

    if (track_max_mem) {
        printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
        printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
        printf("%s: max size compute buf2: %zu\n", __func__, buf_maxs[2]);
        printf("%s: max size compute buf3: %zu\n", __func__, buf_maxs[3]);
    }

    // now that all grads are created, set the graph leafs and grads
    graph_set_leafs_grads(gf);
    graph_set_leafs_grads(gb);

    return t36;
}

void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) {
    float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
    *ptr = value;
@@ -2810,6 +3360,7 @@ struct train_params {
    bool use_adam;
    bool use_flash;
    bool use_scratch;
    bool use_checkpointing;

    // only adam
    int warmup;
@ -2829,6 +3380,8 @@ struct train_params {
|
|||
int mem_compute_gb;
|
||||
int mem_compute0_gb;
|
||||
int mem_compute1_gb;
|
||||
int mem_compute2_gb;
|
||||
int mem_compute3_gb;
|
||||
};
|
||||
|
||||
struct train_params get_default_train_params() {
|
||||
|
@@ -2860,6 +3413,7 @@ struct train_params get_default_train_params() {
    params.use_adam = true;
    params.use_flash = true;
    params.use_scratch = true;
    params.use_checkpointing = true;

    // only adam
    params.warmup = 100;
@@ -2878,8 +3432,9 @@ struct train_params get_default_train_params() {
    params.mem_model_gb = 2;
    params.mem_compute_gb = 24;
    params.mem_compute0_gb = 8;
    params.mem_compute1_gb = 1;
    params.mem_compute2_gb = 2;
    params.mem_compute3_gb = 1;
    return params;
}
@@ -2909,14 +3464,16 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
    fprintf(stderr, "  --samples-after-nl         Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
    fprintf(stderr, "  --use-lbfgs                Use LBFGS optimizer instead of default Adam\n");
    fprintf(stderr, "  --use-adam                 Use Adam optimizer (default)\n");
    fprintf(stderr, "  --no-flash                 Don't use flash attention. Implies no-scratch and no-checkpointing.\n");
    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
    fprintf(stderr, "  --no-scratch               Don't use scratch buffers. Implies no-checkpointing.\n");
    fprintf(stderr, "  --use-scratch              Use scratch buffers. Implies use-flash. (default)\n");
    fprintf(stderr, "  --no-checkpointing         Don't use gradient checkpointing\n");
    fprintf(stderr, "  --use-checkpointing        Use gradient checkpointing. Implies use-scratch and use-flash. (default)\n");
    fprintf(stderr, "  --warmup N                 Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
    fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
    fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
    fprintf(stderr, "  --cos-decay-alpha N        Only for Adam optimizer. Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
    fprintf(stderr, "  --lbfgs-iter N             Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
@@ -2928,6 +3485,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
    fprintf(stderr, "  --mem-compute N            Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
    fprintf(stderr, "  --mem-compute0 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
    fprintf(stderr, "  --mem-compute1 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb);
    fprintf(stderr, "  --mem-compute2 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute2_gb);
    fprintf(stderr, "  --mem-compute3 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute3_gb);
    fprintf(stderr, "\n");
}
@@ -3065,6 +3624,10 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
        params->use_scratch = false;
    } else if (arg == "--use-scratch") {
        params->use_scratch = true;
    } else if (arg == "--no-checkpointing") {
        params->use_checkpointing = false;
    } else if (arg == "--use-checkpointing") {
        params->use_checkpointing = true;
    } else if (arg == "--warmup") {
        if (++i >= argc) {
            invalid_param = true;
@@ -3155,6 +3718,18 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
            break;
        }
        params->mem_compute1_gb = std::stoi(argv[i]);
    } else if (arg == "--mem-compute2") {
        if (++i >= argc) {
            invalid_param = true;
            break;
        }
        params->mem_compute2_gb = std::stoi(argv[i]);
    } else if (arg == "--mem-compute3") {
        if (++i >= argc) {
            invalid_param = true;
            break;
        }
        params->mem_compute3_gb = std::stoi(argv[i]);
    } else if (arg == "-h" || arg == "--help") {
        train_print_usage(argc, argv, &default_params);
        exit(0);
@@ -3316,8 +3891,12 @@ int main(int argc, char ** argv) {

    size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
    size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
    size_t size_buf_2 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute2_gb);
    size_t size_buf_3 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute3_gb);
    uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
    uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
    uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
    uint8_t * compute_buf_3 = new uint8_t[size_buf_3];

    GGML_ASSERT(n_tokens < (int) train_tokens.size());
    std::vector<int> train_samples;
@@ -3376,7 +3955,15 @@ int main(int argc, char ** argv) {
        struct ggml_tensor * loss   = NULL;
        struct ggml_tensor * logits = NULL;

        if (params.use_checkpointing) {
            loss = forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
                    &model, ctx0,
                    gf, gb,
                    &logits, tokens_input, target_probs,
                    compute_buf_0, compute_buf_1, compute_buf_2, compute_buf_3,
                    size_buf_0, size_buf_1, size_buf_2, size_buf_3,
                    n_tokens, n_batch);
        } else if (params.use_scratch) {
            loss = forward_batch_wo_cache_flash_attn_train(
                    &model, ctx0,
                    gf, gb,