speculative : add --draft CLI arg
commit 847896aba7
parent a15ca746c7
3 changed files with 11 additions and 3 deletions
common/common.cpp

@@ -305,6 +305,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_keep = std::stoi(argv[i]);
+        } else if (arg == "--draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_draft = std::stoi(argv[i]);
         } else if (arg == "--chunks") {
             if (++i >= argc) {
                 invalid_param = true;
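The new block follows the same hand-rolled pattern as the surrounding options: pre-increment i, bounds-check against argc, then convert the next argv slot with std::stoi. A minimal self-contained sketch of that pattern (the struct name and error handling here are illustrative stand-ins, not llama.cpp code):

#include <cstdio>
#include <string>

struct params_t {
    int n_draft = 16; // same default the commit gives gpt_params::n_draft
};

int main(int argc, char ** argv) {
    params_t params;

    for (int i = 1; i < argc; i++) {
        const std::string arg = argv[i];
        if (arg == "--draft") {
            // the value lives in the next argv slot; reject a trailing flag
            if (++i >= argc) {
                fprintf(stderr, "error: --draft requires a value\n");
                return 1;
            }
            params.n_draft = std::stoi(argv[i]); // throws on non-numeric input
        }
    }

    printf("n_draft = %d\n", params.n_draft);
    return 0;
}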
@@ -644,6 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
     fprintf(stdout, "  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     fprintf(stdout, "  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    fprintf(stdout, "  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     fprintf(stdout, "  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
         fprintf(stdout, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
@@ -676,7 +683,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
-    fprintf(stdout, "                        draft model for speculative sampling (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "                        draft model for speculative decoding (default: %s)\n", params.model.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
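With the flag wired into both the parser and the help text, the speculative example can be pointed at a target/draft model pair and a draft length along these lines (model paths are placeholders):

./speculative -m models/7B/ggml-model-f16.gguf -md models/7B/ggml-model-draft.gguf -p "..." --draft 24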
common/common.h

@@ -32,6 +32,7 @@ struct gpt_params {
     int32_t n_ctx        = 512;  // context size
     int32_t n_batch      = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep       = 0;    // number of tokens to keep from initial prompt
+    int32_t n_draft      = 16;   // number of tokens to draft during speculative decoding
     int32_t n_chunks     = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers = 0;    // number of layers to store in VRAM
     int32_t main_gpu     = 0;    // the GPU that is used for scratch and small tensors
@@ -63,7 +64,7 @@ struct gpt_params {
     float cfg_scale = 1.f; // How strong is guidance

     std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_draft       = "";        // draft model for speculative sampling
+    std::string model_draft       = "";        // draft model for speculative decoding
     std::string model_alias       = "unknown"; // model alias
     std::string prompt            = "";
     std::string path_prompt_cache = "";        // path to file for saving/loading prompt eval state
examples/speculative/speculative.cpp

@@ -84,7 +84,7 @@ int main(int argc, char ** argv) {
     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));

     // how many tokens to draft each time
-    const int n_draft = 16;
+    const int n_draft = params.n_draft;

     int n_predict = 0;
     int n_drafted = 0;
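What n_draft controls, in sketch form: each step, the cheap draft model proposes up to n_draft tokens, the target model verifies them (in llama.cpp's example that verification happens in a batched forward pass, which is where the speedup comes from), and generation keeps the longest agreeing prefix plus one corrected token. The toy below uses stand-in functions for both models purely to show the control flow; nothing in it is llama.cpp API:

#include <cstdio>
#include <vector>

// stand-in "models": cheap draft vs. authoritative target. The target is made
// to disagree now and then so the rejection path below actually runs.
static int draft_next (int last) { return (last * 7 + 3) % 100; }
static int target_next(int last) {
    const int v = (last * 7 + 3) % 100;
    return (last % 10 == 0) ? v + 1 : v; // occasional disagreement
}

int main() {
    const int n_draft   = 16; // what --draft now sets via params.n_draft
    const int n_predict = 64;

    std::vector<int> out = { 1 }; // one-token "prompt"
    int n_drafted = 0;
    int n_accept  = 0;

    while ((int) out.size() < n_predict) {
        // 1) draft model proposes up to n_draft tokens
        std::vector<int> drafted;
        for (int i = 0, last = out.back(); i < n_draft; i++) {
            last = draft_next(last);
            drafted.push_back(last);
        }
        n_drafted += (int) drafted.size();

        // 2) target verifies: keep the longest matching prefix, then emit
        //    the target's own token at the first disagreement
        int cur = out.back();
        for (const int t : drafted) {
            const int want = target_next(cur);
            if (t != want) {
                out.push_back(want);
                break;
            }
            out.push_back(t);
            n_accept++;
            cur = t;
        }
    }

    printf("drafted %d, accepted %d (%.0f%% accept rate)\n",
            n_drafted, n_accept, 100.0 * n_accept / n_drafted);
    return 0;
}

Raising n_draft amortizes more generated tokens per verification step when the draft model agrees often, but wastes draft work when it does not, which is why the commit exposes it as a flag instead of the hardcoded 16.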