train : allocate grads for backward graphs

Georgi Gerganov 2023-11-07 10:17:45 +02:00
parent aa1f36c90a
commit a4de8042ee
2 changed files with 4 additions and 4 deletions
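
The change in every hunk below is the third argument of ggml_new_graph_custom(ctx, size, grads), which controls whether the graph allocates its per-node gradient table. The backward graph gb is filled in by ggml_build_backward_expand, which needs that table, so creating gb with grads = false left it unallocated. A minimal sketch of the setup these hunks fix, assuming the ggml API as of this commit; build_graphs, loss, and the LLAMA_TRAIN_MAX_NODES value are illustrative stand-ins for the examples' real code:

#include "ggml.h"

#define LLAMA_TRAIN_MAX_NODES 16384  // illustrative; the training examples define their own value

static void build_graphs(struct ggml_context * ctx_compute, struct ggml_tensor * loss) {
    // forward graph: created with grads = true so each node has a gradient slot
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
    ggml_build_forward_expand(gf, loss);

    // backward graph: must also be created with grads = true (the fix in this
    // commit); with false, its grads table is missing and backward expansion,
    // which records a gradient tensor per node, has nowhere to put them
    struct ggml_cgraph * gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);

    // expand gb with the gradient ops for gf; the last argument preserves the
    // original gradient tensors
    ggml_build_backward_expand(ctx_compute, gf, gb, true);
}

Note that the gb_tmp graphs used with gradient checkpointing are still created with false in both files: only gb itself ends up holding the gradients.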


@@ -1769,7 +1769,7 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new_measure(tensor_alignment);
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
@@ -1802,7 +1802,7 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = best_order;
-        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;


@@ -1136,7 +1136,7 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new_measure(tensor_alignment);
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;
@@ -1169,7 +1169,7 @@ int main(int argc, char ** argv) {
         alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
         gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = best_order;
-        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
             ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, false)
             : NULL;