From 1074a81e819b9076599d9da93d6ebd99dcce93b8 Mon Sep 17 00:00:00 2001 From: xaedes Date: Tue, 30 May 2023 16:06:20 +0200 Subject: [PATCH] add train params to specify memory size --- examples/baby-llama/baby-llama-text.cpp | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/examples/baby-llama/baby-llama-text.cpp b/examples/baby-llama/baby-llama-text.cpp index 5d48b7155..03f93c749 100644 --- a/examples/baby-llama/baby-llama-text.cpp +++ b/examples/baby-llama/baby-llama-text.cpp @@ -2122,6 +2122,9 @@ struct train_params { int adam_n_iter; float adam_alpha; float adam_decay; + + int mem_model_gb; + int mem_compute_gb; }; struct train_params get_default_train_params() { @@ -2164,6 +2167,9 @@ struct train_params get_default_train_params() { params.adam_alpha = 1e-3; params.adam_decay = 1e-3; + params.mem_model_gb = 2; + params.mem_compute_gb = 32; + return params; } @@ -2203,6 +2209,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter); fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha); fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay); + fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb); + fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb); fprintf(stderr, "\n"); } @@ -2384,6 +2392,18 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->adam_decay = std::stof(argv[i]); + } else if (arg == "--mem-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->mem_model_gb = std::stoi(argv[i]); + } else if (arg == "--mem-compute") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->mem_compute_gb = std::stoi(argv[i]); } else if (arg == "-h" || arg == "--help") { train_print_usage(argc, argv, &default_params); exit(0); @@ -2480,7 +2500,7 @@ int main(int argc, char ** argv) { struct ggml_init_params lcparams; - lcparams.mem_size = 1024ll*1024ll*1024ll*2ll; + lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb); lcparams.mem_buffer = NULL; lcparams.no_alloc = false; @@ -2536,7 +2556,7 @@ int main(int argc, char ** argv) { printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx)); // ggml_print_tensor_objects(model.ctx); - size_t compute_size = 1024ll*1024ll*1024ll*32ll; + size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb); uint8_t * compute_addr = new uint8_t[compute_size]; GGML_ASSERT(train_tokens.size() > n_tokens);;