diff --git a/common/common.cpp b/common/common.cpp index c11006bcb..af9a853e3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -845,6 +845,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); printf(" -tb N, --threads-batch N\n"); printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); + printf(" -td N, --threads-draft N number of threads to use during generation (default: %d)\n", params.n_threads); + printf(" -tbd N, --threads-batch-draft N\n"); + printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n"); printf(" -p PROMPT, --prompt PROMPT\n"); printf(" prompt to start generation with (default: empty)\n"); printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); diff --git a/common/common.h b/common/common.h index 096468243..804f6be61 100644 --- a/common/common.h +++ b/common/common.h @@ -46,7 +46,9 @@ struct gpt_params { uint32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); + int32_t n_threads_draft = get_num_physical_cores(); int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_threads_batch_draft = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 20f1fb5bf..18aa7cfe7 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -65,6 +65,8 @@ int main(int argc, char ** argv) { // load the draft model params.model = params.model_draft; params.n_gpu_layers = params.n_gpu_layers_draft; + params.n_threads = params.n_threads_draft; + params.n_threads_batch = params.n_threads_batch_draft; std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); {