server : fix format_infill (#10724)

* server : fix format_infill

* fix

* rename

* update test

* use another model

* update test

* update test

* test_invalid_input_extra_req
This commit is contained in:
Xuan Son Nguyen 2024-12-08 23:04:29 +01:00 committed by GitHub
parent e52522b869
commit ce8784bdb1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 53 additions and 8 deletions

View file

@ -3484,6 +3484,11 @@ int main(int argc, char ** argv) {
json data = json::parse(req.body);
// validate input
// Every failed check reports the error through res_error() and then returns,
// so an invalid request is never handed to the rest of the handler. This is
// the same pattern the "input_extra" check below already uses.
if (data.contains("prompt") && !data.at("prompt").is_string()) {
    // prompt is optional, but when present it must be a string
    res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
    return;
}
if (!data.contains("input_prefix")) {
    // input_prefix is mandatory
    res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
    return;
}
@ -3493,9 +3498,11 @@ int main(int argc, char ** argv) {
}
// Validate the optional "input_extra" field: when it is present it must be a JSON array.
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
// input_extra is optional
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
// stop here so the invalid request is not processed any further
return;
}
// read "input_extra", passing an empty array to json_value as the fallback value
json input_extra = json_value(data, "input_extra", json::array());
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
@ -3511,6 +3518,21 @@ int main(int argc, char ** argv) {
}
// Build the final infill prompt and hand the request over to the generic
// completion handler:
//   1. store the "input_extra" value back into the request data,
//   2. tokenize the optional "prompt" string,
//   3. let format_infill() combine prefix, suffix, extra chunks and the
//      tokenized prompt, and store its result as data["prompt"],
//   4. dispatch as a SERVER_TASK_TYPE_INFILL task.
data["input_extra"] = input_extra; // default to an empty array if it does not exist
// read "prompt", passing an empty string to json_value as the fallback value
std::string prompt = json_value(data, "prompt", std::string());
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
// overwrite "prompt" with the value returned by format_infill()
data["prompt"] = format_infill(
ctx_server.ctx,
data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
ctx_server.params_base.n_batch,
ctx_server.params_base.n_predict,
ctx_server.slots[0].n_ctx, // TODO: there should be a better way
ctx_server.params_base.spm_infill,
// NOTE(review): element 0 is used without a size check - presumably
// tokenize_input_prompts() yields at least one entry for a single string; confirm.
tokenized_prompts[0]
);
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
};