Initial implementation

Bach Le 2023-07-07 21:35:46 +08:00
parent 061f5f8d21
commit d09d5ed640
8 changed files with 148 additions and 16 deletions


@@ -236,6 +236,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "--cfg-negative-prompt") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.cfg_scale = std::stof(argv[i]);
} else if (arg == "--cfg-smooth-factor") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.cfg_smooth_factor = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") { } else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@@ -468,6 +486,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stderr, " --cfg-negative-prompt PROMPT \n");
fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
@@ -534,7 +556,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
+std::tuple<struct llama_model *, struct llama_context *, struct llama_context_params> llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
@@ -553,14 +575,14 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-return std::make_tuple(nullptr, nullptr);
+return std::make_tuple(nullptr, nullptr, lparams);
}
llama_context * lctx = llama_new_context_with_model(model, lparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
-return std::make_tuple(nullptr, nullptr);
+return std::make_tuple(nullptr, nullptr, lparams);
}
if (!params.lora_adapter.empty()) {
@@ -572,11 +594,11 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
llama_free(lctx);
llama_free_model(model);
-return std::make_tuple(nullptr, nullptr);
+return std::make_tuple(nullptr, nullptr, lparams);
}
}
-return std::make_tuple(model, lctx);
+return std::make_tuple(model, lctx, lparams);
}
void console_init(console_state & con_st) {


@@ -48,6 +48,12 @@ struct gpt_params {
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
// Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806
std::string cfg_negative_prompt; // string to help guidance
float cfg_scale = 1.f; // How strong is guidance
float cfg_smooth_factor = 1.f; // Smooth factor between old and new logits
std::string model = "models/7B/ggml-model.bin"; // model path std::string model = "models/7B/ggml-model.bin"; // model path
std::string model_alias = "unknown"; // model alias std::string model_alias = "unknown"; // model alias
std::string prompt = ""; std::string prompt = "";
@@ -98,7 +104,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
// Model utils
//
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
+std::tuple<struct llama_model *, struct llama_context *, struct llama_context_params> llama_init_from_gpt_params(const gpt_params & params);
//
// Console utils


@@ -42,7 +42,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
g_ctx = &ctx;
// load the model and apply lora adapter, if any
-std::tie(model, ctx) = llama_init_from_gpt_params(params);
+std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return nullptr;


@@ -41,7 +41,7 @@ int main(int argc, char ** argv) {
llama_context * ctx;
// load the model
-std::tie(model, ctx) = llama_init_from_gpt_params(params);
+std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;


@@ -54,6 +54,20 @@ void sigint_handler(int signo) {
}
#endif
void inplace_log_softmax(float* logits, int n_vocab) {
float sum = 0.f;
for (int i = 0; i < n_vocab; ++i) {
float p = expf(logits[i]);
logits[i] = p;
sum += p;
}
for (int i = 0; i < n_vocab; ++i) {
float p = logits[i];
logits[i] = logf(p/ sum);
}
}
int main(int argc, char ** argv) {
gpt_params params;
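(As a quick sanity check on the inplace_log_softmax helper added above: raw logits {1.0, 2.0} correspond to softmax probabilities of roughly 0.269 and 0.731, so the in-place result is approximately {-1.313, -0.313}.)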
@@ -109,10 +123,16 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
llama_context * guidance_ctx = NULL;
struct llama_context_params lparams;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
-std::tie(model, ctx) = llama_init_from_gpt_params(params);
+std::tie(model, ctx, lparams) = llama_init_from_gpt_params(params);
if (params.cfg_scale > 1.f) {
guidance_ctx = llama_new_context_with_model(model, lparams);
}
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
@@ -183,15 +203,28 @@ int main(int argc, char ** argv) {
// tokenize the prompt
std::vector<llama_token> embd_inp;
-if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');
+if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, true);
} else {
embd_inp = session_tokens;
}
// Tokenize negative prompt
std::vector<llama_token> guidance_inp;
int guidance_offset = 0;
int original_prompt_len = 0;
if (guidance_ctx) {
params.cfg_negative_prompt.insert(0, 1, ' ');
guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true);
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
}
const int n_ctx = llama_n_ctx(ctx);
if ((int) embd_inp.size() > n_ctx - 4) {
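(To make the bookkeeping above concrete: if the negative prompt tokenizes to 12 tokens and params.prompt to 8, then original_prompt_len is 8 and guidance_offset is 4; that offset is what later shifts guidance_n_past and widens the context-overflow check.)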
@@ -258,6 +291,16 @@ int main(int argc, char ** argv) {
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
}
if (guidance_ctx) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
}
}
if (params.n_keep > 0) {
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
@@ -334,11 +377,13 @@ int main(int argc, char ** argv) {
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;
int guidance_n_past = 0;
// the first thing we will do is to output the prompt, so set color accordingly
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd;
std::vector<llama_token> guidance_embd;
// do one empty run to warm up the model
{
@@ -367,11 +412,12 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-if (n_past + (int) embd.size() > n_ctx) {
+if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
const int n_left = n_past - params.n_keep;
// always keep the first token - BOS
n_past = std::max(1, params.n_keep);
guidance_n_past = std::max(1, params.n_keep + guidance_offset);
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
@@ -412,6 +458,48 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
if (guidance_ctx) {
int input_size = 0;
llama_token* input_buf = NULL;
if (guidance_n_past < (int) guidance_inp.size()) {
// Guidance context should have the same data with these modifications:
//
// * Replace the initial prompt
// * Shift everything by guidance_offset
guidance_embd = guidance_inp;
if (embd.begin() + original_prompt_len < embd.end()) {
guidance_embd.insert(
guidance_embd.end(),
embd.begin() + original_prompt_len,
embd.end()
);
}
input_buf = guidance_embd.data();
input_size = guidance_embd.size();
fprintf(stderr, "\n---------------------\n");
for (int i = 0; i < (int) guidance_embd.size(); i++) {
fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i]));
}
fprintf(stderr, "\n---------------------\n");
} else {
input_buf = embd.data();
input_size = embd.size();
}
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
if (llama_eval(guidance_ctx, input_buf + i, n_eval, guidance_n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
guidance_n_past += n_eval;
}
}
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
@@ -431,6 +519,7 @@ int main(int argc, char ** argv) {
}
embd.clear();
guidance_embd.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
@@ -465,6 +554,21 @@ int main(int argc, char ** argv) {
logits[it->first] += it->second;
}
if (guidance_ctx) {
inplace_log_softmax(logits, n_vocab);
auto* guidance_logits = llama_get_logits(guidance_ctx);
inplace_log_softmax(guidance_logits, n_vocab);
for (int i = 0; i < n_vocab; ++i) {
guidance_logits[i] = params.cfg_scale * (logits[i] - guidance_logits[i]) + guidance_logits[i];
}
inplace_log_softmax(guidance_logits, n_vocab);
for (int i = 0; i < n_vocab; ++i) {
logits[i] = guidance_logits[i] * params.cfg_smooth_factor + logits[i] * (1 - params.cfg_smooth_factor);
}
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
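To summarize the sampling change above in one place: the guided distribution is log_softmax(logp_neg + cfg_scale * (logp_cond - logp_neg)), and the final logits are cfg_smooth_factor * guided + (1 - cfg_smooth_factor) * logp_cond. The standalone sketch below is not part of the commit; the vocabulary size, logit values, and parameter values are made up, and the helper simply mirrors the inplace_log_softmax added in main.cpp above.

#include <cmath>
#include <cstdio>
#include <vector>

// Mirrors the in-place log-softmax helper added in main.cpp above.
static void inplace_log_softmax(std::vector<float> & logits) {
    float sum = 0.f;
    for (float & l : logits) {
        l = std::exp(l);
        sum += l;
    }
    for (float & l : logits) {
        l = std::log(l / sum);
    }
}

int main() {
    // Hypothetical logits over a 4-token vocabulary.
    std::vector<float> cond = { 2.0f, 0.5f, -1.0f, 0.0f }; // main (positive-prompt) context
    std::vector<float> neg  = { 1.0f, 1.5f, -0.5f, 0.0f }; // guidance (negative-prompt) context
    const float cfg_scale         = 1.5f; // illustrative; 1.0 leaves the distribution unchanged
    const float cfg_smooth_factor = 0.5f; // illustrative; 1.0 uses the guided values as-is

    inplace_log_softmax(cond);
    inplace_log_softmax(neg);

    // Push the log-probabilities away from the negative prompt, then renormalize.
    for (size_t i = 0; i < neg.size(); ++i) {
        neg[i] = cfg_scale * (cond[i] - neg[i]) + neg[i];
    }
    inplace_log_softmax(neg);

    // Blend guided and original log-probabilities, as the sampling loop does.
    for (size_t i = 0; i < cond.size(); ++i) {
        cond[i] = cfg_smooth_factor * neg[i] + (1.0f - cfg_smooth_factor) * cond[i];
        std::printf("token %zu: guided log-prob %.4f\n", i, cond[i]);
    }
    return 0;
}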


@@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
llama_context * ctx;
// load the model and apply lora adapter, if any
-std::tie(model, ctx) = llama_init_from_gpt_params(params);
+std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;


@@ -245,7 +245,7 @@ struct llama_server_context
bool loadModel(const gpt_params &params_)
{
params = params_;
-std::tie(model, ctx) = llama_init_from_gpt_params(params);
+std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
if (model == nullptr)
{
LOG_ERROR("unable to load model", {{"model", params_.model}});


@@ -71,7 +71,7 @@ int main(int argc, char ** argv)
llama_model * model;
llama_context * ctx;
-std::tie(model, ctx) = llama_init_from_gpt_params( params );
+std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params( params );
if ( model == NULL )
{