fix embeddings

This commit is contained in:
Xuan Son Nguyen 2024-09-02 00:31:40 +02:00
parent 4a5dbd85b5
commit 9f56c17669

View file

@ -3037,8 +3037,6 @@ int main(int argc, char ** argv) {
}, [&](json error_data) { }, [&](json error_data) {
server_sent_event(sink, "error", error_data); server_sent_event(sink, "error", error_data);
}); });
std::string done_event = "[DONE]"; // OAI-compat behavior
sink.write(done_event.c_str(), done_event.size());
sink.done(); sink.done();
return true; return true;
}; };
@ -3106,28 +3104,28 @@ int main(int argc, char ** argv) {
} }
// create and queue the task // create and queue the task
json responses; json responses = json::array();
bool error = false;
{ {
std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); std::vector<server_task> tasks = ctx_server.create_tasks_completion({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
// get the result // get the result
server_task_result result = ctx_server.queue_results.recv(tasks); std::vector<int> task_ids = server_task::get_list_id(tasks);
ctx_server.queue_results.remove_waiting_tasks(tasks);
if (!result.error) { ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
if (result.data.count("results")) { for (const auto & res : results) {
// result for multi-task responses.push_back(res.data);
responses = result.data.at("results");
} else {
// result for single task
responses = std::vector<json>{result.data};
} }
} else { }, [&](json error_data) {
// error received, ignore everything else res_error(res, error_data);
res_error(res, result.data); error = true;
return; });
} }
if (error) {
return;
} }
// write JSON response // write JSON response