server : accept extra_context for the infill endpoint
ggml-ci
parent c7181bd294
commit 5a699f147e
2 changed files with 78 additions and 22 deletions
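With this change the infill endpoint accepts an extra_context field next to input_prefix and input_suffix; each element is expected to be a plain string chunk (non-string elements are skipped with a warning). A minimal request-body sketch, with purely illustrative values:

    {
        "input_prefix": "def fib(n):\n    ",
        "input_suffix": "\n    return result\n",
        "extra_context": [
            "contents of a related source file ...",
            "another relevant snippet ..."
        ]
    }

Each string chunk is tokenized and placed in front of the FIM prompt, preceded by a chunk separator whose bytes decode to "\n\n--- snippet ---\n\n" (see k_chunk_prefix_str in the diff below).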
@@ -139,6 +139,8 @@ struct slot_params {

     json input_prefix;
     json input_suffix;
+
+    json extra_context;
 };

 struct server_slot {
@@ -170,6 +172,7 @@ struct server_slot {

     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
+    std::vector<llama_token> extra_tokens;

     std::string generated_text;
     std::vector<llama_token> cache_tokens;
@@ -906,8 +909,18 @@ struct server_context {
         }

         // infill
         slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
         slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
+        for (const auto & chunk : slot.params.extra_context) {
+            if (chunk.is_string()) {
+                SLT_DBG(slot, "chunk: \n%s\n", chunk.get<std::string>().c_str());
+            } else {
+                SLT_DBG(slot, "%s", "chunk is not a string - skipping\n");
+            }
+        }

         // get prompt
         if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
@@ -1937,10 +1950,28 @@ struct server_context {
                     auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                     auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

-                    // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                    slot.extra_tokens.clear();
+                    for (const auto & e : slot.params.extra_context) {
+                        if (e.is_string()) {
+                            // chunk separator in binary form to avoid confusing the AI
+                            static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+                            static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
+
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+
+                            const auto part = tokenize(e, false, false);
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), part.begin(), part.end());
+                        } else {
+                            SLT_WRN(slot, "%s", "extra context element is not a string\n");
+                        }
+                    }
+
+                    // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
                     const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+
+                    // fill the rest of the context with extra chunks
+                    const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());

                     prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
                     suffix_tokens.resize(n_suffix_take);
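To make the sizing above concrete, a worked example with illustrative numbers (not taken from the diff): assume n_batch = 2048, slot.n_ctx = 8192, slot.n_predict = 1024, and prefix, suffix and extra context that are all longer than the available budget. Then:

    n_suffix_take = min(suffix_tokens.size(), 2048/4)            = 512
    n_prefix_take = min(prefix_tokens.size(), (2048 - 3) - 512)  = 1533
    n_extra_take  = min(max(0, 8192 - 2048 - 2*1024), extra)     = up to 4096 extra-context tokens

So the prefix/suffix still fit in a single batch at roughly a 3:1 ratio, and whatever context window remains beyond the batch and the generation budget is filled from the extra chunks.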
@@ -1954,6 +1985,11 @@ struct server_context {
                         embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                     }

+                    SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
+
+                    // put the extra context before the FIM prefix
+                    embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
+
                     embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
                     embd_inp.push_back(llama_token_fim_mid(model));
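Based on the insert order in this hunk (and assuming, as in the surrounding code, that embd_inp holds the FIM-prefix side and embd_end the FIM-suffix side), the final infill prompt now looks roughly like:

    [last n_extra_take extra-context tokens] [BOS, if added] [FIM prefix + prefix tokens] [FIM suffix + suffix tokens] [FIM_MID]

i.e. the retrieved chunks sit in front of the regular fill-in-the-middle template, and only the most recent n_extra_take tokens of slot.extra_tokens are kept.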

@@ -6596,8 +6596,8 @@ static void llm_load_vocab(
             ) {
                 vocab.special_eot_id = t.second;
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
             }
@@ -6610,8 +6610,8 @@ static void llm_load_vocab(
             ) {
                 vocab.special_eom_id = t.second;
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
             }
@@ -6627,8 +6627,8 @@ static void llm_load_vocab(
             ) {
                 vocab.special_fim_pre_id = t.second;
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
             }
@@ -6644,8 +6644,8 @@ static void llm_load_vocab(
             ) {
                 vocab.special_fim_suf_id = t.second;
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
             }
@@ -6661,8 +6661,8 @@ static void llm_load_vocab(
             ) {
                 vocab.special_fim_mid_id = t.second;
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
             }
@@ -6677,8 +6677,8 @@ static void llm_load_vocab(
             ) {
                 vocab.special_fim_pad_id = t.second;
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
             }
@@ -6694,8 +6694,8 @@ static void llm_load_vocab(
            ) {
                 vocab.special_fim_rep_id = t.second;
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
             }
@@ -6708,8 +6708,8 @@ static void llm_load_vocab(
             ) {
                 vocab.special_fim_sep_id = t.second;
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
             }
@@ -6720,6 +6720,19 @@ static void llm_load_vocab(
         // this is currently determined based on the token text, which is obviously not ideal
         // ref: https://github.com/ggerganov/llama.cpp/issues/9606
         vocab.special_eog_ids.clear();
+
+        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
+        }
+
+        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
+        }
+
+        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
+        }
+
         for (const auto & t : vocab.token_to_id) {
             if (false
                 || t.first == "<|eot_id|>"
@@ -6732,13 +6745,20 @@ static void llm_load_vocab(
             ) {
                 vocab.special_eog_ids.insert(t.second);
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
+            } else {
+                // token is control, but not marked as EOG -> print a warning
+                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_WARN("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
             }
         }

+        // sanity checks
         if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
             vocab.special_eog_ids.insert(vocab.special_eos_id);
             LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
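For reference, the new warning added in the last hunk would print roughly like the line below; the function name comes from the hunk context, while the token id and text are made up for illustration:

    llm_load_vocab: control token:  12345 '<|hypothetical_control_token|>' is not marked as EOG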