rename completion to inference

Xuan Son Nguyen 2024-10-24 16:29:38 +02:00
parent 575b1332ab
commit 4a9f3e7628


@@ -65,7 +65,7 @@ enum server_state {
 };
 enum server_task_type {
-    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_INFERENCE,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -75,11 +75,11 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
-enum server_task_cmpl_type {
-    SERVER_TASK_CMPL_TYPE_NORMAL,
-    SERVER_TASK_CMPL_TYPE_EMBEDDING,
-    SERVER_TASK_CMPL_TYPE_RERANK,
-    SERVER_TASK_CMPL_TYPE_INFILL,
+enum server_task_inf_type {
+    SERVER_TASK_INF_TYPE_COMPLETION,
+    SERVER_TASK_INF_TYPE_EMBEDDING,
+    SERVER_TASK_INF_TYPE_RERANK,
+    SERVER_TASK_INF_TYPE_INFILL,
 };
 struct server_task {
@@ -90,7 +90,7 @@ struct server_task {
     server_task_type type;
     json data;
-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -161,7 +161,7 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
     bool has_next_token = true;
     bool has_new_line = false;
@@ -210,7 +210,7 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+        inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
         generated_token_probs.clear();
     }
@@ -1345,14 +1345,14 @@ struct server_context {
     //
     // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
-    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+    std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
         std::vector<server_task> tasks;
         auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
             SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
             server_task task;
             task.id = queue_tasks.get_new_id();
-            task.cmpl_type = cmpl_type;
-            task.type = SERVER_TASK_TYPE_COMPLETION;
+            task.inf_type = inf_type;
+            task.type = SERVER_TASK_TYPE_INFERENCE;
             task.data = task_data;
             task.prompt_tokens = std::move(prompt_tokens);
             tasks.push_back(std::move(task));
@@ -1364,10 +1364,10 @@ struct server_context {
         }
         // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
-        bool add_special = cmpl_type != SERVER_TASK_CMPL_TYPE_RERANK && cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL;
+        bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
-        switch (cmpl_type) {
-            case SERVER_TASK_CMPL_TYPE_RERANK:
+        switch (inf_type) {
+            case SERVER_TASK_INF_TYPE_RERANK:
                 {
                     // prompts[0] is the question
                     // the rest are the answers/documents
@@ -1379,7 +1379,7 @@ struct server_context {
                         create_task(data, tokens);
                     }
                 } break;
-            case SERVER_TASK_CMPL_TYPE_INFILL:
+            case SERVER_TASK_INF_TYPE_INFILL:
                 {
                     SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
                     for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -1427,7 +1427,7 @@ struct server_context {
         queue_tasks.post(cancel_tasks, true);
     }
-    // receive the results from task(s) created by create_tasks_cmpl
+    // receive the results from task(s) created by create_tasks_inference
     void receive_cmpl_results(
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<server_task_result>&)> & result_handler,
@@ -1451,7 +1451,7 @@ struct server_context {
         result_handler(results);
     }
-    // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+    // receive the results from task(s) created by create_tasks_inference, in stream mode
     void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks, const
             std::function<bool(server_task_result&)> & result_handler, const
@@ -1484,7 +1484,7 @@ struct server_context {
     void process_single_task(const server_task & task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFERENCE:
                 {
                     const int id_slot = json_value(task.data, "id_slot", -1);
@@ -1517,7 +1517,7 @@ struct server_context {
                     slot->reset();
                     slot->id_task = task.id;
-                    slot->cmpl_type = task.cmpl_type;
+                    slot->inf_type = task.inf_type;
                     slot->index = json_value(task.data, "index", 0);
                     slot->prompt_tokens = std::move(task.prompt_tokens);
@@ -1881,7 +1881,7 @@ struct server_context {
                     continue;
                 }
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                     // this prompt is too large to process - discard it
                     if (slot.n_prompt_tokens > n_ubatch) {
                         slot.release();
@@ -1992,7 +1992,7 @@ struct server_context {
                 }
                 // non-causal tasks require to fit the entire prompt in the physical batch
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
@@ -2001,8 +2001,8 @@ struct server_context {
                 // check that we are in the right batch_type, if not defer the slot
                 const bool slot_type =
-                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
-                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+                    slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
+                    slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
                 if (batch_type == -1) {
                     batch_type = slot_type;
@@ -2120,7 +2120,7 @@ struct server_context {
             }
             if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
                     // prompt evaluated for embedding
                     send_embedding(slot, batch_view);
                     slot.release();
@@ -2128,7 +2128,7 @@ struct server_context {
                     continue; // continue loop of slots
                 }
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                     send_rerank(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
@@ -2682,13 +2682,13 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };
-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
         if (ctx_server.params.embedding || ctx_server.params.reranking) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
@@ -2734,7 +2734,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
     };
     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -2784,7 +2784,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
     };
     // TODO: maybe merge this function with "handle_completions_generic"
@@ -2796,7 +2796,7 @@ int main(int argc, char ** argv) {
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
@@ -2940,7 +2940,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
@@ -3017,7 +3017,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);