rename completion to inference

This commit is contained in:
Xuan Son Nguyen 2024-10-24 16:29:38 +02:00
parent 575b1332ab
commit 4a9f3e7628

View file

@ -65,7 +65,7 @@ enum server_state {
};
enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
SERVER_TASK_TYPE_INFERENCE,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS,
@ -75,11 +75,11 @@ enum server_task_type {
SERVER_TASK_TYPE_SET_LORA,
};
enum server_task_cmpl_type {
SERVER_TASK_CMPL_TYPE_NORMAL,
SERVER_TASK_CMPL_TYPE_EMBEDDING,
SERVER_TASK_CMPL_TYPE_RERANK,
SERVER_TASK_CMPL_TYPE_INFILL,
enum server_task_inf_type {
SERVER_TASK_INF_TYPE_COMPLETION,
SERVER_TASK_INF_TYPE_EMBEDDING,
SERVER_TASK_INF_TYPE_RERANK,
SERVER_TASK_INF_TYPE_INFILL,
};
struct server_task {
@ -90,7 +90,7 @@ struct server_task {
server_task_type type;
json data;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
// utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@ -161,7 +161,7 @@ struct server_slot {
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
bool has_next_token = true;
bool has_new_line = false;
@ -210,7 +210,7 @@ struct server_slot {
n_past = 0;
n_sent_text = 0;
n_sent_token_probs = 0;
cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
generated_token_probs.clear();
}
@ -1345,14 +1345,14 @@ struct server_context {
//
// break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
std::vector<server_task> tasks;
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
server_task task;
task.id = queue_tasks.get_new_id();
task.cmpl_type = cmpl_type;
task.type = SERVER_TASK_TYPE_COMPLETION;
task.inf_type = inf_type;
task.type = SERVER_TASK_TYPE_INFERENCE;
task.data = task_data;
task.prompt_tokens = std::move(prompt_tokens);
tasks.push_back(std::move(task));
@ -1364,10 +1364,10 @@ struct server_context {
}
// because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
bool add_special = cmpl_type != SERVER_TASK_CMPL_TYPE_RERANK && cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL;
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
switch (cmpl_type) {
case SERVER_TASK_CMPL_TYPE_RERANK:
switch (inf_type) {
case SERVER_TASK_INF_TYPE_RERANK:
{
// prompts[0] is the question
// the rest are the answers/documents
@ -1379,7 +1379,7 @@ struct server_context {
create_task(data, tokens);
}
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
case SERVER_TASK_INF_TYPE_INFILL:
{
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@ -1427,7 +1427,7 @@ struct server_context {
queue_tasks.post(cancel_tasks, true);
}
// receive the results from task(s) created by create_tasks_cmpl
// receive the results from task(s) created by create_tasks_inference
void receive_cmpl_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result>&)> & result_handler,
@ -1451,7 +1451,7 @@ struct server_context {
result_handler(results);
}
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
// receive the results from task(s) created by create_tasks_inference, in stream mode
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks, const
std::function<bool(server_task_result&)> & result_handler, const
@ -1484,7 +1484,7 @@ struct server_context {
void process_single_task(const server_task & task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFERENCE:
{
const int id_slot = json_value(task.data, "id_slot", -1);
@ -1517,7 +1517,7 @@ struct server_context {
slot->reset();
slot->id_task = task.id;
slot->cmpl_type = task.cmpl_type;
slot->inf_type = task.inf_type;
slot->index = json_value(task.data, "index", 0);
slot->prompt_tokens = std::move(task.prompt_tokens);
@ -1881,7 +1881,7 @@ struct server_context {
continue;
}
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
@ -1992,7 +1992,7 @@ struct server_context {
}
// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
@ -2001,8 +2001,8 @@ struct server_context {
// check that we are in the right batch_type, if not defer the slot
const bool slot_type =
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
if (batch_type == -1) {
batch_type = slot_type;
@ -2120,7 +2120,7 @@ struct server_context {
}
if (slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
// prompt evaluated for embedding
send_embedding(slot, batch_view);
slot.release();
@ -2128,7 +2128,7 @@ struct server_context {
continue; // continue loop of slots
}
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
@ -2682,13 +2682,13 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }});
};
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@ -2734,7 +2734,7 @@ int main(int argc, char ** argv) {
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@ -2784,7 +2784,7 @@ int main(int argc, char ** argv) {
}
data["input_extra"] = input_extra; // default to empty array if it's not exist
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
};
// TODO: maybe merge this function with "handle_completions_generic"
@ -2796,7 +2796,7 @@ int main(int argc, char ** argv) {
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@ -2940,7 +2940,7 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
@ -3017,7 +3017,7 @@ int main(int argc, char ** argv) {
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);