add train params to specify memory size
parent ad966da955
commit 1074a81e81

1 changed file with 22 additions and 2 deletions
@@ -2122,6 +2122,9 @@ struct train_params {
     int adam_n_iter;
     float adam_alpha;
     float adam_decay;
+
+    int mem_model_gb;
+    int mem_compute_gb;
 };

 struct train_params get_default_train_params() {
@@ -2164,6 +2167,9 @@ struct train_params get_default_train_params() {
     params.adam_alpha = 1e-3;
     params.adam_decay = 1e-3;

+    params.mem_model_gb = 2;
+    params.mem_compute_gb = 32;
+
     return params;
 }

@@ -2203,6 +2209,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
     fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
     fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
+    fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
     fprintf(stderr, "\n");
 }

@@ -2384,6 +2392,18 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->adam_decay = std::stof(argv[i]);
+        } else if (arg == "--mem-model") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->mem_model_gb = std::stoi(argv[i]);
+        } else if (arg == "--mem-compute") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->mem_compute_gb = std::stoi(argv[i]);
         } else if (arg == "-h" || arg == "--help") {
             train_print_usage(argc, argv, &default_params);
             exit(0);
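To illustrate the option-handling pattern the two new branches follow (bounds-check the incremented index, then convert the value with std::stoi), here is a minimal, self-contained C++ sketch; the surrounding loop, struct, and error message are simplified assumptions for illustration, not the file's actual parser:

    #include <cstdio>
    #include <string>

    // Simplified stand-in for train_params (assumption for this sketch).
    struct mem_params {
        int mem_model_gb   = 2;  // default set in get_default_train_params()
        int mem_compute_gb = 32; // default set in get_default_train_params()
    };

    int main(int argc, char ** argv) {
        mem_params params;
        bool invalid_param = false;

        for (int i = 1; i < argc; i++) {
            std::string arg = argv[i];
            if (arg == "--mem-model") {
                if (++i >= argc) { invalid_param = true; break; } // option given without a value
                params.mem_model_gb = std::stoi(argv[i]);
            } else if (arg == "--mem-compute") {
                if (++i >= argc) { invalid_param = true; break; }
                params.mem_compute_gb = std::stoi(argv[i]);
            }
        }

        if (invalid_param) {
            fprintf(stderr, "error: missing value for the last option\n");
            return 1;
        }
        printf("mem_model_gb=%d mem_compute_gb=%d\n", params.mem_model_gb, params.mem_compute_gb);
        return 0;
    }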
@@ -2480,7 +2500,7 @@ int main(int argc, char ** argv) {


     struct ggml_init_params lcparams;
-    lcparams.mem_size = 1024ll*1024ll*1024ll*2ll;
+    lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
     lcparams.mem_buffer = NULL;
     lcparams.no_alloc = false;

@@ -2536,7 +2556,7 @@ int main(int argc, char ** argv) {
     printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx));
     // ggml_print_tensor_objects(model.ctx);

-    size_t compute_size = 1024ll*1024ll*1024ll*32ll;
+    size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
     uint8_t * compute_addr = new uint8_t[compute_size];

     GGML_ASSERT(train_tokens.size() > n_tokens);;
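As a quick check on what the new defaults amount to, here is a small, self-contained C++ sketch of the gigabyte-to-byte arithmetic performed by the two changed lines; the struct is a simplified stand-in for train_params, and the values 2 and 32 are the defaults from get_default_train_params():

    #include <cstddef>
    #include <cstdio>

    // Simplified stand-in for the two new train_params fields (assumption for this sketch).
    struct mem_params {
        int mem_model_gb;
        int mem_compute_gb;
    };

    int main() {
        mem_params params = { 2, 32 }; // defaults set in get_default_train_params()

        // Same arithmetic as the changed lines: bytes = 1024*1024*1024 * N gigabytes.
        std::size_t model_bytes   = 1024ll*1024ll*1024ll*((std::size_t) params.mem_model_gb);
        std::size_t compute_bytes = 1024ll*1024ll*1024ll*((std::size_t) params.mem_compute_gb);

        printf("model+cache buffer: %zu bytes\n", model_bytes);   // 2147483648  (2 GiB)
        printf("compute buffer:     %zu bytes\n", compute_bytes); // 34359738368 (32 GiB)
        return 0;
    }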