try fixing format_infill
This commit is contained in:
parent
fea5ca4524
commit
07381f7d97
2 changed files with 28 additions and 19 deletions
|
@ -43,21 +43,6 @@
|
|||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum stop_type {
|
||||
|
@ -2780,12 +2765,19 @@ int main(int argc, char ** argv) {
|
|||
json data = json::parse(req.body);
|
||||
|
||||
// validate input
|
||||
if (!data.contains("input_prefix")) {
|
||||
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (!data.contains("input_suffix")) {
|
||||
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
||||
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
json input_extra = json_value(data, "input_extra", json::array());
|
||||
|
||||
for (const auto & chunk : input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
|
||||
|
@ -2798,6 +2790,7 @@ int main(int argc, char ** argv) {
|
|||
return;
|
||||
}
|
||||
}
|
||||
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
||||
|
||||
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
|
||||
};
|
||||
|
|
|
@ -26,6 +26,21 @@
|
|||
using json = nlohmann::ordered_json;
|
||||
using llama_tokens = std::vector<llama_token>;
|
||||
|
||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||
|
||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||
enum error_type {
|
||||
ERROR_TYPE_INVALID_REQUEST,
|
||||
|
@ -214,6 +229,7 @@ static llama_tokens format_infill(
|
|||
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
|
||||
|
||||
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
||||
// TODO: make project name an input
|
||||
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
|
||||
|
||||
extra_tokens.push_back(llama_token_fim_rep(model));
|
||||
|
@ -221,8 +237,8 @@ static llama_tokens format_infill(
|
|||
}
|
||||
for (const auto & chunk : input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
const std::string text = chunk.value("text", "");
|
||||
const std::string filename = chunk.value("filename", "tmp");
|
||||
const std::string text = json_value(chunk, "text", std::string());
|
||||
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
|
||||
|
||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
|
||||
|
@ -270,7 +286,7 @@ static llama_tokens format_infill(
|
|||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
||||
}
|
||||
|
||||
LOG_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
|
||||
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
|
||||
|
||||
// put the extra context before the FIM prefix
|
||||
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue