more formatting changes

Henri Vasserman 2023-06-11 14:01:42 +03:00
parent bac0ddb58f
commit 2c00bf855d
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986
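
The patch is mostly stylistic: opening braces move onto the same line as their declaration (Allman to K&R style), brace-initializer lists get consistent inner spacing and trailing commas, and the /tokenize and logging handlers are lightly restructured. A minimal before/after sketch of the brace change, using the ends_with helper this patch touches:

// Before: Allman style, opening brace on its own line
bool ends_with(const std::string & str, const std::string & suffix)
{
    return str.size() >= suffix.size() &&
        0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}

// After: K&R style, as applied throughout this commit
bool ends_with(const std::string & str, const std::string & suffix) {
    return str.size() >= suffix.size() &&
        0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}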


@@ -5,8 +5,7 @@
 #include "httplib.h"
 #include "json.hpp"

-struct server_params
-{
+struct server_params {
     std::string hostname = "127.0.0.1";
     int32_t port = 8080;
     int32_t read_timeout = 600;
@@ -25,14 +24,12 @@ enum stop_type {
     STOP_PARTIAL,
 };

-bool ends_with(const std::string & str, const std::string & suffix)
-{
+bool ends_with(const std::string & str, const std::string & suffix) {
     return str.size() >= suffix.size() &&
         0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }

-size_t find_partial_stop_string(const std::string & stop, const std::string & text)
-{
+size_t find_partial_stop_string(const std::string & stop, const std::string & text) {
     if (!text.empty() && !stop.empty()) {
         const char text_last_char = text.back();
         for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
@@ -59,8 +56,8 @@ static std::string debug_str(const std::string & s) {
     return ret;
 }

-template<class InputIt, class OutputIt>
-static std::string tokens_to_str(llama_context * ctx, InputIt begin, OutputIt end) {
+template<class Iter>
+static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
     std::string ret;
     for (; begin != end; (void)++begin) {
         ret += llama_token_to_str(ctx, *begin);
@@ -68,8 +65,7 @@ static std::string tokens_to_str(llama_context * ctx, InputIt begin, OutputIt en
     return ret;
 }

-struct llama_server_context
-{
+struct llama_server_context {
     bool stream = false;
     bool has_next_token = false;
     std::string generated_text = "";
@@ -90,8 +86,7 @@ struct llama_server_context
     int json_indent = -1;
     int32_t multibyte_pending = 0;

-    ~llama_server_context()
-    {
+    ~llama_server_context() {
         if (ctx) {
             llama_free(ctx);
             ctx = nullptr;
@@ -110,12 +105,10 @@ struct llama_server_context
         n_past = 0;
     }

-    bool loadModel(const gpt_params & params_)
-    {
+    bool loadModel(const gpt_params & params_) {
         params = params_;
         ctx = llama_init_from_gpt_params(params);
-        if (ctx == NULL)
-        {
+        if (ctx == NULL) {
             fprintf(stderr, "%s: error: unable to load model\n", __func__);
             return false;
         }
@@ -184,8 +177,7 @@ struct llama_server_context
         has_next_token = true;
     }

-    void beginCompletion()
-    {
+    void beginCompletion() {
         // number of tokens to keep when resetting context
         n_remain = params.n_predict;
         llama_set_rng_seed(ctx, params.seed);
@@ -215,15 +207,12 @@ struct llama_server_context
             }
         }

-        while (n_past < embd.size())
-        {
+        while (n_past < embd.size()) {
             int n_eval = (int)embd.size() - n_past;
-            if (n_eval > params.n_batch)
-            {
+            if (n_eval > params.n_batch) {
                 n_eval = params.n_batch;
             }
-            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
-            {
+            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 has_next_token = false;
                 return result;
@@ -245,8 +234,7 @@ struct llama_server_context
         const float mirostat_tau = params.mirostat_tau;
         const float mirostat_eta = params.mirostat_eta;
         const bool penalize_nl = params.penalize_nl;

-        llama_token id = 0;
-        {
+        llama_token id = 0; {
             auto * logits = llama_get_logits(ctx);
             auto n_vocab = llama_n_vocab(ctx);
@@ -257,8 +245,7 @@ struct llama_server_context
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);

-            for (llama_token token_id = 0; token_id < n_vocab; token_id++)
-            {
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                 candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
             }

@@ -273,18 +260,15 @@ struct llama_server_context
             llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
                 last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                 last_n_repeat, alpha_frequency, alpha_presence);
-            if (!penalize_nl)
-            {
+            if (!penalize_nl) {
                 logits[llama_token_nl()] = nl_logit;
             }

-            if (temp <= 0)
-            {
+            if (temp <= 0) {
                 // Greedy sampling
                 id = llama_sample_token_greedy(ctx, &candidates_p);
             } else {
-                if (mirostat == 1)
-                {
+                if (mirostat == 1) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
                     const int mirostat_m = 100;
                     llama_sample_temperature(ctx, &candidates_p, temp);
@@ -328,8 +312,7 @@ struct llama_server_context
     }

     size_t findStoppingStrings(const std::string & text, const size_t last_token_size,
-                               const stop_type type)
-    {
+                               const stop_type type) {
         size_t stop_pos = std::string::npos;
         for (const std::string & word : params.antiprompt) {
             size_t pos;
@@ -353,8 +336,7 @@ struct llama_server_context
         return stop_pos;
     }

-    std::string doCompletion()
-    {
+    std::string doCompletion() {
         llama_token token = nextToken();

         std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
@@ -405,8 +387,7 @@ using namespace httplib;

 using json = nlohmann::json;

-void server_print_usage(int /*argc*/, char ** argv, const gpt_params & params, const server_params & sparams)
-{
+void server_print_usage(int /*argc*/, char ** argv, const gpt_params & params, const server_params & sparams) {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
@@ -417,12 +398,10 @@ void server_print_usage(int /*argc*/, char ** argv, const gpt_params & params, c
     fprintf(stderr, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stderr, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
-    if (llama_mlock_supported())
-    {
+    if (llama_mlock_supported()) {
         fprintf(stderr, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
-    if (llama_mmap_supported())
-    {
+    if (llama_mmap_supported()) {
         fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
 #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
@@ -446,8 +425,7 @@ void server_print_usage(int /*argc*/, char ** argv, const gpt_params & params, c
 }

 void server_params_parse(int argc, char ** argv, server_params & sparams,
-                         gpt_params & params)
-{
+                         gpt_params & params) {
     gpt_params default_params;
     server_params default_sparams;
     std::string arg;
@@ -522,10 +500,8 @@ void server_params_parse(int argc, char ** argv, server_params & sparams,
             fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
 #endif
         }
-        else if (arg == "--tensor-split" || arg == "-ts")
-        {
-            if (++i >= argc)
-            {
+        else if (arg == "--tensor-split" || arg == "-ts") {
+            if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -538,14 +514,11 @@ void server_params_parse(int argc, char ** argv, server_params & sparams,
             std::vector<std::string> split_arg{ it, {} };
             GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);

-            for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i)
-            {
-                if (i < split_arg.size())
-                {
+            for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
+                if (i < split_arg.size()) {
                     params.tensor_split[i] = std::stof(split_arg[i]);
                 }
-                else
-                {
+                else {
                     params.tensor_split[i] = 0.0f;
                 }
             }
@@ -553,10 +526,8 @@ void server_params_parse(int argc, char ** argv, server_params & sparams,
             fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
 #endif // GGML_USE_CUBLAS
         }
-        else if (arg == "--main-gpu" || arg == "-mg")
-        {
-            if (++i >= argc)
-            {
+        else if (arg == "--main-gpu" || arg == "-mg") {
+            if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -603,32 +574,31 @@ json format_generation_settings(llama_server_context & llama) {
     const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
         eos_bias->second < 0.0f && std::isinf(eos_bias->second);

-    return json{
+    return json {
         { "seed", llama.params.seed },
         { "temp", llama.params.temp },
         { "top_k", llama.params.top_k },
         { "top_p", llama.params.top_p },
         { "tfs_z", llama.params.tfs_z },
         { "typical_p", llama.params.typical_p },
         { "repeat_last_n", llama.params.repeat_last_n },
         { "repeat_penalty", llama.params.repeat_penalty },
         { "presence_penalty", llama.params.presence_penalty },
         { "frequency_penalty", llama.params.frequency_penalty },
         { "mirostat", llama.params.mirostat },
         { "mirostat_tau", llama.params.mirostat_tau },
         { "mirostat_eta", llama.params.mirostat_eta },
         { "penalize_nl", llama.params.penalize_nl },
         { "stop", llama.params.antiprompt },
         { "n_predict", llama.params.n_predict },
         { "n_keep", llama.params.n_keep },
         { "ignore_eos", ignore_eos },
         { "stream", llama.stream },
         { "logit_bias", llama.params.logit_bias },
     };
 }

-bool parse_options_completion(json body, llama_server_context & llama, Response & res)
-{
+bool parse_options_completion(json body, llama_server_context & llama, Response & res) {
     gpt_params default_params;
     if (!body["stream"].is_null()) {
         llama.stream = body["stream"].get<bool>();
@@ -766,8 +736,7 @@ bool parse_options_completion(json body, llama_server_context & llama, Response
     return true;
 }

-int main(int argc, char ** argv)
-{
+int main(int argc, char ** argv) {
     // own arguments required by this example
     gpt_params params;
     server_params sparams;
@@ -791,20 +760,20 @@ int main(int argc, char ** argv)
             std::thread::hardware_concurrency(), llama_print_system_info());

     // load the model
-    if (!llama.loadModel(params))
-    {
+    if (!llama.loadModel(params)) {
         return 1;
     }

     Server svr;

     svr.set_default_headers({
-        {"Access-Control-Allow-Origin", "*"},
-        {"Access-Control-Allow-Headers", "content-type"}
+        { "Access-Control-Allow-Origin", "*" },
+        { "Access-Control-Allow-Headers", "content-type" }
     });

-    svr.Get("/", [](const Request &, Response & res)
-            { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
+    svr.Get("/", [](const Request &, Response & res) {
+        res.set_content("<h1>llama.cpp server works</h1>", "text/html");
+    });

     svr.Post("/completion", [&llama](const Request & req, Response & res) {
@@ -836,13 +805,15 @@ int main(int argc, char ** argv)
                     llama.generated_text.end());
             }

-            json data = { {"content", llama.generated_text},
-                {"stop", true},
-                {"model", llama.params.model_alias},
-                {"tokens_predicted", llama.num_tokens_predicted},
-                {"generation_settings", format_generation_settings(llama)},
-                {"prompt", llama.params.prompt},
-                {"stopping_word", llama.stopping_word} };
+            json data {
+                { "content", llama.generated_text },
+                { "stop", true },
+                { "model", llama.params.model_alias },
+                { "tokens_predicted", llama.num_tokens_predicted },
+                { "generation_settings", format_generation_settings(llama) },
+                { "prompt", llama.params.prompt },
+                { "stopping_word", llama.stopping_word },
+            };

             llama_print_timings(llama.ctx);

@@ -851,7 +822,7 @@ int main(int argc, char ** argv)
                 "application/json");
         }
         else {
-            const auto chunked_content_provider = [&](size_t, DataSink& sink) {
+            const auto chunked_content_provider = [&](size_t, DataSink & sink) {
                 size_t sent_count = 0;

                 while (llama.has_next_token) {
@@ -880,18 +851,22 @@ int main(int argc, char ** argv)

                     json data;
                     if (llama.has_next_token) {
-                        data = { {"content", to_send}, {"stop", false} };
+                        data = {
+                            { "content", to_send },
+                            { "stop", false },
+                        };
                     } else {
                         // Generation is done, send extra information.
                         data = {
-                            {"content", to_send},
-                            {"stop", true},
-                            {"model", llama.params.model_alias},
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", llama.generated_text} };
+                            { "content", to_send },
+                            { "stop", true },
+                            { "model", llama.params.model_alias },
+                            { "tokens_predicted", llama.num_tokens_predicted },
+                            { "generation_settings", format_generation_settings(llama) },
+                            { "prompt", llama.params.prompt },
+                            { "stopping_word", llama.stopping_word },
+                            { "generated_text", llama.generated_text },
+                        };
                     }

                     std::string str =
@@ -919,31 +894,31 @@ int main(int argc, char ** argv)
             };
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
         }
     });

-    svr.Options(R"(/.*)", [](const Request &, Response & res)
-    {
-        return res.set_content("", "application/json");
-    });
+    svr.Options(R"(/.*)", [](const Request &, Response & res) {
+        return res.set_content("", "application/json");
+    });

-    svr.Post("/tokenize", [&llama](const Request & req, Response & res)
-    {
-        json body = json::parse(req.body);
-        json data = {
-            {"tokens", ::llama_tokenize(llama.ctx, body["content"].get<std::string>(), false) } };
-        return res.set_content(data.dump(llama.json_indent), "application/json");
-    });
+    svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
+        json body = json::parse(req.body);
+        std::string content = body["content"].get<std::string>();
+        std::vector<llama_token> tokens = ::llama_tokenize(llama.ctx, content, false);
+        json data {{ "tokens", tokens }};
+        return res.set_content(data.dump(llama.json_indent), "application/json");
+    });

     svr.set_logger([](const Request & req, const Response & res) {
         json log = {
+            { "time", time(NULL) },
+            { "ip", req.remote_addr },
             { "status", res.status },
             { "path", req.path },
             { "request", req.body },
             { "response", res.body },
         };
-        fprintf(stdout, "http_request: %s\n",
-            log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
-    });
+        fprintf(stdout, "%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
+    });

     svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {
         const auto * fmt = "500 Internal Server Error\n%s";