diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 28fbd2dc8..15f60513f 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -3721,6 +3721,8 @@ struct train_params { bool use_flash; bool use_scratch; bool use_checkpointing; + bool use_alloc; + bool use_unified; // only adam int warmup; @@ -3782,6 +3784,8 @@ struct train_params get_default_train_params() { params.use_flash = true; params.use_scratch = true; params.use_checkpointing = true; + params.use_alloc = true; + params.use_unified = true; params.opt_past = 0; params.opt_delta = 1e-5f; @@ -3845,6 +3849,10 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p fprintf(stderr, " --use-scratch Use scratch buffers. Implies use-flash. (default)\n"); fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n"); fprintf(stderr, " --use-checkpointing Use gradient checkpointing. Implies use-scratch and use-flash. (default)\n"); + fprintf(stderr, " --no-alloc Don't use allocator\n"); + fprintf(stderr, " --use-alloc Use allocator. Implies use-unified. (default)\n"); + fprintf(stderr, " --no-unified Don't use unified\n"); + fprintf(stderr, " --use-unified Use unified. (default)\n"); fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup); fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps); fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart); @@ -4010,6 +4018,14 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { params->use_checkpointing = false; } else if (arg == "--use-checkpointing") { params->use_checkpointing = true; + } else if (arg == "--no-alloc") { + params->use_alloc = false; + } else if (arg == "--use-alloc") { + params->use_alloc = true; + } else if (arg == "--no-unified") { + params->use_unified = false; + } else if (arg == "--use-unified") { + params->use_unified = true; } else if (arg == "--warmup") { if (++i >= argc) { invalid_param = true;