fix embeddings

This commit is contained in:
Xuan Son Nguyen 2024-09-02 00:31:40 +02:00
parent 4a5dbd85b5
commit 9f56c17669

View file

@ -3037,8 +3037,6 @@ int main(int argc, char ** argv) {
}, [&](json error_data) {
server_sent_event(sink, "error", error_data);
});
std::string done_event = "[DONE]"; // OAI-compat behavior
sink.write(done_event.c_str(), done_event.size());
sink.done();
return true;
};
@ -3106,29 +3104,29 @@ int main(int argc, char ** argv) {
}
// create and queue the task
json responses;
json responses = json::array();
bool error = false;
{
std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
// get the result
server_task_result result = ctx_server.queue_results.recv(tasks);
ctx_server.queue_results.remove_waiting_tasks(tasks);
if (!result.error) {
if (result.data.count("results")) {
// result for multi-task
responses = result.data.at("results");
} else {
// result for single task
responses = std::vector<json>{result.data};
std::vector<int> task_ids = server_task::get_list_id(tasks);
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
for (const auto & res : results) {
responses.push_back(res.data);
}
} else {
// error received, ignore everything else
res_error(res, result.data);
}, [&](json error_data) {
res_error(res, error_data);
error = true;
});
}
if (error) {
return;
}
}
// write JSON response
json root = is_openai