server: add llama_server_response_event

This commit is contained in:
ngxson 2024-01-21 14:45:28 +01:00
parent 6e29f4c725
commit 12829b2e64
5 changed files with 161 additions and 86 deletions

View file

@ -619,7 +619,7 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C
save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -Wno-cast-qual
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)

View file

@ -1,7 +1,7 @@
set(TARGET server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>

14
examples/server/oai.hpp Normal file
View file

@ -0,0 +1,14 @@
#pragma once
#include <string>
#include <vector>
#include <set>
#include <mutex>
#include <condition_variable>
#include <unordered_map>
#include "json.hpp"
#include "utils.hpp"
using json = nlohmann::json;

View file

@ -43,7 +43,7 @@ struct server_params
int32_t write_timeout = 600;
};
static bool server_verbose = false;
bool server_verbose = false;
json oaicompat_completion_params_parse(const json &body);
std::string format_chatml(std::vector<json> messages);
@ -279,7 +279,7 @@ struct llama_client_slot
}
void release() {
if (state == IDLE || state == PROCESSING)
if (state == PROCESSING)
{
t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
command = RELEASE;
@ -342,12 +342,9 @@ struct llama_server_context
std::vector<llama_client_slot> slots;
llama_server_queue<task_server> queue_tasks;
std::vector<task_result> queue_results;
llama_server_response_event queue_results;
std::vector<task_multi> queue_multitasks;
std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
// std::condition_variable condition_tasks;
std::mutex mutex_results;
std::condition_variable condition_results;
std::mutex mutex_multitasks;
~llama_server_context()
{
@ -968,20 +965,18 @@ struct llama_server_context
void send_error(task_server& task, const std::string &error)
{
LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = task.id;
res.multitask_id = task.multitask_id;
res.stop = false;
res.error = true;
res.result_json = { { "content", error } };
queue_results.push_back(res);
condition_results.notify_all();
queue_results.send(res);
}
void add_multi_task(int id, std::vector<int>& sub_ids)
void add_multitask(int id, std::vector<int>& sub_ids)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
std::lock_guard<std::mutex> lock(mutex_multitasks);
task_multi multi;
multi.id = id;
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
@ -989,9 +984,9 @@ struct llama_server_context
// TODO @ngxson : Do we need to notify the queue_tasks?
}
void update_multi_task(int multitask_id, int subtask_id, task_result& result)
void update_multitask(int multitask_id, int subtask_id, task_result& result)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
std::lock_guard<std::mutex> lock(mutex_multitasks);
for (auto& multitask : queue_multitasks)
{
if (multitask.id == multitask_id)
@ -1046,7 +1041,6 @@ struct llama_server_context
void send_partial_response(llama_client_slot &slot, completion_token_output tkn)
{
std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
@ -1081,13 +1075,11 @@ struct llama_server_context
res.result_json["model"] = slot.oaicompat_model;
}
queue_results.push_back(res);
condition_results.notify_all();
queue_results.send(res);
}
void send_final_response(llama_client_slot &slot)
{
std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
@ -1136,22 +1128,17 @@ struct llama_server_context
res.result_json["model"] = slot.oaicompat_model;
}
queue_results.push_back(res);
condition_results.notify_all();
// done with results, unlock
lock.unlock();
queue_results.send(res);
// parent multitask, if any, needs to be updated
if (slot.multitask_id != -1)
{
update_multi_task(slot.multitask_id, slot.task_id, res);
update_multitask(slot.multitask_id, slot.task_id, res);
}
}
void send_embedding(llama_client_slot &slot)
{
std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
@ -1178,13 +1165,12 @@ struct llama_server_context
{"embedding", embedding },
};
}
queue_results.push_back(res);
condition_results.notify_all();
queue_results.send(res);
}
int request_completion(json data, bool infill, bool embedding, int multitask_id)
{
std::unique_lock<std::mutex> lock(mutex_tasks);
std::unique_lock<std::mutex> lock(mutex_multitasks);
task_server task;
task.target_id = 0;
task.data = std::move(data);
@ -1204,40 +1190,6 @@ struct llama_server_context
return queue_tasks.post(task);
}
task_result next_result(int task_id)
{
LOG_TEE("next_result %i \n", task_id);
while (true)
{
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{
return !queue_results.empty();
});
for (int i = 0; i < (int) queue_results.size(); i++)
{
// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
if (queue_results[i].multitask_id == task_id)
{
update_multi_task(task_id, queue_results[i].id, queue_results[i]);
queue_results.erase(queue_results.begin() + i);
continue;
}
if (queue_results[i].id == task_id)
{
assert(queue_results[i].multitask_id == -1);
task_result res = queue_results[i];
queue_results.erase(queue_results.begin() + i);
return res;
}
}
}
// never reached
//return task_result{-1, false, false, {}};
}
// for multiple images processing
bool ingest_images(llama_client_slot &slot, int n_batch)
{
@ -1331,7 +1283,7 @@ struct llama_server_context
}
// queue up the multitask so we can track its subtask progression
add_multi_task(multitask_id, subtask_ids);
add_multitask(multitask_id, subtask_ids);
return multitask_id;
}
@ -1344,8 +1296,8 @@ struct llama_server_context
if (slot == nullptr)
{
// if no slot is available, we defer this task for processing later
// deferred_tasks.push_back(task);
LOG_INFO("no slot", {});
LOG_TEE("no slot\n");
queue_tasks.defer(task);
break;
}
@ -1417,12 +1369,7 @@ struct llama_server_context
aggregate_result.error = aggregate_result.error && subres.error;
}
aggregate_result.result_json = json{ "results", result_jsons };
agg_results.push_back(aggregate_result);
condition_results.notify_all();
queue_iterator = queue_multitasks.erase(queue_iterator);
}
else
@ -1432,8 +1379,9 @@ struct llama_server_context
}
// copy aggregate results of complete multi-tasks to the results queue
std::lock_guard<std::mutex> lock_results(mutex_results);
queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end());
for (auto& res : agg_results) {
queue_results.send(res);
}
}
bool update_slots() {
@ -2845,9 +2793,10 @@ int main(int argc, char **argv)
}
json data = json::parse(req.body);
const int task_id = llama.request_completion(data, false, false, -1);
llama.queue_results.add_waiting_task_id(task_id);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {
res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
}
@ -2855,6 +2804,7 @@ int main(int argc, char **argv)
{
res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
llama.queue_results.remove_waiting_task_id(task_id);
return;
}
} else {
@ -2862,7 +2812,7 @@ int main(int argc, char **argv)
{
while (true)
{
task_result result = llama.next_result(task_id);
task_result result = llama.queue_results.recv(task_id);
if (!result.error) {
const std::string str =
"data: " +
@ -2873,6 +2823,7 @@ int main(int argc, char **argv)
});
if (!sink.write(str.c_str(), str.size()))
{
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
if (result.stop) {
@ -2888,12 +2839,14 @@ int main(int argc, char **argv)
});
if (!sink.write(str.c_str(), str.size()))
{
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
break;
}
}
sink.done();
llama.queue_results.remove_waiting_task_id(task_id);
return true;
};
@ -2901,6 +2854,7 @@ int main(int argc, char **argv)
{
// cancel
llama.request_cancel(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@ -2938,10 +2892,11 @@ int main(int argc, char **argv)
json data = oaicompat_completion_params_parse(json::parse(req.body));
const int task_id = llama.request_completion(data, false, false, -1);
llama.queue_results.add_waiting_task_id(task_id);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {
json oaicompat_result = format_final_response_oaicompat(data, result);
@ -2952,12 +2907,13 @@ int main(int argc, char **argv)
} else {
res.status = 500;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
llama.queue_results.remove_waiting_task_id(task_id);
return;
}
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
while (true) {
task_result llama_result = llama.next_result(task_id);
task_result llama_result = llama.queue_results.recv(task_id);
if (!llama_result.error) {
std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
@ -2970,6 +2926,7 @@ int main(int argc, char **argv)
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
}
@ -2985,18 +2942,21 @@ int main(int argc, char **argv)
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
break;
}
}
sink.done();
llama.queue_results.remove_waiting_task_id(task_id);
return true;
};
auto on_complete = [task_id, &llama](bool) {
// cancel request
llama.request_cancel(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@ -3013,7 +2973,7 @@ int main(int argc, char **argv)
const int task_id = llama.request_completion(data, true, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop)
{
res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
@ -3028,7 +2988,7 @@ int main(int argc, char **argv)
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
while (true)
{
task_result result = llama.next_result(task_id);
task_result result = llama.queue_results.recv(task_id);
if (!result.error) {
const std::string str =
"data: " +
@ -3128,7 +3088,7 @@ int main(int argc, char **argv)
}
const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
task_result result = llama.next_result(task_id);
task_result result = llama.queue_results.recv(task_id);
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
});
@ -3149,6 +3109,13 @@ int main(int argc, char **argv)
llama.queue_tasks.on_all_tasks_finished(std::bind(
&llama_server_context::run_on_all_tasks_finished, &llama));
llama.queue_tasks.start_loop();
llama.queue_results.on_multitask_update(std::bind(
&llama_server_context::update_multitask,
&llama,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3
));
t.join();

View file

@ -5,6 +5,7 @@
#include <set>
#include <mutex>
#include <condition_variable>
#include <unordered_map>
#include "json.hpp"
@ -12,6 +13,8 @@
using json = nlohmann::json;
extern bool server_verbose;
#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif
@ -160,12 +163,12 @@ struct llama_server_queue {
int id = 0;
std::mutex mutex_tasks;
std::vector<T> queue_tasks;
std::vector<T> queue_tasks_deferred;
std::condition_variable condition_tasks;
std::function<void(T)> callback_new_task;
std::function<void(void)> callback_all_task_finished;
int post(T task) {
LOG_INFO("post", {});
std::unique_lock<std::mutex> lock(mutex_tasks);
task.id = id++;
queue_tasks.push_back(std::move(task));
@ -173,6 +176,11 @@ struct llama_server_queue {
return task.id;
}
void defer(T task) {
std::unique_lock<std::mutex> lock(mutex_tasks);
queue_tasks_deferred.push_back(std::move(task));
}
int get_next_id() {
std::unique_lock<std::mutex> lock(mutex_tasks);
return id++;
@ -189,7 +197,7 @@ struct llama_server_queue {
void start_loop() {
while (true) {
// new task arrived
LOG_INFO("have new task", {});
LOG_VERBOSE("have new task", {});
{
while (true)
{
@ -201,13 +209,27 @@ struct llama_server_queue {
task_server task = queue_tasks.front();
queue_tasks.erase(queue_tasks.begin());
lock.unlock();
LOG_INFO("callback_new_task", {});
LOG_VERBOSE("callback_new_task", {});
callback_new_task(task);
}
LOG_INFO("callback_all_task_finished", {});
// move deferred tasks back to main loop
{
std::unique_lock<std::mutex> lock(mutex_tasks);
//queue_tasks.insert(
// queue_tasks.end(),
// std::make_move_iterator(queue_tasks_deferred.begin()),
// std::make_move_iterator(queue_tasks_deferred.end())
//);
for (auto & task : queue_tasks_deferred) {
queue_tasks.push_back(task);
}
queue_tasks_deferred.clear();
lock.unlock();
}
LOG_VERBOSE("callback_all_task_finished", {});
callback_all_task_finished();
}
LOG_INFO("wait for new task", {});
LOG_VERBOSE("wait for new task", {});
// wait for new task
{
std::unique_lock<std::mutex> lock(mutex_tasks);
@ -221,6 +243,78 @@ struct llama_server_queue {
}
};
// Routes task results from the worker loop to the HTTP handlers waiting on them.
// Handlers register interest with add_waiting_task_id(), block in recv(), and must
// call remove_waiting_task_id() when done (including on client disconnect).
struct llama_server_response_event {
    typedef std::function<void(int, int, task_result&)> callback_multitask_t;
    std::vector<task_result> queue_results;     // results waiting to be picked up by recv()
    std::mutex mutex_task_ids;                  // guards waiting_task_ids
    std::set<int> waiting_task_ids;             // task ids some handler is (or will be) waiting on
    std::mutex mutex_results;                   // guards queue_results
    std::condition_variable condition_results;
    callback_multitask_t callback_update_multitask;
    // Register a task id whose result a handler will wait for via recv().
    void add_waiting_task_id(int task_id) {
        std::unique_lock<std::mutex> lock(mutex_task_ids);
        waiting_task_ids.insert(task_id);
    }
    // Unregister a task id; results sent afterwards for this id are dropped.
    void remove_waiting_task_id(int task_id) {
        std::unique_lock<std::mutex> lock(mutex_task_ids);
        waiting_task_ids.erase(task_id);
    }
    // Block until the result for task_id arrives, remove it from the queue and return it.
    task_result recv(int task_id) {
        while (true)
        {
            std::unique_lock<std::mutex> lock(mutex_results);
            // Wait until *our* result is queued. Waiting merely on
            // !queue_results.empty() would busy-spin whenever the queue holds
            // results that belong to other waiting tasks (the predicate stays
            // true, wait() returns immediately, and the scan below finds nothing).
            condition_results.wait(lock, [&]{
                for (const auto & res : queue_results) {
                    if (res.id == task_id) {
                        return true;
                    }
                }
                return false;
            });
            LOG_VERBOSE("condition_results unblock", {});
            for (int i = 0; i < (int) queue_results.size(); i++)
            {
                if (queue_results[i].id == task_id)
                {
                    // subtask results are consumed by the multitask callback in
                    // send(), so anything reaching here must be a standalone task
                    assert(queue_results[i].multitask_id == -1);
                    task_result res = std::move(queue_results[i]);
                    queue_results.erase(queue_results.begin() + i);
                    return res;
                }
            }
        }
        // should never reach here
    }
    // Install the callback that folds a subtask result into its parent multitask.
    void on_multitask_update(callback_multitask_t callback) {
        callback_update_multitask = callback;
    }
    // Deliver a result: either forward it to its parent multitask via the
    // callback, or queue it for the handler blocked in recv(). Results for
    // task ids nobody is waiting on are silently dropped.
    void send(task_result result) {
        // lock order: mutex_results before mutex_task_ids — keep consistent
        // everywhere to avoid deadlock
        std::unique_lock<std::mutex> lock(mutex_results);
        std::unique_lock<std::mutex> lock1(mutex_task_ids);
        LOG_VERBOSE("send new result", {});
        for (auto& task_id : waiting_task_ids) {
            // for now, tasks that have associated parent multitasks just get
            // erased once multitask picks up the result
            if (result.multitask_id == task_id)
            {
                LOG_VERBOSE("callback_update_multitask", {});
                // NOTE(review): callback runs with both mutexes held; it must
                // not call back into this object — verify update_multitask
                callback_update_multitask(task_id, result.id, result);
                continue;
            }
            if (result.id == task_id)
            {
                LOG_VERBOSE("queue_results.push_back", {});
                queue_results.push_back(std::move(result));
                // notify_all, not notify_one: several handlers may be blocked
                // in recv(); waking a single arbitrary waiter could leave the
                // intended one asleep forever (lost wakeup)
                condition_results.notify_all();
                return;
            }
        }
    }
};
//
// base64 utils (TODO: move to common in the future)
//