From 3744a9be74b27c758b06ea2bdf8ee97046e2b196 Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 2 Jul 2023 21:11:11 +0200
Subject: [PATCH] improve gradient checkpointing

sqrt(n_layers) is only the best checkpoint step when the mem size of
checkpoints and the mem size of layers are equal. since layers require more
memory than the single-tensor checkpoints we use, the optimal value is
computed differently; here n is the number of layers, u the mem cost of one
layer in the second forward pass and v the mem cost of one checkpoint:

```
given: n, u, v
objective: minimize(a*u+b*v) where a*b=n, a>0, b>0
b=n/a
minimize(a*u+v*n/a)
diff(a*u+v*n/a, a) = u - (v*n/a)/a
diff(a*u+v*n/a, a) == 0
u - (v*n/a)/a == 0
u == v*n/(a*a)
u*a*a = v*n
a*a = v*n/u
a = sqrt(n*v/u)
```

this change results in more checkpoints, requiring fewer layers to be stored
between checkpoints, overall improving memory usage.
---
 .../train-text-from-scratch.cpp | 39 +++++++++++++------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 9ee255f4e..ae3f79c63 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -2090,22 +2090,39 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
     struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00));
     assert_shape_2d(t01, n_embd, N*n_batch);
 
+    {
+        // given: n, u, v
+        // objective: minimize(a*u+b*v) where a*b=n, a>0, b>0
+        // b=n/a
+        // minimize(a*u+v*n/a)
+        // diff(a*u+v*n/a, a) = u - (v*n/a)/a
+        // diff(a*u+v*n/a, a) == 0
+        // u - (v*n/a)/a == 0
+        // u == v*n/(a*a)
+        // u*a*a = v*n
+        // a*a = v*n/u
+        // a = sqrt(n*v/u)
+    }
+
+    float memcost_checkpoint = n_embd; // (..)*N*n_batch
+    float memcost_snd_fwd_pass = 14*n_embd+4*n_ff; // (..)*N*n_batch
+
+    int n_checkstep = (int)(sqrtf(n_layer*memcost_checkpoint/memcost_snd_fwd_pass) + 0.5f);
+    if (n_checkstep < 1) {
+        n_checkstep = 1;
+    }
     std::vector<int> checkpoints;
-    // for (int il = 0; il < n_layer; ++il) {
-    //     checkpoints.push_back(il);
-    // }
-    // n_check: number of layers between checkpoints
-    int n_check = (int)(sqrtf(n_layer) + 0.5f);
-    printf("%s: n_check = %d\n", __func__, n_check);
-    for (int chk = n_check-1; chk+1 < n_layer; chk += n_check) {
+    for (int chk = n_checkstep-1; chk+1 < n_layer; chk += n_checkstep) {
         checkpoints.push_back(chk);
     }
+    int n_check = checkpoints.size();
+    // printf("%s: n_check = %d n_checkstep = %d\n", __func__, n_check, n_checkstep);
 
-    for (int i = 0; i < checkpoints.size(); ++i) {
-        printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]);
-    }
+    // for (int i = 0; i < n_check; ++i) {
+    //     printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]);
+    // }
 
-    // example for 16 layers:
+    // example for 16 layers and memcost_checkpoint=memcost_snd_fwd_pass:
     // inp ~ implicit zeroth checkpoint == input
     // L00 f 4b [
     // L01 f 4b 4th second forward pass
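
To get a concrete feel for the formula, here is a small standalone sketch (not part of the patch) that evaluates the memory model from the commit message. The cost terms mirror memcost_checkpoint and memcost_snd_fwd_pass from the diff; the sizes n_layer, n_embd and n_ff are made-up example values chosen only for illustration.

```cpp
// standalone sketch: compares the naive checkpoint step sqrt(n) with the
// weighted step sqrt(n*v/u) derived in the commit message.
#include <cmath>
#include <cstdio>

int main() {
    // example sizes (illustrative assumptions, not from the patch)
    const int   n_layer = 16;
    const float n_embd  = 4096.0f;
    const float n_ff    = 11008.0f;

    // per-(N*n_batch) memory costs, mirroring the patch:
    const float v = n_embd;                    // memcost_checkpoint: one stored checkpoint tensor
    const float u = 14.0f*n_embd + 4.0f*n_ff;  // memcost_snd_fwd_pass: re-running one layer

    // total cost(a) = a*u + (n/a)*v, where a = layers per checkpoint segment
    auto cost = [&](float a) { return a*u + ((float)n_layer/a)*v; };

    const float a_naive    = std::sqrt((float)n_layer);       // optimal only if u == v
    const float a_weighted = std::sqrt((float)n_layer*v/u);   // a = sqrt(n*v/u)

    printf("naive    step %.3f -> cost %.1f\n", a_naive,    cost(a_naive));
    printf("weighted step %.3f -> cost %.1f\n", a_weighted, cost(a_weighted));
    return 0;
}
```

Because a single checkpoint tensor is much cheaper than re-running a layer (v << u), the weighted step comes out well below sqrt(n_layer); the patch rounds it to an int and clamps it to at least 1, so in this regime a checkpoint is stored after every layer.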