From d09d5ed640583b3504f69926c72334e97aa45b86 Mon Sep 17 00:00:00 2001
From: Bach Le
Date: Fri, 7 Jul 2023 21:35:46 +0800
Subject: [PATCH] Initial implementation

---
 examples/common.cpp                    |  32 +++++--
 examples/common.h                      |   8 +-
 examples/embd-input/embd-input-lib.cpp |   2 +-
 examples/embedding/embedding.cpp       |   2 +-
 examples/main/main.cpp                 | 114 +++++++++++++++++++++++--
 examples/perplexity/perplexity.cpp     |   2 +-
 examples/server/server.cpp             |   2 +-
 examples/simple/simple.cpp             |   2 +-
 8 files changed, 148 insertions(+), 16 deletions(-)

diff --git a/examples/common.cpp b/examples/common.cpp
index 3278a0643..35be2b5aa 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -236,6 +236,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.mirostat_tau = std::stof(argv[i]);
+        } else if (arg == "--cfg-negative-prompt") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.cfg_negative_prompt = argv[i];
+        } else if (arg == "--cfg-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.cfg_scale = std::stof(argv[i]);
+        } else if (arg == "--cfg-smooth-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.cfg_smooth_factor = std::stof(argv[i]);
         } else if (arg == "-b" || arg == "--batch-size") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -468,6 +486,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n");
     fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
     fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+    fprintf(stderr, "  --cfg-negative-prompt PROMPT \n");
+    fprintf(stderr, "                        negative prompt to use for guidance. (default: empty)\n");
+    fprintf(stderr, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
+    fprintf(stderr, "  --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
@@ -534,7 +556,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
+std::tuple<struct llama_model *, struct llama_context *, struct llama_context_params> llama_init_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
     lparams.n_ctx = params.n_ctx;
@@ -553,14 +575,14 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-        return std::make_tuple(nullptr, nullptr);
+        return std::make_tuple(nullptr, nullptr, lparams);
     }
 
     llama_context * lctx = llama_new_context_with_model(model, lparams);
     if (lctx == NULL) {
         fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
         llama_free_model(model);
-        return std::make_tuple(nullptr, nullptr);
+        return std::make_tuple(nullptr, nullptr, lparams);
     }
 
     if (!params.lora_adapter.empty()) {
@@ -572,11 +594,11 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
             llama_free(lctx);
             llama_free_model(model);
-            return std::make_tuple(nullptr, nullptr);
+            return std::make_tuple(nullptr, nullptr, lparams);
         }
     }
 
-    return std::make_tuple(model, lctx);
+    return std::make_tuple(model, lctx, lparams);
 }
 
 void console_init(console_state & con_st) {
diff --git a/examples/common.h b/examples/common.h
index 96f2228f8..bed576438 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -48,6 +48,12 @@ struct gpt_params {
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
 
+    // Classifier-Free Guidance
+    // https://arxiv.org/abs/2306.17806
+    std::string cfg_negative_prompt;     // string to help guidance
+    float       cfg_scale         = 1.f; // How strong is guidance
+    float       cfg_smooth_factor = 1.f; // Smooth factor between old and new logits
+
     std::string model = "models/7B/ggml-model.bin"; // model path
     std::string model_alias = "unknown"; // model alias
     std::string prompt = "";
@@ -98,7 +104,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 // Model utils
 //
 
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
+std::tuple<struct llama_model *, struct llama_context *, struct llama_context_params> llama_init_from_gpt_params(const gpt_params & params);
 
 //
 // Console utils
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 5fa4942be..576ac0af0 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -42,7 +42,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
-    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return nullptr;
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 03e801c2a..7b1135e6a 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -41,7 +41,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     // load the model
-    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 0f6391acb..65ead0a00 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -54,6 +54,20 @@ void sigint_handler(int signo) {
 }
 #endif
 
+void inplace_log_softmax(float* logits, int n_vocab) {
+    float sum = 0.f;
+    for (int i = 0; i < n_vocab; ++i) {
+        float p = expf(logits[i]);
+        logits[i] = p;
+        sum += p;
+    }
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float p = logits[i];
+        logits[i] = logf(p / sum);
+    }
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -109,10 +123,16 @@ int main(int argc, char ** argv) {
 
     llama_model * model;
     llama_context * ctx;
+    llama_context * guidance_ctx = NULL;
+    struct llama_context_params lparams;
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
-    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    std::tie(model, ctx, lparams) = llama_init_from_gpt_params(params);
+    if (params.cfg_scale > 1.f) {
+        guidance_ctx = llama_new_context_with_model(model, lparams);
+    }
+
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
@@ -183,15 +203,28 @@ int main(int argc, char ** argv) {
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        params.prompt.insert(0, 1, ' ');
-
+    // Add a space in front of the first character to match OG llama tokenizer behavior
+    params.prompt.insert(0, 1, ' ');
+
+    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         embd_inp = ::llama_tokenize(ctx, params.prompt, true);
     } else {
         embd_inp = session_tokens;
     }
 
+    // Tokenize negative prompt
+    std::vector<llama_token> guidance_inp;
+    int guidance_offset = 0;
+    int original_prompt_len = 0;
+    if (guidance_ctx) {
+        params.cfg_negative_prompt.insert(0, 1, ' ');
+        guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true);
+
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
+        original_prompt_len = original_inp.size();
+        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
+    }
+
     const int n_ctx = llama_n_ctx(ctx);
 
     if ((int) embd_inp.size() > n_ctx - 4) {
@@ -258,6 +291,16 @@ int main(int argc, char ** argv) {
         for (int i = 0; i < (int) embd_inp.size(); i++) {
             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
         }
+
+        if (guidance_ctx) {
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+            fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
+            for (int i = 0; i < (int) guidance_inp.size(); i++) {
+                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
+            }
+        }
+
         if (params.n_keep > 0) {
             fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
@@ -334,11 +377,13 @@ int main(int argc, char ** argv) {
     int n_remain = params.n_predict;
     int n_consumed = 0;
     int n_session_consumed = 0;
+    int guidance_n_past = 0;
 
     // the first thing we will do is to output the prompt, so set color accordingly
     console_set_color(con_st, CONSOLE_COLOR_PROMPT);
 
     std::vector<llama_token> embd;
+    std::vector<llama_token> guidance_embd;
 
     // do one empty run to warm up the model
     {
@@ -367,11 +412,12 @@ int main(int argc, char ** argv) {
             // if we run out of context:
            // - take the n_keep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            if (n_past + (int) embd.size() > n_ctx) {
+            if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) {
                 const int n_left = n_past - params.n_keep;
 
                 // always keep the first token - BOS
                 n_past = std::max(1, params.n_keep);
+                guidance_n_past = std::max(1, params.n_keep + guidance_offset);
 
                 // insert n_left/2 tokens at the start of embd from last_n_tokens
                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
@@ -412,6 +458,48 @@ int main(int argc, char ** argv) {
 
             // evaluate tokens in batches
             // embd is typically prepared beforehand to fit within a batch, but not always
+
+            if (guidance_ctx) {
+                int input_size = 0;
+                llama_token* input_buf = NULL;
+
+                if (guidance_n_past < (int) guidance_inp.size()) {
+                    // Guidance context should have the same data with these modifications:
+                    //
+                    // * Replace the initial prompt
+                    // * Shift everything by guidance_offset
+                    guidance_embd = guidance_inp;
+                    if (embd.begin() + original_prompt_len < embd.end()) {
+                        guidance_embd.insert(
+                            guidance_embd.end(),
+                            embd.begin() + original_prompt_len,
+                            embd.end()
+                        );
+                    }
+
+                    input_buf = guidance_embd.data();
+                    input_size = guidance_embd.size();
+                    fprintf(stderr, "\n---------------------\n");
+                    for (int i = 0; i < (int) guidance_embd.size(); i++) {
+                        fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i]));
+                    }
+                    fprintf(stderr, "\n---------------------\n");
+                } else {
+                    input_buf = embd.data();
+                    input_size = embd.size();
+                }
+
+                for (int i = 0; i < input_size; i += params.n_batch) {
+                    int n_eval = std::min(input_size - i, params.n_batch);
+                    if (llama_eval(guidance_ctx, input_buf + i, n_eval, guidance_n_past, params.n_threads)) {
+                        fprintf(stderr, "%s : failed to eval\n", __func__);
+                        return 1;
+                    }
+
+                    guidance_n_past += n_eval;
+                }
+            }
+
             for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
                 int n_eval = (int) embd.size() - i;
                 if (n_eval > params.n_batch) {
@@ -431,6 +519,7 @@ int main(int argc, char ** argv) {
         }
 
         embd.clear();
+        guidance_embd.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
@@ -465,6 +554,21 @@ int main(int argc, char ** argv) {
                 logits[it->first] += it->second;
             }
 
+            if (guidance_ctx) {
+                inplace_log_softmax(logits, n_vocab);
+                auto* guidance_logits = llama_get_logits(guidance_ctx);
+                inplace_log_softmax(guidance_logits, n_vocab);
+
+                for (int i = 0; i < n_vocab; ++i) {
+                    guidance_logits[i] = params.cfg_scale * (logits[i] - guidance_logits[i]) + guidance_logits[i];
+                }
+                inplace_log_softmax(guidance_logits, n_vocab);
+
+                for (int i = 0; i < n_vocab; ++i) {
+                    logits[i] = guidance_logits[i] * params.cfg_smooth_factor + logits[i] * (1 - params.cfg_smooth_factor);
+                }
+            }
+
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
             for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index fd4b03cb2..768c2b400 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     // load the model and apply lora adapter, if any
-    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2cbfc0018..55cf1c94d 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -245,7 +245,7 @@ struct llama_server_context
     bool loadModel(const gpt_params &params_)
     {
         params = params_;
-        std::tie(model, ctx) = llama_init_from_gpt_params(params);
+        std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
         if (model == nullptr)
         {
             LOG_ERROR("unable to load model", {{"model", params_.model}});
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 2d913cebb..f59788865 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -71,7 +71,7 @@ int main(int argc, char ** argv)
     llama_model * model;
     llama_context * ctx;
 
-    std::tie(model, ctx) = llama_init_from_gpt_params( params );
+    std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params( params );
 
     if ( model == NULL )
     {