improve gradient checkpointing

sqrt(n_layers) is only the best checkpoint step when mem size of checkpoints and mem size of layers are equal.
since layers require more memory than the single-tensor-checkpoint we use, the optimal values are compute different:

```
  given: n, u, v
  objective: minimize(a*u+b*v) where a*b=n, a>0, b>0
  b=n/a
  minimize(a*u+v*n/a)
  diff(a*u+v*n/a, a) = u - (v*n/a)/a
  diff(a*u+v*n/a, a) == 0
  u - (v*n/a)/a == 0
  u == v*n/(a*a)
  u*a*a = v*n
  a*a = v*n/u
  a = sqrt(n*v/u)
```

this change results in more checkpoints, requiring less layers to store between checkpoints, overall improving memory usage.
This commit is contained in:
xaedes 2023-07-02 21:11:11 +02:00
parent 51dc77092f
commit 3744a9be74
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -2090,22 +2090,39 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch);
{
// given: n, u, v
// objective: minimize(a*u+b*v) where a*b=n, a>0, b>0
// b=n/a
// minimize(a*u+v*n/a)
// diff(a*u+v*n/a, a) = u - (v*n/a)/a
// diff(a*u+v*n/a, a) == 0
// u - (v*n/a)/a == 0
// u == v*n/(a*a)
// u*a*a = v*n
// a*a = v*n/u
// a = sqrt(n*v/u)
}
float memcost_checkpoint = n_embd; // (..)*N*n_batch
float memcost_snd_fwd_pass = 14*n_embd+4*n_ff; // (..)*N*n_batch
int n_checkstep = (int)(sqrtf(n_layer*memcost_checkpoint/memcost_snd_fwd_pass) + 0.5f);
if (n_checkstep < 1) {
n_checkstep = 1;
}
std::vector<int> checkpoints;
// for (int il = 0; il < n_layer; ++il) {
// checkpoints.push_back(il);
// }
// n_check: number of layers between checkpoints
int n_check = (int)(sqrtf(n_layer) + 0.5f);
printf("%s: n_check = %d\n", __func__, n_check);
for (int chk = n_check-1; chk+1 < n_layer; chk += n_check) {
for (int chk = n_checkstep-1; chk+1 < n_layer; chk += n_checkstep) {
checkpoints.push_back(chk);
}
int n_check = checkpoints.size();
// printf("%s: n_check = %d n_checkstep = %d\n", __func__, n_check, n_checkstep);
for (int i = 0; i < checkpoints.size(); ++i) {
printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]);
}
// for (int i = 0; i < n_check; ++i) {
// printf("%s: checkpoint #%d = %d\n", __func__, i, checkpoints[i]);
// }
// example for 16 layers:
// example for 16 layers and memcost_checkpoint=memcost_snd_fwd_pass:
// inp ~ implicit zeroth checkpoint == input
// L00 f 4b [
// L01 f 4b 4th second forward pass