try fixing format_infill

parent fea5ca4524
commit 07381f7d97

2 changed files with 28 additions and 19 deletions
@@ -43,21 +43,6 @@
 #include <unordered_map>
 #include <unordered_set>
 
-#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-
-#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
-#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
 using json = nlohmann::ordered_json;
 
 enum stop_type {
@@ -2780,12 +2765,19 @@ int main(int argc, char ** argv) {
     json data = json::parse(req.body);
 
     // validate input
+    if (!data.contains("input_prefix")) {
+        res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
+    }
+
+    if (!data.contains("input_suffix")) {
+        res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
+    }
+
     if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
         res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
         return;
     }
     json input_extra = json_value(data, "input_extra", json::array());
 
     for (const auto & chunk : input_extra) {
         // { "text": string, "filename": string }
         if (!chunk.contains("text") || !chunk.at("text").is_string()) {
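For reference, a request body that passes the new checks needs at least input_prefix and input_suffix, while input_extra, when present, must be an array of {"filename": string, "text": string} objects. A hypothetical example (the field values are illustrative, not from this commit), built with the same nlohmann json type the server uses:

    json body = {
        {"input_prefix", "int add(int a, int b) {\n    return "},  // text before the cursor
        {"input_suffix", ";\n}\n"},                                // text after the cursor
        {"input_extra", json::array({
            // extra context chunks, e.g. headers from the same project
            {{"filename", "math.h"}, {"text", "int add(int a, int b);\n"}},
        })},
    };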
@@ -2798,6 +2790,7 @@ int main(int argc, char ** argv) {
             return;
         }
     }
+    data["input_extra"] = input_extra; // default to empty array if it's not exist
 
     return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
 };
@@ -26,6 +26,21 @@
 using json = nlohmann::ordered_json;
 using llama_tokens = std::vector<llama_token>;
 
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+
+#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
 enum error_type {
     ERROR_TYPE_INVALID_REQUEST,
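These macros all funnel __func__ through printf's precision-from-argument form: in %12.*s, the field width 12 pads the name to a 12-character column and the .* takes the next argument (also 12) as the maximum number of characters to print. A minimal standalone illustration (the function name below is just an example, not from this commit):

    #include <cstdio>

    int main() {
        // precision argument 12 truncates the name; width 12 keeps the column aligned
        std::printf("srv %12.*s: %s\n", 12, "handle_completions_generic", "request received");
        // prints: srv handle_compl: request received
        return 0;
    }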
@@ -214,6 +229,7 @@ static llama_tokens format_infill(
     auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
 
     if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+        // TODO: make project name an input
         static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
 
         extra_tokens.push_back(llama_token_fim_rep(model));
@@ -221,8 +237,8 @@ static llama_tokens format_infill(
     }
     for (const auto & chunk : input_extra) {
         // { "text": string, "filename": string }
-        const std::string text = chunk.value("text", "");
-        const std::string filename = chunk.value("filename", "tmp");
+        const std::string text = json_value(chunk, "text", std::string());
+        const std::string filename = json_value(chunk, "filename", std::string("tmp"));
 
         if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
             const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
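The switch away from chunk.value(...) matters because nlohmann's value() only substitutes the default when the key is absent; if the key exists with a mismatched type (say, "text": 123) it throws json::type_error. A json_value-style helper can fall back to the default in that case as well. A minimal sketch of such a helper, assuming the real one in this codebase may also log a warning:

    template <typename T>
    static T json_value(const json & body, const std::string & key, const T & default_value) {
        // use the default when the key is missing, null, or of the wrong type
        if (body.contains(key) && !body.at(key).is_null()) {
            try {
                return body.at(key).template get<T>();
            } catch (const json::type_error &) {
                return default_value;
            }
        }
        return default_value;
    }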
@@ -270,7 +286,7 @@ static llama_tokens format_infill(
         embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
     }
 
-    LOG_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
+    SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
 
     // put the extra context before the FIM prefix
     embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());