examples/main: Add --prompt-cache-clobber parameter

This commit is contained in:
crasm 2023-08-07 21:12:43 -04:00
parent f3c3b4b167
commit acea8e10a3
3 changed files with 35 additions and 21 deletions

View file

@ -137,6 +137,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.prompt_cache_all = true; params.prompt_cache_all = true;
} else if (arg == "--prompt-cache-ro") { } else if (arg == "--prompt-cache-ro") {
params.prompt_cache_ro = true; params.prompt_cache_ro = true;
} else if (arg == "--prompt-cache-clobber") {
params.prompt_cache_clobber = true;
} else if (arg == "-f" || arg == "--file") { } else if (arg == "-f" || arg == "--file") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -537,6 +539,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " --prompt-cache-all if specified, saves user input and generations to cache as well.\n"); fprintf(stdout, " --prompt-cache-all if specified, saves user input and generations to cache as well.\n");
fprintf(stdout, " not supported with --interactive or other interactive options\n"); fprintf(stdout, " not supported with --interactive or other interactive options\n");
fprintf(stdout, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); fprintf(stdout, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n");
fprintf(stdout, " --prompt-cache-clobber\n");
fprintf(stdout, " if error on loading prompt cache, treat as new file\n");
fprintf(stdout, " --random-prompt start with a randomized prompt.\n"); fprintf(stdout, " --random-prompt start with a randomized prompt.\n");
fprintf(stdout, " --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); fprintf(stdout, " --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n");
fprintf(stdout, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); fprintf(stdout, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");

View file

@ -68,14 +68,15 @@ struct gpt_params {
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q = false; // if true, use experimental mul_mat_q kernels bool mul_mat_q = false; // if true, use experimental mul_mat_q kernels
bool memory_f16 = true; // use f16 instead of f32 for memory kv bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode bool interactive = false; // interactive mode
bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
bool prompt_cache_clobber = false; // if error on loading prompt cache, treat as new file
bool embedding = false; // get only sentence embedding bool embedding = false; // get only sentence embedding
bool interactive_first = false; // wait for user input immediately bool interactive_first = false; // wait for user input immediately

View file

@ -167,26 +167,35 @@ int main(int argc, char ** argv) {
std::vector<llama_token> session_tokens; std::vector<llama_token> session_tokens;
if (!path_session.empty()) { if (!path_session.empty()) {
fprintf(stderr, "%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str()); auto load_session = [&ctx, &params, &path_session, &session_tokens]() {
fprintf(stderr, "load_session: attempting to load saved session from '%s'\n" , path_session.c_str());
// fopen to check for existing session // fopen to check for existing session
FILE * fp = std::fopen(path_session.c_str(), "rb"); FILE * fp = std::fopen(path_session.c_str(), "rb");
if (fp != NULL) { if (fp == NULL) {
fprintf(stderr, "load_session: session file does not exist, will create\n");
return;
}
std::fclose(fp); std::fclose(fp);
session_tokens.resize(params.n_ctx); session_tokens.resize(params.n_ctx);
size_t n_token_count_out = 0; size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { if (llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); session_tokens.resize(n_token_count_out);
return 1; llama_set_rng_seed(ctx, params.seed);
fprintf(stderr, "load_session: loaded a session with prompt size of %d tokens\n", (int)session_tokens.size());
return;
} }
session_tokens.resize(n_token_count_out);
llama_set_rng_seed(ctx, params.seed);
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size()); fprintf(stderr, "load_session: error: failed to load session file '%s'\n", path_session.c_str());
} else { if (params.prompt_cache_clobber) {
fprintf(stderr, "%s: session file does not exist, will create\n", __func__); fprintf(stderr, "load_session: attempting to clobber session file\n");
} } else {
fprintf(stderr, "load_session: use --prompt-cache-clobber to overwrite this file\n");
std::exit(1);
}
};
load_session();
} }
// tokenize the prompt // tokenize the prompt