This commit is contained in:
jon-chuang 2023-04-16 00:12:04 +08:00
parent e524ce99fe
commit 81edec9776
2 changed files with 14 additions and 10 deletions

View file

@ -1,13 +1,14 @@
#include "common.h"
#include <cassert>
#include <iostream>
#include <cstring>
#include <fstream>
#include <string>
#include <iterator>
#include <algorithm>
#elif defined(__APPLE__) && defined(__MACH__)
#if defined(__APPLE__) && defined(__MACH__)
#include <sys/types.h>
#include <sys/sysctl.h>
#endif
@ -42,7 +43,6 @@ int32_t get_num_physical_cores() {
}
}
#elif defined(__APPLE__) && defined(__MACH__)
int32_t num_physical_cores;
size_t len = sizeof(num_physical_cores);
int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
@ -62,11 +62,6 @@ int32_t get_num_physical_cores() {
}
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
// Clip if not a valid number of threads
if (params.n_threads <= 0) {
params.n_threads = std::max(1, std::min(8, (int32_t) std::thread::hardware_concurrency()));
}
bool invalid_param = false;
std::string arg;
gpt_params default_params;
@ -229,6 +224,17 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
exit(1);
}
// Clip if not a valid number of threads
if (params.n_threads <= 0) {
int32_t physical_cores = get_num_physical_cores();
if (physical_cores > 4) {
std::cout << "\n\033[1;31mWARNING:\033[0m Defaulting to 4 threads. "
<< "(detected " << physical_cores << " physical cores)" << std::endl
<< "Adjust --threads based on your observed inference speed in ms/token." << std::endl << std::endl;
}
params.n_threads = std::max(1, std::min(4, physical_cores));
}
return true;
}

View file

@ -13,11 +13,9 @@
// CLI argument parsing
//
int32_t get_num_physical_cores();
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores(); // (if <= 0, = clip(num_logical_cores, 1, 8))
int32_t n_threads = 0;
int32_t n_predict = 128; // new tokens to predict
int32_t repeat_last_n = 64; // last n tokens to penalize
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)