rename completion to inference

parent 575b1332ab
commit 4a9f3e7628

1 changed file with 33 additions and 33 deletions
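For quick reference, here is the identifier mapping this commit applies, collected from the hunks below (comment-only sketch, not part of the diff itself):

```cpp
// old name                            -> new name
// server_task_cmpl_type               -> server_task_inf_type
// SERVER_TASK_CMPL_TYPE_NORMAL        -> SERVER_TASK_INF_TYPE_COMPLETION
// SERVER_TASK_CMPL_TYPE_EMBEDDING     -> SERVER_TASK_INF_TYPE_EMBEDDING
// SERVER_TASK_CMPL_TYPE_RERANK        -> SERVER_TASK_INF_TYPE_RERANK
// SERVER_TASK_CMPL_TYPE_INFILL        -> SERVER_TASK_INF_TYPE_INFILL
// SERVER_TASK_TYPE_COMPLETION         -> SERVER_TASK_TYPE_INFERENCE
// server_task::cmpl_type, slot.cmpl_type -> inf_type
// create_tasks_cmpl(...)              -> create_tasks_inference(...)
```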
@@ -65,7 +65,7 @@ enum server_state {
 };

 enum server_task_type {
-    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_INFERENCE,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -75,11 +75,11 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };

-enum server_task_cmpl_type {
-    SERVER_TASK_CMPL_TYPE_NORMAL,
-    SERVER_TASK_CMPL_TYPE_EMBEDDING,
-    SERVER_TASK_CMPL_TYPE_RERANK,
-    SERVER_TASK_CMPL_TYPE_INFILL,
+enum server_task_inf_type {
+    SERVER_TASK_INF_TYPE_COMPLETION,
+    SERVER_TASK_INF_TYPE_EMBEDDING,
+    SERVER_TASK_INF_TYPE_RERANK,
+    SERVER_TASK_INF_TYPE_INFILL,
 };

 struct server_task {
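As a usage sketch, a task carrying the renamed enum is populated like this. It mirrors the create_task lambda inside create_tasks_inference further down in this diff, so nothing here is new logic:

```cpp
// sketch only - mirrors the create_task lambda in create_tasks_inference below
server_task task;
task.id            = queue_tasks.get_new_id();
task.inf_type      = inf_type;                    // e.g. SERVER_TASK_INF_TYPE_COMPLETION
task.type          = SERVER_TASK_TYPE_INFERENCE;  // the single task type for all inference work
task.data          = task_data;
task.prompt_tokens = std::move(prompt_tokens);
```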
@@ -90,7 +90,7 @@ struct server_task {
     server_task_type type;
     json data;

-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -161,7 +161,7 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;

-    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

     bool has_next_token = true;
     bool has_new_line = false;
@@ -210,7 +210,7 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+        inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

         generated_token_probs.clear();
     }
@@ -1345,14 +1345,14 @@ struct server_context {
     //

     // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
-    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+    std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
         std::vector<server_task> tasks;
         auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
             SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
             server_task task;
             task.id = queue_tasks.get_new_id();
-            task.cmpl_type = cmpl_type;
-            task.type = SERVER_TASK_TYPE_COMPLETION;
+            task.inf_type = inf_type;
+            task.type = SERVER_TASK_TYPE_INFERENCE;
             task.data = task_data;
             task.prompt_tokens = std::move(prompt_tokens);
             tasks.push_back(std::move(task));
@@ -1364,10 +1364,10 @@ struct server_context {
         }

         // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
-        bool add_special = cmpl_type != SERVER_TASK_CMPL_TYPE_RERANK && cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL;
+        bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
-        switch (cmpl_type) {
-            case SERVER_TASK_CMPL_TYPE_RERANK:
+        switch (inf_type) {
+            case SERVER_TASK_INF_TYPE_RERANK:
                 {
                     // prompts[0] is the question
                     // the rest are the answers/documents
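The add_special line above encodes a small rule: special tokens (e.g. BOS) are added for completion and embedding prompts, but not for rerank or infill prompts. A hypothetical helper, not part of this commit, expressing the same rule:

```cpp
// hypothetical helper, not introduced by this commit - same condition as the
// add_special expression in create_tasks_inference above
static bool prompt_wants_special_tokens(server_task_inf_type inf_type) {
    return inf_type != SERVER_TASK_INF_TYPE_RERANK &&
           inf_type != SERVER_TASK_INF_TYPE_INFILL;
}
```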
@@ -1379,7 +1379,7 @@ struct server_context {
                         create_task(data, tokens);
                     }
                 } break;
-            case SERVER_TASK_CMPL_TYPE_INFILL:
+            case SERVER_TASK_INF_TYPE_INFILL:
                 {
                     SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
                     for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -1427,7 +1427,7 @@ struct server_context {
         queue_tasks.post(cancel_tasks, true);
     }

-    // receive the results from task(s) created by create_tasks_cmpl
+    // receive the results from task(s) created by create_tasks_inference
     void receive_cmpl_results(
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<server_task_result>&)> & result_handler,
@@ -1451,7 +1451,7 @@ struct server_context {
         result_handler(results);
     }

-    // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+    // receive the results from task(s) created by create_tasks_inference, in stream mode
     void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks, const
             std::function<bool(server_task_result&)> & result_handler, const
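For orientation, a minimal sketch of how these pieces fit together after the rename, assembled only from the hunks in this commit; given a json data object, tasks are created, queued, and their ids collected for result retrieval:

```cpp
// sketch assembled from the hunks in this commit
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);

// the ids are then used to collect results, e.g. via receive_cmpl_results()
// or receive_cmpl_results_stream() declared above
const auto task_ids = server_task::get_list_id(tasks);
```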
@@ -1484,7 +1484,7 @@ struct server_context {

     void process_single_task(const server_task & task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFERENCE:
                 {
                     const int id_slot = json_value(task.data, "id_slot", -1);

@@ -1517,7 +1517,7 @@ struct server_context {
                     slot->reset();

                     slot->id_task = task.id;
-                    slot->cmpl_type = task.cmpl_type;
+                    slot->inf_type = task.inf_type;
                     slot->index = json_value(task.data, "index", 0);
                     slot->prompt_tokens = std::move(task.prompt_tokens);

@@ -1881,7 +1881,7 @@ struct server_context {
                 continue;
             }

-            if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+            if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                 // this prompt is too large to process - discard it
                 if (slot.n_prompt_tokens > n_ubatch) {
                     slot.release();
@@ -1992,7 +1992,7 @@ struct server_context {
                 }

                 // non-causal tasks require to fit the entire prompt in the physical batch
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
@@ -2001,8 +2001,8 @@ struct server_context {

                 // check that we are in the right batch_type, if not defer the slot
                 const bool slot_type =
-                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
-                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+                    slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
+                    slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;

                 if (batch_type == -1) {
                     batch_type = slot_type;
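The checks in this hunk and the two before it all test the same condition: embedding and rerank slots are the non-causal ones (per the comment in the previous hunk) and must fit their whole prompt into a single physical batch. A hypothetical helper, not part of this commit, that names the condition:

```cpp
// hypothetical helper, not introduced by this commit
static bool slot_is_non_causal(const server_slot & slot) {
    return slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
           slot.inf_type == SERVER_TASK_INF_TYPE_RERANK;
}
```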
@@ -2120,7 +2120,7 @@ struct server_context {
                 }

                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
                         send_embedding(slot, batch_view);
                         slot.release();
@@ -2128,7 +2128,7 @@ struct server_context {
                         continue; // continue loop of slots
                     }

-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                    if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
                         send_rerank(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
@@ -2682,13 +2682,13 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };

-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
         if (ctx_server.params.embedding || ctx_server.params.reranking) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }

-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);

@@ -2734,7 +2734,7 @@ int main(int argc, char ** argv) {

     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
     };

     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -2784,7 +2784,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist

-        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
     };

     // TODO: maybe merge this function with "handle_completions_generic"
@@ -2796,7 +2796,7 @@ int main(int argc, char ** argv) {

         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);

@@ -2940,7 +2940,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);

@@ -3017,7 +3017,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);

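Taken together, the HTTP handlers in this file now map onto the renamed inference types as follows (comment-only summary of the hunks above):

```cpp
// handle_completions (and the handler using oaicompat_completion_params_parse)
//                        -> SERVER_TASK_INF_TYPE_COMPLETION
// handle_infill          -> SERVER_TASK_INF_TYPE_INFILL
// embeddings handler     -> SERVER_TASK_INF_TYPE_EMBEDDING
// rerank handler         -> SERVER_TASK_INF_TYPE_RERANK
// all of them go through create_tasks_inference(), which posts
// SERVER_TASK_TYPE_INFERENCE tasks to queue_tasks
```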