vvhg-code-infill (#1)
parent 7eb41179ed
commit 9d3514a2a6

7 changed files with 314 additions and 7 deletions

@@ -356,6 +356,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
            params.interactive_first = true;
        } else if (arg == "-ins" || arg == "--instruct") {
            params.instruct = true;
        } else if (arg == "--infill") {
            params.infill = true;
        } else if (arg == "--multiline-input") {
            params.multiline_input = true;
        } else if (arg == "--simple-io") {

@@ -118,6 +118,7 @@ struct gpt_params {
    bool numa           = false; // attempt optimizations that help on some NUMA systems
    bool export_cgraph  = false; // export the computation graph
    bool verbose_prompt = false; // print prompt tokens before generation
    bool infill         = false; // use infill mode
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

@@ -239,8 +239,15 @@ int main(int argc, char ** argv) {
    LOG("add_bos: %d\n", add_bos);

    std::vector<llama_token> embd_inp;

    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
    if(params.infill) {
        std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
        std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
        inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
        inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
        embd_inp = inp_pfx;
        embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
        embd_inp.push_back(llama_token_middle(ctx));
    } else if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
        LOG("tokenize the prompt\n");
        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
    } else {

@@ -709,9 +716,58 @@ int main(int argc, char ** argv) {
                LOG("found antiprompt: %s\n", last_output.c_str());
            }
        }

        // deal with eot token in infill mode
        if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.infill && params.interactive){
            if(is_interacting && !params.interactive_first) {
                // print an eot token
                printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
            }
            fflush(stdout);
            printf("\n");
            console::set_display(console::user_input);
            std::string buffer;
            std::string line;
            bool another_line=true;
            // set a new prefix via stdin
            do {
                another_line = console::readline(line, params.multiline_input);
                buffer += line;
            } while (another_line);
            // check if we got an empty line, if so we use the old input
            if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                params.input_prefix = buffer;
            }
            buffer.clear();
            // set a new suffix via stdin
            do {
                another_line = console::readline(line, params.multiline_input);
                buffer += line;
            } while (another_line);
            // check if we got an empty line
            if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                params.input_suffix = buffer;
            }
            buffer.clear();
            // done taking input, reset color
            console::set_display(console::reset);
            // tokenize new prefix and suffix
            std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
            std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
            inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
            inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
            embd_inp = inp_pfx;
            embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
            embd_inp.push_back(llama_token_middle(ctx));
            embd.clear();
            embd_guidance.clear();
            n_remain = params.n_predict;
            n_past = 0;
            n_consumed = 0;
            // LOG_TEE("took new input\n");
            is_interacting = false;
        }
        // deal with end of text token in interactive mode
        if (last_tokens.back() == llama_token_eos(ctx)) {
        else if (last_tokens.back() == llama_token_eos(ctx)) {
            LOG("found EOS token\n");

            if (params.interactive) {

@@ -731,7 +787,7 @@ int main(int argc, char ** argv) {
            }
        }

        if (n_past > 0 && is_interacting) {
        if (n_past > 0 && is_interacting && !(params.infill && params.interactive)) {
            LOG("waiting for user input\n");

            if (params.instruct) {

@@ -825,17 +881,23 @@ int main(int argc, char ** argv) {

        // end of text token
        if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
            LOG_TEE(" [end of text]\n");
            if (!params.infill){
                LOG_TEE(" [end of text]\n");
            }
            break;
        }

        // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
        // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
        if (params.interactive && n_remain <= 0 && params.n_predict >= 0) {
        if (params.interactive && n_remain <= 0 && params.n_predict >= 0 ) {
            n_remain = params.n_predict;
            is_interacting = true;
        }
    }
    if (params.infill && !params.interactive && n_remain <= 0) {
        printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
        fflush(stdout);
    }

    if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
        LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());

@@ -176,6 +176,16 @@ node index.js

    `content`: Set the text to process.

**POST** `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as a stream.

*Options:*

    `input_prefix`: Set the prefix of the code to infill.

    `input_suffix`: Set the suffix of the code to infill.

It also accepts all the options of `/completion` except `stream` and `prompt`.
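
A minimal request sketch (not part of this diff): assuming the server is running on its default `localhost:8080`, an infill request could look like the following, where `n_predict` is one of the shared `/completion` options and the snippet values are placeholders.

    curl --request POST \
        --url http://localhost:8080/infill \
        --header "Content-Type: application/json" \
        --data '{"input_prefix": "def add(a, b):\n    ", "input_suffix": "\n    return c\n", "n_predict": 32}'

The response is streamed back as `data: {...}` server-sent-event chunks, the same framing the `/completion` endpoint uses.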

## More examples

### Interactive mode

@@ -341,6 +341,71 @@ struct llama_server_context
        return true;
    }

    void loadInfill()
    {
        auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS
        auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS
        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
        prefix_tokens.push_back(llama_token_middle(ctx));
        auto prompt_tokens = prefix_tokens;

        num_prompt_tokens = prompt_tokens.size();

        if (params.n_keep < 0)
        {
            params.n_keep = (int)num_prompt_tokens;
        }
        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);

        // if input prompt is too big, truncate like normal
        if (num_prompt_tokens >= (size_t)params.n_ctx)
        {
            printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
            // todo we probably want to cut from both sides
            const int n_left = (params.n_ctx - params.n_keep) / 2;
            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());

            LOG_VERBOSE("input truncated", {
                {"n_ctx", params.n_ctx},
                {"n_keep", params.n_keep},
                {"n_left", n_left},
                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
            });

            truncated = true;
            prompt_tokens = new_tokens;
        }
        else
        {
            const size_t ps = num_prompt_tokens;
            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
        }

        // compare the evaluated prompt with the new prompt
        n_past = common_part(embd, prompt_tokens);
        printf("n_past: %d\n", n_past);
        embd = prompt_tokens;
        if (n_past == num_prompt_tokens)
        {
            // we have to evaluate at least 1 token to generate logits.
            printf("we have to evaluate at least 1 token to generate logits\n");
            n_past--;
        }

        LOG_VERBOSE("prompt ingested", {
            {"n_past", n_past},
            {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
            {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
        });

        has_next_token = true;
    }
    void loadPrompt()
    {
        auto prompt_tokens = tokenize(prompt, true); // always add BOS

@@ -1199,6 +1264,27 @@ static void parse_options_completion(const json &body, llama_server_context &llama)
    LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}

static void parse_options_infill(const json &body, llama_server_context &llama)
{
    if (body.count("input_prefix") != 0)
    {
        llama.params.input_prefix = body["input_prefix"];
    }
    else
    {
        llama.params.input_prefix = "";
    }
    if (body.count("input_suffix") != 0)
    {
        llama.params.input_suffix = body["input_suffix"];
    }
    else
    {
        llama.params.input_suffix = "";
    }
    parse_options_completion(body, llama);
}

static void log_server_request(const Request &req, const Response &res)
{
    LOG_INFO("request", {

@@ -1498,6 +1584,127 @@ int main(int argc, char **argv)
                res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
            } });

    svr.Post("/infill", [&llama](const Request &req, Response &res)
            {
                auto lock = llama.lock();

                llama.rewind();

                llama_reset_timings(llama.ctx);

                parse_options_infill(json::parse(req.body), llama);

                if (!llama.loadGrammar())
                {
                    res.status = 400;
                    return;
                }
                llama.loadInfill();
                llama.beginCompletion();
                const auto chunked_content_provider = [&](size_t, DataSink & sink) {
                    size_t sent_count = 0;
                    size_t sent_token_probs_index = 0;

                    while (llama.has_next_token) {
                        const completion_token_output token_with_probs = llama.doCompletion();
                        if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
                            continue;
                        }
                        const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);

                        size_t pos = std::min(sent_count, llama.generated_text.size());

                        const std::string str_test = llama.generated_text.substr(pos);
                        bool is_stop_full = false;
                        size_t stop_pos =
                            llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
                        if (stop_pos != std::string::npos) {
                            is_stop_full = true;
                            llama.generated_text.erase(
                                llama.generated_text.begin() + pos + stop_pos,
                                llama.generated_text.end());
                            pos = std::min(sent_count, llama.generated_text.size());
                        } else {
                            is_stop_full = false;
                            stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
                                STOP_PARTIAL);
                        }

                        if (
                            stop_pos == std::string::npos ||
                            // Send rest of the text if we are at the end of the generation
                            (!llama.has_next_token && !is_stop_full && stop_pos > 0)
                        ) {
                            const std::string to_send = llama.generated_text.substr(pos, std::string::npos);

                            sent_count += to_send.size();

                            std::vector<completion_token_output> probs_output = {};

                            if (llama.params.n_probs > 0) {
                                const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                                size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
                                if (probs_pos < probs_stop_pos) {
                                    probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
                                }
                                sent_token_probs_index = probs_stop_pos;
                            }

                            const json data = format_partial_response(llama, to_send, probs_output);

                            const std::string str =
                                "data: " +
                                data.dump(-1, ' ', false, json::error_handler_t::replace) +
                                "\n\n";

                            LOG_VERBOSE("data stream", {
                                { "to_send", str }
                            });

                            if (!sink.write(str.data(), str.size())) {
                                LOG_VERBOSE("stream closed", {});
                                llama_print_timings(llama.ctx);
                                return false;
                            }
                        }

                        if (!llama.has_next_token) {
                            // Generation is done, send extra information.
                            const json data = format_final_response(
                                llama,
                                "",
                                std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
                            );

                            const std::string str =
                                "data: " +
                                data.dump(-1, ' ', false, json::error_handler_t::replace) +
                                "\n\n";

                            LOG_VERBOSE("data stream", {
                                { "to_send", str }
                            });

                            if (!sink.write(str.data(), str.size())) {
                                LOG_VERBOSE("stream closed", {});
                                llama_print_timings(llama.ctx);
                                return false;
                            }
                        }
                    }

                    llama_print_timings(llama.ctx);
                    sink.done();
                    return true;
                };
                const auto on_complete = [&](bool) {
                    llama.mutex.unlock();
                };
                lock.release();
                res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
            });

    svr.Get("/model.json", [&llama](const Request &, Response &res)
            {
                const json data = format_generation_settings(llama);

llama.cpp

@@ -1052,6 +1052,10 @@ struct llama_vocab {
    id special_pad_id = -1;

    id linefeed_id       = 13;
    id special_prefix_id = 32007;
    id special_middle_id = 32009;
    id special_suffix_id = 32008;
    id special_eot_id    = 32010;

    int find_bpe_rank(std::string token_left, std::string token_right) const {
        replace_all(token_left, " ", "\u0120");

@@ -7024,6 +7028,22 @@ llama_token llama_token_eos(const struct llama_context * ctx) {
llama_token llama_token_nl(const struct llama_context * ctx) {
    return ctx->model.vocab.linefeed_id;
}
llama_token llama_token_prefix(const struct llama_context * ctx) {
    return ctx->model.vocab.special_prefix_id;
}

llama_token llama_token_middle(const struct llama_context * ctx) {
    return ctx->model.vocab.special_middle_id;
}

llama_token llama_token_suffix(const struct llama_context * ctx) {
    return ctx->model.vocab.special_suffix_id;
}

llama_token llama_token_eot(const struct llama_context * ctx) {
    return ctx->model.vocab.special_eot_id;
}

int llama_tokenize(
        struct llama_context * ctx,

llama.h

@@ -362,6 +362,11 @@ extern "C" {
    LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
    LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
    LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
    // codellama FIM tokens
    LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of FIM prefix
    LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of FIM middle
    LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of FIM suffix
    LLAMA_API llama_token llama_token_eot   (const struct llama_context * ctx); // End of FIM middle

    //
    // Tokenization