diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index b891908be..58f37487e 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -1742,11 +1742,9 @@ int main(int argc, char ** argv) { ggml_allocr_free(alloc); // context for compute tensors without their data - size_t estimated_compute_size_wo_data = ( - ggml_tensor_overhead()*LLAMA_TRAIN_MAX_NODES*2 - + (GGML_OBJECT_SIZE+ggml_graph_overhead())*( - params.common.use_checkpointing ? 3 : 2 - ) + const size_t estimated_compute_size_wo_data = ( + 2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() + + (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true)) ); struct ggml_init_params ctx_compute_params = { estimated_compute_size_wo_data, // mem_size diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 5a734691f..f317e4a84 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1109,11 +1109,9 @@ int main(int argc, char ** argv) { ggml_allocr_free(alloc); // context for compute tensors without their data - size_t estimated_compute_size_wo_data = ( - ggml_tensor_overhead()*LLAMA_TRAIN_MAX_NODES*2 - + (GGML_OBJECT_SIZE+ggml_graph_overhead())*( - params.common.use_checkpointing ? 3 : 2 - ) + const size_t estimated_compute_size_wo_data = ( + 2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() + + (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true)) ); struct ggml_init_params ctx_compute_params = { estimated_compute_size_wo_data, // mem_size