specify types
parent 1a29871348
commit a601da6fd4
1 changed file with 13 additions and 14 deletions
@@ -3373,40 +3373,39 @@ int main(int argc, char ** argv) {
         bool is_openai = false;
 
         // an input prompt can be a string or a list of tokens (integer)
-        json prompts = json::array();
+        json prompt;
         if (body.count("input") != 0) {
             is_openai = true;
-            prompts = body["input"].is_array()
-                ? body["input"]        // support multiple prompts
-                : json{body["input"]}; // single input prompt
+            prompt = body["input"];
         } else if (body.count("content") != 0) {
             // with "content", we only support single prompt
-            std::string content = body["content"];
-            prompts.push_back(content);
+            prompt = std::vector<std::string>{body["content"]};
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
         // create and queue the task
-        json responses = json::array();
+        json responses;
         {
             const int id_task = ctx_server.queue_tasks.get_new_id();
             ctx_server.queue_results.add_waiting_task_id(id_task);
-            // if number of prompts is more than 1, we pass an array to create a multi-task
-            // otherwise, we pass a single prompt to make a single task
             ctx_server.request_completion(id_task, -1, {
-                {"prompt",    prompts.size() == 1 ? prompts[0] : prompts},
-                {"n_predict", 0}
+                {"prompt",    prompt},
+                {"n_predict", 0},
             }, false, true);
 
             // get the result
             server_task_result result = ctx_server.queue_results.recv(id_task);
             ctx_server.queue_results.remove_waiting_task_id(id_task);
             if (!result.error) {
-                responses = result.data.count("results")
-                    ? result.data["results"] // result for multi-task
-                    : json{result.data};     // result for single task
+                if (result.data.count("results")) {
+                    // result for multi-task
+                    responses = result.data["results"];
+                } else {
+                    // result for single task
+                    responses = std::vector<json>{result.data};
+                }
             } else {
                 // error received, ignore everything else
                 res_error(res, result.data);
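
The rewrite leans on nlohmann::json's implicit container conversions instead of building arrays with json::array() plus push_back, so each assignment states the intended JSON shape at the call site. A minimal standalone sketch (not part of the commit; names and values are illustrative) of the conversions the new code relies on:

// Standalone sketch (not from the commit): std::vector values convert to
// JSON arrays on assignment, which is what the explicit
// std::vector<std::string>{...} and std::vector<json>{...} wrappers use.
#include <iostream>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // "content" path: a single string wrapped as a one-element JSON array
    json prompt = std::vector<std::string>{"hello world"};
    std::cout << prompt.dump() << "\n";                         // ["hello world"]

    // "input" path: the value is forwarded as-is (string, array, or token list)
    json body    = json::parse(R"({"input": ["a", "b"]})");
    json prompt2 = body["input"];
    std::cout << std::boolalpha << prompt2.is_array() << "\n";  // true

    // single-task result path: one result object wrapped as a one-element array
    json result_data = {{"embedding", {0.1, 0.2}}};
    json responses   = std::vector<json>{result_data};
    std::cout << responses.dump() << "\n";                      // [{"embedding":[0.1,0.2]}]

    return 0;
}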