more json changes and stop info

Henri Vasserman 2023-06-12 17:29:25 +03:00
parent dff11a14d2
commit 13cf6929b7
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

examples/server/server.cpp

@@ -105,6 +105,10 @@ struct llama_server_context {
     llama_context * ctx = nullptr;
     gpt_params params;
 
+    bool truncated = false;
+    bool stopped_eos = false;
+    bool stopped_word = false;
+    bool stopped_limit = false;
     std::string stopping_word;
 
     int json_indent = -1;
@@ -122,6 +126,10 @@ struct llama_server_context {
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
+        truncated = false;
+        stopped_eos = false;
+        stopped_word = false;
+        stopped_limit = false;
         stopping_word = "";
         multibyte_pending = 0;
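The new `truncated` and `stopped_*` flags are reset at the start of every request (above) and set at most once while the request runs, so the final response can report exactly why generation ended. A minimal sketch of how a consumer might collapse them into a single reason string — `stop_reason` is a hypothetical helper, not part of this patch:

```cpp
#include <string>

// Hypothetical helper (not in the patch): collapse the per-request stop
// flags into one human-readable reason for logging or API clients.
static std::string stop_reason(bool stopped_eos, bool stopped_word, bool stopped_limit) {
    if (stopped_eos)   { return "eos";   }  // model emitted the end-of-stream token
    if (stopped_word)  { return "word";  }  // a user-supplied stop string matched
    if (stopped_limit) { return "limit"; }  // the n_predict budget ran out
    return "none";                          // still generating (streaming case)
}
```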
@@ -166,6 +174,7 @@ struct llama_server_context {
                     { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
                 });
+                truncated = true;
                 prompt_tokens = new_tokens;
             } else {
                 const size_t ps = prompt_tokens.size();
@@ -207,14 +216,13 @@ struct llama_server_context {
             new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
             embd = new_tokens;
             n_past = params.n_keep;
-            if (server_verbose) {
-                LOG_VERBOSE("input truncated", {
-                    { "n_ctx", params.n_ctx },
-                    { "n_keep", params.n_keep },
-                    { "n_left", n_left },
-                    { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
-                });
-            }
+            truncated = true;
+            LOG_VERBOSE("input truncated", {
+                { "n_ctx", params.n_ctx },
+                { "n_keep", params.n_keep },
+                { "n_left", n_left },
+                { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
+            });
         }
 
         while (n_past < embd.size()) {
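The hunk above also shows the rolling-context behaviour that `truncated` now records: when `embd` would overflow the context, the first `params.n_keep` tokens are kept verbatim and only the last `n_left` tokens are carried over. A self-contained sketch of that splice, with plain ints standing in for `llama_token` (the `n_left` value here is illustrative; the server derives it from `n_ctx` and `n_keep`):

```cpp
#include <cassert>
#include <vector>

int main() {
    std::vector<int> embd = {0, 1, 2, 3, 4, 5, 6, 7};  // pretend the context is full
    const int n_keep = 2;                               // tokens pinned at the front
    const int n_left = 3;                               // illustrative tail size

    // Same splice as the patch: keep the first n_keep tokens...
    std::vector<int> new_tokens(embd.begin(), embd.begin() + n_keep);
    // ...then append the last n_left tokens of the old context.
    new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
    embd = new_tokens;

    assert((embd == std::vector<int>{0, 1, 5, 6, 7}));
    return 0;
}
```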
@@ -314,8 +322,9 @@ struct llama_server_context {
         --n_remain;
 
         if (!embd.empty() && embd.back() == llama_token_eos()) {
-            stopping_word = llama_token_to_str(ctx, embd.back());
+            //stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;
+            stopped_eos = true;
             LOG_VERBOSE("eos token found", {});
             return result;
         }
@@ -341,6 +350,7 @@ struct llama_server_context {
                 (stop_pos == std::string::npos || pos < stop_pos)) {
                 if (type == STOP_FULL) {
                     stopping_word = word;
+                    stopped_word = true;
                     has_next_token = false;
                 }
                 stop_pos = pos;
@@ -378,17 +388,22 @@ struct llama_server_context {
             n_remain++;
         }
 
-        if (server_verbose) {
-            LOG_VERBOSE("next token", {
-                { "token", token },
-                { "token_text", llama_token_to_str(ctx, token) },
-                { "has_next_token", has_next_token },
-                { "n_remain", n_remain },
-                { "num_tokens_predicted", num_tokens_predicted },
-                { "stopping_word", stopping_word },
-            });
-        }
+        if (!has_next_token && n_remain == 0) {
+            stopped_limit = true;
+        }
+
+        LOG_VERBOSE("next token", {
+            { "token", token },
+            { "token_text", llama_token_to_str(ctx, token) },
+            { "has_next_token", has_next_token },
+            { "n_remain", n_remain },
+            { "num_tokens_predicted", num_tokens_predicted },
+            { "stopped_eos", stopped_eos },
+            { "stopped_word", stopped_word },
+            { "stopped_limit", stopped_limit },
+            { "stopping_word", stopping_word },
+        });
 
         return token_text;
     }
 };
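Note that `stopped_limit` is derived rather than set at a stop site: if `has_next_token` was cleared and the per-request budget `n_remain` is exactly zero, the loop is taken to have ended on the `n_predict` limit rather than on EOS or a stop word. A tiny sketch of that counting, with a hypothetical driver loop in place of the real sampler:

```cpp
#include <cstdio>

int main() {
    int  n_remain       = 4;     // stands in for params.n_predict
    bool has_next_token = true;
    bool stopped_limit  = false;

    while (has_next_token) {
        // ... sample one token here ...
        if (--n_remain == 0) {
            has_next_token = false;  // budget exhausted
        }
    }
    if (!has_next_token && n_remain == 0) {
        stopped_limit = true;        // same check as the patch
    }
    std::printf("stopped_limit = %d\n", stopped_limit);
    return 0;
}
```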
@@ -578,7 +593,7 @@ void server_params_parse(int argc, char ** argv, server_params & sparams,
     }
 }
 
-json format_generation_settings(llama_server_context & llama) {
+static json format_generation_settings(llama_server_context & llama) {
     const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
     const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@@ -607,6 +622,35 @@ json format_generation_settings(llama_server_context & llama) {
     };
 }
 
+static json format_final_response(llama_server_context & llama, const std::string & content) {
+    return json {
+        { "content", content },
+        { "stop", true },
+        { "model", llama.params.model_alias },
+        { "tokens_predicted", llama.num_tokens_predicted },
+        { "generation_settings", format_generation_settings(llama) },
+        { "prompt", llama.params.prompt },
+        { "truncated", llama.truncated },
+        { "stopped_eos", llama.stopped_eos },
+        { "stopped_word", llama.stopped_word },
+        { "stopped_limit", llama.stopped_limit },
+        { "stopping_word", llama.stopping_word },
+    };
+}
+
+static json format_partial_response(const std::string & content) {
+    return json {
+        { "content", content },
+        { "stop", false },
+    };
+}
+
+static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
+    return json {
+        { "tokens", tokens }
+    };
+}
+
 bool parse_options_completion(json body, llama_server_context & llama) {
     gpt_params default_params;
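With the response shape centralized in `format_final_response`, a client can rely on the same keys in both the blocking and streaming paths. A hedged sketch of client-side handling with nlohmann::json (the payload string is illustrative, mirroring the helper's keys):

```cpp
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
    // Illustrative payload with the keys emitted by format_final_response.
    const std::string body = R"({
        "content": "Hello",
        "stop": true,
        "stopped_eos": false,
        "stopped_word": true,
        "stopped_limit": false,
        "stopping_word": "\n\n"
    })";

    const auto data = nlohmann::json::parse(body);
    if (data.value("stopped_word", false)) {
        // On a full stop-word match the server trims the match from the
        // generated text; the word itself is reported separately here.
        std::cout << "stopped on: " << data["stopping_word"].get<std::string>() << "\n";
    }
    return 0;
}
```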
@@ -663,6 +707,17 @@ bool parse_options_completion(json body, llama_server_context & llama) {
     return true;
 }
 
+static void log_server_request(const Request & req, const Response & res) {
+    LOG_INFO("request", {
+        { "remote_addr", req.remote_addr },
+        { "remote_port", req.remote_port },
+        { "status", res.status },
+        { "path", req.path },
+        { "request", req.body },
+        { "response", res.body },
+    });
+}
+
 int main(int argc, char ** argv) {
     // own arguments required by this example
     gpt_params params;
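Extracting the logging lambda into `log_server_request` works because cpp-httplib's `Server::set_logger` accepts anything convertible to `std::function<void(const Request &, const Response &)>`, including a plain function. A minimal standalone sketch of the same hookup (the endpoint and port are arbitrary, not from the patch):

```cpp
#include "httplib.h"
#include <cstdio>

int main() {
    httplib::Server svr;
    svr.Get("/ping", [](const httplib::Request &, httplib::Response & res) {
        res.set_content("pong", "text/plain");
    });
    // The logger runs once per completed request/response pair.
    svr.set_logger([](const httplib::Request & req, const httplib::Response & res) {
        std::fprintf(stderr, "%s %s -> %d\n", req.method.c_str(), req.path.c_str(), res.status);
    });
    svr.listen("127.0.0.1", 8080);
    return 0;
}
```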
@@ -739,15 +794,7 @@ int main(int argc, char ** argv) {
                     llama.generated_text.end());
             }
 
-            json data {
-                { "content", llama.generated_text },
-                { "stop", true },
-                { "model", llama.params.model_alias },
-                { "tokens_predicted", llama.num_tokens_predicted },
-                { "generation_settings", format_generation_settings(llama) },
-                { "prompt", llama.params.prompt },
-                { "stopping_word", llama.stopping_word },
-            };
+            json data = format_final_response(llama, llama.generated_text);
 
             llama_print_timings(llama.ctx);
@@ -785,22 +832,10 @@ int main(int argc, char ** argv) {
                 json data;
                 if (llama.has_next_token) {
-                    data = {
-                        { "content", to_send },
-                        { "stop", false },
-                    };
+                    data = format_partial_response(to_send);
                 } else {
                     // Generation is done, send extra information.
-                    data = {
-                        { "content", to_send },
-                        { "stop", true },
-                        { "model", llama.params.model_alias },
-                        { "tokens_predicted", llama.num_tokens_predicted },
-                        { "generation_settings", format_generation_settings(llama) },
-                        { "prompt", llama.params.prompt },
-                        { "stopping_word", llama.stopping_word },
-                        { "generated_text", llama.generated_text },
-                    };
+                    data = format_final_response(llama, to_send);
                 }
 
                 std::string str =
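On the streaming path the contract for clients stays simple: every chunk carries `stop`, with partial chunks shaped as `{content, stop: false}` and the final chunk being the full `format_final_response` object. A hedged sketch of a consumer loop (the chunks are hard-coded stand-ins for parsed streaming events):

```cpp
#include <iostream>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

int main() {
    // Stand-ins for parsed "data:" events from the streaming response.
    const std::vector<std::string> chunks = {
        R"({"content": "Hel", "stop": false})",
        R"({"content": "lo",  "stop": false})",
        R"({"content": "",    "stop": true, "stopped_eos": true})",
    };

    std::string text;
    for (const auto & chunk : chunks) {
        const auto data = nlohmann::json::parse(chunk);
        text += data.value("content", "");
        if (data.value("stop", false)) {
            break;  // final chunk: stop flags and settings ride along here
        }
    }
    std::cout << text << "\n";  // prints "Hello"
    return 0;
}
```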
@@ -836,20 +871,11 @@ int main(int argc, char ** argv) {
         json body = json::parse(req.body);
         std::string content = body["content"].get<std::string>();
         std::vector<llama_token> tokens = ::llama_tokenize(llama.ctx, content, false);
-        json data {{ "tokens", tokens }};
+        json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(llama.json_indent), "application/json");
     });
 
-    svr.set_logger([](const Request & req, const Response & res) {
-        LOG_INFO("request", {
-            { "remote_addr", req.remote_addr },
-            { "remote_port", req.remote_port },
-            { "status", res.status },
-            { "path", req.path },
-            { "request", req.body },
-            { "response", res.body },
-        });
-    });
+    svr.set_logger(log_server_request);
 
     svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {
         const auto * fmt = "500 Internal Server Error\n%s";