diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 01e9a6da1..6e97ff7d0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -54,7 +54,10 @@ enum server_state {
 };
 
 enum server_task_type {
-    SERVER_TASK_TYPE_INFERENCE,
+    SERVER_TASK_TYPE_COMPLETION,
+    SERVER_TASK_TYPE_EMBEDDING,
+    SERVER_TASK_TYPE_RERANK,
+    SERVER_TASK_TYPE_INFILL,
     SERVER_TASK_TYPE_CANCEL,
     SERVER_TASK_TYPE_NEXT_RESPONSE,
     SERVER_TASK_TYPE_METRICS,
@@ -64,13 +67,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SET_LORA,
 };
 
-enum server_task_inf_type {
-    SERVER_TASK_INF_TYPE_COMPLETION,
-    SERVER_TASK_INF_TYPE_EMBEDDING,
-    SERVER_TASK_INF_TYPE_RERANK,
-    SERVER_TASK_INF_TYPE_INFILL,
-};
-
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
     ERROR_TYPE_INVALID_REQUEST,
@@ -163,8 +159,7 @@ struct server_task {
     int id = -1; // to be filled by server_queue
     int index = -1; // used when there are multiple prompts (batch request)
 
-    server_task_type type;
-    server_task_inf_type inf_type;
+    server_task_type type;
 
     // used by SERVER_TASK_TYPE_CANCEL
    int id_target = -1;
@@ -185,9 +180,7 @@ struct server_task {
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
-    server_task(
-            server_task_type type,
-            server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION) : type(type), inf_type(inf_type) {}
+    server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
@@ -893,6 +886,9 @@ struct server_slot {
     int id;
     int id_task = -1;
 
+    // only used for completion/embedding/infill/rerank
+    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
     llama_batch batch_spec = {};
 
     llama_context * ctx = nullptr;
@@ -931,8 +927,6 @@ struct server_slot {
     llama_tokens cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
-
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
@@ -972,11 +966,15 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
+        task_type = SERVER_TASK_TYPE_COMPLETION;
 
         generated_token_probs.clear();
     }
 
+    bool is_non_causal() const {
+        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
@@ -1088,6 +1086,7 @@ struct server_slot {
             {"n_ctx", n_ctx},
             {"speculative", can_speculate()},
             {"is_processing", is_processing()},
+            {"non_causal", is_non_causal()},
             {"params", params.to_json()},
             {"prompt", common_detokenize(ctx, prompt_tokens)},
             {"next_token",
@@ -1653,8 +1652,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot.reset();
         slot.id_task = task.id;
-        slot.inf_type = task.inf_type;
         slot.index = task.index;
+        slot.task_type = task.type;
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
@@ -2120,7 +2119,10 @@ struct server_context {
 
     void process_single_task(server_task task) {
         switch (task.type) {
-            case SERVER_TASK_TYPE_INFERENCE:
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+            case SERVER_TASK_TYPE_EMBEDDING:
+            case SERVER_TASK_TYPE_RERANK:
                 {
                     const int id_slot = task.id_selected_slot;
 
@@ -2462,7 +2464,7 @@ struct server_context {
                            continue;
                         }
 
-                        if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                        if (slot.is_non_causal()) {
                            if (slot.n_prompt_tokens > n_ubatch) {
                                slot.release();
                                send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -2577,7 +2579,7 @@ struct server_context {
                     }
 
                     // non-causal tasks require to fit the entire prompt in the physical batch
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                    if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -2585,10 +2587,7 @@ struct server_context {
                     }
 
                     // check that we are in the right batch_type, if not defer the slot
-                    const bool slot_type =
-                        slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
-                        slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
-
+                    int slot_type = slot.is_non_causal();
                     if (batch_type == -1) {
                         batch_type = slot_type;
                     } else if (batch_type != slot_type) {
@@ -2705,7 +2704,7 @@ struct server_context {
                 }
 
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
+                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
                         send_embedding(slot, batch_view);
                         slot.release();
@@ -2713,7 +2712,7 @@ struct server_context {
                         continue; // continue loop of slots
                     }
 
-                    if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
                         send_rerank(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
@@ -3352,11 +3351,13 @@ int main(int argc, char ** argv) {
     // handle completion-like requests (completion, chat, infill)
     // we can optionally provide a custom format for partial results and final results
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
-            server_task_inf_type inf_type,
+            server_task_type type,
             json & data,
             httplib::Response & res,
             bool oaicompat = false,
             bool oaicompat_chat = false) {
+        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -3369,7 +3370,8 @@ int main(int argc, char ** argv) {
             std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, inf_type);
+                server_task task = server_task(type);
+
                 task.id = ctx_server.queue_tasks.get_new_id();
                 task.index = i;
 
@@ -3450,7 +3452,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
         return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
+            SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
             /* oaicompat */ false,
@@ -3504,7 +3506,7 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
+        return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3515,7 +3517,7 @@ int main(int argc, char ** argv) {
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
         return handle_completions_generic(
-            SERVER_TASK_INF_TYPE_COMPLETION,
+            SERVER_TASK_TYPE_COMPLETION,
             data,
             res,
             /* oaicompat */ true,
@@ -3616,7 +3618,7 @@ int main(int argc, char ** argv) {
             std::vector<server_task> tasks;
             std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_EMBEDDING);
+                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
                 task.id = ctx_server.queue_tasks.get_new_id();
                 task.index = i;
                 task.prompt_tokens = std::move(tokenized_prompts[i]);
@@ -3698,7 +3700,7 @@ int main(int argc, char ** argv) {
             std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
             tasks.reserve(tokenized_docs.size());
             for (size_t i = 0; i < tokenized_docs.size(); i++) {
-                server_task task = server_task(SERVER_TASK_TYPE_INFERENCE, SERVER_TASK_INF_TYPE_RERANK);
+                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                 task.id = ctx_server.queue_tasks.get_new_id();
                 task.index = i;
                 task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);