server: add llama_server_queue struct

This commit is contained in:
ngxson 2024-01-20 00:25:20 +01:00
parent 381ee19572
commit 6e29f4c725
4 changed files with 372 additions and 295 deletions

View file

@ -619,7 +619,7 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C
save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -Wno-cast-qual
gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)

View file

@ -1,7 +1,7 @@
set(TARGET server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp json.hpp httplib.h)
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>

View file

@ -1,6 +1,7 @@
#include "common.h"
#include "llama.h"
#include "grammar-parser.h"
#include "utils.hpp"
#include "../llava/clip.h"
@ -28,10 +29,6 @@
#include <condition_variable>
#include <atomic>
#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::json;
@ -48,196 +45,9 @@ struct server_params
static bool server_verbose = false;
#if SERVER_VERBOSE != 1
#define LOG_VERBOSE(MSG, ...)
#else
#define LOG_VERBOSE(MSG, ...) \
do \
{ \
if (server_verbose) \
{ \
server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
} \
} while (0)
#endif
#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
json oaicompat_completion_params_parse(const json &body);
std::string format_chatml(std::vector<json> messages);
//
// base64 utils (TODO: move to common in the future)
//
static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
static inline bool is_base64(uint8_t c)
{
return (isalnum(c) || (c == '+') || (c == '/'));
}
static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
{
int i = 0;
int j = 0;
int in_ = 0;
int in_len = encoded_string.size();
uint8_t char_array_4[4];
uint8_t char_array_3[3];
std::vector<uint8_t> ret;
while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
{
char_array_4[i++] = encoded_string[in_]; in_++;
if (i == 4)
{
for (i = 0; i <4; i++)
{
char_array_4[i] = base64_chars.find(char_array_4[i]);
}
char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; (i < 3); i++)
{
ret.push_back(char_array_3[i]);
}
i = 0;
}
}
if (i)
{
for (j = i; j <4; j++)
{
char_array_4[j] = 0;
}
for (j = 0; j <4; j++)
{
char_array_4[j] = base64_chars.find(char_array_4[j]);
}
char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; (j < i - 1); j++)
{
ret.push_back(char_array_3[j]);
}
}
return ret;
}
//
// parallel
//
enum server_state {
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
SERVER_STATE_READY, // Server is ready and model is loaded
SERVER_STATE_ERROR // An error occurred, load_model failed
};
enum task_type {
TASK_TYPE_COMPLETION,
TASK_TYPE_CANCEL,
};
struct task_server {
int id;
int target_id;
task_type type;
json data;
bool infill_mode = false;
bool embedding_mode = false;
int multitask_id = -1;
};
struct task_result {
int id;
int multitask_id = -1;
bool stop;
bool error;
json result_json;
};
struct task_multi {
int id;
std::set<int> subtasks_remaining{};
std::vector<task_result> results{};
};
// TODO: can become bool if we can't find use of more states
enum slot_state
{
IDLE,
PROCESSING,
};
enum slot_command
{
NONE,
LOAD_PROMPT,
RELEASE,
};
struct slot_params
{
bool stream = true;
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
uint32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_predict = -1; // new tokens to predict
std::vector<std::string> antiprompt;
json input_prefix;
json input_suffix;
};
struct slot_image
{
int32_t id;
bool request_encode_image = false;
float * image_embedding = nullptr;
int32_t image_tokens = 0;
clip_image_u8 * img_data;
std::string prefix_prompt; // before of this image
};
// completion token output with probabilities
struct completion_token_output
{
struct token_prob
{
llama_token tok;
float prob;
};
std::vector<token_prob> probs;
llama_token tok;
std::string text_to_send;
};
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
{
size_t i;
@ -292,28 +102,6 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
return ret;
}
static void server_log(const char *level, const char *function, int line,
const char *message, const nlohmann::ordered_json &extra)
{
nlohmann::ordered_json log
{
{"timestamp", time(nullptr)},
{"level", level},
{"function", function},
{"line", line},
{"message", message},
};
if (!extra.empty())
{
log.merge_patch(extra);
}
const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
printf("%.*s\n", (int)str.size(), str.data());
fflush(stdout);
}
// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
{
@ -539,7 +327,6 @@ struct llama_server_context
bool all_slots_are_idle = false;
bool add_bos_token = true;
int32_t id_gen;
int32_t n_ctx; // total context for all clients / slots
// system prompt
@ -554,11 +341,11 @@ struct llama_server_context
// slots / clients
std::vector<llama_client_slot> slots;
std::vector<task_server> queue_tasks;
llama_server_queue<task_server> queue_tasks;
std::vector<task_result> queue_results;
std::vector<task_multi> queue_multitasks;
std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
std::condition_variable condition_tasks;
// std::condition_variable condition_tasks;
std::mutex mutex_results;
std::condition_variable condition_results;
@ -619,8 +406,6 @@ struct llama_server_context
}
void initialize() {
id_gen = 0;
// create slots
all_slots_are_idle = true;
@ -1201,7 +986,7 @@ struct llama_server_context
multi.id = id;
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
queue_multitasks.push_back(multi);
condition_tasks.notify_one();
// TODO @ngxson : Do we need to notify the queue_tasks?
}
void update_multi_task(int multitask_id, int subtask_id, task_result& result)
@ -1213,7 +998,7 @@ struct llama_server_context
{
multitask.subtasks_remaining.erase(subtask_id);
multitask.results.push_back(result);
condition_tasks.notify_one();
// TODO @ngxson : Do we need to notify the queue_tasks?
}
}
}
@ -1401,7 +1186,6 @@ struct llama_server_context
{
std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
task.target_id = 0;
task.data = std::move(data);
task.infill_mode = infill;
@ -1417,13 +1201,12 @@ struct llama_server_context
}
// otherwise, it's a single-prompt task, we actually queue it
queue_tasks.push_back(task);
condition_tasks.notify_one();
return task.id;
return queue_tasks.post(task);
}
task_result next_result(int task_id)
{
LOG_TEE("next_result %i \n", task_id);
while (true)
{
std::unique_lock<std::mutex> lock(mutex_results);
@ -1525,13 +1308,10 @@ struct llama_server_context
void request_cancel(int task_id)
{
std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
task.type = TASK_TYPE_CANCEL;
task.target_id = task_id;
queue_tasks.push_back(task);
condition_tasks.notify_one();
queue_tasks.post(task);
}
int split_multiprompt_task(task_server& multiprompt_task)
@ -1539,7 +1319,7 @@ struct llama_server_context
int prompt_count = multiprompt_task.data.at("prompt").size();
assert(prompt_count > 1);
int multitask_id = id_gen++;
int multitask_id = queue_tasks.get_next_id();
std::vector<int> subtask_ids(prompt_count);
for (int i = 0; i < prompt_count; i++)
{
@ -1555,73 +1335,67 @@ struct llama_server_context
return multitask_id;
}
void process_tasks()
void process_single_task(task_server task)
{
std::unique_lock<std::mutex> lock(mutex_tasks);
std::vector<task_server> deferred_tasks;
while (!queue_tasks.empty())
switch (task.type)
{
task_server task = queue_tasks.front();
queue_tasks.erase(queue_tasks.begin());
switch (task.type)
{
case TASK_TYPE_COMPLETION: {
llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
if (slot == nullptr)
{
// if no slot is available, we defer this task for processing later
deferred_tasks.push_back(task);
case TASK_TYPE_COMPLETION: {
llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
if (slot == nullptr)
{
// if no slot is available, we defer this task for processing later
// deferred_tasks.push_back(task);
LOG_INFO("no slot", {});
break;
}
if (task.data.contains("system_prompt"))
{
if (!all_slots_are_idle) {
send_error(task, "system prompt can only be updated when all slots are idle");
break;
}
process_system_prompt_data(task.data["system_prompt"]);
if (task.data.contains("system_prompt"))
// reset cache_tokens for all slots
for (llama_client_slot &slot : slots)
{
if (!all_slots_are_idle) {
send_error(task, "system prompt can only be updated when all slots are idle");
break;
}
process_system_prompt_data(task.data["system_prompt"]);
// reset cache_tokens for all slots
for (llama_client_slot &slot : slots)
{
slot.cache_tokens.clear();
}
slot.cache_tokens.clear();
}
}
slot->reset();
slot->reset();
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->task_id = task.id;
slot->multitask_id = task.multitask_id;
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->task_id = task.id;
slot->multitask_id = task.multitask_id;
if (!launch_slot_with_data(slot, task.data))
if (!launch_slot_with_data(slot, task.data))
{
// send error result
send_error(task, "internal_error");
break;
}
} break;
case TASK_TYPE_CANCEL: { // release slot linked with the task id
for (auto & slot : slots)
{
if (slot.task_id == task.target_id)
{
// send error result
send_error(task, "internal_error");
slot.release();
break;
}
} break;
case TASK_TYPE_CANCEL: { // release slot linked with the task id
for (auto & slot : slots)
{
if (slot.task_id == task.target_id)
{
slot.release();
break;
}
}
} break;
}
}
// add all the deferred tasks back the the queue
for (task_server &task : deferred_tasks)
{
queue_tasks.push_back(task);
}
} break;
case TASK_TYPE_NEXT_RESPONSE: {
// do nothing
} break;
}
}
void process_multitask()
{
// remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
std::vector<task_result> agg_results;
auto queue_iterator = queue_multitasks.begin();
@ -1657,18 +1431,12 @@ struct llama_server_context
}
}
// done with tasks, unlock
lock.unlock();
// copy aggregate results of complete multi-tasks to the results queue
std::lock_guard<std::mutex> lock_results(mutex_results);
queue_results.insert(queue_results.end(), agg_results.begin(), agg_results.end());
}
bool update_slots() {
// attend tasks
process_tasks();
if (system_need_update)
{
LOG_TEE("updating system prompt\n");
@ -1684,10 +1452,12 @@ struct llama_server_context
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
kv_cache_clear();
}
std::unique_lock<std::mutex> lock(mutex_tasks);
condition_tasks.wait(lock, [&]{
return !queue_tasks.empty();
});
return true;
} else {
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);
}
for (llama_client_slot &slot : slots)
@ -1997,6 +1767,11 @@ struct llama_server_context
}
return true;
}
void run_on_all_tasks_finished() {
process_multitask();
update_slots();
}
};
static void server_print_usage(const char *argv0, const gpt_params &params,
@ -3360,15 +3135,21 @@ int main(int argc, char **argv)
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
// "Bus error: 10" - this is on macOS, it does not crash on Linux
//std::thread t2([&]()
{
/*{
bool running = true;
while (running)
{
running = llama.update_slots();
}
}
}*/
//);
llama.queue_tasks.on_new_task(std::bind(
&llama_server_context::process_single_task, &llama, std::placeholders::_1));
llama.queue_tasks.on_all_tasks_finished(std::bind(
&llama_server_context::run_on_all_tasks_finished, &llama));
llama.queue_tasks.start_loop();
t.join();
llama_backend_free();

296
examples/server/utils.hpp Normal file
View file

@ -0,0 +1,296 @@
#pragma once
#include <string>
#include <vector>
#include <set>
#include <mutex>
#include <condition_variable>
#include "json.hpp"
#include "../llava/clip.h"
using json = nlohmann::json;
#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif
#if SERVER_VERBOSE != 1
#define LOG_VERBOSE(MSG, ...)
#else
#define LOG_VERBOSE(MSG, ...) \
do \
{ \
if (server_verbose) \
{ \
server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
} \
} while (0)
#endif
#define LOG_ERROR( MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
//
// parallel
//
enum server_state {
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
SERVER_STATE_READY, // Server is ready and model is loaded
SERVER_STATE_ERROR // An error occurred, load_model failed
};
enum task_type {
TASK_TYPE_COMPLETION,
TASK_TYPE_CANCEL,
TASK_TYPE_NEXT_RESPONSE
};
struct task_server {
int id = -1; // to be filled by llama_server_queue
int target_id;
task_type type;
json data;
bool infill_mode = false;
bool embedding_mode = false;
int multitask_id = -1;
};
struct task_result {
int id;
int multitask_id = -1;
bool stop;
bool error;
json result_json;
};
struct task_multi {
int id;
std::set<int> subtasks_remaining{};
std::vector<task_result> results{};
};
// TODO: can become bool if we can't find use of more states
enum slot_state
{
IDLE,
PROCESSING,
};
enum slot_command
{
NONE,
LOAD_PROMPT,
RELEASE,
};
struct slot_params
{
bool stream = true;
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
uint32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_predict = -1; // new tokens to predict
std::vector<std::string> antiprompt;
json input_prefix;
json input_suffix;
};
struct slot_image
{
int32_t id;
bool request_encode_image = false;
float * image_embedding = nullptr;
int32_t image_tokens = 0;
clip_image_u8 * img_data;
std::string prefix_prompt; // before of this image
};
// completion token output with probabilities
struct completion_token_output
{
struct token_prob
{
llama_token tok;
float prob;
};
std::vector<token_prob> probs;
llama_token tok;
std::string text_to_send;
};
static inline void server_log(const char *level, const char *function, int line,
const char *message, const nlohmann::ordered_json &extra)
{
nlohmann::ordered_json log
{
{"timestamp", time(nullptr)},
{"level", level},
{"function", function},
{"line", line},
{"message", message},
};
if (!extra.empty())
{
log.merge_patch(extra);
}
const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
printf("%.*s\n", (int)str.size(), str.data());
fflush(stdout);
}
//
// work queue utils
//
template<typename T>
struct llama_server_queue {
int id = 0;
std::mutex mutex_tasks;
std::vector<T> queue_tasks;
std::condition_variable condition_tasks;
std::function<void(T)> callback_new_task;
std::function<void(void)> callback_all_task_finished;
int post(T task) {
LOG_INFO("post", {});
std::unique_lock<std::mutex> lock(mutex_tasks);
task.id = id++;
queue_tasks.push_back(std::move(task));
condition_tasks.notify_one();
return task.id;
}
int get_next_id() {
std::unique_lock<std::mutex> lock(mutex_tasks);
return id++;
}
void on_new_task(std::function<void(T)> callback) {
callback_new_task = callback;
}
void on_all_tasks_finished(std::function<void(void)> callback) {
callback_all_task_finished = callback;
}
void start_loop() {
while (true) {
// new task arrived
LOG_INFO("have new task", {});
{
while (true)
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (queue_tasks.empty()) {
lock.unlock();
break;
}
task_server task = queue_tasks.front();
queue_tasks.erase(queue_tasks.begin());
lock.unlock();
LOG_INFO("callback_new_task", {});
callback_new_task(task);
}
LOG_INFO("callback_all_task_finished", {});
callback_all_task_finished();
}
LOG_INFO("wait for new task", {});
// wait for new task
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (queue_tasks.empty()) {
condition_tasks.wait(lock, [&]{
return !queue_tasks.empty();
});
}
}
}
}
};
//
// base64 utils (TODO: move to common in the future)
//
static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
static inline bool is_base64(uint8_t c)
{
return (isalnum(c) || (c == '+') || (c == '/'));
}
static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string)
{
int i = 0;
int j = 0;
int in_ = 0;
int in_len = encoded_string.size();
uint8_t char_array_4[4];
uint8_t char_array_3[3];
std::vector<uint8_t> ret;
while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_]))
{
char_array_4[i++] = encoded_string[in_]; in_++;
if (i == 4)
{
for (i = 0; i <4; i++)
{
char_array_4[i] = base64_chars.find(char_array_4[i]);
}
char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; (i < 3); i++)
{
ret.push_back(char_array_3[i]);
}
i = 0;
}
}
if (i)
{
for (j = i; j <4; j++)
{
char_array_4[j] = 0;
}
for (j = 0; j <4; j++)
{
char_array_4[j] = base64_chars.find(char_array_4[j]);
}
char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; (j < i - 1); j++)
{
ret.push_back(char_array_3[j]);
}
}
return ret;
}