Added capturing of the stopping word and sending it along with the final JSON.
Fixed an fprintf format-string warning.
Fixed a bug that broke streaming.
Properly removed thread changing via JSON (it only grabbed batch_size before).
This commit is contained in:
parent 2e5c5ee224
commit dda915cac4

1 changed file with 14 additions and 7 deletions
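
As a reading aid for the parse_options_completion() change in the diff below, here is a minimal client-side sketch (not part of this commit) that builds a request body from the option names the parser accepts after this change. Only the field names as_loop, n_predict, temperature and repeat_last_n come from the diff; the values, the prompt field, and printing the body instead of POSTing it are illustrative assumptions.

// Hypothetical sketch: assemble a completion request body with nlohmann::json.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json body = {
        {"as_loop",       true},   // streaming mode flag; parsing of "threads" was removed in this commit
        {"n_predict",     128},    // maximum number of tokens to generate
        {"temperature",   0.8},
        {"repeat_last_n", 64},     // now read with get<int>() instead of get<float>()
        {"prompt",        "Building a website can be done in 10 simple steps:"}  // assumed field, not shown in this diff
    };

    std::cout << body.dump(2) << std::endl;  // this JSON would be POSTed to the completion endpoint
    return 0;
}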
@@ -33,13 +33,14 @@ struct llama_server_context
     llama_context *ctx;
     gpt_params params;
 
-    bool reload_ctx = false;
+    std::string stopping_word = "";
 
     void rewind() {
         as_loop = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
+        stopping_word = "";
 
         //processed_tokens.clear();
         embd_inp.clear();
@@ -233,6 +234,7 @@ struct llama_server_context
         }
 
         if (!embd.empty() && embd.back() == llama_token_eos()) {
+            stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;
         }
 
@@ -258,6 +260,7 @@ struct llama_server_context
             size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
             if (i != std::string::npos) {
                 generated_text.erase(generated_text.begin() + i, generated_text.begin() + i + word.size());
+                stopping_word = word;
                 has_next_token = false;
                 break;
             }
@@ -313,7 +316,7 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, con
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
-    fprintf(stderr, "  --host                ip address to listen (default (default: %d)\n", sparams.hostname);
+    fprintf(stderr, "  --host                ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     fprintf(stderr, "  --port PORT           port to listen (default (default: %d)\n", sparams.port);
     fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     fprintf(stderr, "\n");
@@ -449,9 +452,9 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 }
 
 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
-    if (!body["threads"].is_null())
+    if (!body["as_loop"].is_null())
     {
-        llama.params.n_threads = body["threads"].get<int>();
+        llama.as_loop = body["as_loop"].get<bool>();
     }
     if (!body["n_predict"].is_null())
     {
@@ -475,7 +478,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     }
     if (!body["repeat_last_n"].is_null())
     {
-        llama.params.repeat_last_n = body["repeat_last_n"].get<float>();
+        llama.params.repeat_last_n = body["repeat_last_n"].get<int>();
     }
     if (!body["temperature"].is_null())
     {
@@ -630,7 +633,8 @@ int main(int argc, char **argv)
                     {"penalize_nl", llama.params.penalize_nl}
                     }
                 },
-                {"prompt", llama.params.prompt} };
+                {"prompt", llama.params.prompt},
+                {"stopping_word", llama.stopping_word} };
             return res.set_content(data.dump(), "application/json");
         }
         catch (const json::exception &e)
@@ -684,6 +688,7 @@ int main(int argc, char **argv)
                 json data;
                 if (llama.has_next_token)
                 {
+                    //fprintf(stdout, "Result: %s\n", result);
                     final_text += result;
                     data = {
                         {"content", result },
@@ -715,6 +720,7 @@ int main(int argc, char **argv)
                             }
                         },
                         {"prompt", llama.params.prompt},
+                        {"stopping_word", llama.stopping_word},
                         {"generated_text", final_text}
                     };
                 }
@@ -735,7 +741,7 @@ int main(int argc, char **argv)
                 {
                     // Generation is done, send extra information.
                     data = {
-                        {"content", "\uFFFD" },
+                        {"content", u8"\uFFFD" },
                         {"stop", true },
                         {"tokens_predicted", llama.num_tokens_predicted},
                         {"seed", llama.params.seed},
@@ -756,6 +762,7 @@ int main(int argc, char **argv)
                             }
                         },
                         {"prompt", llama.params.prompt},
+                        {"stopping_word", llama.stopping_word},
                        {"generated_text", final_text}
                     };
                 }
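
For completeness, a minimal consumer-side sketch (also not part of this commit) showing how the new stopping_word field might be read from the final JSON chunk. The field names mirror the diff above; the literal values, including the "</s>" stopping word and the seed, are made up, and parsing with nlohmann::json on the client side is an assumption.

// Hypothetical sketch: parse the final chunk of a streamed completion.
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Example payload shaped like the final "data" object in the diff above; values are illustrative.
    const std::string final_chunk = R"({
        "content": "\uFFFD",
        "stop": true,
        "tokens_predicted": 42,
        "seed": 1684000000,
        "prompt": "Building a website can be done in 10 simple steps:",
        "stopping_word": "</s>",
        "generated_text": " Step 1: ..."
    })";

    json data = json::parse(final_chunk);
    if (data.value("stop", false)) {
        std::cout << "stopping word:  " << data.value("stopping_word", std::string{}) << "\n";
        std::cout << "generated text: " << data.value("generated_text", std::string{}) << "\n";
    }
    return 0;
}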