remove task inf_type

parent e721f4c6b4
commit 090a113417

1 changed file with 35 additions and 33 deletions
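In short: the separate server_task_inf_type enum goes away and its variants (completion, embedding, rerank, infill) become first-class values of server_task_type; the slot now carries a task_type field, and a new is_non_causal() helper replaces the repeated inf_type == EMBEDDING || inf_type == RERANK checks. The sketch below is a minimal, self-contained illustration of that shape, built only from the constructors and fields visible in this diff; the real server's queue, batching, and HTTP plumbing are omitted.

// Minimal sketch only: mirrors the enum/constructor/field shape from this diff,
// not the full llama.cpp server implementation.
#include <cassert>

enum server_task_type {
    SERVER_TASK_TYPE_COMPLETION,
    SERVER_TASK_TYPE_EMBEDDING,
    SERVER_TASK_TYPE_RERANK,
    SERVER_TASK_TYPE_INFILL,
};

struct server_task {
    server_task_type type;
    // single-argument constructor, as introduced by this commit
    server_task(server_task_type type) : type(type) {}
};

struct server_slot {
    // only used for completion/embedding/infill/rerank
    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;

    // replaces the old inf_type == EMBEDDING || inf_type == RERANK checks
    bool is_non_causal() const {
        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
    }
};

int main() {
    // was: server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK)
    server_task task(SERVER_TASK_TYPE_RERANK);

    server_slot slot;
    slot.task_type = task.type; // launch_slot_with_task now copies the task type into the slot

    assert(slot.is_non_causal()); // non-causal slots must fit their whole prompt in one physical batch
    return 0;
}

With a single enum, the batching loop can derive the causal/non-causal batch type directly from the task type (slot_type = slot.is_non_causal()) instead of consulting two parallel enums.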
@@ -54,7 +54,10 @@ enum server_state {
 };
 
 enum server_task_type {
-    SERVER_TASK_TYPE_INFERENCE,
+    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_EMBEDDING,
+    SERVER_TASK_TYPE_RERANK,
+    SERVER_TASK_TYPE_INFILL,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -64,13 +67,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
-enum server_task_inf_type {
-    SERVER_TASK_INF_TYPE_COMPLETION,
-    SERVER_TASK_INF_TYPE_EMBEDDING,
-    SERVER_TASK_INF_TYPE_RERANK,
-    SERVER_TASK_INF_TYPE_INFILL,
-};
-
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
     ERROR_TYPE_INVALID_REQUEST,
@@ -164,7 +160,6 @@ struct server_task {
     int index = -1; // used when there are multiple prompts (batch request)
 
     server_task_type type;
-    server_task_inf_type inf_type;
 
     // used by SERVER_TASK_TYPE_CANCEL
     int id_target = -1;
@@ -185,9 +180,7 @@ struct server_task {
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
-    server_task(
-        server_task_type type,
-        server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION) : type(type), inf_type(inf_type) {}
+    server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
         const llama_model * model,
@@ -893,6 +886,9 @@ struct server_slot {
     int id;
     int id_task = -1;
 
+    // only used for completion/embedding/infill/rerank
+    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
     llama_batch batch_spec = {};
 
     llama_context * ctx = nullptr;
@@ -931,8 +927,6 @@ struct server_slot {
     llama_tokens cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
@@ -972,11 +966,15 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
+        task_type = SERVER_TASK_TYPE_COMPLETION;
 
         generated_token_probs.clear();
     }
 
+    bool is_non_causal() const {
+        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
@@ -1088,6 +1086,7 @@ struct server_slot {
             {"n_ctx", n_ctx},
             {"speculative", can_speculate()},
             {"is_processing", is_processing()},
+            {"non_causal", is_non_causal()},
             {"params", params.to_json()},
             {"prompt", common_detokenize(ctx, prompt_tokens)},
             {"next_token",
@@ -1653,8 +1652,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot.reset();
         slot.id_task = task.id;
-        slot.inf_type = task.inf_type;
         slot.index = task.index;
+        slot.task_type = task.type;
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
@@ -2120,7 +2119,10 @@ struct server_context {
 
     void process_single_task(server_task task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_INFERENCE:
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+            case SERVER_TASK_TYPE_EMBEDDING:
+            case SERVER_TASK_TYPE_RERANK:
                 {
                     const int id_slot = task.id_selected_slot;
 
@@ -2462,7 +2464,7 @@ struct server_context {
                         continue;
                     }
 
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                    if (slot.is_non_causal()) {
                         if (slot.n_prompt_tokens > n_ubatch) {
                             slot.release();
                             send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -2577,7 +2579,7 @@ struct server_context {
                     }
 
                     // non-causal tasks require to fit the entire prompt in the physical batch
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                    if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -2585,10 +2587,7 @@ struct server_context {
                     }
 
                     // check that we are in the right batch_type, if not defer the slot
-                    const bool slot_type =
-                        slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
-                        slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
-
+                    int slot_type = slot.is_non_causal();
                     if (batch_type == -1) {
                         batch_type = slot_type;
                     } else if (batch_type != slot_type) {
@@ -2705,7 +2704,7 @@ struct server_context {
                 }
 
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
+                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
                         send_embedding(slot, batch_view);
                         slot.release();
@@ -2713,7 +2712,7 @@ struct server_context {
                         continue; // continue loop of slots
                     }
 
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
                         send_rerank(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
@@ -3352,11 +3351,13 @@ int main(int argc, char ** argv) {
     // handle completion-like requests (completion, chat, infill)
    // we can optionally provide a custom format for partial results and final results
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
-            server_task_inf_type inf_type,
+            server_task_type type,
             json & data,
             httplib::Response & res,
             bool oaicompat = false,
             bool oaicompat_chat = false) {
+        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -3369,7 +3370,8 @@ int main(int argc, char ** argv) {
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
         tasks.reserve(tokenized_prompts.size());
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, inf_type);
+            server_task task = server_task(type);
+
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
 
@@ -3450,7 +3452,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
         return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
+            SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
             /* oaicompat */ false,
@@ -3504,7 +3506,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3515,7 +3517,7 @@ int main(int argc, char ** argv) {
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
         return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
+            SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
             /* oaicompat */ true,
@@ -3616,7 +3618,7 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;
         std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING);
+            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
             task.prompt_tokens = std::move(tokenized_prompts[i]);
@@ -3698,7 +3700,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
         tasks.reserve(tokenized_docs.size());
         for (size_t i = 0; i < tokenized_docs.size(); i++) {
-            server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK);
+            server_task task = server_task(SERVER_TASK_TYPE_RERANK);
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
             task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);