diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 9ee255f4e..ae3f79c63 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -2090,22 +2090,39 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing( struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch); + { + // given: n, u, v + // objective: minimize(a*u+b*v) where a*b=n, a>0, b>0 + // b=n/a + // minimize(a*u+v*n/a) + // diff(a*u+v*n/a, a) = u - (v*n/a)/a + // diff(a*u+v*n/a, a) == 0 + // u - (v*n/a)/a == 0 + // u == v*n/(a*a) + // u*a*a = v*n + // a*a = v*n/u + // a = sqrt(n*v/u) + } + + float memcost_checkpoint = n_embd; // (..)*N*n_batch + float memcost_snd_fwd_pass = 14*n_embd+4*n_ff; // (..)*N*n_batch + + int n_checkstep = (int)(sqrtf(n_layer*memcost_checkpoint/memcost_snd_fwd_pass) + 0.5f); + if (n_checkstep < 1) { + n_checkstep = 1; + } std::vector checkpoints; - // for (int il = 0; il < n_layer; ++il) { - // checkpoints.push_back(il); - // } - // n_check: number of layers between checkpoints - int n_check = (int)(sqrtf(n_layer) + 0.5f); - printf("%s: n_check = %d\n", __func__, n_check); - for (int chk = n_check-1; chk+1 < n_layer; chk += n_check) { + for (int chk = n_checkstep-1; chk+1 < n_layer; chk += n_checkstep) { checkpoints.push_back(chk); } + int n_check = checkpoints.size(); + // printf("%s: n_check = %d n_checkstep = %d\n", __func__, n_check, n_checkstep); - for (int i = 0; i < checkpoints.size(); ++i) { - printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]); - } + // for (int i = 0; i < n_check; ++i) { + // printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]); + // } - // example for 16 layers: + // example for 16 layers and memcost_checkpoint=memcost_snd_fwd_pass: // inp ~ implicit zeroth checkpoint == input // L00 f 4b [ // L01 f 4b 4th second forward pass