server : (refactor) no more json in server_task input

Xuan Son Nguyen 2024-12-06 15:01:12 +01:00
parent 6c5bc0625f
commit db97c8b19b
3 changed files with 372 additions and 365 deletions
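In short: HTTP handlers no longer stuff a raw json blob into server_task; the request is parsed once, on the HTTP thread, into typed fields (slot_params, prompt_tokens, slot_action, ...) before the task is posted. A condensed sketch of the old versus new calling pattern for an inference task, using only identifiers that appear in the diff below (error handling and result waiting omitted):

    // before: the task carried raw JSON, re-interpreted later in launch_slot_with_task()
    server_task task;
    task.type = SERVER_TASK_TYPE_INFERENCE;
    task.data = data; // json, parsed again on the worker side
    ctx_server.queue_tasks.post(task);

    // after: the HTTP thread parses the request up front into typed members
    server_task task(SERVER_TASK_TYPE_INFERENCE, inf_type);
    task.id            = ctx_server.queue_tasks.get_new_id();
    task.prompt_tokens = std::move(tokenized_prompts[i]);
    task.params        = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
    ctx_server.queue_tasks.post(task);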


@@ -82,28 +82,6 @@ enum error_type {
ERROR_TYPE_NOT_SUPPORTED, // custom error
};
struct server_task {
int id = -1; // to be filled by server_queue
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
llama_tokens prompt_tokens;
server_task_type type;
// TODO @ngxson : we should get rid of json type here
json data;
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
// utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
}
return ids;
}
};
struct slot_params {
bool stream = true;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
@@ -118,6 +96,7 @@ struct slot_params {
std::vector<std::string> antiprompt;
bool timings_per_token = false;
bool ignore_eos = false;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
@@ -134,7 +113,7 @@ struct slot_params {
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
-json to_json() {
+json to_json() const {
std::vector<std::string> samplers;
samplers.reserve(sampling.samplers.size());
for (const auto & sampler : sampling.samplers) {
@@ -172,7 +151,7 @@ struct slot_params {
{"n_discard", n_discard},
{"ignore_eos", sampling.ignore_eos},
{"stream", stream},
-//{"logit_bias", sampling.logit_bias},
+{"logit_bias", format_logit_bias(sampling.logit_bias)},
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
@@ -186,6 +165,212 @@ struct slot_params {
}
};
struct server_task {
int id = -1; // to be filled by server_queue
int index = -1; // used when there are multiple prompts (batch request)
server_task_type type;
server_task_inf_type inf_type;
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
// used by SERVER_TASK_TYPE_INFERENCE
slot_params params;
llama_tokens prompt_tokens;
int id_selected_slot = -1;
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
struct slot_action {
int slot_id;
std::string filename;
std::string filepath;
};
slot_action slot_action;
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;
server_task(
server_task_type type,
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION) : type(type), inf_type(inf_type) {}
static slot_params params_from_json_cmpl(
const llama_model * model,
const common_params & params_base,
const json & data) {
slot_params params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
slot_params defaults;
defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
// enabling this will output extra debug information in the HTTP responses from the server
params.verbose = params_base.verbosity > 9;
params.timings_per_token = json_value(data, "timings_per_token", false);
params.stream = json_value(data, "stream", false);
params.cache_prompt = json_value(data, "cache_prompt", true);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
if (params.sampling.dry_base < 1.0f) {
params.sampling.dry_base = defaults.sampling.dry_base;
}
// sequence breakers for DRY
{
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (params.sampling.dry_sequence_breakers.empty()) {
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
}
}
}
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
params.sampling.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
}
{
params.sampling.logit_bias.clear();
params.ignore_eos = json_value(data, "ignore_eos", false);
const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model);
for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) {
float bias;
if (el[1].is_number()) {
bias = el[1].get<float>();
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
bias = -INFINITY;
} else {
continue;
}
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
params.sampling.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
params.sampling.logit_bias.push_back({tok, bias});
}
}
}
}
}
}
{
params.antiprompt.clear();
const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
for (const auto & word : *stop) {
if (!word.empty()) {
params.antiprompt.push_back(word);
}
}
}
}
{
const auto & samplers = data.find("samplers");
if (samplers != data.end()) {
if (samplers->is_array()) {
std::vector<std::string> sampler_names;
for (const auto & name : *samplers) {
if (name.is_string()) {
sampler_names.emplace_back(name);
}
}
params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
} else if (samplers->is_string()){
std::string sampler_string;
for (const auto & name : *samplers) {
sampler_string += name;
}
params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
}
} else {
params.sampling.samplers = defaults.sampling.samplers;
}
}
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
params.oaicompat_model = json_value(data, "model", model_name);
return params;
}
// utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
}
return ids;
}
};
struct result_timings {
int32_t prompt_n = -1;
double prompt_ms;
@@ -197,7 +382,7 @@ struct result_timings {
double predicted_per_token_ms;
double predicted_per_second;
-json to_json() {
+json to_json() const {
return {
{"prompt_n", prompt_n},
{"prompt_ms", prompt_ms},
@@ -861,6 +1046,22 @@ struct server_slot {
return timings;
}
json to_json() const {
json res = params.to_json();
res["id"] = id;
res["id_task"] = id_task;
res["is_processing"] = is_processing();
res["prompt_tokens"] = prompt_tokens;
res["next_token"] = {
{"has_next_token", has_next_token},
{"has_new_line", has_new_line},
{"n_remain", n_remaining},
{"n_decoded", n_decoded},
{"stopping_word", stopping_word},
};
return res;
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
@@ -978,9 +1179,7 @@ struct server_queue {
// Add a new task to the end of the queue
int post(server_task task, bool front = false) {
std::unique_lock<std::mutex> lock(mutex_tasks);
-if (task.id == -1) {
-task.id = id++;
-}
+GGML_ASSERT(task.id != -1);
QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
if (front) {
queue_tasks.push_front(std::move(task));
@@ -1458,104 +1657,14 @@ struct server_context {
}
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
-// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-slot_params defaults;
-defaults.sampling = params_base.sampling;
-defaults.speculative = params_base.speculative;
-const auto & data = task.data;
+slot.reset();
+slot.id_task = task.id;
+slot.inf_type = task.inf_type;
+slot.index = task.index;
+slot.params = std::move(task.params);
+slot.prompt_tokens = std::move(task.prompt_tokens);
+SLT_DBG(slot, "launching slot : %s\n", slot.to_json().dump().c_str());
if (data.count("__oaicompat") != 0) {
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
slot.params.oaicompat = true;
slot.params.oaicompat_chat = json_value(data, "__oaicompat_chat", false);
slot.params.oaicompat_model = json_value(data, "model", model_name);
slot.params.oaicompat_cmpl_id = json_value(data, "completion_id", std::string());
} else {
slot.params.oaicompat = false;
}
// enabling this will output extra debug information in the HTTP responses from the server
slot.params.verbose = params_base.verbosity > 9;
slot.params.timings_per_token = json_value(data, "timings_per_token", false);
slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
slot.params.n_indent = json_value(data, "n_indent", defaults.n_indent);
slot.params.n_keep = json_value(data, "n_keep", defaults.n_keep);
slot.params.n_discard = json_value(data, "n_discard", defaults.n_discard);
//slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
slot.params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
slot.params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
slot.params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
slot.params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
slot.params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
slot.params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
slot.params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
slot.params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
slot.params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
slot.params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
slot.params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
slot.params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
slot.params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
slot.params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
slot.params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
slot.params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
slot.params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
slot.params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
slot.params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
if (slot.params.sampling.dry_base < 1.0f) {
slot.params.sampling.dry_base = defaults.sampling.dry_base;
}
// sequence breakers for DRY
{
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
slot.params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
if (slot.params.sampling.dry_sequence_breakers.empty()) {
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
}
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
return false;
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
slot.params.sampling.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
return false;
}
} else {
slot.params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
}
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
@@ -1563,80 +1672,10 @@ struct server_context {
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
}
-{
-slot.params.sampling.logit_bias.clear();
-if (json_value(data, "ignore_eos", false) && has_eos_token) {
+if (slot.params.ignore_eos && has_eos_token) {
slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
}
const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model);
for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) {
float bias;
if (el[1].is_number()) {
bias = el[1].get<float>();
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
bias = -INFINITY;
} else {
continue;
}
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
slot.params.sampling.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
slot.params.sampling.logit_bias.push_back({tok, bias});
}
}
}
}
}
}
{
slot.params.antiprompt.clear();
const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
for (const auto & word : *stop) {
if (!word.empty()) {
slot.params.antiprompt.push_back(word);
}
}
}
}
{
const auto & samplers = data.find("samplers");
if (samplers != data.end()) {
if (samplers->is_array()) {
std::vector<std::string> sampler_names;
for (const auto & name : *samplers) {
if (name.is_string()) {
sampler_names.emplace_back(name);
}
}
slot.params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
} else if (samplers->is_string()){
std::string sampler_string;
for (const auto & name : *samplers) {
sampler_string += name;
}
slot.params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
}
} else {
slot.params.sampling.samplers = defaults.sampling.samplers;
}
}
{
if (slot.smpl != nullptr) {
common_sampler_free(slot.smpl);
@@ -2007,81 +2046,13 @@ struct server_context {
// Functions to create new task(s) and receive result(s)
//
// break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
std::vector<server_task> tasks;
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
server_task task;
task.id = queue_tasks.get_new_id();
task.inf_type = inf_type;
task.type = SERVER_TASK_TYPE_INFERENCE;
task.data = task_data;
task.prompt_tokens = std::move(prompt_tokens);
tasks.push_back(std::move(task));
};
static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
if (!data.contains("prompt")) {
throw std::runtime_error(error_msg);
}
// because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
switch (inf_type) {
case SERVER_TASK_INF_TYPE_RERANK:
{
// prompts[0] is the question
// the rest are the answers/documents
GGML_ASSERT(tokenized_prompts.size() > 1);
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
for (size_t i = 1; i < tokenized_prompts.size(); i++) {
data["index"] = i - 1;
auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
create_task(data, tokens);
}
} break;
case SERVER_TASK_INF_TYPE_INFILL:
{
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
data["index"] = i;
auto tokens = format_infill(
ctx,
data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
params_base.n_batch,
params_base.n_predict,
slots[0].n_ctx, // TODO: there should be a better way
params_base.spm_infill,
tokenized_prompts[i]
);
create_task(data, tokens);
}
} break;
default:
{
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
data["index"] = i;
create_task(data, tokenized_prompts[i]);
}
}
}
return tasks;
}
void cancel_tasks(const std::unordered_set<int> & id_tasks) {
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
for (const auto & id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);
-server_task task;
-task.type = SERVER_TASK_TYPE_CANCEL;
+server_task task(SERVER_TASK_TYPE_CANCEL);
task.id_target = id_task;
cancel_tasks.push_back(task);
queue_results.remove_waiting_task_id(id_task);
@@ -2090,7 +2061,7 @@ struct server_context {
queue_tasks.post(cancel_tasks, true);
}
-// receive the results from task(s) created by create_tasks_inference
+// receive the results from task(s)
void receive_multi_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
@@ -2117,7 +2088,7 @@ struct server_context {
result_handler(results);
}
-// receive the results from task(s) created by create_tasks_inference, in stream mode
+// receive the results from task(s), in stream mode
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks,
const std::function<bool(server_task_result_ptr&)> & result_handler,
@@ -2154,7 +2125,7 @@ struct server_context {
switch (task.type) {
case SERVER_TASK_TYPE_INFERENCE:
{
-const int id_slot = json_value(task.data, "id_slot", -1);
+const int id_slot = task.id_selected_slot;
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
@@ -2171,13 +2142,6 @@ struct server_context {
break;
}
slot->reset();
slot->id_task = task.id;
slot->inf_type = task.inf_type;
slot->index = json_value(task.data, "index", 0);
slot->prompt_tokens = std::move(task.prompt_tokens);
if (!launch_slot_with_task(*slot, task)) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
break;
@@ -2205,18 +2169,8 @@ struct server_context {
int n_processing_slots = 0;
for (server_slot & slot : slots) {
-json slot_data = slot.params.to_json();
-slot_data["id"] = slot.id;
-slot_data["id_task"] = slot.id_task;
-slot_data["is_processing"] = slot.is_processing();
+json slot_data = slot.to_json();
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
slot_data["next_token"] = {
{"has_next_token", slot.has_next_token},
{"has_new_line", slot.has_new_line},
{"n_remain", slot.n_remaining},
{"n_decoded", slot.n_decoded},
{"stopping_word", slot.stopping_word},
};
if (slot.is_processing()) {
n_processing_slots++;
@@ -2251,14 +2205,14 @@ struct server_context {
res->n_decode_total = metrics.n_decode_total;
res->n_busy_slots_total = metrics.n_busy_slots_total;
-if (json_value(task.data, "reset_bucket", false)) {
+if (task.metrics_reset_bucket) {
metrics.reset_bucket();
}
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
-int id_slot = task.data.at("id_slot");
+int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -2274,8 +2228,8 @@ struct server_context {
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();
-std::string filename = task.data.at("filename");
-std::string filepath = task.data.at("filepath");
+std::string filename = task.slot_action.filename;
+std::string filepath = task.slot_action.filepath;
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
@@ -2294,7 +2248,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
-int id_slot = task.data.at("id_slot");
+int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -2309,8 +2263,8 @@ struct server_context {
const int64_t t_start = ggml_time_us();
-std::string filename = task.data.at("filename");
-std::string filepath = task.data.at("filepath");
+std::string filename = task.slot_action.filename;
+std::string filepath = task.slot_action.filepath;
slot->cache_tokens.resize(slot->n_ctx);
size_t token_count = 0;
@@ -2337,7 +2291,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
-int id_slot = task.data.at("id_slot");
+int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -2396,10 +2350,8 @@ struct server_context {
{
SRV_DBG("%s", "posting NEXT_RESPONSE\n");
-server_task task;
-task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
-task.id_target = -1;
+server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
+task.id = queue_tasks.get_new_id();
queue_tasks.post(task);
}
@@ -3136,10 +3088,8 @@ int main(int argc, char ** argv) {
}
// request slots data using task queue
-server_task task;
+server_task task(SERVER_TASK_TYPE_METRICS);
task.id = ctx_server.queue_tasks.get_new_id();
-task.type = SERVER_TASK_TYPE_METRICS;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task, true); // high-priority task
@@ -3174,11 +3124,9 @@ int main(int argc, char ** argv) {
}
// request slots data using task queue
-server_task task;
+server_task task(SERVER_TASK_TYPE_METRICS);
task.id = ctx_server.queue_tasks.get_new_id();
-task.id_target = -1;
-task.type = SERVER_TASK_TYPE_METRICS;
-task.data.push_back({{"reset_bucket", true}});
+task.metrics_reset_bucket = true;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task, true); // high-priority task
@@ -3282,19 +3230,17 @@ int main(int argc, char ** argv) {
}
std::string filepath = params.slot_save_path + filename;
-server_task task;
-task.type = SERVER_TASK_TYPE_SLOT_SAVE;
-task.data = {
-{ "id_slot", id_slot },
-{ "filename", filename },
-{ "filepath", filepath },
-};
+server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+task.id = ctx_server.queue_tasks.get_new_id();
+task.slot_action.slot_id = id_slot;
+task.slot_action.filename = filename;
+task.slot_action.filepath = filepath;
-const int id_task = ctx_server.queue_tasks.post(task);
-ctx_server.queue_results.add_waiting_task_id(id_task);
-server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
-ctx_server.queue_results.remove_waiting_task_id(id_task);
+ctx_server.queue_results.add_waiting_task_id(task.id);
+ctx_server.queue_tasks.post(task);
+server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -3313,19 +3259,17 @@ int main(int argc, char ** argv) {
}
std::string filepath = params.slot_save_path + filename;
-server_task task;
-task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
-task.data = {
-{ "id_slot", id_slot },
-{ "filename", filename },
-{ "filepath", filepath },
-};
+server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+task.id = ctx_server.queue_tasks.get_new_id();
+task.slot_action.slot_id = id_slot;
+task.slot_action.filename = filename;
+task.slot_action.filepath = filepath;
-const int id_task = ctx_server.queue_tasks.post(task);
-ctx_server.queue_results.add_waiting_task_id(id_task);
-server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
-ctx_server.queue_results.remove_waiting_task_id(id_task);
+ctx_server.queue_results.add_waiting_task_id(task.id);
+ctx_server.queue_tasks.post(task);
+server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -3337,17 +3281,15 @@ int main(int argc, char ** argv) {
};
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
-server_task task;
-task.type = SERVER_TASK_TYPE_SLOT_ERASE;
-task.data = {
-{ "id_slot", id_slot },
-};
+server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+task.id = ctx_server.queue_tasks.get_new_id();
+task.slot_action.slot_id = id_slot;
-const int id_task = ctx_server.queue_tasks.post(task);
-ctx_server.queue_results.add_waiting_task_id(id_task);
-server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
-ctx_server.queue_results.remove_waiting_task_id(id_task);
+ctx_server.queue_results.add_waiting_task_id(task.id);
+ctx_server.queue_tasks.post(task);
+server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -3416,14 +3358,41 @@ int main(int argc, char ** argv) {
server_task_inf_type inf_type,
json & data,
httplib::Response & res,
-bool oai_compat = false) {
+bool oaicompat = false,
+bool oaicompat_chat = false) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
-data["completion_id"] = gen_chatcmplid();
+auto completion_id = gen_chatcmplid();
-std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
+std::vector<server_task> tasks;
try {
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, inf_type);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
task.id_selected_slot = json_value(data, "id_slot", -1);
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.oaicompat_chat = oaicompat_chat;
task.params.oaicompat_cmpl_id = completion_id;
// oaicompat_model is already populated by params_from_json_cmpl
tasks.push_back(task);
}
} catch (const std::exception & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
return;
}
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -3449,7 +3418,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
-const auto chunked_content_provider = [task_ids, &ctx_server, oai_compat](size_t, httplib::DataSink & sink) {
+const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
json res_json = result->to_json();
if (res_json.is_array()) {
@@ -3465,7 +3434,7 @@ int main(int argc, char ** argv) {
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
-if (oai_compat) {
+if (oaicompat) {
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
}
@@ -3483,7 +3452,12 @@ int main(int argc, char ** argv) {
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
-return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
+return handle_completions_generic(
+SERVER_TASK_INF_TYPE_COMPLETION,
+data,
+res,
+/* oaicompat */ false,
+/* oaicompat_chat */ false);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3543,8 +3517,12 @@ int main(int argc, char ** argv) {
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-data["__oaicompat_chat"] = true;
-return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
+return handle_completions_generic(
+SERVER_TASK_INF_TYPE_COMPLETION,
+data,
+res,
+/* oaicompat */ true,
+/* oaicompat_chat */ true);
};
const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
@@ -3638,7 +3616,16 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
-std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
+std::vector<server_task> tasks;
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
tasks.push_back(task);
}
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -3665,7 +3652,7 @@ int main(int argc, char ** argv) {
// write JSON response
json root = oaicompat
? format_embeddings_response_oaicompat(body, responses)
-: responses[0];
+: responses.size() == 1 ? responses[0] : json(responses);
res_ok(res, root);
};
@@ -3704,20 +3691,23 @@ int main(int argc, char ** argv) {
return;
}
-// construct prompt object: array of ["query", "doc0", "doc1", ...]
-json prompt;
-prompt.push_back(query);
-for (const auto & doc : documents) {
-prompt.push_back(doc);
-}
-LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0];
// create and queue the task
json responses = json::array();
bool error = false;
{
-std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
+std::vector<server_task> tasks;
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
tasks.push_back(task);
}
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@@ -3778,13 +3768,13 @@ int main(int argc, char ** argv) {
}
}
-server_task task;
-task.type = SERVER_TASK_TYPE_SET_LORA;
+server_task task(SERVER_TASK_TYPE_SET_LORA);
+task.id = ctx_server.queue_tasks.get_new_id();
-const int id_task = ctx_server.queue_tasks.post(task);
-ctx_server.queue_results.add_waiting_task_id(id_task);
+ctx_server.queue_results.add_waiting_task_id(task.id);
+ctx_server.queue_tasks.post(task);
-server_task_result_ptr result = ctx_server.queue_results.recv(id_task);
-ctx_server.queue_results.remove_waiting_task_id(id_task);
+server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());


@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
],
})
assert res.status_code == 200
assert "cmpl" in res.body["id"]
assert res.body["model"] == model if model is not None else server.model_alias
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted
@@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
"stream": True,
})
content = ""
last_cmpl_id = None
for data in res:
choice = data["choices"][0]
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
if last_cmpl_id is None:
last_cmpl_id = data["id"]
assert last_cmpl_id == data["id"]
if choice["finish_reason"] in ["stop", "length"]:
assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted


@@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
} else {
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
}
if (result.empty()) {
throw std::runtime_error("\"prompt\" must not be empty");
}
return result;
}
@@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
const std::string & chat_template) {
json llama_params;
llama_params["__oaicompat"] = true;
// Apply chat template to the list of messages
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
@@ -648,3 +649,14 @@ static json format_detokenized_response(const std::string & content) {
{"content", content}
};
}
static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
json data = json::array();
for (const auto & lb : logit_bias) {
data.push_back(json{
{"bias", lb.bias},
{"token", lb.token},
});
}
return data;
}