server : refactor slot input data, move tokenizer to HTTP thread (#10023)
* server : refactor slot input data, move tokenizer to HTTP thread * move prompt_tokens.empty() check * fix incorrect if branch * fix infinite generation loop * bring back infill validation * add infill test * try fixing format_infill * fix test * remove redundant code * rename completion to inference * update docs * use llama_tokens everywhere
This commit is contained in:
parent
40f2555797
commit
958367bf53
5 changed files with 468 additions and 348 deletions
|
@ -43,21 +43,6 @@
|
|||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum stop_type {
|
||||
|
@ -68,6 +53,7 @@ enum stop_type {
|
|||
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
||||
enum slot_state {
|
||||
SLOT_STATE_IDLE,
|
||||
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
||||
SLOT_STATE_PROCESSING_PROMPT,
|
||||
SLOT_STATE_DONE_PROMPT,
|
||||
SLOT_STATE_GENERATING,
|
||||
|
@ -79,7 +65,7 @@ enum server_state {
|
|||
};
|
||||
|
||||
enum server_task_type {
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
SERVER_TASK_TYPE_INFERENCE,
|
||||
SERVER_TASK_TYPE_CANCEL,
|
||||
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
||||
SERVER_TASK_TYPE_METRICS,
|
||||
|
@ -89,21 +75,22 @@ enum server_task_type {
|
|||
SERVER_TASK_TYPE_SET_LORA,
|
||||
};
|
||||
|
||||
enum server_task_cmpl_type {
|
||||
SERVER_TASK_CMPL_TYPE_NORMAL,
|
||||
SERVER_TASK_CMPL_TYPE_EMBEDDING,
|
||||
SERVER_TASK_CMPL_TYPE_RERANK,
|
||||
SERVER_TASK_CMPL_TYPE_INFILL,
|
||||
enum server_task_inf_type {
|
||||
SERVER_TASK_INF_TYPE_COMPLETION,
|
||||
SERVER_TASK_INF_TYPE_EMBEDDING,
|
||||
SERVER_TASK_INF_TYPE_RERANK,
|
||||
SERVER_TASK_INF_TYPE_INFILL,
|
||||
};
|
||||
|
||||
struct server_task {
|
||||
int id = -1; // to be filled by server_queue
|
||||
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
|
||||
|
||||
llama_tokens prompt_tokens;
|
||||
server_task_type type;
|
||||
json data;
|
||||
|
||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||
|
||||
// utility function
|
||||
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
||||
|
@ -161,26 +148,20 @@ struct server_slot {
|
|||
int32_t i_batch = -1;
|
||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||
|
||||
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_prompt_tokens_processed = 0;
|
||||
|
||||
json prompt; // can be either a string, array of strings or array of token ids
|
||||
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
json input_extra;
|
||||
|
||||
// when a task is submitted, we first tokenize the prompt and store it here
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
std::vector<llama_token> extra_tokens;
|
||||
// input prompt tokens
|
||||
llama_tokens prompt_tokens;
|
||||
|
||||
size_t last_nl_pos = 0;
|
||||
|
||||
std::string generated_text;
|
||||
std::vector<llama_token> cache_tokens;
|
||||
llama_tokens cache_tokens;
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
|
||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||
|
||||
bool has_next_token = true;
|
||||
bool has_new_line = false;
|
||||
|
@ -229,7 +210,7 @@ struct server_slot {
|
|||
n_past = 0;
|
||||
n_sent_text = 0;
|
||||
n_sent_token_probs = 0;
|
||||
cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
||||
inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||
|
||||
generated_token_probs.clear();
|
||||
}
|
||||
|
@ -734,42 +715,6 @@ struct server_context {
|
|||
metrics.init();
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
|
||||
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
||||
// or the first element of the json_prompt array is a string.
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
|
||||
if (json_prompt.is_array()) {
|
||||
bool first = true;
|
||||
for (const auto & p : json_prompt) {
|
||||
if (p.is_string()) {
|
||||
auto s = p.template get<std::string>();
|
||||
|
||||
std::vector<llama_token> p;
|
||||
if (first) {
|
||||
p = common_tokenize(ctx, s, add_special, parse_special);
|
||||
first = false;
|
||||
} else {
|
||||
p = common_tokenize(ctx, s, false, parse_special);
|
||||
}
|
||||
|
||||
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
||||
} else {
|
||||
if (first) {
|
||||
first = false;
|
||||
}
|
||||
|
||||
prompt_tokens.push_back(p.template get<llama_token>());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto s = json_prompt.template get<std::string>();
|
||||
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
|
||||
}
|
||||
|
||||
return prompt_tokens;
|
||||
}
|
||||
|
||||
server_slot * get_slot_by_id(int id) {
|
||||
for (server_slot & slot : slots) {
|
||||
if (slot.id == id) {
|
||||
|
@ -794,22 +739,16 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
// skip the slot if it does not contains prompt
|
||||
if (!slot.prompt.is_string()) {
|
||||
// skip the slot if it does not contains cached tokens
|
||||
if (slot.prompt_tokens.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// current slot's prompt
|
||||
std::string slot_prompt = slot.prompt.get<std::string>();
|
||||
|
||||
// length of the current slot's prompt
|
||||
int slot_prompt_len = slot_prompt.size();
|
||||
|
||||
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||
int lcp_len = longest_common_prefix(slot_prompt, prompt);
|
||||
int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
|
||||
|
||||
// fraction of the common substring length compared to the current slot's prompt length
|
||||
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
||||
similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
|
||||
|
||||
// select the current slot if the criteria match
|
||||
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
||||
|
@ -914,57 +853,6 @@ struct server_context {
|
|||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
||||
}
|
||||
|
||||
// infill
|
||||
slot.input_prefix = json_value(data, "input_prefix", json());
|
||||
slot.input_suffix = json_value(data, "input_suffix", json());
|
||||
slot.input_extra = json_value(data, "input_extra", json());
|
||||
|
||||
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
|
||||
for (const auto & chunk : slot.input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk["text"].is_string()) {
|
||||
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
// filename is optional
|
||||
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
|
||||
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
|
||||
}
|
||||
|
||||
// get prompt
|
||||
{
|
||||
const auto & prompt = data.find("prompt");
|
||||
if (prompt == data.end()) {
|
||||
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((prompt->is_string()) ||
|
||||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
||||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
||||
slot.prompt = *prompt;
|
||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
||||
slot.prompt = prompt->at(0);
|
||||
} else if (prompt->is_array() && prompt->size() > 1) {
|
||||
// array of strings
|
||||
for (const auto & el : *prompt) {
|
||||
if (!el.is_string()) {
|
||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
slot.prompt = *prompt;
|
||||
} else {
|
||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
slot.sparams.logit_bias.clear();
|
||||
|
||||
|
@ -1044,8 +932,7 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
slot.prompt_tokens.clear();
|
||||
slot.state = SLOT_STATE_STARTED;
|
||||
|
||||
SLT_INF(slot, "%s", "processing task\n");
|
||||
|
||||
|
@ -1297,7 +1184,7 @@ struct server_context {
|
|||
};
|
||||
|
||||
if (slot.sparams.n_probs > 0) {
|
||||
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
||||
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
||||
|
||||
|
@ -1333,7 +1220,7 @@ struct server_context {
|
|||
{"tokens_predicted", slot.n_decoded},
|
||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||
{"generation_settings", get_formated_generation(slot)},
|
||||
{"prompt", slot.prompt},
|
||||
{"prompt", common_detokenize(ctx, slot.prompt_tokens)},
|
||||
{"has_new_line", slot.has_new_line},
|
||||
{"truncated", slot.truncated},
|
||||
{"stopped_eos", slot.stopped_eos},
|
||||
|
@ -1348,7 +1235,7 @@ struct server_context {
|
|||
if (slot.sparams.n_probs > 0) {
|
||||
std::vector<completion_token_output> probs;
|
||||
if (!slot.params.stream && slot.stopped_word) {
|
||||
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
||||
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
||||
|
||||
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
||||
probs = std::vector<completion_token_output>(
|
||||
|
@ -1457,19 +1344,17 @@ struct server_context {
|
|||
// Functions to create new task(s) and receive result(s)
|
||||
//
|
||||
|
||||
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
|
||||
// break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
|
||||
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
|
||||
std::vector<server_task> tasks;
|
||||
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
|
||||
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
|
||||
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
|
||||
server_task task;
|
||||
task.id = queue_tasks.get_new_id();
|
||||
task.cmpl_type = cmpl_type;
|
||||
task.type = SERVER_TASK_TYPE_COMPLETION;
|
||||
if (replace_prompt) {
|
||||
task.data = task_data;
|
||||
task.data["prompt"] = std::move(prompt);
|
||||
} else {
|
||||
task.data = std::move(task_data);
|
||||
}
|
||||
task.id = queue_tasks.get_new_id();
|
||||
task.inf_type = inf_type;
|
||||
task.type = SERVER_TASK_TYPE_INFERENCE;
|
||||
task.data = task_data;
|
||||
task.prompt_tokens = std::move(prompt_tokens);
|
||||
tasks.push_back(std::move(task));
|
||||
};
|
||||
|
||||
|
@ -1478,41 +1363,49 @@ struct server_context {
|
|||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
json prompt = data.at("prompt");
|
||||
|
||||
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
|
||||
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
|
||||
data["index"] = 0;
|
||||
create_task(data, false, nullptr);
|
||||
} else if (prompt.is_array()) {
|
||||
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
||||
std::vector<json> prompts = prompt;
|
||||
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// prompts[0] is the question
|
||||
// the rest are the answers/documents
|
||||
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
|
||||
for (size_t i = 1; i < prompts.size(); i++) {
|
||||
json qd;
|
||||
qd.push_back(prompts[0]);
|
||||
qd.push_back(prompts[i]);
|
||||
data["index"] = i - 1;
|
||||
create_task(data, true, qd);
|
||||
}
|
||||
} else {
|
||||
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
|
||||
for (size_t i = 0; i < prompts.size(); i++) {
|
||||
const auto & e = prompts[i];
|
||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
||||
// because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
|
||||
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
|
||||
switch (inf_type) {
|
||||
case SERVER_TASK_INF_TYPE_RERANK:
|
||||
{
|
||||
// prompts[0] is the question
|
||||
// the rest are the answers/documents
|
||||
GGML_ASSERT(tokenized_prompts.size() > 1);
|
||||
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
|
||||
for (size_t i = 1; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i - 1;
|
||||
auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
|
||||
create_task(data, tokens);
|
||||
}
|
||||
} break;
|
||||
case SERVER_TASK_INF_TYPE_INFILL:
|
||||
{
|
||||
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i;
|
||||
create_task(data, true, e);
|
||||
} else {
|
||||
throw std::runtime_error(error_msg);
|
||||
auto tokens = format_infill(
|
||||
ctx,
|
||||
data.at("input_prefix"),
|
||||
data.at("input_suffix"),
|
||||
data.at("input_extra"),
|
||||
params.n_batch,
|
||||
params.n_predict,
|
||||
slots[0].n_ctx, // TODO: there should be a better way
|
||||
params.spm_infill,
|
||||
tokenized_prompts[i]
|
||||
);
|
||||
create_task(data, tokens);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i;
|
||||
create_task(data, tokenized_prompts[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// invalid case
|
||||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
return tasks;
|
||||
|
@ -1534,7 +1427,7 @@ struct server_context {
|
|||
queue_tasks.post(cancel_tasks, true);
|
||||
}
|
||||
|
||||
// receive the results from task(s) created by create_tasks_cmpl
|
||||
// receive the results from task(s) created by create_tasks_inference
|
||||
void receive_cmpl_results(
|
||||
const std::unordered_set<int> & id_tasks,
|
||||
const std::function<void(std::vector<server_task_result>&)> & result_handler,
|
||||
|
@ -1558,7 +1451,7 @@ struct server_context {
|
|||
result_handler(results);
|
||||
}
|
||||
|
||||
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
|
||||
// receive the results from task(s) created by create_tasks_inference, in stream mode
|
||||
void receive_cmpl_results_stream(
|
||||
const std::unordered_set<int> & id_tasks, const
|
||||
std::function<bool(server_task_result&)> & result_handler, const
|
||||
|
@ -1591,7 +1484,7 @@ struct server_context {
|
|||
|
||||
void process_single_task(const server_task & task) {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFERENCE:
|
||||
{
|
||||
const int id_slot = json_value(task.data, "id_slot", -1);
|
||||
|
||||
|
@ -1623,9 +1516,10 @@ struct server_context {
|
|||
|
||||
slot->reset();
|
||||
|
||||
slot->id_task = task.id;
|
||||
slot->cmpl_type = task.cmpl_type;
|
||||
slot->index = json_value(task.data, "index", 0);
|
||||
slot->id_task = task.id;
|
||||
slot->inf_type = task.inf_type;
|
||||
slot->index = json_value(task.data, "index", 0);
|
||||
slot->prompt_tokens = std::move(task.prompt_tokens);
|
||||
|
||||
if (!launch_slot_with_task(*slot, task)) {
|
||||
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
||||
|
@ -1658,7 +1552,7 @@ struct server_context {
|
|||
slot_data["id"] = slot.id;
|
||||
slot_data["id_task"] = slot.id_task;
|
||||
slot_data["state"] = slot.state;
|
||||
slot_data["prompt"] = slot.prompt;
|
||||
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
|
||||
slot_data["next_token"] = {
|
||||
{"has_next_token", slot.has_next_token},
|
||||
{"has_new_line", slot.has_new_line},
|
||||
|
@ -1785,9 +1679,6 @@ struct server_context {
|
|||
}
|
||||
slot->cache_tokens.resize(token_count);
|
||||
|
||||
// TODO: maybe detokenize the slot->cache_tokens instead?
|
||||
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||
|
||||
|
@ -1954,142 +1845,18 @@ struct server_context {
|
|||
if (params.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
// this slot still has a prompt to be processed
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||
auto & prompt_tokens = slot.prompt_tokens;
|
||||
|
||||
// we haven't tokenized the prompt yet - do it now:
|
||||
if (prompt_tokens.empty()) {
|
||||
SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
|
||||
|
||||
// TODO: maybe move branch to outside of this loop in the future
|
||||
if (slot.state == SLOT_STATE_STARTED) {
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
|
||||
switch (slot.cmpl_type) {
|
||||
case SERVER_TASK_CMPL_TYPE_NORMAL:
|
||||
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
|
||||
{
|
||||
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
|
||||
} break;
|
||||
case SERVER_TASK_CMPL_TYPE_RERANK:
|
||||
{
|
||||
// require slot.prompt to be array of 2 strings
|
||||
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
|
||||
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
|
||||
slot.release();
|
||||
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
|
||||
continue;
|
||||
}
|
||||
|
||||
// prompt: [BOS]query[EOS][SEP]doc[EOS]
|
||||
prompt_tokens.clear();
|
||||
prompt_tokens.push_back(llama_token_bos(model));
|
||||
{
|
||||
const auto part = tokenize(slot.prompt[0], false, false);
|
||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||
}
|
||||
prompt_tokens.push_back(llama_token_eos(model));
|
||||
prompt_tokens.push_back(llama_token_sep(model));
|
||||
{
|
||||
const auto part = tokenize(slot.prompt[1], false, false);
|
||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||
}
|
||||
prompt_tokens.push_back(llama_token_eos(model));
|
||||
} break;
|
||||
case SERVER_TASK_CMPL_TYPE_INFILL:
|
||||
{
|
||||
// TODO: optimize this block by reducing memory allocations and movement
|
||||
|
||||
// use FIM repo-level pattern:
|
||||
// ref: https://arxiv.org/pdf/2409.12186
|
||||
//
|
||||
// [FIM_REP]myproject
|
||||
// [FIM_SEP]filename0
|
||||
// extra chunk 0
|
||||
// [FIM_SEP]filename1
|
||||
// extra chunk 1
|
||||
// ...
|
||||
// [FIM_SEP]filename
|
||||
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
|
||||
//
|
||||
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
|
||||
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
|
||||
auto tokens_prompt = tokenize(slot.prompt, false, false);
|
||||
|
||||
slot.extra_tokens.clear();
|
||||
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
||||
static const auto k_fim_repo = tokenize("myproject\n", false, false);
|
||||
|
||||
slot.extra_tokens.push_back(llama_token_fim_rep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
|
||||
}
|
||||
|
||||
for (const auto & chunk : slot.input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
const std::string text = chunk.value("text", "");
|
||||
const std::string filename = chunk.value("filename", "tmp");
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
const auto k_fim_file = tokenize(filename + "\n", false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
} else {
|
||||
// chunk separator in binary form to avoid confusing the AI
|
||||
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
|
||||
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
|
||||
}
|
||||
|
||||
const auto chunk_tokens = tokenize(text, false, false);
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
||||
}
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
// TODO: current filename
|
||||
static const auto k_fim_file = tokenize("filename\n", false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
}
|
||||
|
||||
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
||||
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
|
||||
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
|
||||
|
||||
// fill the rest of the context with extra chunks
|
||||
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
|
||||
|
||||
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
|
||||
tokens_suffix.resize(n_suffix_take);
|
||||
|
||||
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
|
||||
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
|
||||
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
|
||||
|
||||
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
|
||||
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
|
||||
|
||||
if (llama_add_bos_token(model)) {
|
||||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
|
||||
|
||||
// put the extra context before the FIM prefix
|
||||
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
||||
embd_inp.push_back(llama_token_fim_mid(model));
|
||||
|
||||
prompt_tokens = std::move(embd_inp);
|
||||
} break;
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
|
||||
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||
|
||||
// print prompt tokens (for debugging)
|
||||
if (1) {
|
||||
|
@ -2114,7 +1881,7 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||
// this prompt is too large to process - discard it
|
||||
if (slot.n_prompt_tokens > n_ubatch) {
|
||||
slot.release();
|
||||
|
@ -2144,7 +1911,7 @@ struct server_context {
|
|||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
|
||||
std::vector<llama_token> new_tokens(
|
||||
llama_tokens new_tokens(
|
||||
prompt_tokens.begin(),
|
||||
prompt_tokens.begin() + slot.params.n_keep);
|
||||
|
||||
|
@ -2225,7 +1992,7 @@ struct server_context {
|
|||
}
|
||||
|
||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||
continue;
|
||||
|
@ -2234,8 +2001,8 @@ struct server_context {
|
|||
|
||||
// check that we are in the right batch_type, if not defer the slot
|
||||
const bool slot_type =
|
||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
|
||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
|
||||
slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
|
||||
slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
|
||||
|
||||
if (batch_type == -1) {
|
||||
batch_type = slot_type;
|
||||
|
@ -2353,7 +2120,7 @@ struct server_context {
|
|||
}
|
||||
|
||||
if (slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
|
||||
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
|
||||
// prompt evaluated for embedding
|
||||
send_embedding(slot, batch_view);
|
||||
slot.release();
|
||||
|
@ -2361,7 +2128,7 @@ struct server_context {
|
|||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||
send_rerank(slot, batch_view);
|
||||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
|
@ -2915,13 +2682,13 @@ int main(int argc, char ** argv) {
|
|||
res_ok(res, {{ "success", true }});
|
||||
};
|
||||
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
|
||||
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
@ -2967,10 +2734,11 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
|
||||
};
|
||||
|
||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
// check model compatibility
|
||||
std::string err;
|
||||
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
||||
err += "prefix token is missing. ";
|
||||
|
@ -2981,14 +2749,42 @@ int main(int argc, char ** argv) {
|
|||
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
||||
err += "middle token is missing. ";
|
||||
}
|
||||
|
||||
if (!err.empty()) {
|
||||
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
|
||||
|
||||
// validate input
|
||||
if (!data.contains("input_prefix")) {
|
||||
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (!data.contains("input_suffix")) {
|
||||
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
||||
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
json input_extra = json_value(data, "input_extra", json::array());
|
||||
for (const auto & chunk : input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
|
||||
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
// filename is optional
|
||||
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
|
||||
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
||||
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
|
||||
};
|
||||
|
||||
// TODO: maybe merge this function with "handle_completions_generic"
|
||||
|
@ -3000,7 +2796,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
@ -3073,7 +2869,7 @@ int main(int argc, char ** argv) {
|
|||
const bool add_special = json_value(body, "add_special", false);
|
||||
const bool with_pieces = json_value(body, "with_pieces", false);
|
||||
|
||||
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
|
||||
llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
|
||||
|
||||
if (with_pieces) {
|
||||
for (const auto& token : tokens) {
|
||||
|
@ -3110,7 +2906,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::string content;
|
||||
if (body.count("tokens") != 0) {
|
||||
const std::vector<llama_token> tokens = body.at("tokens");
|
||||
const llama_tokens tokens = body.at("tokens");
|
||||
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
|
||||
}
|
||||
|
||||
|
@ -3144,7 +2940,7 @@ int main(int argc, char ** argv) {
|
|||
json responses = json::array();
|
||||
bool error = false;
|
||||
{
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
@ -3221,7 +3017,7 @@ int main(int argc, char ** argv) {
|
|||
json responses = json::array();
|
||||
bool error = false;
|
||||
{
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue