speculative: revert default behavior when -td is unspecified

This commit is contained in:
Stéphane du Hamel 2024-01-15 18:11:54 +01:00
parent 49cbf3ec4d
commit f4fe6333d8
3 changed files with 6 additions and 4 deletions

View file

@ -864,7 +864,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -tb N, --threads-batch N\n"); printf(" -tb N, --threads-batch N\n");
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -td N, --threads-draft N"); printf(" -td N, --threads-draft N");
printf(" number of threads to use during generation (default: %d)\n", params.n_threads_draft); printf(" number of threads to use during generation (default: same as --threads)");
printf(" -tbd N, --threads-batch-draft N\n"); printf(" -tbd N, --threads-batch-draft N\n");
printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n"); printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n");
printf(" -p PROMPT, --prompt PROMPT\n"); printf(" -p PROMPT, --prompt PROMPT\n");

View file

@ -46,9 +46,9 @@ struct gpt_params {
uint32_t seed = -1; // RNG seed uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores(); int32_t n_threads = get_num_physical_cores();
int32_t n_threads_draft = get_num_physical_cores(); int32_t n_threads_draft = -1;
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_threads_batch_draft = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_threads_batch_draft = -1;
int32_t n_predict = -1; // new tokens to predict int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)

View file

@ -65,7 +65,9 @@ int main(int argc, char ** argv) {
// load the draft model // load the draft model
params.model = params.model_draft; params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft; params.n_gpu_layers = params.n_gpu_layers_draft;
params.n_threads = params.n_threads_draft; if (params.n_threads_draft > 0) {
params.n_threads = params.n_threads_draft;
}
params.n_threads_batch = params.n_threads_batch_draft; params.n_threads_batch = params.n_threads_batch_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);