rename completion to inference

This commit is contained in:
parent 575b1332ab
commit 4a9f3e7628

1 changed file with 33 additions and 33 deletions
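Before the full diff, here is a minimal, self-contained sketch (not part of the commit; the enums are trimmed and the standalone main() is purely illustrative) showing how the renamed identifiers relate to the old ones:

// Sketch of the renamed types after this commit (simplified; the real
// server's queue/slot machinery and remaining enum values are omitted).
#include <cstdio>

// task categories handled by the server queue
enum server_task_type {
    SERVER_TASK_TYPE_INFERENCE,      // was SERVER_TASK_TYPE_COMPLETION
    SERVER_TASK_TYPE_CANCEL,
    SERVER_TASK_TYPE_NEXT_RESPONSE,
    SERVER_TASK_TYPE_METRICS,
};

// kind of inference a task performs (was server_task_cmpl_type)
enum server_task_inf_type {
    SERVER_TASK_INF_TYPE_COMPLETION, // was SERVER_TASK_CMPL_TYPE_NORMAL
    SERVER_TASK_INF_TYPE_EMBEDDING,  // was SERVER_TASK_CMPL_TYPE_EMBEDDING
    SERVER_TASK_INF_TYPE_RERANK,     // was SERVER_TASK_CMPL_TYPE_RERANK
    SERVER_TASK_INF_TYPE_INFILL,     // was SERVER_TASK_CMPL_TYPE_INFILL
};

struct server_task {
    server_task_type     type     = SERVER_TASK_TYPE_INFERENCE;
    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION; // was cmpl_type
};

int main() {
    // a plain completion request becomes an inference task of type COMPLETION
    server_task task;
    task.type     = SERVER_TASK_TYPE_INFERENCE;
    task.inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
    std::printf("type=%d inf_type=%d\n", (int) task.type, (int) task.inf_type);
    return 0;
}

The actual change, as a unified diff: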
@@ -65,7 +65,7 @@ enum server_state {
 };
 
 enum server_task_type {
-    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_INFERENCE,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -75,11 +75,11 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
-enum server_task_cmpl_type {
-    SERVER_TASK_CMPL_TYPE_NORMAL,
-    SERVER_TASK_CMPL_TYPE_EMBEDDING,
-    SERVER_TASK_CMPL_TYPE_RERANK,
-    SERVER_TASK_CMPL_TYPE_INFILL,
+enum server_task_inf_type {
+    SERVER_TASK_INF_TYPE_COMPLETION,
+    SERVER_TASK_INF_TYPE_EMBEDDING,
+    SERVER_TASK_INF_TYPE_RERANK,
+    SERVER_TASK_INF_TYPE_INFILL,
 };
 
 struct server_task {
@@ -90,7 +90,7 @@ struct server_task {
     server_task_type type;
     json data;
 
-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
 
     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -161,7 +161,7 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
 
     bool has_next_token = true;
     bool has_new_line = false;
@@ -210,7 +210,7 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+        inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
 
         generated_token_probs.clear();
     }
@@ -1345,14 +1345,14 @@ struct server_context {
     //
 
     // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
-    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+    std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
         std::vector<server_task> tasks;
         auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
             SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
             server_task task;
             task.id = queue_tasks.get_new_id();
-            task.cmpl_type = cmpl_type;
-            task.type = SERVER_TASK_TYPE_COMPLETION;
+            task.inf_type = inf_type;
+            task.type = SERVER_TASK_TYPE_INFERENCE;
             task.data = task_data;
             task.prompt_tokens = std::move(prompt_tokens);
             tasks.push_back(std::move(task));
@@ -1364,10 +1364,10 @@ struct server_context {
         }
 
         // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
-        bool add_special = cmpl_type != SERVER_TASK_CMPL_TYPE_RERANK && cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL;
+        bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
-        switch (cmpl_type) {
-            case SERVER_TASK_CMPL_TYPE_RERANK:
+        switch (inf_type) {
+            case SERVER_TASK_INF_TYPE_RERANK:
                 {
                     // prompts[0] is the question
                     // the rest are the answers/documents
@@ -1379,7 +1379,7 @@ struct server_context {
                         create_task(data, tokens);
                     }
                 } break;
-            case SERVER_TASK_CMPL_TYPE_INFILL:
+            case SERVER_TASK_INF_TYPE_INFILL:
                 {
                     SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
                     for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -1427,7 +1427,7 @@ struct server_context {
         queue_tasks.post(cancel_tasks, true);
     }
 
-    // receive the results from task(s) created by create_tasks_cmpl
+    // receive the results from task(s) created by create_tasks_inference
     void receive_cmpl_results(
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<server_task_result>&)> & result_handler,
@@ -1451,7 +1451,7 @@ struct server_context {
         result_handler(results);
     }
 
-    // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+    // receive the results from task(s) created by create_tasks_inference, in stream mode
    void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks, const
             std::function<bool(server_task_result&)> & result_handler, const
@@ -1484,7 +1484,7 @@ struct server_context {
 
     void process_single_task(const server_task & task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFERENCE:
                 {
                     const int id_slot = json_value(task.data, "id_slot", -1);
 
@@ -1517,7 +1517,7 @@ struct server_context {
                     slot->reset();
 
                     slot->id_task = task.id;
-                    slot->cmpl_type = task.cmpl_type;
+                    slot->inf_type = task.inf_type;
                     slot->index = json_value(task.data, "index", 0);
                     slot->prompt_tokens = std::move(task.prompt_tokens);
 
@@ -1881,7 +1881,7 @@ struct server_context {
                    continue;
                 }
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                     // this prompt is too large to process - discard it
                     if (slot.n_prompt_tokens > n_ubatch) {
                         slot.release();
@@ -1992,7 +1992,7 @@ struct server_context {
                     }
 
                     // non-causal tasks require to fit the entire prompt in the physical batch
-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -2001,8 +2001,8 @@ struct server_context {
 
                     // check that we are in the right batch_type, if not defer the slot
                     const bool slot_type =
-                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
-                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+                        slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
+                        slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
 
                     if (batch_type == -1) {
                         batch_type = slot_type;
@@ -2120,7 +2120,7 @@ struct server_context {
             }
 
             if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
                     // prompt evaluated for embedding
                     send_embedding(slot, batch_view);
                     slot.release();
@@ -2128,7 +2128,7 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                     send_rerank(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
@@ -2682,13 +2682,13 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };
 
-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
         if (ctx_server.params.embedding || ctx_server.params.reranking) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
 
@@ -2734,7 +2734,7 @@ int main(int argc, char ** argv) {
 
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
     };
 
     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -2784,7 +2784,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
     };
 
     // TODO: maybe merge this function with "handle_completions_generic"
@@ -2796,7 +2796,7 @@ int main(int argc, char ** argv) {
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
 
@@ -2940,7 +2940,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
@@ -3017,7 +3017,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 