server : coding-style normalization (part 2)

Author: Georgi Gerganov, 2023-10-19 14:09:45 +03:00
Commit: 654e0a1fe0 (parent e44ed60187)
GPG key ID: 449E073F9DC10735 (no known key found for this signature in database)

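Every hunk below applies the same normalization: opening braces move onto their own line (Allman style), control-flow keywords such as if, for, while, and catch get a space before the parenthesis, and statements that shared a line with a closing brace are split apart. A minimal before/after sketch of the rule, distilled from the hunks rather than quoted from any single one:

    // before normalization
    if(slot->has_new_token()) { // new token notification
        completion_text += slot->next().text_to_send;
    } else {
        std::this_thread::sleep_for(std::chrono::microseconds(5));
    }

    // after normalization
    if (slot->has_new_token())
    {
        // new token notification
        completion_text += slot->next().text_to_send;
    }
    else
    {
        std::this_thread::sleep_for(std::chrono::microseconds(5));
    }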

@@ -1933,12 +1933,15 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
         size_t end_prefix = pos;
         pos += pattern.length();
         size_t end_pos = prompt.find("]", pos);
-        if (end_pos != std::string::npos) {
+        if (end_pos != std::string::npos)
+        {
             std::string image_id = prompt.substr(pos, end_pos - pos);
-            try {
+            try
+            {
                 int img_id = std::stoi(image_id);
                 bool found = false;
-                for(slot_image &img : slot->images) {
+                for (slot_image &img : slot->images)
+                {
                     if (img.id == img_id) {
                         found = true;
                         img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix);
@@ -2043,7 +2046,8 @@ static void beam_search_callback(void *callback_data, llama_beams_state beams_st
 #endif
 }

-struct token_translator {
+struct token_translator
+{
     llama_context * ctx;
     std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
     std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); }
@@ -2055,10 +2059,12 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
     auto translator = token_translator{llama.ctx};
     auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
     const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
-    if (slot->generated_text.capacity() < slot->generated_text.size() + len) {
+    if (slot->generated_text.capacity() < slot->generated_text.size() + len)
+    {
         slot->generated_text.reserve(slot->generated_text.size() + len);
     }
-    for (const completion_token_output & cto : gtps) {
+    for (const completion_token_output & cto : gtps)
+    {
         slot->generated_text += translator(cto);
     }
 }
@@ -2108,25 +2114,29 @@ int main(int argc, char **argv)
     svr.Get("/", [](const httplib::Request &, httplib::Response &res)
             {
                 res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html");
-                return false; });
+                return false;
+            });

     // this is only called if no index.js is found in the public --path
     svr.Get("/index.js", [](const httplib::Request &, httplib::Response &res)
             {
                 res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript");
-                return false; });
+                return false;
+            });

     // this is only called if no index.html is found in the public --path
     svr.Get("/completion.js", [](const httplib::Request &, httplib::Response &res)
             {
                 res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript");
-                return false; });
+                return false;
+            });

     // this is only called if no index.html is found in the public --path
     svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response &res)
             {
                 res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript");
-                return false; });
+                return false;
+            });

     svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res)
             {
@@ -2135,7 +2145,8 @@ int main(int argc, char **argv)
                     { "user_name", llama.user_name.c_str() },
                     { "assistant_name", llama.assistant_name.c_str() }
                 };
-                res.set_content(data.dump(), "application/json"); });
+                res.set_content(data.dump(), "application/json");
+            });

     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
@@ -2166,7 +2177,8 @@ int main(int argc, char **argv)
                 if (!slot->params.stream) {
                     std::string completion_text = "";
-                    if (llama.params.n_beams) {
+                    if (llama.params.n_beams)
+                    {
                         // Fill llama.generated_token_probs vector with final beam.
                         server_beam_search_callback_data data_beam;
                         data_beam.slot = slot;
                         llama_beam_search(llama.ctx, beam_search_callback, &data_beam,
@@ -2175,18 +2187,25 @@ int main(int argc, char **argv)
                                           slot->n_past, llama.params.n_predict);
                         // Translate llama.generated_token_probs to llama.generated_text.
                         append_to_generated_text_from_generated_token_probs(llama, slot);
-                    } else {
-                        while (slot->is_processing()) {
-                            if(slot->has_new_token()) {
+                    }
+                    else
+                    {
+                        while (slot->is_processing())
+                        {
+                            if (slot->has_new_token())
+                            {
                                 completion_text += slot->next().text_to_send;
-                            } else {
+                            }
+                            else
+                            {
                                 std::this_thread::sleep_for(std::chrono::microseconds(5));
                             }
                         }
                     }

                     auto probs = slot->generated_token_probs;
-                    if (slot->sparams.n_probs > 0 && slot->stopped_word) {
+                    if (slot->sparams.n_probs > 0 && slot->stopped_word)
+                    {
                         const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, slot->stopping_word, false);
                         probs = std::vector<completion_token_output>(slot->generated_token_probs.begin(), slot->generated_token_probs.end() - stop_word_toks.size());
                     }
@@ -2194,20 +2213,23 @@ int main(int argc, char **argv)
                     const json data = format_final_response(llama, slot, completion_text, probs);
                     slot_print_timings(slot);
                     slot->release();
-                    res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
-                                    "application/json");
+                    res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
                 } else {
                     const auto chunked_content_provider = [slot, &llama](size_t, httplib::DataSink & sink) {
                         size_t sent_token_probs_index = 0;
-                        while(slot->is_processing()) {
-                            if(slot->has_new_token()) { // new token notification
+                        while (slot->is_processing())
+                        {
+                            if (slot->has_new_token())
+                            { // new token notification
                                 const completion_token_output token = slot->next();
                                 std::vector<completion_token_output> probs_output = {};
-                                if (slot->sparams.n_probs > 0) {
+                                if (slot->sparams.n_probs > 0)
+                                {
                                     const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
                                     size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
                                     size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
-                                    if (probs_pos < probs_stop_pos) {
+                                    if (probs_pos < probs_stop_pos)
+                                    {
                                         probs_output = std::vector<completion_token_output>(slot->generated_token_probs.begin() + probs_pos, slot->generated_token_probs.begin() + probs_stop_pos);
                                     }
                                     sent_token_probs_index = probs_stop_pos;
@@ -2220,11 +2242,14 @@ int main(int argc, char **argv)
                                 LOG_VERBOSE("data stream", {
                                     { "to_send", str }
                                 });
-                                if(!sink.write(str.c_str(), str.size())) {
+                                if (!sink.write(str.c_str(), str.size()))
+                                {
                                     slot->release();
                                     return false;
                                 }
-                            } else {
+                            }
+                            else
+                            {
                                 std::this_thread::sleep_for(std::chrono::microseconds(5));
                             }
                         }
@@ -2240,10 +2265,13 @@ int main(int argc, char **argv)
                                 "data: " +
                                 data.dump(-1, ' ', false, json::error_handler_t::replace) +
                                 "\n\n";
                             LOG_VERBOSE("data stream", {
                                 { "to_send", str }
                             });
-                            if (!sink.write(str.data(), str.size())) {
+
+                            if (!sink.write(str.data(), str.size()))
+                            {
                                 slot->release();
                                 return false;
                             }
@@ -2255,23 +2283,25 @@ int main(int argc, char **argv)
                         slot->clean_tokens();
                     };
                     res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-                } });
+                }
+            });

     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);

                 llama_client_slot* slot = llama.get_slot(json_value(data, "slot_id", -1));

-                if(slot == nullptr) {
+                if (slot == nullptr)
+                {
                     LOG_TEE("slot unavailable\n");
                     res.status = 404;
                     res.set_content("slot_error", "text/plain");
                     return;
                 }

-                if(data.contains("system_prompt")) {
+                if (data.contains("system_prompt"))
+                {
                     llama.process_system_prompt_data(data["system_prompt"]);
                 }
@@ -2294,7 +2324,9 @@ int main(int argc, char **argv)
                     if(slot->has_new_token())
                     {
                         completion_text += slot->next().text_to_send;
-                    } else {
+                    }
+                    else
+                    {
                         std::this_thread::sleep_for(std::chrono::microseconds(5));
                     }
                 }
@@ -2315,15 +2347,20 @@ int main(int argc, char **argv)
                 {
                     const auto chunked_content_provider = [slot, &llama](size_t, httplib::DataSink & sink) {
                         size_t sent_token_probs_index = 0;
-                        while(slot->is_processing()) {
-                            if(slot->has_new_token()) { // new token notification
+                        while (slot->is_processing())
+                        {
+                            if (slot->has_new_token())
+                            {
+                                // new token notification
                                 const completion_token_output token = slot->next();
                                 std::vector<completion_token_output> probs_output = {};
-                                if (slot->sparams.n_probs > 0) {
+                                if (slot->sparams.n_probs > 0)
+                                {
                                     const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, token.text_to_send, false);
                                     size_t probs_pos = std::min(sent_token_probs_index, slot->generated_token_probs.size());
                                     size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), slot->generated_token_probs.size());
-                                    if (probs_pos < probs_stop_pos) {
+                                    if (probs_pos < probs_stop_pos)
+                                    {
                                         probs_output = std::vector<completion_token_output>(slot->generated_token_probs.begin() + probs_pos, slot->generated_token_probs.begin() + probs_stop_pos);
                                     }
                                     sent_token_probs_index = probs_stop_pos;
@@ -2336,11 +2373,14 @@ int main(int argc, char **argv)
                                 LOG_VERBOSE("data stream", {
                                     { "to_send", str }
                                 });
-                                if(!sink.write(str.c_str(), str.size())) {
+                                if (!sink.write(str.c_str(), str.size()))
+                                {
                                     slot->release();
                                     return false;
                                 }
-                            } else {
+                            }
+                            else
+                            {
                                 std::this_thread::sleep_for(std::chrono::milliseconds(5));
                             }
                         }
@@ -2359,7 +2399,8 @@ int main(int argc, char **argv)
                             LOG_VERBOSE("data stream", {
                                 { "to_send", str }
                             });
-                            if (!sink.write(str.data(), str.size())) {
+                            if (!sink.write(str.data(), str.size()))
+                            {
                                 slot->release();
                                 return false;
                             }
@@ -2378,14 +2419,14 @@ int main(int argc, char **argv)
     svr.Get("/model.json", [&llama](const httplib::Request &, httplib::Response &res)
             {
                 const json data = format_generation_settings(llama, llama.get_slot(0));
-                return res.set_content(data.dump(), "application/json"); });
+                return res.set_content(data.dump(), "application/json");
+            });

     svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response &res)
             { return res.set_content("", "application/json"); });

     svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 const json body = json::parse(req.body);
                 std::vector<llama_token> tokens;
                 if (body.count("content") != 0)
@@ -2393,11 +2434,11 @@ int main(int argc, char **argv)
                     tokens = llama.tokenize(body["content"], false);
                 }
                 const json data = format_tokenizer_response(tokens);
-                return res.set_content(data.dump(), "application/json"); });
+                return res.set_content(data.dump(), "application/json");
+            });

     svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 const json body = json::parse(req.body);
                 std::string content;
                 if (body.count("tokens") != 0)
@@ -2407,7 +2448,8 @@ int main(int argc, char **argv)
                 }
                 const json data = format_detokenized_response(content);
-                return res.set_content(data.dump(), "application/json"); });
+                return res.set_content(data.dump(), "application/json");
+            });

     svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
             {
@@ -2428,7 +2470,8 @@ int main(int argc, char **argv)
                     std::this_thread::sleep_for(std::chrono::microseconds(10));
                 }
                 const json data = format_embedding_response(llama);
-                return res.set_content(data.dump(), "application/json"); });
+                return res.set_content(data.dump(), "application/json");
+            });

     svr.set_logger(log_server_request);
@@ -2436,24 +2479,34 @@ int main(int argc, char **argv)
             {
                 const char fmt[] = "500 Internal Server Error\n%s";
                 char buf[BUFSIZ];
-                try {
+                try
+                {
                     std::rethrow_exception(std::move(ep));
-                } catch (std::exception & e) {
+                }
+                catch (std::exception &e)
+                {
                     snprintf(buf, sizeof(buf), fmt, e.what());
-                } catch (...) {
+                }
+                catch (...)
+                {
                     snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
                 }
                 res.set_content(buf, "text/plain");
-                res.status = 500; });
+                res.status = 500;
+            });

     svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
             {
-                if (res.status == 400) {
+                if (res.status == 400)
+                {
                     res.set_content("Invalid request", "text/plain");
-                } else if (res.status != 500) {
+                }
+                else if (res.status != 500)
+                {
                     res.set_content("File Not Found", "text/plain");
                     res.status = 404;
-                } });
+                }
+            });

     // set timeouts and change hostname and port
     svr.set_read_timeout (sparams.read_timeout);