add input tensors as checkpoints

so that the recursive tensor cloning performed by gradient checkpointing terminates at the input tensors
This commit is contained in:
xaedes 2023-08-14 17:58:49 +02:00
parent b2f1310196
commit 5884b43a62
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@@ -1579,7 +1579,10 @@ struct ggml_tensor * llama_build_train_graphs(
struct ggml_tensor * cur = t01;
std::vector<struct ggml_tensor *> checkpoints;
checkpoints.push_back(cur);
checkpoints.push_back(tokens_input);
checkpoints.push_back(targets);
checkpoints.push_back(t00);
checkpoints.push_back(t01);
struct ggml_tensor * kv_scale;
if (flash_attn) {