From acea8e10a315bf4c82bd4c6399f70ca270b7abe1 Mon Sep 17 00:00:00 2001 From: crasm Date: Mon, 7 Aug 2023 21:12:43 -0400 Subject: [PATCH] examples/main: Add --prompt-cache-clobber parameter --- examples/common.cpp | 4 ++++ examples/common.h | 17 +++++++++-------- examples/main/main.cpp | 35 ++++++++++++++++++++++------------- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 4d3ba9bb2..46afb2b51 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -137,6 +137,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.prompt_cache_all = true; } else if (arg == "--prompt-cache-ro") { params.prompt_cache_ro = true; + } else if (arg == "--prompt-cache-clobber") { + params.prompt_cache_clobber = true; } else if (arg == "-f" || arg == "--file") { if (++i >= argc) { invalid_param = true; @@ -537,6 +539,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --prompt-cache-all if specified, saves user input and generations to cache as well.\n"); fprintf(stdout, " not supported with --interactive or other interactive options\n"); fprintf(stdout, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); + fprintf(stdout, " --prompt-cache-clobber\n"); + fprintf(stdout, " if error on loading prompt cache, treat as new file\n"); fprintf(stdout, " --random-prompt start with a randomized prompt.\n"); fprintf(stdout, " --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); fprintf(stdout, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); diff --git a/examples/common.h b/examples/common.h index 375bc0a3d..cd4faa71b 100644 --- a/examples/common.h +++ b/examples/common.h @@ -68,14 +68,15 @@ struct gpt_params { bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use 
when computing the HellaSwag score - bool low_vram = false; // if true, reduce VRAM usage at the cost of performance - bool mul_mat_q = false; // if true, use experimental mul_mat_q kernels - bool memory_f16 = true; // use f16 instead of f32 for memory kv - bool random_prompt = false; // do not randomize prompt if none provided - bool use_color = false; // use color to distinguish generations and inputs - bool interactive = false; // interactive mode - bool prompt_cache_all = false; // save user input and generations to prompt cache - bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it + bool low_vram = false; // if true, reduce VRAM usage at the cost of performance + bool mul_mat_q = false; // if true, use experimental mul_mat_q kernels + bool memory_f16 = true; // use f16 instead of f32 for memory kv + bool random_prompt = false; // do not randomize prompt if none provided + bool use_color = false; // use color to distinguish generations and inputs + bool interactive = false; // interactive mode + bool prompt_cache_all = false; // save user input and generations to prompt cache + bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it + bool prompt_cache_clobber = false; // if error on loading prompt cache, treat as new file bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 56ada7e69..fa7c2de12 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -167,26 +167,35 @@ int main(int argc, char ** argv) { std::vector<llama_token> session_tokens; if (!path_session.empty()) { - fprintf(stderr, "%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str()); + auto load_session = [&ctx, &params, &path_session, &session_tokens]() { + fprintf(stderr, "load_session: attempting to load saved session from '%s'\n" , path_session.c_str()); - // fopen to check
for existing session - FILE * fp = std::fopen(path_session.c_str(), "rb"); - if (fp != NULL) { + // fopen to check for existing session + FILE * fp = std::fopen(path_session.c_str(), "rb"); + if (fp == NULL) { + fprintf(stderr, "load_session: session file does not exist, will create\n"); + return; + } std::fclose(fp); session_tokens.resize(params.n_ctx); size_t n_token_count_out = 0; - if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { - fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); - return 1; + if (llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { + session_tokens.resize(n_token_count_out); + llama_set_rng_seed(ctx, params.seed); + fprintf(stderr, "load_session: loaded a session with prompt size of %d tokens\n", (int)session_tokens.size()); + return; } - session_tokens.resize(n_token_count_out); - llama_set_rng_seed(ctx, params.seed); - fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size()); - } else { - fprintf(stderr, "%s: session file does not exist, will create\n", __func__); - } + fprintf(stderr, "load_session: error: failed to load session file '%s'\n", path_session.c_str()); + if (params.prompt_cache_clobber) { + fprintf(stderr, "load_session: attempting to clobber session file\n"); + } else { + fprintf(stderr, "load_session: use --prompt-cache-clobber to overwrite this file\n"); + std::exit(1); + } + }; + load_session(); } // tokenize the prompt