diff --git a/examples/common.cpp b/examples/common.cpp index 1b77fef43..94b5ead6f 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -1,13 +1,14 @@ #include "common.h" #include +#include #include #include #include #include #include -#elif defined(__APPLE__) && defined(__MACH__) +#if defined(__APPLE__) && defined(__MACH__) #include #include #endif @@ -42,7 +43,6 @@ int32_t get_num_physical_cores() { } } #elif defined(__APPLE__) && defined(__MACH__) - int32_t num_physical_cores; size_t len = sizeof(num_physical_cores); int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0); @@ -62,11 +62,6 @@ int32_t get_num_physical_cores() { } bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { - // Clip if not a valid number of threads - if (params.n_threads <= 0) { - params.n_threads = std::max(1, std::min(8, (int32_t) std::thread::hardware_concurrency())); - } - bool invalid_param = false; std::string arg; gpt_params default_params; @@ -229,6 +224,17 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { exit(1); } + // Clip if not a valid number of threads + if (params.n_threads <= 0) { + int32_t physical_cores = get_num_physical_cores(); + if (physical_cores > 4) { + std::cout << "\n\033[1;31mWARNING:\033[0m Defaulting to 4 threads. " + << "(detected " << physical_cores << " physical cores)" << std::endl + << "Adjust --threads based on your observed inference speed in ms/token." << std::endl << std::endl; + } + params.n_threads = std::max(1, std::min(4, physical_cores)); + } + return true; } diff --git a/examples/common.h b/examples/common.h index 61e49557f..37a4d66ab 100644 --- a/examples/common.h +++ b/examples/common.h @@ -13,11 +13,9 @@ // CLI argument parsing // -int32_t get_num_physical_cores(); - struct gpt_params { int32_t seed = -1; // RNG seed - int32_t n_threads = get_num_physical_cores(); // (if <= 0, = clip(num_logical_cores, 1, 8)) + int32_t n_threads = 0; int32_t n_predict = 128; // new tokens to predict int32_t repeat_last_n = 64; // last n tokens to penalize int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)