Added capturing the stopping word and sending it along with the final JSON.

Fixed an fprintf format warning.
Fixed a bug that broke streaming.
Properly removed thread changing from the JSON options (previously only batch_size was grabbed).
digiwombat 2023-05-28 08:43:38 -04:00
parent 2e5c5ee224
commit dda915cac4


@@ -33,13 +33,14 @@ struct llama_server_context
     llama_context *ctx;
     gpt_params params;
-    bool reload_ctx = false;
+    std::string stopping_word = "";

     void rewind() {
         as_loop = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
+        stopping_word = "";
         //processed_tokens.clear();
         embd_inp.clear();
@ -233,6 +234,7 @@ struct llama_server_context
} }
if (!embd.empty() && embd.back() == llama_token_eos()) { if (!embd.empty() && embd.back() == llama_token_eos()) {
stopping_word = llama_token_to_str(ctx, embd.back());
has_next_token = false; has_next_token = false;
} }
@@ -258,6 +260,7 @@ struct llama_server_context
                 size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
                 if (i != std::string::npos) {
                     generated_text.erase(generated_text.begin() + i, generated_text.begin() + i + word.size());
+                    stopping_word = word;
                     has_next_token = false;
                     break;
                 }
@@ -313,7 +316,7 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, con
     fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
-    fprintf(stderr, " --host ip address to listen (default (default: %d)\n", sparams.hostname);
+    fprintf(stderr, " --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     fprintf(stderr, " --port PORT port to listen (default (default: %d)\n", sparams.port);
     fprintf(stderr, " -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     fprintf(stderr, "\n");
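The --host line fix above is the usual std::string-versus-format-specifier mismatch. A standalone sketch (not part of the commit) of why %d on a std::string draws the compiler's format warning and why %s with c_str() is the correct form:

#include <cstdio>
#include <string>

int main() {
    std::string hostname = "127.0.0.1";
    // Passing a std::string through fprintf's "..." is undefined behavior, and
    // a %d specifier does not match it; this is what the warning pointed at:
    // fprintf(stderr, "--host (default: %d)\n", hostname);
    // The fix: %s with a C string obtained via c_str().
    fprintf(stderr, "--host (default: %s)\n", hostname.c_str());
    return 0;
}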
@@ -449,9 +452,9 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 }

 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
-    if (!body["threads"].is_null())
+    if (!body["as_loop"].is_null())
     {
-        llama.params.n_threads = body["threads"].get<int>();
+        llama.as_loop = body["as_loop"].get<bool>();
     }
     if (!body["n_predict"].is_null())
     {
@@ -475,7 +478,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     }
     if (!body["repeat_last_n"].is_null())
     {
-        llama.params.repeat_last_n = body["repeat_last_n"].get<float>();
+        llama.params.repeat_last_n = body["repeat_last_n"].get<int>();
     }
     if (!body["temperature"].is_null())
     {
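Taken together, the two option-parsing hunks mean the completion body no longer changes the thread count, as_loop is now read as a bool, and repeat_last_n as an int. A minimal sketch of a request body that uses only the option names visible in this diff (the real endpoint accepts more fields than shown), built with the same nlohmann::json types the server uses:

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Only option names that appear in the hunks above; illustrative values.
    json body = {
        {"as_loop", true},        // parsed with get<bool>() after this commit
        {"n_predict", 128},
        {"repeat_last_n", 64},    // parsed with get<int>() instead of get<float>()
        {"temperature", 0.8}
        // a "threads" field is simply ignored now
    };
    std::cout << body.dump(2) << std::endl;
    return 0;
}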
@@ -630,7 +633,8 @@ int main(int argc, char **argv)
                     "penalize_nl", llama.params.penalize_nl
                 }
             },
-            {"prompt", llama.params.prompt} };
+            {"prompt", llama.params.prompt},
+            {"stopping_word", llama.stopping_word} };
         return res.set_content(data.dump(), "application/json");
     }
     catch (const json::exception &e)
@@ -684,6 +688,7 @@ int main(int argc, char **argv)
             json data;
             if (llama.has_next_token)
             {
+                //fprintf(stdout, "Result: %s\n", result);
                 final_text += result;
                 data = {
                     {"content", result },
@@ -715,6 +720,7 @@ int main(int argc, char **argv)
                         }
                     },
                     {"prompt", llama.params.prompt},
+                    {"stopping_word", llama.stopping_word},
                     {"generated_text", final_text}
                 };
             }
@@ -735,7 +741,7 @@ int main(int argc, char **argv)
             {
                 // Generation is done, send extra information.
                 data = {
-                    {"content", "\uFFFD" },
+                    {"content", u8"\uFFFD" },
                     {"stop", true },
                     {"tokens_predicted", llama.num_tokens_predicted},
                     {"seed", llama.params.seed},
@@ -756,6 +762,7 @@ int main(int argc, char **argv)
                         }
                     },
                     {"prompt", llama.params.prompt},
+                    {"stopping_word", llama.stopping_word},
                     {"generated_text", final_text}
                 };
             }
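With the stopping word captured, the end-of-generation payload reports which antiprompt (or EOS token text) stopped the run alongside the accumulated text. A rough sketch of that final object, limited to fields visible in the hunks above and filled with placeholder values (not taken from the commit):

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    std::string stopping_word = "###";                        // hypothetical stop word that ended generation
    std::string final_text    = "...accumulated output...";   // placeholder for the streamed text

    json data = {
        {"content", "\uFFFD"},                 // the source uses u8"\uFFFD" to mark the final chunk
        {"stop", true},
        {"tokens_predicted", 42},              // placeholder count
        {"seed", -1},                          // placeholder seed
        {"prompt", "..."},                     // echoed prompt, placeholder here
        {"stopping_word", stopping_word},      // new field added by this commit
        {"generated_text", final_text}
    };
    std::cout << data.dump(2) << std::endl;
    return 0;
}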