improve gradient checkpointing
sqrt(n_layers) is only the best checkpoint step when the memory size of a checkpoint and the memory size of a layer are equal. Since layers require more memory than the single-tensor checkpoint we use, the optimal value is computed differently:
```
given: n, u, v
objective: minimize(a*u+b*v) where a*b=n, a>0, b>0
b=n/a
minimize(a*u+v*n/a)
diff(a*u+v*n/a, a) = u - (v*n/a)/a
diff(a*u+v*n/a, a) == 0
u - (v*n/a)/a == 0
u == v*n/(a*a)
u*a*a = v*n
a*a = v*n/u
a = sqrt(n*v/u)
```
This change results in more checkpoints, requiring fewer layers to be stored between checkpoints, which improves overall memory usage.
This commit is contained in:
parent
51dc77092f
commit
3744a9be74
1 changed files with 28 additions and 11 deletions
|
@ -2090,22 +2090,39 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
|
|||
struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch);
|
||||
|
||||
|
||||
{
|
||||
// given: n, u, v
|
||||
// objective: minimize(a*u+b*v) where a*b=n, a>0, b>0
|
||||
// b=n/a
|
||||
// minimize(a*u+v*n/a)
|
||||
// diff(a*u+v*n/a, a) = u - (v*n/a)/a
|
||||
// diff(a*u+v*n/a, a) == 0
|
||||
// u - (v*n/a)/a == 0
|
||||
// u == v*n/(a*a)
|
||||
// u*a*a = v*n
|
||||
// a*a = v*n/u
|
||||
// a = sqrt(n*v/u)
|
||||
}
|
||||
|
||||
float memcost_checkpoint = n_embd; // (..)*N*n_batch
|
||||
float memcost_snd_fwd_pass = 14*n_embd+4*n_ff; // (..)*N*n_batch
|
||||
|
||||
int n_checkstep = (int)(sqrtf(n_layer*memcost_checkpoint/memcost_snd_fwd_pass) + 0.5f);
|
||||
if (n_checkstep < 1) {
|
||||
n_checkstep = 1;
|
||||
}
|
||||
std::vector<int> checkpoints;
|
||||
// for (int il = 0; il < n_layer; ++il) {
|
||||
// checkpoints.push_back(il);
|
||||
// }
|
||||
// n_check: number of layers between checkpoints
|
||||
int n_check = (int)(sqrtf(n_layer) + 0.5f);
|
||||
printf("%s: n_check = %d\n", __func__, n_check);
|
||||
for (int chk = n_check-1; chk+1 < n_layer; chk += n_check) {
|
||||
for (int chk = n_checkstep-1; chk+1 < n_layer; chk += n_checkstep) {
|
||||
checkpoints.push_back(chk);
|
||||
}
|
||||
int n_check = checkpoints.size();
|
||||
// printf("%s: n_check = %d n_checkstep = %d\n", __func__, n_check, n_checkstep);
|
||||
|
||||
for (int i = 0; i < checkpoints.size(); ++i) {
|
||||
printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]);
|
||||
}
|
||||
// for (int i = 0; i < n_check; ++i) {
|
||||
// printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]);
|
||||
// }
|
||||
|
||||
// example for 16 layers:
|
||||
// example for 16 layers and memcost_checkpoint=memcost_snd_fwd_pass:
|
||||
// inp ~ implicit zeroth checkpoint == input
|
||||
// L00 f 4b [
|
||||
// L01 f 4b 4th second forward pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue