Added capturing the stopping word and sending it along with the final JSON.
Fixed an fprintf warning.
Fixed a bug that broke streaming.
Properly removed thread changing in json (only grabbed batch_size before).
parent 2e5c5ee224
commit dda915cac4
1 changed file with 14 additions and 7 deletions
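For reference, after this change a client can read the stop word back out of the final JSON message. The sketch below is illustrative only: it assumes the nlohmann::json single header ("json.hpp") that the server itself uses, and the payload values are made up rather than taken from a real run.

    // Illustrative only: read the new "stopping_word" field from a final response body.
    #include <iostream>
    #include <string>
    #include "json.hpp"

    using json = nlohmann::json;

    int main() {
        // Hypothetical final message, shaped roughly like the one built in main() below.
        std::string body = R"({
            "content": "\uFFFD",
            "stop": true,
            "tokens_predicted": 42,
            "prompt": "Hello",
            "stopping_word": "User:",
            "generated_text": "Hi there."
        })";

        json data = json::parse(body);
        // May be empty if no stop word was recorded before generation finished.
        std::cout << "stopping_word:  " << data["stopping_word"].get<std::string>() << "\n";
        std::cout << "generated_text: " << data["generated_text"].get<std::string>() << "\n";
        return 0;
    }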
@@ -33,13 +33,14 @@ struct llama_server_context
     llama_context *ctx;
     gpt_params params;
 
     bool reload_ctx = false;
+    std::string stopping_word = "";
 
     void rewind() {
         as_loop = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
+        stopping_word = "";
 
         //processed_tokens.clear();
         embd_inp.clear();
@@ -233,6 +234,7 @@ struct llama_server_context
         }
 
         if (!embd.empty() && embd.back() == llama_token_eos()) {
+            stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;
         }
 
@@ -258,6 +260,7 @@ struct llama_server_context
             size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
             if (i != std::string::npos) {
                 generated_text.erase(generated_text.begin() + i, generated_text.begin() + i + word.size());
+                stopping_word = word;
                 has_next_token = false;
                 break;
             }
@@ -313,7 +316,7 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, con
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
-    fprintf(stderr, "  --host                ip address to listen (default (default: %d)\n", sparams.hostname);
+    fprintf(stderr, "  --host                ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     fprintf(stderr, "  --port PORT           port to listen (default (default: %d)\n", sparams.port);
     fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     fprintf(stderr, "\n");
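The --host hunk above is the fprintf warning fix from the commit message: sparams.hostname is a std::string, and passing a class type through C varargs with a %d specifier is undefined behaviour (and a -Wformat warning on GCC/Clang). A minimal sketch of the pattern, separate from the project code:

    #include <cstdio>
    #include <string>

    int main() {
        std::string host = "127.0.0.1";
        // std::printf("%d\n", host);        // wrong: std::string passed through C varargs, UB plus -Wformat warning
        std::printf("%s\n", host.c_str());   // correct: pass a const char* and format it with %s
        return 0;
    }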
@@ -449,9 +452,9 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 }
 
 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
-    if (!body["threads"].is_null())
+    if (!body["as_loop"].is_null())
     {
-        llama.params.n_threads = body["threads"].get<int>();
+        llama.as_loop = body["as_loop"].get<bool>();
     }
     if (!body["n_predict"].is_null())
     {
@@ -475,7 +478,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     }
     if (!body["repeat_last_n"].is_null())
     {
-        llama.params.repeat_last_n = body["repeat_last_n"].get<float>();
+        llama.params.repeat_last_n = body["repeat_last_n"].get<int>();
     }
     if (!body["temperature"].is_null())
     {
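The two hunks above change what parse_options_completion accepts: the "threads" key is no longer mapped onto n_threads, "as_loop" is read as a bool, and "repeat_last_n" is read as an int rather than a float. A rough sketch of a request body that exercises these options follows; the exact field set a real client would send is an assumption, not something this diff shows.

    #include <iostream>
    #include "json.hpp"

    using json = nlohmann::json;

    int main() {
        // Hypothetical completion request body after this commit.
        json body = {
            {"prompt", "Hello"},      // assumed field, not shown in this diff
            {"as_loop", true},        // now read with get<bool>() into llama.as_loop
            {"repeat_last_n", 64},    // now read with get<int>() instead of get<float>()
            {"threads", 8}            // no longer honoured by parse_options_completion
        };
        std::cout << body.dump(2) << std::endl;
        return 0;
    }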
@@ -630,7 +633,8 @@ int main(int argc, char **argv)
                     "penalize_nl", llama.params.penalize_nl
                 }
             },
-            {"prompt", llama.params.prompt} };
+            {"prompt", llama.params.prompt},
+            {"stopping_word", llama.stopping_word} };
             return res.set_content(data.dump(), "application/json");
         }
         catch (const json::exception &e)
@@ -684,6 +688,7 @@ int main(int argc, char **argv)
             json data;
             if (llama.has_next_token)
             {
                 //fprintf(stdout, "Result: %s\n", result);
+                final_text += result;
                 data = {
                     {"content", result },
@@ -715,6 +720,7 @@ int main(int argc, char **argv)
                     }
                 },
                 {"prompt", llama.params.prompt},
+                {"stopping_word", llama.stopping_word},
                 {"generated_text", final_text}
             };
         }
@@ -735,7 +741,7 @@ int main(int argc, char **argv)
         {
             // Generation is done, send extra information.
             data = {
-                {"content", "\uFFFD" },
+                {"content", u8"\uFFFD" },
                 {"stop", true },
                 {"tokens_predicted", llama.num_tokens_predicted},
                 {"seed", llama.params.seed},
@@ -756,6 +762,7 @@ int main(int argc, char **argv)
                     }
                 },
                 {"prompt", llama.params.prompt},
+                {"stopping_word", llama.stopping_word},
                 {"generated_text", final_text}
             };
         }