From 847896aba762bdba22e17b44e602e5adeb84547f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 3 Sep 2023 13:51:07 +0300
Subject: [PATCH] speculative : add --draft CLI arg

---
 common/common.cpp                    | 9 ++++++++-
 common/common.h                      | 3 ++-
 examples/speculative/speculative.cpp | 2 +-
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 0a01f24ac..313821375 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -305,6 +305,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_keep = std::stoi(argv[i]);
+        } else if (arg == "--draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_draft = std::stoi(argv[i]);
         } else if (arg == "--chunks") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -644,6 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
     fprintf(stdout, "  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     fprintf(stdout, "  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    fprintf(stdout, "  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     fprintf(stdout, "  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
         fprintf(stdout, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
@@ -676,7 +683,7 @@
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
-    fprintf(stdout, "                        draft model for speculative sampling (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "                        draft model for speculative decoding (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
diff --git a/common/common.h b/common/common.h
index e77fa3cf5..105fb09e4 100644
--- a/common/common.h
+++ b/common/common.h
@@ -32,6 +32,7 @@ struct gpt_params {
     int32_t n_ctx        = 512;  // context size
     int32_t n_batch      = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep       = 0;    // number of tokens to keep from initial prompt
+    int32_t n_draft      = 16;   // number of tokens to draft during speculative decoding
     int32_t n_chunks     = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers = 0;    // number of layers to store in VRAM
     int32_t main_gpu     = 0;    // the GPU that is used for scratch and small tensors
@@ -63,7 +64,7 @@
     float cfg_scale = 1.f; // How strong is guidance
 
     std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_draft       = ""; // draft model for speculative sampling
+    std::string model_draft       = ""; // draft model for speculative decoding
     std::string model_alias       = "unknown"; // model alias
     std::string prompt            = "";
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 67d5ba113..f0400c13f 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -84,7 +84,7 @@ int main(int argc, char ** argv) {
     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
 
     // how many tokens to draft each time
-    const int n_draft = 16;
+    const int n_draft = params.n_draft;
 
     int n_predict = 0;
     int n_drafted = 0;
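
Usage sketch: with this change, the number of tokens drafted per step can be overridden from the command line instead of being hard-coded to 16. The invocation below is illustrative only; it assumes the speculative example builds to ./bin/speculative and uses placeholder model paths and prompt, while -m, -md, -p and --draft are the options handled by gpt_params_parse above.

    # draft 8 tokens per step instead of the default 16 (paths and prompt are placeholders)
    ./bin/speculative \
        -m  models/7B/ggml-model-f16.gguf \
        -md models/draft/ggml-model-f16.gguf \
        -p  "Hello, my name is" \
        --draft 8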