Merge branch 'master' into compilade/refactor-kv-cache
This commit is contained in:
commit
bc320ef66d
395 changed files with 57725 additions and 169970 deletions
|
@ -15,6 +15,8 @@
|
|||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
// mime type for sending response
|
||||
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
||||
|
||||
// auto generated files (update with ./deps.sh)
|
||||
#include "colorthemes.css.hpp"
|
||||
|
@ -67,7 +69,6 @@ enum slot_command {
|
|||
enum server_state {
|
||||
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
|
||||
SERVER_STATE_READY, // Server is ready and model is loaded
|
||||
SERVER_STATE_ERROR // An error occurred, load_model failed
|
||||
};
|
||||
|
||||
enum server_task_type {
|
||||
|
@ -78,6 +79,7 @@ enum server_task_type {
|
|||
SERVER_TASK_TYPE_SLOT_SAVE,
|
||||
SERVER_TASK_TYPE_SLOT_RESTORE,
|
||||
SERVER_TASK_TYPE_SLOT_ERASE,
|
||||
SERVER_TASK_TYPE_SET_LORA,
|
||||
};
|
||||
|
||||
struct server_task {
|
||||
|
@ -622,6 +624,7 @@ struct server_response {
|
|||
struct server_context {
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
std::vector<llama_lora_adapter_container> lora_adapters;
|
||||
|
||||
gpt_params params;
|
||||
|
||||
|
@ -629,6 +632,7 @@ struct server_context {
|
|||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
bool has_eos_token = false;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
|
||||
|
@ -677,7 +681,11 @@ struct server_context {
|
|||
// dedicate one sequence to the system prompt
|
||||
params.n_parallel += 1;
|
||||
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
llama_init_result llama_init = llama_init_from_gpt_params(params);
|
||||
|
||||
model = llama_init.model;
|
||||
ctx = llama_init.context;
|
||||
lora_adapters = llama_init.lora_adapters;
|
||||
params.n_parallel -= 1; // but be sneaky about it
|
||||
if (model == nullptr) {
|
||||
LOG_ERROR("unable to load model", {{"model", params.model}});
|
||||
|
@ -686,8 +694,8 @@ struct server_context {
|
|||
|
||||
n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
add_bos_token = llama_should_add_bos_token(model);
|
||||
GGML_ASSERT(llama_add_eos_token(model) != 1);
|
||||
add_bos_token = llama_add_bos_token(model);
|
||||
has_eos_token = !llama_add_eos_token(model);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -737,6 +745,8 @@ struct server_context {
|
|||
slot.ga_n = ga_n;
|
||||
slot.ga_w = ga_w;
|
||||
|
||||
slot.sparams = params.sparams;
|
||||
|
||||
slot.reset();
|
||||
|
||||
slots.push_back(slot);
|
||||
|
@ -745,13 +755,13 @@ struct server_context {
|
|||
default_generation_settings_for_props = get_formated_generation(slots.front());
|
||||
default_generation_settings_for_props["seed"] = -1;
|
||||
|
||||
// the update_slots() logic will always submit a maximum of n_batch tokens
|
||||
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
|
||||
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
|
||||
{
|
||||
const int32_t n_batch = llama_n_batch(ctx);
|
||||
|
||||
// only a single seq_id per token is needed
|
||||
batch = llama_batch_init(n_batch, 0, 1);
|
||||
batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
|
||||
}
|
||||
|
||||
metrics.init();
|
||||
|
@ -884,7 +894,8 @@ struct server_context {
|
|||
|
||||
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
||||
slot_params default_params;
|
||||
llama_sampling_params default_sparams;
|
||||
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
|
||||
llama_sampling_params default_sparams = params.sparams;
|
||||
auto & data = task.data;
|
||||
|
||||
if (data.count("__oaicompat") != 0) {
|
||||
|
@ -897,7 +908,7 @@ struct server_context {
|
|||
|
||||
slot.params.stream = json_value(data, "stream", false);
|
||||
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
|
||||
slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
|
||||
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
|
||||
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||
|
@ -966,6 +977,8 @@ struct server_context {
|
|||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
||||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
||||
slot.prompt = *prompt;
|
||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
||||
slot.prompt = prompt->at(0);
|
||||
} else {
|
||||
send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
|
@ -1020,7 +1033,7 @@ struct server_context {
|
|||
{
|
||||
slot.sparams.logit_bias.clear();
|
||||
|
||||
if (json_value(data, "ignore_eos", false)) {
|
||||
if (json_value(data, "ignore_eos", false) && has_eos_token) {
|
||||
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||
}
|
||||
|
||||
|
@ -1125,28 +1138,19 @@ struct server_context {
|
|||
if (!system_prompt.empty()) {
|
||||
system_tokens = ::llama_tokenize(ctx, system_prompt, true);
|
||||
|
||||
llama_batch_clear(batch);
|
||||
|
||||
for (int i = 0; i < (int)system_tokens.size(); ++i) {
|
||||
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
||||
}
|
||||
|
||||
const int32_t n_batch = llama_n_batch(ctx);
|
||||
const int32_t n_tokens_prompt = system_tokens.size();
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
|
||||
|
||||
if (llama_decode(ctx, batch_view) != 0) {
|
||||
llama_batch_clear(batch);
|
||||
|
||||
for (int32_t j = 0; j < n_tokens; ++j) {
|
||||
llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
LOG_ERROR("llama_decode() failed", {});
|
||||
return;
|
||||
}
|
||||
|
@ -1179,7 +1183,7 @@ struct server_context {
|
|||
|
||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = llama_token_to_piece(ctx, result.tok, false);
|
||||
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
|
||||
slot.sampled = result.tok;
|
||||
|
||||
// search stop word and delete it
|
||||
|
@ -1319,7 +1323,7 @@ struct server_context {
|
|||
|
||||
return json {
|
||||
{"n_ctx", slot.n_ctx},
|
||||
{"n_predict", slot.n_predict},
|
||||
{"n_predict", slot.n_predict}, // Server configured n_predict
|
||||
{"model", params.model_alias},
|
||||
{"seed", slot.sparams.seed},
|
||||
{"temperature", slot.sparams.temp},
|
||||
|
@ -1341,7 +1345,7 @@ struct server_context {
|
|||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||
{"penalize_nl", slot.sparams.penalize_nl},
|
||||
{"stop", slot.params.antiprompt},
|
||||
{"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
|
||||
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_discard", slot.params.n_discard},
|
||||
{"ignore_eos", ignore_eos},
|
||||
|
@ -1844,6 +1848,16 @@ struct server_context {
|
|||
};
|
||||
queue_results.send(result);
|
||||
} break;
|
||||
case SERVER_TASK_TYPE_SET_LORA:
|
||||
{
|
||||
llama_lora_adapters_apply(ctx, lora_adapters);
|
||||
server_task_result result;
|
||||
result.id = task.id;
|
||||
result.stop = true;
|
||||
result.error = false;
|
||||
result.data = json{{ "success", true }};
|
||||
queue_results.send(result);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2002,6 +2016,11 @@ struct server_context {
|
|||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||
|
||||
// track if this is an embedding or non-embedding batch
|
||||
// if we've added sampled tokens above, we are in non-embedding mode
|
||||
// -1: none, 0: non-embedding, 1: embedding
|
||||
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
|
@ -2020,7 +2039,7 @@ struct server_context {
|
|||
slot.t_start_generation = 0;
|
||||
|
||||
if (slot.infill) {
|
||||
const bool add_bos = llama_should_add_bos_token(model);
|
||||
const bool add_bos = llama_add_bos_token(model);
|
||||
bool suff_rm_leading_spc = true;
|
||||
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
|
||||
params.input_suffix.erase(0, 1);
|
||||
|
@ -2172,6 +2191,14 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
// check that we are in the right batch_type, if not defer the slot
|
||||
bool slot_type = slot.embedding ? 1 : 0;
|
||||
if (batch_type == -1) {
|
||||
batch_type = slot_type;
|
||||
} else if (batch_type != slot_type) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// keep only the common part
|
||||
llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past;
|
||||
|
||||
|
@ -2278,6 +2305,9 @@ struct server_context {
|
|||
{"n_tokens", batch.n_tokens},
|
||||
});
|
||||
|
||||
// make sure we're in the right embedding mode
|
||||
llama_set_embeddings(ctx, batch_type == 1);
|
||||
|
||||
// process the created batch of tokens
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
|
@ -2482,6 +2512,9 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
// parse arguments from environment variables
|
||||
gpt_params_parse_from_env(params);
|
||||
|
||||
// TODO: not great to use extern vars
|
||||
server_log_json = params.log_json;
|
||||
server_verbose = params.verbosity > 0;
|
||||
|
@ -2506,8 +2539,8 @@ int main(int argc, char ** argv) {
|
|||
});
|
||||
|
||||
LOG_INFO("system info", {
|
||||
{"n_threads", params.n_threads},
|
||||
{"n_threads_batch", params.n_threads_batch},
|
||||
{"n_threads", params.cpuparams.n_threads},
|
||||
{"n_threads_batch", params.cpuparams_batch.n_threads},
|
||||
{"total_threads", std::thread::hardware_concurrency()},
|
||||
{"system_info", llama_print_system_info()},
|
||||
});
|
||||
|
@ -2532,19 +2565,19 @@ int main(int argc, char ** argv) {
|
|||
svr->set_default_headers({{"Server", "llama.cpp"}});
|
||||
|
||||
// CORS preflight
|
||||
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
|
||||
// Access-Control-Allow-Origin is already set by middleware
|
||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||
res.set_header("Access-Control-Allow-Methods", "POST");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
return res.set_content("", "application/json; charset=utf-8");
|
||||
return res.set_content("", "text/html"); // blank response, no data
|
||||
});
|
||||
|
||||
svr->set_logger(log_server_request);
|
||||
|
||||
auto res_error = [](httplib::Response & res, json error_data) {
|
||||
json final_response {{"error", error_data}};
|
||||
res.set_content(final_response.dump(), "application/json; charset=utf-8");
|
||||
res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
};
|
||||
|
||||
|
@ -2574,11 +2607,6 @@ int main(int argc, char ** argv) {
|
|||
svr->set_read_timeout (params.timeout_read);
|
||||
svr->set_write_timeout(params.timeout_write);
|
||||
|
||||
if (!svr->bind_to_port(params.hostname, params.port)) {
|
||||
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::string> log_data;
|
||||
|
||||
log_data["hostname"] = params.hostname;
|
||||
|
@ -2594,35 +2622,6 @@ int main(int argc, char ** argv) {
|
|||
// Necessary similarity of prompt for slot selection
|
||||
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
|
||||
|
||||
// load the model
|
||||
if (!ctx_server.load_model(params)) {
|
||||
state.store(SERVER_STATE_ERROR);
|
||||
return 1;
|
||||
} else {
|
||||
ctx_server.init();
|
||||
state.store(SERVER_STATE_READY);
|
||||
}
|
||||
|
||||
LOG_INFO("model loaded", {});
|
||||
|
||||
const auto model_meta = ctx_server.model_meta();
|
||||
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_model_chat_template()) {
|
||||
LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
}
|
||||
|
||||
// print sample chat example to make it clear which template is used
|
||||
{
|
||||
LOG_INFO("chat template", {
|
||||
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
||||
{"built_in", params.chat_template.empty()},
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Middlewares
|
||||
//
|
||||
|
@ -2666,8 +2665,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
// TODO: make another middleware for CORS related logic
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
||||
|
||||
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
||||
|
@ -2675,8 +2672,21 @@ int main(int argc, char ** argv) {
|
|||
return false;
|
||||
};
|
||||
|
||||
auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) {
|
||||
server_state current_state = state.load();
|
||||
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// register server middlewares
|
||||
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
|
||||
svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!middleware_server_state(req, res)) {
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
if (!middleware_validate_api_key(req, res)) {
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
|
@ -2687,62 +2697,15 @@ int main(int argc, char ** argv) {
|
|||
// Route handlers (or controllers)
|
||||
//
|
||||
|
||||
const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
server_state current_state = state.load();
|
||||
switch (current_state) {
|
||||
case SERVER_STATE_READY:
|
||||
{
|
||||
// request slots data using task queue
|
||||
server_task task;
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.type = SERVER_TASK_TYPE_METRICS;
|
||||
task.id_target = -1;
|
||||
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task);
|
||||
|
||||
// get the result
|
||||
server_task_result result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
|
||||
const int n_idle_slots = result.data.at("idle");
|
||||
const int n_processing_slots = result.data.at("processing");
|
||||
|
||||
json health = {
|
||||
{"status", "ok"},
|
||||
{"slots_idle", n_idle_slots},
|
||||
{"slots_processing", n_processing_slots}
|
||||
};
|
||||
|
||||
res.status = 200; // HTTP OK
|
||||
if (params.endpoint_slots && req.has_param("include_slots")) {
|
||||
health["slots"] = result.data.at("slots");
|
||||
}
|
||||
|
||||
if (n_idle_slots == 0) {
|
||||
health["status"] = "no slot available";
|
||||
if (req.has_param("fail_on_no_slot")) {
|
||||
res.status = 503; // HTTP Service Unavailable
|
||||
}
|
||||
}
|
||||
|
||||
res.set_content(health.dump(), "application/json");
|
||||
break;
|
||||
}
|
||||
case SERVER_STATE_LOADING_MODEL:
|
||||
{
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
} break;
|
||||
case SERVER_STATE_ERROR:
|
||||
{
|
||||
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
|
||||
} break;
|
||||
}
|
||||
const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
|
||||
// error and loading states are handled by middleware
|
||||
json health = {{"status", "ok"}};
|
||||
res.set_content(health.dump(), "application/json");
|
||||
};
|
||||
|
||||
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
|
||||
const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!params.endpoint_slots) {
|
||||
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2760,13 +2723,22 @@ int main(int argc, char ** argv) {
|
|||
server_task_result result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
|
||||
res.set_content(result.data.at("slots").dump(), "application/json");
|
||||
// optionally return "fail_on_no_slot" error
|
||||
const int n_idle_slots = result.data.at("idle");
|
||||
if (req.has_param("fail_on_no_slot")) {
|
||||
if (n_idle_slots == 0) {
|
||||
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON);
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_metrics) {
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2891,7 +2863,7 @@ int main(int argc, char ** argv) {
|
|||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
res.set_content(result.data.dump(), MIMETYPE_JSON);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2921,7 +2893,7 @@ int main(int argc, char ** argv) {
|
|||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
res.set_content(result.data.dump(), MIMETYPE_JSON);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2941,13 +2913,11 @@ int main(int argc, char ** argv) {
|
|||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
res.set_content(result.data.dump(), MIMETYPE_JSON);
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
std::string id_slot_str = req.path_params.at("id_slot");
|
||||
int id_slot;
|
||||
|
||||
|
@ -2971,19 +2941,30 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
};
|
||||
|
||||
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
|
||||
std::string template_key = "tokenizer.chat_template", curr_tmpl;
|
||||
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
|
||||
if (tlen > 0) {
|
||||
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
|
||||
if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
|
||||
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
|
||||
}
|
||||
}
|
||||
json data = {
|
||||
{ "system_prompt", ctx_server.system_prompt.c_str() },
|
||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params.n_parallel }
|
||||
{ "total_slots", ctx_server.params.n_parallel },
|
||||
{ "chat_template", curr_tmpl.c_str() }
|
||||
};
|
||||
|
||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
res.set_content(data.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (ctx_server.params.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
json data = json::parse(req.body);
|
||||
|
||||
|
@ -2995,7 +2976,7 @@ int main(int argc, char ** argv) {
|
|||
if (!json_value(data, "stream", false)) {
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
if (!result.error && result.stop) {
|
||||
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
||||
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
} else {
|
||||
res_error(res, result.data);
|
||||
}
|
||||
|
@ -3058,9 +3039,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) {
|
||||
json models = {
|
||||
{"object", "list"},
|
||||
{"data", {
|
||||
|
@ -3069,16 +3048,19 @@ int main(int argc, char ** argv) {
|
|||
{"object", "model"},
|
||||
{"created", std::time(0)},
|
||||
{"owned_by", "llamacpp"},
|
||||
{"meta", model_meta}
|
||||
{"meta", ctx_server.model_meta()}
|
||||
},
|
||||
}}
|
||||
};
|
||||
|
||||
res.set_content(models.dump(), "application/json; charset=utf-8");
|
||||
res.set_content(models.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (ctx_server.params.embedding) {
|
||||
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
@ -3093,7 +3075,7 @@ int main(int argc, char ** argv) {
|
|||
if (!result.error && result.stop) {
|
||||
json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
|
||||
|
||||
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
||||
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
} else {
|
||||
res_error(res, result.data);
|
||||
}
|
||||
|
@ -3150,7 +3132,10 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (ctx_server.params.embedding) {
|
||||
res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
json data = json::parse(req.body);
|
||||
|
||||
|
@ -3162,7 +3147,7 @@ int main(int argc, char ** argv) {
|
|||
if (!json_value(data, "stream", false)) {
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
if (!result.error && result.stop) {
|
||||
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
||||
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
} else {
|
||||
res_error(res, result.data);
|
||||
}
|
||||
|
@ -3210,7 +3195,6 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
|
@ -3219,11 +3203,10 @@ int main(int argc, char ** argv) {
|
|||
tokens = ctx_server.tokenize(body.at("content"), add_special);
|
||||
}
|
||||
const json data = format_tokenizer_response(tokens);
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
return res.set_content(data.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
std::string content;
|
||||
|
@ -3233,17 +3216,10 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
const json data = format_detokenized_response(content);
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
return res.set_content(data.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_embeddings = [¶ms, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!params.embedding) {
|
||||
res.status = 501;
|
||||
res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
|
||||
return;
|
||||
}
|
||||
|
||||
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
const json body = json::parse(req.body);
|
||||
bool is_openai = false;
|
||||
|
||||
|
@ -3289,7 +3265,53 @@ int main(int argc, char ** argv) {
|
|||
json root = is_openai
|
||||
? format_embeddings_response_oaicompat(body, responses)
|
||||
: responses[0];
|
||||
return res.set_content(root.dump(), "application/json; charset=utf-8");
|
||||
return res.set_content(root.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
||||
json result = json::array();
|
||||
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
|
||||
auto & la = ctx_server.lora_adapters[i];
|
||||
result.push_back({
|
||||
{"id", i},
|
||||
{"path", la.path},
|
||||
{"scale", la.scale},
|
||||
});
|
||||
}
|
||||
res.set_content(result.dump(), MIMETYPE_JSON);
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
const std::vector<json> body = json::parse(req.body);
|
||||
int max_idx = ctx_server.lora_adapters.size();
|
||||
|
||||
// clear existing value
|
||||
for (auto & la : ctx_server.lora_adapters) {
|
||||
la.scale = 0.0f;
|
||||
}
|
||||
|
||||
// set value
|
||||
for (auto entry : body) {
|
||||
int id = entry.at("id");
|
||||
float scale = entry.at("scale");
|
||||
if (0 <= id && id < max_idx) {
|
||||
ctx_server.lora_adapters[id].scale = scale;
|
||||
} else {
|
||||
throw std::runtime_error("invalid adapter id");
|
||||
}
|
||||
}
|
||||
|
||||
server_task task;
|
||||
task.type = SERVER_TASK_TYPE_SET_LORA;
|
||||
const int id_task = ctx_server.queue_tasks.post(task);
|
||||
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
res.set_content(result.data.dump(), MIMETYPE_JSON);
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
|
||||
|
@ -3330,7 +3352,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// register API routes
|
||||
svr->Get ("/health", handle_health);
|
||||
svr->Get ("/slots", handle_slots);
|
||||
svr->Get ("/metrics", handle_metrics);
|
||||
svr->Get ("/props", handle_props);
|
||||
svr->Get ("/v1/models", handle_models);
|
||||
|
@ -3345,6 +3366,11 @@ int main(int argc, char ** argv) {
|
|||
svr->Post("/v1/embeddings", handle_embeddings);
|
||||
svr->Post("/tokenize", handle_tokenize);
|
||||
svr->Post("/detokenize", handle_detokenize);
|
||||
// LoRA adapters hotswap
|
||||
svr->Get ("/lora-adapters", handle_lora_adapters_list);
|
||||
svr->Post("/lora-adapters", handle_lora_adapters_apply);
|
||||
// Save & load slots
|
||||
svr->Get ("/slots", handle_slots);
|
||||
if (!params.slot_save_path.empty()) {
|
||||
// only enable slot endpoints if slot_save_path is set
|
||||
svr->Post("/slots/:id_slot", handle_slots_action);
|
||||
|
@ -3360,35 +3386,75 @@ int main(int argc, char ** argv) {
|
|||
log_data["n_threads_http"] = std::to_string(params.n_threads_http);
|
||||
svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); };
|
||||
|
||||
LOG_INFO("HTTP server listening", log_data);
|
||||
// clean up function, to be called before exit
|
||||
auto clean_up = [&svr]() {
|
||||
svr->stop();
|
||||
llama_backend_free();
|
||||
};
|
||||
|
||||
// run the HTTP server in a thread - see comment below
|
||||
std::thread t([&]() {
|
||||
if (!svr->listen_after_bind()) {
|
||||
state.store(SERVER_STATE_ERROR);
|
||||
return 1;
|
||||
// bind HTTP listen port, run the HTTP server in a thread
|
||||
if (!svr->bind_to_port(params.hostname, params.port)) {
|
||||
LOG_ERROR("couldn't bind HTTP server socket", {
|
||||
{"hostname", params.hostname},
|
||||
{"port", params.port},
|
||||
});
|
||||
clean_up();
|
||||
LOG_ERROR("exiting due to HTTP server error", {});
|
||||
return 1;
|
||||
}
|
||||
std::thread t([&]() { svr->listen_after_bind(); });
|
||||
svr->wait_until_ready();
|
||||
|
||||
LOG_INFO("HTTP server is listening", log_data);
|
||||
|
||||
// load the model
|
||||
LOG_INFO("loading model", log_data);
|
||||
if (!ctx_server.load_model(params)) {
|
||||
clean_up();
|
||||
t.join();
|
||||
LOG_ERROR("exiting due to model loading error", {});
|
||||
return 1;
|
||||
} else {
|
||||
ctx_server.init();
|
||||
state.store(SERVER_STATE_READY);
|
||||
|
||||
LOG_INFO("model loaded", {});
|
||||
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_model_chat_template()) {
|
||||
LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
});
|
||||
// print sample chat example to make it clear which template is used
|
||||
{
|
||||
LOG_INFO("chat template", {
|
||||
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
||||
{"built_in", params.chat_template.empty()},
|
||||
});
|
||||
}
|
||||
|
||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_finish_multitask(std::bind(
|
||||
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_update_slots(std::bind(
|
||||
&server_context::update_slots, &ctx_server));
|
||||
ctx_server.queue_results.on_multitask_update(std::bind(
|
||||
&server_queue::update_multitask,
|
||||
&ctx_server.queue_tasks,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_finish_multitask(std::bind(
|
||||
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_update_slots(std::bind(
|
||||
&server_context::update_slots, &ctx_server));
|
||||
ctx_server.queue_results.on_multitask_update(std::bind(
|
||||
&server_queue::update_multitask,
|
||||
&ctx_server.queue_tasks,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
ctx_server.queue_tasks.terminate();
|
||||
};
|
||||
shutdown_handler = [&](int) {
|
||||
ctx_server.queue_tasks.terminate();
|
||||
};
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
|
@ -3404,12 +3470,8 @@ int main(int argc, char ** argv) {
|
|||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
svr->stop();
|
||||
clean_up();
|
||||
t.join();
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue