diff --git a/common/common.cpp b/common/common.cpp
index 2597ba06a..61e285c4b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -356,6 +356,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.interactive_first = true;
         } else if (arg == "-ins" || arg == "--instruct") {
             params.instruct = true;
+        } else if (arg == "--infill") {
+            params.infill = true;
         } else if (arg == "--multiline-input") {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
diff --git a/common/common.h b/common/common.h
index 18aea38ce..28a9ad359 100644
--- a/common/common.h
+++ b/common/common.h
@@ -118,6 +118,7 @@ struct gpt_params {
     bool numa = false; // attempt optimizations that help on some NUMA systems
     bool export_cgraph = false; // export the computation graph
     bool verbose_prompt = false; // print prompt tokens before generation
+    bool infill = false; // use infill mode
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d78112260..365640ffb 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -239,8 +239,15 @@ int main(int argc, char ** argv) {
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
-
-    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+    if (params.infill) {
+        std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
+        std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
+        inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+        inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
+        embd_inp = inp_pfx;
+        embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+        embd_inp.push_back(llama_token_middle(ctx));
+    } else if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
         embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
     } else {
@@ -709,9 +716,58 @@ int main(int argc, char ** argv) {
                     LOG("found antiprompt: %s\n", last_output.c_str());
                 }
             }
-
+            // deal with eot token in infill mode
+            if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.infill && params.interactive) {
+                if (is_interacting && !params.interactive_first) {
+                    // print an eot token
+                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
+                }
+                fflush(stdout);
+                printf("\n");
+                console::set_display(console::user_input);
+                std::string buffer;
+                std::string line;
+                bool another_line = true;
+                // set a new prefix via stdin
+                do {
+                    another_line = console::readline(line, params.multiline_input);
+                    buffer += line;
+                } while (another_line);
+                // check if we got an empty line, if so we use the old input
+                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                    params.input_prefix = buffer;
+                }
+                buffer.clear();
+                // set a new suffix via stdin
+                do {
+                    another_line = console::readline(line, params.multiline_input);
+                    buffer += line;
+                } while (another_line);
+                // check if we got an empty line
+                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                    params.input_suffix = buffer;
+                }
+                buffer.clear();
+                // done taking input, reset color
+                console::set_display(console::reset);
+                // tokenize new prefix and suffix
+                std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
+                std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
+                inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+                inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
+                embd_inp = inp_pfx;
+                embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                embd_inp.push_back(llama_token_middle(ctx));
+                embd.clear();
+                embd_guidance.clear();
+                n_remain = params.n_predict;
+                n_past = 0;
+                n_consumed = 0;
+                // LOG_TEE("took new input\n");
+                is_interacting = false;
+            }
             // deal with end of text token in interactive mode
-            if (last_tokens.back() == llama_token_eos(ctx)) {
+            else if (last_tokens.back() == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -731,7 +787,7 @@ int main(int argc, char ** argv) {
                 }
             }
 
-            if (n_past > 0 && is_interacting) {
+            if (n_past > 0 && is_interacting && !(params.infill && params.interactive)) {
                 LOG("waiting for user input\n");
 
                 if (params.instruct) {
@@ -825,17 +881,23 @@ int main(int argc, char ** argv) {
 
         // end of text token
         if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
-            LOG_TEE(" [end of text]\n");
+            if (!params.infill) {
+                LOG_TEE(" [end of text]\n");
+            }
             break;
         }
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
         // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
-        if (params.interactive && n_remain <= 0 && params.n_predict >= 0) {
+        if (params.interactive && n_remain <= 0 && params.n_predict >= 0 ) {
             n_remain = params.n_predict;
             is_interacting = true;
         }
     }
+    if (params.infill && !params.interactive && n_remain <= 0) {
+        printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
+        fflush(stdout);
+    }
 
     if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
         LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
diff --git a/examples/server/README.md b/examples/server/README.md
index 517608046..d07f225e9 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -176,6 +176,16 @@ node index.js
 
     `content`: Set the text to process.
 
+- **POST** `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as a stream.
+
+    *Options:*
+
+    `input_prefix`: Set the prefix of the code to infill.
+
+    `input_suffix`: Set the suffix of the code to infill.
+
+    It also accepts all the options of `/completion` except `stream` and `prompt`.
+
 ## More examples
 
 ### Interactive mode
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ebd7f2fc5..3b7c57a4f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -341,6 +341,71 @@ struct llama_server_context
         return true;
     }
 
+    void loadInfill()
+    {
+        auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS
+        auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS
+        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
+        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
+        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+        prefix_tokens.push_back(llama_token_middle(ctx));
+        auto prompt_tokens = prefix_tokens;
+
+        num_prompt_tokens = prompt_tokens.size();
+
+        if (params.n_keep < 0)
+        {
+            params.n_keep = (int)num_prompt_tokens;
+        }
+        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (num_prompt_tokens >= (size_t)params.n_ctx)
+        {
+            printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
+            // todo we probably want to cut from both sides
+            const int n_left = (params.n_ctx - params.n_keep) / 2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
+            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
+            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+            LOG_VERBOSE("input truncated", {
+                {"n_ctx", params.n_ctx},
+                {"n_keep", params.n_keep},
+                {"n_left", n_left},
+                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+            });
+
+            truncated = true;
+            prompt_tokens = new_tokens;
+        }
+        else
+        {
+            const size_t ps = num_prompt_tokens;
+            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+        }
+
+        // compare the evaluated prompt with the new prompt
+        n_past = common_part(embd, prompt_tokens);
+        printf("n_past: %d\n", n_past);
+        embd = prompt_tokens;
+        if (n_past == num_prompt_tokens)
+        {
+            // we have to evaluate at least 1 token to generate logits.
+            printf("we have to evaluate at least 1 token to generate logits\n");
+            n_past--;
+        }
+
+        LOG_VERBOSE("prompt ingested", {
+            {"n_past", n_past},
+            {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
+            {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
+        });
+
+        has_next_token = true;
+    }
     void loadPrompt()
     {
         auto prompt_tokens = tokenize(prompt, true); // always add BOS
@@ -1199,6 +1264,27 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
 
+static void parse_options_infill(const json &body, llama_server_context &llama)
+{
+    if (body.count("input_prefix") != 0)
+    {
+        llama.params.input_prefix = body["input_prefix"];
+    }
+    else
+    {
+        llama.params.input_prefix = "";
+    }
+    if (body.count("input_suffix") != 0)
+    {
+        llama.params.input_suffix = body["input_suffix"];
+    }
+    else
+    {
+        llama.params.input_suffix = "";
+    }
+    parse_options_completion(body, llama);
+}
+
 static void log_server_request(const Request &req, const Response &res)
 {
     LOG_INFO("request", {
@@ -1498,6 +1584,127 @@ int main(int argc, char **argv)
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
     });
+    svr.Post("/infill", [&llama](const Request &req, Response &res)
+            {
+                auto lock = llama.lock();
+
+                llama.rewind();
+
+                llama_reset_timings(llama.ctx);
+
+                parse_options_infill(json::parse(req.body), llama);
+
+                if (!llama.loadGrammar())
+                {
+                    res.status = 400;
+                    return;
+                }
+                llama.loadInfill();
+                llama.beginCompletion();
+                const auto chunked_content_provider = [&](size_t, DataSink & sink) {
+                    size_t sent_count = 0;
+                    size_t sent_token_probs_index = 0;
+
+                    while (llama.has_next_token) {
+                        const completion_token_output token_with_probs = llama.doCompletion();
+                        if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
+                            continue;
+                        }
+                        const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
+
+                        size_t pos = std::min(sent_count, llama.generated_text.size());
+
+                        const std::string str_test = llama.generated_text.substr(pos);
+                        bool is_stop_full = false;
+                        size_t stop_pos =
+                            llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
+                        if (stop_pos != std::string::npos) {
+                            is_stop_full = true;
+                            llama.generated_text.erase(
+                                llama.generated_text.begin() + pos + stop_pos,
+                                llama.generated_text.end());
+                            pos = std::min(sent_count, llama.generated_text.size());
+                        } else {
+                            is_stop_full = false;
+                            stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
+                                                                 STOP_PARTIAL);
+                        }
+
+                        if (
+                            stop_pos == std::string::npos ||
+                            // Send rest of the text if we are at the end of the generation
+                            (!llama.has_next_token && !is_stop_full && stop_pos > 0)
+                        ) {
+                            const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
+
+                            sent_count += to_send.size();
+
+                            std::vector<completion_token_output> probs_output = {};
+
+                            if (llama.params.n_probs > 0) {
+                                const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                                size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                                if (probs_pos < probs_stop_pos) {
+                                    probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                                }
+                                sent_token_probs_index = probs_stop_pos;
+                            }
+
+                            const json data = format_partial_response(llama, to_send, probs_output);
+
+                            const std::string str =
+                                "data: " +
+                                data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                                "\n\n";
+
+                            LOG_VERBOSE("data stream", {
+                                { "to_send", str }
+                            });
+
+                            if (!sink.write(str.data(), str.size())) {
+                                LOG_VERBOSE("stream closed", {});
+                                llama_print_timings(llama.ctx);
+                                return false;
+                            }
+                        }
+
+                        if (!llama.has_next_token) {
+                            // Generation is done, send extra information.
+                            const json data = format_final_response(
+                                llama,
+                                "",
+                                std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
+                            );
+
+                            const std::string str =
+                                "data: " +
+                                data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                                "\n\n";
+
+                            LOG_VERBOSE("data stream", {
+                                { "to_send", str }
+                            });
+
+                            if (!sink.write(str.data(), str.size())) {
+                                LOG_VERBOSE("stream closed", {});
+                                llama_print_timings(llama.ctx);
+                                return false;
+                            }
+                        }
+                    }
+
+                    llama_print_timings(llama.ctx);
+                    sink.done();
+                    return true;
+                };
+                const auto on_complete = [&](bool) {
+                    llama.mutex.unlock();
+                };
+                lock.release();
+                res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+            });
+
     svr.Get("/model.json", [&llama](const Request &, Response &res)
             {
                 const json data = format_generation_settings(llama);
diff --git a/llama.cpp b/llama.cpp
index 358bf5ec8..93afb3a57 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1052,6 +1052,10 @@ struct llama_vocab {
     id special_pad_id = -1;
 
     id linefeed_id = 13;
+    id special_prefix_id = 32007;
+    id special_middle_id = 32009;
+    id special_suffix_id = 32008;
+    id special_eot_id = 32010;
 
     int find_bpe_rank(std::string token_left, std::string token_right) const {
         replace_all(token_left, " ", "\u0120");
@@ -7024,6 +7028,22 @@ llama_token llama_token_eos(const struct llama_context * ctx) {
 llama_token llama_token_nl(const struct llama_context * ctx) {
     return ctx->model.vocab.linefeed_id;
 }
 
+llama_token llama_token_prefix(const struct llama_context * ctx) {
+    return ctx->model.vocab.special_prefix_id;
+}
+
+llama_token llama_token_middle(const struct llama_context * ctx) {
+    return ctx->model.vocab.special_middle_id;
+}
+
+llama_token llama_token_suffix(const struct llama_context * ctx) {
+    return ctx->model.vocab.special_suffix_id;
+}
+
+llama_token llama_token_eot(const struct llama_context * ctx) {
+    return ctx->model.vocab.special_eot_id;
+}
+
 int llama_tokenize(
         struct llama_context * ctx,
diff --git a/llama.h b/llama.h
index 369be048c..d7aa9d5ef 100644
--- a/llama.h
+++ b/llama.h
@@ -362,6 +362,11 @@ extern "C" {
     LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
     LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
+    // codellama FIM tokens
+    LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of FIM prefix
+    LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of FIM middle
+    LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of FIM suffix
+    LLAMA_API llama_token llama_token_eot   (const struct llama_context * ctx); // End of FIM middle
 
     //
     // Tokenization
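For illustration, a minimal sketch (not part of the patch) of how a caller could assemble a fill-in-the-middle prompt with the new FIM token accessors, mirroring the construction used in `examples/main/main.cpp` and in `loadInfill()` above. The helper name `build_infill_prompt` is hypothetical; `::llama_tokenize` is the convenience wrapper from `common/common.h` that the patch itself uses.

```cpp
// Illustrative sketch only -- not part of this patch. Shows the FIM prompt
// layout the patch builds: <PRE> prefix tokens <SUF> suffix tokens <MID>,
// with generation expected to stop at the EOT token (llama_token_eot).
#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

// Hypothetical helper mirroring the prompt construction in main.cpp/server.cpp.
static std::vector<llama_token> build_infill_prompt(
        llama_context * ctx,
        const std::string & prefix,
        const std::string & suffix,
        bool add_bos) {
    std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, prefix, add_bos);
    std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, suffix, add_bos);

    // Prepend the special prefix/suffix markers, then append the middle marker.
    inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));

    std::vector<llama_token> embd_inp = inp_pfx;
    embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
    embd_inp.push_back(llama_token_middle(ctx));
    return embd_inp;
}
```

The server's `/infill` route performs the same construction internally in `loadInfill()`, so HTTP clients only need to supply `input_prefix` and `input_suffix`.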