Spaces to 4 and other code style cleanup. Notes in README.

Randall Fitzgerald 2023-06-09 04:47:18 -04:00
parent ccd85e0a6b
commit a9c34779f6
2 changed files with 842 additions and 842 deletions

README.md

@@ -23,6 +23,8 @@ Command line options:
 ## Quick Start
+**Note:** The server is not built by default. Make sure to add `LLAMA_BUILD_SERVER=ON` to your CMake command.
 To get started right away, run the following command, making sure to use the correct path for the model you have:
 ### Unix-based systems (Linux, macOS, etc.):
@@ -99,7 +101,7 @@ node .
 `top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9).
-`n_predict`: Set the number of tokens to predict when generating text (default: 128, -1 = infinity).
+`n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the limit slightly if the last token is a partial multibyte character. (default: 128, -1 = infinity).
 `n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context.
 By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
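
Taken together, `top_p`, `n_predict`, and `n_keep` are fields of the JSON body accepted by the `/completion` endpoint implemented in the server source below. As a minimal client sketch, assuming the server is listening at `localhost:8080` and reusing the cpp-httplib and nlohmann::json headers the server itself is built on (the prompt text is purely illustrative):

```cpp
#include <iostream>
#include <string>

#include "httplib.h"   // cpp-httplib, the same header the server uses
#include "json.hpp"    // nlohmann::json, the same header the server uses

using json = nlohmann::json;

int main() {
    // Request body using the parameters documented above; "stop" maps onto the
    // server's antiprompt list (empty strings are filtered out by the server).
    json body = {
        {"prompt", "Building a website can be done in 10 simple steps:"},
        {"n_predict", 128},   // may be slightly exceeded on a partial multibyte token
        {"top_p", 0.9},
        {"n_keep", -1},       // keep the whole initial prompt on context reset
        {"stop", json::array({"###"})},
        {"stream", false}
    };

    // Host and port are assumptions; adjust to wherever the server is listening.
    httplib::Client cli("localhost", 8080);
    auto res = cli.Post("/completion", body.dump(), "application/json");
    if (!res || res->status != 200) {
        std::cerr << "request failed\n";
        return 1;
    }

    // Non-streaming responses carry the full result in the "content" field.
    json data = json::parse(res->body);
    std::cout << data["content"].get<std::string>() << std::endl;
    return 0;
}
```

With `"stream": true` the same endpoint instead emits chunked JSON objects carrying `content` and a `stop` flag until generation finishes, as the handler code below shows.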

server.cpp

@@ -14,7 +14,7 @@ struct server_params
 bool verbose = false;
 };
-static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+static size_t common_part(const std::vector<llama_token>& a, const std::vector<llama_token>& b) {
 size_t i;
 for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
 return i;
@@ -25,13 +25,13 @@ enum stop_type {
 STOP_PARTIAL,
 };
-bool ends_with(const std::string &str, const std::string &suffix)
+bool ends_with(const std::string& str, const std::string& suffix)
 {
 return str.size() >= suffix.size() &&
 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
-size_t find_partial_stop_string(const std::string &stop, const std::string &text)
+size_t find_partial_stop_string(const std::string& stop, const std::string& text)
 {
 if (!text.empty()) {
 const char text_last_char = text.back();
@@ -47,7 +47,7 @@ size_t find_partial_stop_string(const std::string &stop, const std::string &text
 return std::string::npos;
 }
-static std::string debug_str(const std::string & s) {
+static std::string debug_str(const std::string& s) {
 std::string ret;
 for (size_t i = 0; s[i]; i++) {
 switch (s[i]) {
@@ -60,7 +60,7 @@ static std::string debug_str(const std::string & s) {
 }
 template<class InputIt, class OutputIt>
-static std::string tokens_to_str(llama_context * ctx, InputIt begin, OutputIt end) {
+static std::string tokens_to_str(llama_context* ctx, InputIt begin, OutputIt end) {
 std::string ret;
 for (; begin != end; (void)++begin) {
 ret += llama_token_to_str(ctx, *begin);
@@ -81,7 +81,7 @@ struct llama_server_context
 std::vector<llama_token> embd;
 std::vector<llama_token> last_n_tokens;
-llama_context *ctx = nullptr;
+llama_context* ctx = nullptr;
 gpt_params params;
 std::string stopping_word;
@@ -110,7 +110,7 @@ struct llama_server_context
 n_past = 0;
 }
-bool loadModel(const gpt_params &params_)
+bool loadModel(const gpt_params& params_)
 {
 params = params_;
 ctx = llama_init_from_gpt_params(params);
@@ -136,7 +136,7 @@ struct llama_server_context
 // if input prompt is too big, truncate like normal
 if (prompt_tokens.size() >= (size_t)params.n_ctx) {
-const int n_left = (params.n_ctx - params.n_keep)/2;
+const int n_left = (params.n_ctx - params.n_keep) / 2;
 std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
 const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
 new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
@@ -196,7 +196,7 @@ struct llama_server_context
 if (embd.size() >= (size_t)params.n_ctx) {
 // Reset context
-const int n_left = (params.n_ctx - params.n_keep)/2;
+const int n_left = (params.n_ctx - params.n_keep) / 2;
 std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
 new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
@@ -247,11 +247,11 @@ struct llama_server_context
 const bool penalize_nl = params.penalize_nl;
 llama_token id = 0;
 {
-auto *logits = llama_get_logits(ctx);
+auto* logits = llama_get_logits(ctx);
 auto n_vocab = llama_n_vocab(ctx);
 // Apply params.logit_bias map
-for (const auto &it : params.logit_bias) {
+for (const auto& it : params.logit_bias) {
 logits[it.first] += it.second;
 }
@@ -259,10 +259,10 @@ struct llama_server_context
 candidates.reserve(n_vocab);
 for (llama_token token_id = 0; token_id < n_vocab; token_id++)
 {
-candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
 }
-llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
+llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 // Apply penalties
 float nl_logit = logits[llama_token_nl()];
@@ -282,24 +282,18 @@ struct llama_server_context
 {
 // Greedy sampling
 id = llama_sample_token_greedy(ctx, &candidates_p);
-}
-else
-{
+} else {
 if (mirostat == 1)
 {
 static float mirostat_mu = 2.0f * mirostat_tau;
 const int mirostat_m = 100;
 llama_sample_temperature(ctx, &candidates_p, temp);
 id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-}
-else if (mirostat == 2)
-{
+} else if (mirostat == 2) {
 static float mirostat_mu = 2.0f * mirostat_tau;
 llama_sample_temperature(ctx, &candidates_p, temp);
 id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-}
-else
-{
+} else {
 // Temperature sampling
 llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
 llama_sample_typical(ctx, &candidates_p, typical_p, 1);
@@ -333,17 +327,18 @@ struct llama_server_context
 return result;
 }
-size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
+size_t findStoppingStrings(const std::string& text, const size_t last_token_size,
 const stop_type type)
 {
 size_t stop_pos = std::string::npos;
-for (const std::string &word : params.antiprompt) {
+for (const std::string& word : params.antiprompt) {
 size_t pos;
 if (type == STOP_FULL) {
 const size_t tmp = word.size() + last_token_size;
 const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
 pos = text.find(word, from_pos);
-} else {
+}
+else {
 pos = find_partial_stop_string(word, text);
 }
 if (pos != std::string::npos &&
@@ -410,7 +405,7 @@ using namespace httplib;
 using json = nlohmann::json;
-void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, const server_params &sparams)
+void server_print_usage(int /*argc*/, char** argv, const gpt_params& params, const server_params& sparams)
 {
 fprintf(stderr, "usage: %s [options]\n", argv[0]);
 fprintf(stderr, "\n");
@@ -450,8 +445,8 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, con
 fprintf(stderr, "\n");
 }
-void server_params_parse(int argc, char **argv, server_params &sparams,
-gpt_params &params)
+void server_params_parse(int argc, char** argv, server_params& sparams,
+gpt_params& params)
 {
 gpt_params default_params;
 server_params default_sparams;
@@ -603,12 +598,12 @@ void server_params_parse(int argc, char **argv, server_params &sparams,
 }
 }
-json format_generation_settings(llama_server_context &llama) {
+json format_generation_settings(llama_server_context& llama) {
 const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
 const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
 eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-return json {
+return json{
 { "seed", llama.params.seed },
 { "temp", llama.params.temp },
 { "top_k", llama.params.top_k },
@@ -632,7 +627,7 @@ json format_generation_settings(llama_server_context &llama) {
 };
 }
-bool parse_options_completion(json body, llama_server_context& llama, Response &res)
+bool parse_options_completion(json body, llama_server_context& llama, Response& res)
 {
 gpt_params default_params;
 if (!body["stream"].is_null()) {
@@ -727,7 +722,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
 }
 if (body["logit_bias"].is_array()) {
 int n_vocab = llama_n_vocab(llama.ctx);
-for (const auto &el : body["logit_bias"]) {
+for (const auto& el : body["logit_bias"]) {
 if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) {
 llama_token tok = el[0].get<llama_token>();
 if (tok >= 0 && tok < n_vocab) {
@@ -744,7 +739,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
 if (!body["prompt"].is_null()) {
 llama.params.prompt = body["prompt"].get<std::string>();
 } else {
-json data = {{"status", "error"}, {"reason", "You need to provide a prompt"}};
+json data = { {"status", "error"}, {"reason", "You need to provide a prompt"} };
 res.set_content(data.dump(llama.json_indent), "application/json");
 res.status = 400;
 return false;
@@ -755,7 +750,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
 const auto stop = body["stop"].get<std::vector<std::string>>();
 std::copy_if(stop.begin(), stop.end(),
 std::back_inserter(llama.params.antiprompt),
-[](const std::string &str) { return !str.empty(); });
+[](const std::string& str) { return !str.empty(); });
 }
 if (llama.verbose) {
@@ -771,7 +766,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
 return true;
 }
-int main(int argc, char **argv)
+int main(int argc, char** argv)
 {
 // own arguments required by this example
 gpt_params params;
@@ -809,10 +804,10 @@ int main(int argc, char **argv)
 {"Access-Control-Allow-Headers", "content-type"}
 });
-svr.Get("/", [](const Request &, Response &res)
+svr.Get("/", [](const Request&, Response& res)
 { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
-svr.Post("/completion", [&llama](const Request &req, Response &res) {
+svr.Post("/completion", [&llama](const Request& req, Response& res) {
 llama.rewind();
 llama_reset_timings(llama.ctx);
@@ -842,21 +837,22 @@ int main(int argc, char **argv)
 llama.generated_text.end());
 }
-json data = {{"content", llama.generated_text},
+json data = { {"content", llama.generated_text},
 {"stop", true},
 {"model", llama.params.model_alias},
 {"tokens_predicted", llama.num_tokens_predicted},
 {"generation_settings", format_generation_settings(llama)},
 {"prompt", llama.params.prompt},
-{"stopping_word", llama.stopping_word}};
+{"stopping_word", llama.stopping_word} };
 llama_print_timings(llama.ctx);
 res.set_content(
 data.dump(llama.json_indent, ' ', false, json::error_handler_t::replace),
 "application/json");
-} else {
-const auto chunked_content_provider = [&](size_t, DataSink &sink) {
+}
+else {
+const auto chunked_content_provider = [&](size_t, DataSink& sink) {
 size_t sent_count = 0;
 while (llama.has_next_token) {
@@ -867,7 +863,7 @@ int main(int argc, char **argv)
 size_t pos = std::min(sent_count, llama.generated_text.size());
-const char *str_test = llama.generated_text.c_str() + pos;
+const char* str_test = llama.generated_text.c_str() + pos;
 size_t stop_pos =
 llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
 if (stop_pos != std::string::npos) {
@@ -885,7 +881,7 @@ int main(int argc, char **argv)
 json data;
 if (llama.has_next_token) {
-data = {{"content", to_send}, {"stop", false}};
+data = { {"content", to_send}, {"stop", false} };
 } else {
 // Generation is done, send extra information.
 data = {
@@ -896,7 +892,7 @@ int main(int argc, char **argv)
 {"generation_settings", format_generation_settings(llama)},
 {"prompt", llama.params.prompt},
 {"stopping_word", llama.stopping_word},
-{"generated_text", llama.generated_text}};
+{"generated_text", llama.generated_text} };
 }
 std::string str =
@@ -926,12 +922,12 @@ int main(int argc, char **argv)
 }
 });
-svr.Options(R"(/.*)", [](const Request &, Response &res)
+svr.Options(R"(/.*)", [](const Request&, Response& res)
 {
 return res.set_content("", "application/json");
 });
-svr.Post("/tokenize", [&llama](const Request &req, Response &res)
+svr.Post("/tokenize", [&llama](const Request& req, Response& res)
 {
 json body = json::parse(req.body);
 json data = {
@@ -950,14 +946,16 @@ int main(int argc, char **argv)
 log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
 });
-svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
-const auto *fmt = "500 Internal Server Error\n%s";
+svr.set_exception_handler([](const Request&, Response& res, std::exception_ptr ep) {
+const auto* fmt = "500 Internal Server Error\n%s";
 char buf[BUFSIZ];
 try {
 std::rethrow_exception(std::move(ep));
-} catch (std::exception &e) {
+}
+catch (std::exception& e) {
 snprintf(buf, sizeof(buf), fmt, e.what());
-} catch (...) {
+}
+catch (...) {
 snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
 }
 res.set_content(buf, "text/plain");