server : refactor slot input data, move tokenizer to HTTP thread
This commit is contained in:
parent
190a37d797
commit
125835b253
2 changed files with 492 additions and 489 deletions
|
@ -100,6 +100,7 @@ struct server_task {
|
|||
int id = -1; // to be filled by server_queue
|
||||
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
|
||||
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
server_task_type type;
|
||||
json data;
|
||||
|
||||
|
@ -161,18 +162,12 @@ struct server_slot {
|
|||
int32_t i_batch = -1;
|
||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||
|
||||
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
|
||||
int32_t n_prompt_tokens = 0;
|
||||
int32_t n_prompt_tokens_processed = 0;
|
||||
|
||||
json prompt; // can be either a string, array of strings or array of token ids
|
||||
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
json input_extra;
|
||||
|
||||
// when a task is submitted, we first tokenize the prompt and store it here
|
||||
// input prompt tokens
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
std::vector<llama_token> extra_tokens;
|
||||
|
||||
size_t last_nl_pos = 0;
|
||||
|
||||
|
@ -735,39 +730,7 @@ struct server_context {
|
|||
}
|
||||
|
||||
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
|
||||
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
||||
// or the first element of the json_prompt array is a string.
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
|
||||
if (json_prompt.is_array()) {
|
||||
bool first = true;
|
||||
for (const auto & p : json_prompt) {
|
||||
if (p.is_string()) {
|
||||
auto s = p.template get<std::string>();
|
||||
|
||||
std::vector<llama_token> p;
|
||||
if (first) {
|
||||
p = common_tokenize(ctx, s, add_special, parse_special);
|
||||
first = false;
|
||||
} else {
|
||||
p = common_tokenize(ctx, s, false, parse_special);
|
||||
}
|
||||
|
||||
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
||||
} else {
|
||||
if (first) {
|
||||
first = false;
|
||||
}
|
||||
|
||||
prompt_tokens.push_back(p.template get<llama_token>());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto s = json_prompt.template get<std::string>();
|
||||
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
|
||||
}
|
||||
|
||||
return prompt_tokens;
|
||||
return tokenize_mixed(ctx, json_prompt, add_special, parse_special);
|
||||
}
|
||||
|
||||
server_slot * get_slot_by_id(int id) {
|
||||
|
@ -794,22 +757,16 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
// skip the slot if it does not contains prompt
|
||||
if (!slot.prompt.is_string()) {
|
||||
// skip the slot if it does not contains cached tokens
|
||||
if (slot.prompt_tokens.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// current slot's prompt
|
||||
std::string slot_prompt = slot.prompt.get<std::string>();
|
||||
|
||||
// length of the current slot's prompt
|
||||
int slot_prompt_len = slot_prompt.size();
|
||||
|
||||
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||
int lcp_len = longest_common_prefix(slot_prompt, prompt);
|
||||
int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
|
||||
|
||||
// fraction of the common substring length compared to the current slot's prompt length
|
||||
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
||||
similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
|
||||
|
||||
// select the current slot if the criteria match
|
||||
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
||||
|
@ -914,57 +871,6 @@ struct server_context {
|
|||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
||||
}
|
||||
|
||||
// infill
|
||||
slot.input_prefix = json_value(data, "input_prefix", json());
|
||||
slot.input_suffix = json_value(data, "input_suffix", json());
|
||||
slot.input_extra = json_value(data, "input_extra", json());
|
||||
|
||||
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
|
||||
for (const auto & chunk : slot.input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk["text"].is_string()) {
|
||||
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
// filename is optional
|
||||
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
|
||||
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
|
||||
}
|
||||
|
||||
// get prompt
|
||||
{
|
||||
const auto & prompt = data.find("prompt");
|
||||
if (prompt == data.end()) {
|
||||
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((prompt->is_string()) ||
|
||||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
||||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
||||
slot.prompt = *prompt;
|
||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
||||
slot.prompt = prompt->at(0);
|
||||
} else if (prompt->is_array() && prompt->size() > 1) {
|
||||
// array of strings
|
||||
for (const auto & el : *prompt) {
|
||||
if (!el.is_string()) {
|
||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
slot.prompt = *prompt;
|
||||
} else {
|
||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
slot.sparams.logit_bias.clear();
|
||||
|
||||
|
@ -1045,7 +951,6 @@ struct server_context {
|
|||
}
|
||||
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
slot.prompt_tokens.clear();
|
||||
|
||||
SLT_INF(slot, "%s", "processing task\n");
|
||||
|
||||
|
@ -1333,7 +1238,7 @@ struct server_context {
|
|||
{"tokens_predicted", slot.n_decoded},
|
||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||
{"generation_settings", get_formated_generation(slot)},
|
||||
{"prompt", slot.prompt},
|
||||
{"prompt", common_detokenize(ctx, slot.prompt_tokens)},
|
||||
{"has_new_line", slot.has_new_line},
|
||||
{"truncated", slot.truncated},
|
||||
{"stopped_eos", slot.stopped_eos},
|
||||
|
@ -1457,19 +1362,21 @@ struct server_context {
|
|||
// Functions to create new task(s) and receive result(s)
|
||||
//
|
||||
|
||||
// break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
|
||||
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
|
||||
std::vector<server_task> tasks;
|
||||
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
|
||||
server_task task;
|
||||
task.id = queue_tasks.get_new_id();
|
||||
task.cmpl_type = cmpl_type;
|
||||
task.type = SERVER_TASK_TYPE_COMPLETION;
|
||||
if (replace_prompt) {
|
||||
task.data = task_data;
|
||||
task.data["prompt"] = std::move(prompt);
|
||||
} else {
|
||||
task.data = std::move(task_data);
|
||||
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
|
||||
if (prompt_tokens.empty()) {
|
||||
// TODO @ngxson : should not throw an error
|
||||
throw std::runtime_error("prompt must not be empty");
|
||||
}
|
||||
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
|
||||
server_task task;
|
||||
task.id = queue_tasks.get_new_id();
|
||||
task.cmpl_type = cmpl_type;
|
||||
task.type = SERVER_TASK_TYPE_COMPLETION;
|
||||
task.data = task_data;
|
||||
task.prompt_tokens = std::move(prompt_tokens);
|
||||
tasks.push_back(std::move(task));
|
||||
};
|
||||
|
||||
|
@ -1478,41 +1385,49 @@ struct server_context {
|
|||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
json prompt = data.at("prompt");
|
||||
|
||||
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
|
||||
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
|
||||
data["index"] = 0;
|
||||
create_task(data, false, nullptr);
|
||||
} else if (prompt.is_array()) {
|
||||
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
||||
std::vector<json> prompts = prompt;
|
||||
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// prompts[0] is the question
|
||||
// the rest are the answers/documents
|
||||
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
|
||||
for (size_t i = 1; i < prompts.size(); i++) {
|
||||
json qd;
|
||||
qd.push_back(prompts[0]);
|
||||
qd.push_back(prompts[i]);
|
||||
data["index"] = i - 1;
|
||||
create_task(data, true, qd);
|
||||
}
|
||||
} else {
|
||||
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
|
||||
for (size_t i = 0; i < prompts.size(); i++) {
|
||||
const auto & e = prompts[i];
|
||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
||||
// because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
|
||||
bool add_special = cmpl_type != SERVER_TASK_CMPL_TYPE_RERANK && cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL;
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
|
||||
switch (cmpl_type) {
|
||||
case SERVER_TASK_CMPL_TYPE_RERANK:
|
||||
{
|
||||
// prompts[0] is the question
|
||||
// the rest are the answers/documents
|
||||
GGML_ASSERT(tokenized_prompts.size() > 1);
|
||||
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
|
||||
for (size_t i = 1; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i - 1;
|
||||
auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
|
||||
create_task(data, tokens);
|
||||
}
|
||||
} break;
|
||||
case SERVER_TASK_CMPL_TYPE_INFILL:
|
||||
{
|
||||
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i;
|
||||
create_task(data, true, e);
|
||||
} else {
|
||||
throw std::runtime_error(error_msg);
|
||||
auto tokens = format_infill(
|
||||
ctx,
|
||||
data.at("input_prefix"),
|
||||
data.at("input_suffix"),
|
||||
data.at("input_extra"),
|
||||
params.n_batch,
|
||||
params.n_predict,
|
||||
slots[0].n_ctx, // TODO: there should be a better way
|
||||
params.spm_infill,
|
||||
tokenized_prompts[i]
|
||||
);
|
||||
create_task(data, tokens);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
data["index"] = i;
|
||||
create_task(data, tokenized_prompts[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// invalid case
|
||||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
return tasks;
|
||||
|
@ -1623,9 +1538,10 @@ struct server_context {
|
|||
|
||||
slot->reset();
|
||||
|
||||
slot->id_task = task.id;
|
||||
slot->cmpl_type = task.cmpl_type;
|
||||
slot->index = json_value(task.data, "index", 0);
|
||||
slot->id_task = task.id;
|
||||
slot->cmpl_type = task.cmpl_type;
|
||||
slot->index = json_value(task.data, "index", 0);
|
||||
slot->prompt_tokens = std::move(task.prompt_tokens);
|
||||
|
||||
if (!launch_slot_with_task(*slot, task)) {
|
||||
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
||||
|
@ -1658,7 +1574,7 @@ struct server_context {
|
|||
slot_data["id"] = slot.id;
|
||||
slot_data["id_task"] = slot.id_task;
|
||||
slot_data["state"] = slot.state;
|
||||
slot_data["prompt"] = slot.prompt;
|
||||
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
|
||||
slot_data["next_token"] = {
|
||||
{"has_next_token", slot.has_next_token},
|
||||
{"has_new_line", slot.has_new_line},
|
||||
|
@ -1785,9 +1701,6 @@ struct server_context {
|
|||
}
|
||||
slot->cache_tokens.resize(token_count);
|
||||
|
||||
// TODO: maybe detokenize the slot->cache_tokens instead?
|
||||
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||
|
||||
|
@ -1953,349 +1866,225 @@ struct server_context {
|
|||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.is_processing()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// this slot still has a prompt to be processed
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
|
||||
if (!slot.prompt_tokens.empty() && slot.state == SLOT_STATE_PROCESSING_PROMPT) {
|
||||
auto & prompt_tokens = slot.prompt_tokens;
|
||||
|
||||
// we haven't tokenized the prompt yet - do it now:
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
slot.n_past = 0;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||
|
||||
// print prompt tokens (for debugging)
|
||||
if (1) {
|
||||
// first 16 tokens (avoid flooding logs)
|
||||
for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
}
|
||||
} else {
|
||||
// all
|
||||
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// empty prompt passed -> release the slot and send empty response
|
||||
if (prompt_tokens.empty()) {
|
||||
SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
|
||||
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
|
||||
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
|
||||
switch (slot.cmpl_type) {
|
||||
case SERVER_TASK_CMPL_TYPE_NORMAL:
|
||||
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
|
||||
{
|
||||
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
|
||||
} break;
|
||||
case SERVER_TASK_CMPL_TYPE_RERANK:
|
||||
{
|
||||
// require slot.prompt to be array of 2 strings
|
||||
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
|
||||
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
|
||||
slot.release();
|
||||
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
|
||||
continue;
|
||||
}
|
||||
|
||||
// prompt: [BOS]query[EOS][SEP]doc[EOS]
|
||||
prompt_tokens.clear();
|
||||
prompt_tokens.push_back(llama_token_bos(model));
|
||||
{
|
||||
const auto part = tokenize(slot.prompt[0], false, false);
|
||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||
}
|
||||
prompt_tokens.push_back(llama_token_eos(model));
|
||||
prompt_tokens.push_back(llama_token_sep(model));
|
||||
{
|
||||
const auto part = tokenize(slot.prompt[1], false, false);
|
||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
||||
}
|
||||
prompt_tokens.push_back(llama_token_eos(model));
|
||||
} break;
|
||||
case SERVER_TASK_CMPL_TYPE_INFILL:
|
||||
{
|
||||
// TODO: optimize this block by reducing memory allocations and movement
|
||||
|
||||
// use FIM repo-level pattern:
|
||||
// ref: https://arxiv.org/pdf/2409.12186
|
||||
//
|
||||
// [FIM_REP]myproject
|
||||
// [FIM_SEP]filename0
|
||||
// extra chunk 0
|
||||
// [FIM_SEP]filename1
|
||||
// extra chunk 1
|
||||
// ...
|
||||
// [FIM_SEP]filename
|
||||
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
|
||||
//
|
||||
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
|
||||
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
|
||||
auto tokens_prompt = tokenize(slot.prompt, false, false);
|
||||
|
||||
slot.extra_tokens.clear();
|
||||
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
||||
static const auto k_fim_repo = tokenize("myproject\n", false, false);
|
||||
|
||||
slot.extra_tokens.push_back(llama_token_fim_rep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
|
||||
}
|
||||
|
||||
for (const auto & chunk : slot.input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
const std::string text = chunk.value("text", "");
|
||||
const std::string filename = chunk.value("filename", "tmp");
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
const auto k_fim_file = tokenize(filename + "\n", false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
} else {
|
||||
// chunk separator in binary form to avoid confusing the AI
|
||||
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
|
||||
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
|
||||
}
|
||||
|
||||
const auto chunk_tokens = tokenize(text, false, false);
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
||||
}
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
// TODO: current filename
|
||||
static const auto k_fim_file = tokenize("filename\n", false, false);
|
||||
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
}
|
||||
|
||||
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
||||
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
|
||||
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
|
||||
|
||||
// fill the rest of the context with extra chunks
|
||||
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
|
||||
|
||||
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
|
||||
tokens_suffix.resize(n_suffix_take);
|
||||
|
||||
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
|
||||
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
|
||||
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
|
||||
|
||||
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
|
||||
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
|
||||
|
||||
if (llama_add_bos_token(model)) {
|
||||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
|
||||
|
||||
// put the extra context before the FIM prefix
|
||||
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
||||
embd_inp.push_back(llama_token_fim_mid(model));
|
||||
|
||||
prompt_tokens = std::move(embd_inp);
|
||||
} break;
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||
|
||||
// print prompt tokens (for debugging)
|
||||
if (1) {
|
||||
// first 16 tokens (avoid flooding logs)
|
||||
for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
}
|
||||
} else {
|
||||
// all
|
||||
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// empty prompt passed -> release the slot and send empty response
|
||||
if (prompt_tokens.empty()) {
|
||||
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
|
||||
|
||||
slot.release();
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// this prompt is too large to process - discard it
|
||||
if (slot.n_prompt_tokens > n_ubatch) {
|
||||
slot.release();
|
||||
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (!params.ctx_shift) {
|
||||
// if context shift is disabled, we make sure prompt size is smaller than KV size
|
||||
// TODO: there should be a separate parameter that control prompt truncation
|
||||
// context shift should be applied only during the generation phase
|
||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
slot.release();
|
||||
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (slot.params.n_keep < 0) {
|
||||
slot.params.n_keep = slot.n_prompt_tokens;
|
||||
}
|
||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||
|
||||
// if input prompt is too big, truncate it
|
||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
|
||||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
|
||||
std::vector<llama_token> new_tokens(
|
||||
prompt_tokens.begin(),
|
||||
prompt_tokens.begin() + slot.params.n_keep);
|
||||
|
||||
new_tokens.insert(
|
||||
new_tokens.end(),
|
||||
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
|
||||
prompt_tokens.end());
|
||||
|
||||
prompt_tokens = std::move(new_tokens);
|
||||
|
||||
slot.truncated = true;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
|
||||
|
||||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
}
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
|
||||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (int i = 0; i < slot.n_past; ++i) {
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
|
||||
}
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (params.n_cache_reuse > 0) {
|
||||
size_t head_c = slot.n_past; // cache
|
||||
size_t head_p = slot.n_past; // current prompt
|
||||
|
||||
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
|
||||
|
||||
while (head_c < slot.cache_tokens.size() &&
|
||||
head_p < prompt_tokens.size()) {
|
||||
|
||||
size_t n_match = 0;
|
||||
while (head_c + n_match < slot.cache_tokens.size() &&
|
||||
head_p + n_match < prompt_tokens.size() &&
|
||||
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
|
||||
|
||||
n_match++;
|
||||
}
|
||||
|
||||
if (n_match >= (size_t) params.n_cache_reuse) {
|
||||
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
|
||||
//for (size_t i = head_p; i < head_p + n_match; i++) {
|
||||
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
//}
|
||||
|
||||
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
|
||||
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
|
||||
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
head_c += n_match;
|
||||
head_p += n_match;
|
||||
} else {
|
||||
head_c += 1;
|
||||
}
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
|
||||
|
||||
slot.n_past--;
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed = 0;
|
||||
}
|
||||
|
||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// check that we are in the right batch_type, if not defer the slot
|
||||
const bool slot_type =
|
||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
|
||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
|
||||
|
||||
if (batch_type == -1) {
|
||||
batch_type = slot_type;
|
||||
} else if (batch_type != slot_type) {
|
||||
slot.release();
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
continue;
|
||||
}
|
||||
|
||||
// keep only the common part
|
||||
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
|
||||
// could not partially delete (likely using a non-Transformer model)
|
||||
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// this prompt is too large to process - discard it
|
||||
if (slot.n_prompt_tokens > n_ubatch) {
|
||||
slot.release();
|
||||
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (!params.ctx_shift) {
|
||||
// if context shift is disabled, we make sure prompt size is smaller than KV size
|
||||
// TODO: there should be a separate parameter that control prompt truncation
|
||||
// context shift should be applied only during the generation phase
|
||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
slot.release();
|
||||
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (slot.params.n_keep < 0) {
|
||||
slot.params.n_keep = slot.n_prompt_tokens;
|
||||
}
|
||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_past = 0;
|
||||
// if input prompt is too big, truncate it
|
||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
}
|
||||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
|
||||
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
||||
std::vector<llama_token> new_tokens(
|
||||
prompt_tokens.begin(),
|
||||
prompt_tokens.begin() + slot.params.n_keep);
|
||||
|
||||
// remove the non-common part from the cache
|
||||
slot.cache_tokens.resize(slot.n_past);
|
||||
new_tokens.insert(
|
||||
new_tokens.end(),
|
||||
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
|
||||
prompt_tokens.end());
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
|
||||
prompt_tokens = std::move(new_tokens);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||
slot.truncated = true;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
|
||||
|
||||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot.n_past++;
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
|
||||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (int i = 0; i < slot.n_past; ++i) {
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
|
||||
}
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (params.n_cache_reuse > 0) {
|
||||
size_t head_c = slot.n_past; // cache
|
||||
size_t head_p = slot.n_past; // current prompt
|
||||
|
||||
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
|
||||
|
||||
while (head_c < slot.cache_tokens.size() &&
|
||||
head_p < prompt_tokens.size()) {
|
||||
|
||||
size_t n_match = 0;
|
||||
while (head_c + n_match < slot.cache_tokens.size() &&
|
||||
head_p + n_match < prompt_tokens.size() &&
|
||||
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
|
||||
|
||||
n_match++;
|
||||
}
|
||||
|
||||
if (n_match >= (size_t) params.n_cache_reuse) {
|
||||
SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
|
||||
//for (size_t i = head_p; i < head_p + n_match; i++) {
|
||||
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
//}
|
||||
|
||||
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
|
||||
|
||||
common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
|
||||
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
head_c += n_match;
|
||||
head_p += n_match;
|
||||
} else {
|
||||
head_c += 1;
|
||||
}
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.n_past == slot.n_prompt_tokens) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
|
||||
slot.n_past--;
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed = 0;
|
||||
}
|
||||
|
||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// check that we are in the right batch_type, if not defer the slot
|
||||
const bool slot_type =
|
||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
|
||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
|
||||
|
||||
if (batch_type == -1) {
|
||||
batch_type = slot_type;
|
||||
} else if (batch_type != slot_type) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// keep only the common part
|
||||
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
|
||||
// could not partially delete (likely using a non-Transformer model)
|
||||
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_past = 0;
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
}
|
||||
|
||||
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
||||
|
||||
// remove the non-common part from the cache
|
||||
slot.cache_tokens.resize(slot.n_past);
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.n_past == slot.n_prompt_tokens) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
|
||||
}
|
||||
|
||||
if (batch.n_tokens >= n_batch) {
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
using llama_tokens = std::vector<llama_token>;
|
||||
|
||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||
enum error_type {
|
||||
|
@ -52,9 +53,234 @@ static T json_value(const json & body, const std::string & key, const T & defaul
|
|||
}
|
||||
|
||||
//
|
||||
// chat template utils
|
||||
// tokenizer and input processing utils
|
||||
//
|
||||
|
||||
static bool json_is_array_of_numbers(const json & data) {
|
||||
if (data.is_array()) {
|
||||
for (const auto & e : data) {
|
||||
if (!e.is_number_integer()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// is array having BOTH numbers & strings?
|
||||
static bool json_is_array_of_mixed_numbers_strings(const json & data) {
|
||||
bool seen_string = false;
|
||||
bool seen_number = false;
|
||||
if (data.is_array()) {
|
||||
for (const auto & e : data) {
|
||||
seen_string |= e.is_string();
|
||||
seen_number |= e.is_number_integer();
|
||||
if (seen_number && seen_string) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* this handles 2 cases:
|
||||
* - only string, example: "string"
|
||||
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
|
||||
*/
|
||||
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
|
||||
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
||||
// or the first element of the json_prompt array is a string.
|
||||
llama_tokens prompt_tokens;
|
||||
|
||||
if (json_prompt.is_array()) {
|
||||
bool first = true;
|
||||
for (const auto & p : json_prompt) {
|
||||
if (p.is_string()) {
|
||||
auto s = p.template get<std::string>();
|
||||
|
||||
llama_tokens p;
|
||||
if (first) {
|
||||
p = common_tokenize(ctx, s, add_special, parse_special);
|
||||
first = false;
|
||||
} else {
|
||||
p = common_tokenize(ctx, s, false, parse_special);
|
||||
}
|
||||
|
||||
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
||||
} else {
|
||||
if (first) {
|
||||
first = false;
|
||||
}
|
||||
|
||||
prompt_tokens.push_back(p.template get<llama_token>());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto s = json_prompt.template get<std::string>();
|
||||
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
|
||||
}
|
||||
|
||||
return prompt_tokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||
* this supports these cases:
|
||||
* - "prompt": "string"
|
||||
* - "prompt": [12, 34, 56]
|
||||
* - "prompt": [12, 34, "string", 56, 78]
|
||||
* and multiple prompts (multi-tasks):
|
||||
* - "prompt": ["string1", "string2"]
|
||||
* - "prompt": ["string1", [12, 34, 56]]
|
||||
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
|
||||
*/
|
||||
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
|
||||
std::vector<llama_tokens> result;
|
||||
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
|
||||
// string or mixed
|
||||
result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
|
||||
} else if (json_is_array_of_numbers(json_prompt)) {
|
||||
// array of tokens
|
||||
result.push_back(json_prompt.get<llama_tokens>());
|
||||
} else if (json_prompt.is_array()) {
|
||||
// array of prompts
|
||||
result.reserve(json_prompt.size());
|
||||
for (const auto & p : json_prompt) {
|
||||
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
|
||||
result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
|
||||
} else if (json_is_array_of_numbers(p)) {
|
||||
// array of tokens
|
||||
result.push_back(p.get<llama_tokens>());
|
||||
} else {
|
||||
throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// template utils
|
||||
//
|
||||
|
||||
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
|
||||
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
|
||||
llama_tokens result;
|
||||
result.reserve(doc.size() + query.size() + 4);
|
||||
result.push_back(llama_token_bos(model));
|
||||
result.insert(result.end(), query.begin(), query.end());
|
||||
result.push_back(llama_token_eos(model));
|
||||
result.push_back(llama_token_sep(model));
|
||||
result.insert(result.end(), doc.begin(), doc.end());
|
||||
result.push_back(llama_token_eos(model));
|
||||
return result;
|
||||
}
|
||||
|
||||
// format infill task
|
||||
static llama_tokens format_infill(
|
||||
const llama_context * ctx,
|
||||
const json & input_prefix,
|
||||
const json & input_suffix,
|
||||
const json & input_extra,
|
||||
const int n_batch,
|
||||
const int n_predict,
|
||||
const int n_ctx,
|
||||
const bool spm_infill,
|
||||
const llama_tokens & tokens_prompt
|
||||
) {
|
||||
// TODO: optimize this block by reducing memory allocations and movement
|
||||
|
||||
// use FIM repo-level pattern:
|
||||
// ref: https://arxiv.org/pdf/2409.12186
|
||||
//
|
||||
// [FIM_REP]myproject
|
||||
// [FIM_SEP]filename0
|
||||
// extra chunk 0
|
||||
// [FIM_SEP]filename1
|
||||
// extra chunk 1
|
||||
// ...
|
||||
// [FIM_SEP]filename
|
||||
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
|
||||
//
|
||||
llama_tokens extra_tokens;
|
||||
extra_tokens.reserve(n_ctx);
|
||||
|
||||
auto model = llama_get_model(ctx);
|
||||
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
|
||||
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
|
||||
|
||||
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
||||
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
|
||||
|
||||
extra_tokens.push_back(llama_token_fim_rep(model));
|
||||
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
|
||||
}
|
||||
for (const auto & chunk : input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
const std::string text = chunk.value("text", "");
|
||||
const std::string filename = chunk.value("filename", "tmp");
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
|
||||
|
||||
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
|
||||
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
} else {
|
||||
// chunk separator in binary form to avoid confusing the AI
|
||||
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
|
||||
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
|
||||
|
||||
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
|
||||
}
|
||||
|
||||
const auto chunk_tokens = common_tokenize(ctx, text, false, false);
|
||||
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
||||
}
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
// TODO: current filename
|
||||
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
|
||||
|
||||
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
|
||||
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||
}
|
||||
|
||||
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
||||
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
|
||||
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
|
||||
|
||||
// fill the rest of the context with extra chunks
|
||||
const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
|
||||
|
||||
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
|
||||
tokens_suffix.resize(n_suffix_take);
|
||||
|
||||
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
|
||||
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
|
||||
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
|
||||
|
||||
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
|
||||
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
|
||||
|
||||
if (llama_add_bos_token(model)) {
|
||||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
||||
}
|
||||
|
||||
LOG_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
|
||||
|
||||
// put the extra context before the FIM prefix
|
||||
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
|
||||
|
||||
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
||||
embd_inp.push_back(llama_token_fim_mid(model));
|
||||
|
||||
return embd_inp;
|
||||
}
|
||||
|
||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
|
||||
std::vector<common_chat_msg> chat;
|
||||
|
@ -229,18 +455,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
|
|||
return std::string::npos;
|
||||
}
|
||||
|
||||
static bool json_is_array_of_numbers(const json & data) {
|
||||
if (data.is_array()) {
|
||||
for (const auto & e : data) {
|
||||
if (!e.is_number()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: reuse llama_detokenize
|
||||
template <class Iter>
|
||||
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue