specify types

This commit is contained in:
ngxson 2024-03-11 22:14:27 +01:00
parent 1a29871348
commit a601da6fd4

View file

@@ -3373,40 +3373,39 @@ int main(int argc, char ** argv) {
bool is_openai = false; bool is_openai = false;
// an input prompt can be a string or a list of tokens (integer) // an input prompt can be a string or a list of tokens (integer)
json prompts = json::array(); json prompt;
if (body.count("input") != 0) { if (body.count("input") != 0) {
is_openai = true; is_openai = true;
prompts = body["input"].is_array() prompt = body["input"];
? body["input"] // support multiple prompts
: json{body["input"]}; // single input prompt
} else if (body.count("content") != 0) { } else if (body.count("content") != 0) {
// with "content", we only support single prompt // with "content", we only support single prompt
std::string content = body["content"]; prompt = std::vector<std::string>{body["content"]};
prompts.push_back(content);
} else { } else {
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return; return;
} }
// create and queue the task // create and queue the task
json responses = json::array(); json responses;
{ {
const int id_task = ctx_server.queue_tasks.get_new_id(); const int id_task = ctx_server.queue_tasks.get_new_id();
ctx_server.queue_results.add_waiting_task_id(id_task); ctx_server.queue_results.add_waiting_task_id(id_task);
// if number of prompts is more than 1, we pass an array to create a multi-task
// otherwise, we pass a single prompt to make a single task
ctx_server.request_completion(id_task, -1, { ctx_server.request_completion(id_task, -1, {
{"prompt", prompts.size() == 1 ? prompts[0] : prompts}, {"prompt", prompt},
{"n_predict", 0} {"n_predict", 0},
}, false, true); }, false, true);
// get the result // get the result
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
if (!result.error) { if (!result.error) {
responses = result.data.count("results") if (result.data.count("results")) {
? result.data["results"] // result for multi-task // result for multi-task
: json{result.data}; // result for single task responses = result.data["results"];
} else {
// result for single task
responses = std::vector<json>{result.data};
}
} else { } else {
// error received, ignore everything else // error received, ignore everything else
res_error(res, result.data); res_error(res, result.data);