From b1592ea05451f7e02e6f60d1889bc2305d9b428e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 3 Nov 2023 14:40:45 +0200
Subject: [PATCH] train : fix context size calculations

---
 examples/finetune/finetune.cpp                           | 8 +++-----
 .../train-text-from-scratch/train-text-from-scratch.cpp | 8 +++-----
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index b891908be..58f37487e 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1742,11 +1742,9 @@ int main(int argc, char ** argv) {
     ggml_allocr_free(alloc);
 
     // context for compute tensors without their data
-    size_t estimated_compute_size_wo_data = (
-        ggml_tensor_overhead()*LLAMA_TRAIN_MAX_NODES*2
-        + (GGML_OBJECT_SIZE+ggml_graph_overhead())*(
-            params.common.use_checkpointing ? 3 : 2
-        )
+    const size_t estimated_compute_size_wo_data = (
+        2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+        (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
     );
     struct ggml_init_params ctx_compute_params = {
         estimated_compute_size_wo_data, // mem_size
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 5a734691f..f317e4a84 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1109,11 +1109,9 @@ int main(int argc, char ** argv) {
     ggml_allocr_free(alloc);
 
     // context for compute tensors without their data
-    size_t estimated_compute_size_wo_data = (
-        ggml_tensor_overhead()*LLAMA_TRAIN_MAX_NODES*2
-        + (GGML_OBJECT_SIZE+ggml_graph_overhead())*(
-            params.common.use_checkpointing ? 3 : 2
-        )
+    const size_t estimated_compute_size_wo_data = (
+        2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+        (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
     );
     struct ggml_init_params ctx_compute_params = {
         estimated_compute_size_wo_data, // mem_size
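
The revised estimate reserves tensor metadata for two graphs' worth of nodes (forward and backward) plus the graph structures themselves, sized for LLAMA_TRAIN_MAX_NODES via ggml_graph_overhead_custom() instead of the default-sized ggml_graph_overhead(). Below is a minimal standalone sketch of that calculation, not part of the patch, assuming the ggml.h and train.h headers from the same source tree; use_checkpointing here is a hypothetical stand-in for params.common.use_checkpointing.

// Minimal sketch (not part of the patch): the revised compute-context size
// estimate, assuming ggml.h and train.h from the same source tree.
#include "ggml.h"
#include "train.h"

#include <cstdio>

int main() {
    // Hypothetical stand-in for params.common.use_checkpointing.
    const bool use_checkpointing = true;

    // Metadata for up to 2*LLAMA_TRAIN_MAX_NODES tensors (forward + backward),
    // plus the graph structures themselves: 3 graphs when checkpointing is
    // enabled (forward, backward, and a temporary backward graph), otherwise 2.
    // Each graph is sized for LLAMA_TRAIN_MAX_NODES nodes with gradients,
    // matching the expression introduced by the patch.
    const size_t estimated_compute_size_wo_data =
        2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
        (use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE + ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true));

    printf("estimated compute context size (no data): %zu bytes\n", estimated_compute_size_wo_data);
    return 0;
}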