rename some vars, fix slot state handling + improve performance
parent 9a1039d9ee
commit 9e6e714dc5
2 changed files with 127 additions and 70 deletions
@@ -1,22 +1,27 @@

 const auto index_html = R"(
 <!DOCTYPE html>
 <html>
 <head>
 <title>llama.cpp - server parallel PoC</title>
 </head>
 <body>
 <div style="width: 90%;margin: auto;">
-<h2>Server parallel - Proof of Concept</h2>
-<form id="myForm">
-<label for="client_slot">Client Slot (-1 load in a idle client)</label>
-<input type="number" id="client_slot" value="-1" required>
-<br><br>
+<h2>Server parallel - PoC</h2>
+<form id="myForm" >
+<label for="slot_id">Slot ID (-1 loads in an idle slot)</label>
+<input type="number" id="slot_id" value="-1" required>
+<br>
+<label for="temperature">Temperature</label>
+<input type="number" id="temperature" value="0.1" required>
+<br>
 <label for="message">Message</label>
 <input id="message" style="width: 80%;" required>
 <br><br>
 <button type="button" id="btn_send" onclick="perform() " >Send</button>
-<button type="button" onclick="reset() " >Reset</button>
+<br>
+<br>
+<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
 </form>
 <div id="conversation_view">
 </div>
@@ -26,9 +31,20 @@ const auto index_html = R"(
 let conversation = [];
 let current_message = -1;
 const questions = ["Who is Elon Musk?", "Who is Jeff Bezos?", "How to get a job at google?", "What are you?", "When was born Abraham Lincoln?"];
-window.onload = function() {
+docReady(() => {
     document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
-};
+});

+function docReady(fn) {
+    // see if the DOM is already available
+    if (document.readyState === "complete" || document.readyState === "interactive") {
+        // call on the next available tick
+        setTimeout(fn, 1);
+    } else {
+        document.addEventListener("DOMContentLoaded", fn);
+    }
+}
+
 function updateView() {
     let conv_view = document.getElementById("conversation_view");
@@ -64,6 +80,7 @@ const auto index_html = R"(
     while (cont) {
         const result = await reader.read();
         if (result.done) {
+            document.getElementById("btn_send").disabled = false;
             break;
         }

@@ -108,43 +125,51 @@ const auto index_html = R"(
 }

 function generatePrompt() {
+    // build the prompt from the whole conversation to keep coherence
     let prompt = '';
     for(let index in conversation) {
         if(index == 0) {
             prompt += conversation[index].user + "\n";
         } else {
-            prompt += "User: " + conversation[index].user + "\n";
+            prompt += "User:" + conversation[index].user + "\n";
         }
         if(index == current_message) {
             prompt += "Assistant:";
         } else {
-            prompt += "Assistant: " + conversation[index].assistant;
+            prompt += "Assistant:" + conversation[index].assistant;
         }
     }
     return prompt;
 }

-function reset() {
-    conversation = [];
-    document.getElementById("client_slot").value = "-1";
-    document.getElementById("message").value = "";
+function resetBtn() {
+    document.getElementById("slot_id").value = "-1";
+    document.getElementById("temperature").value = "0.1";
+    document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
     document.getElementById("conversation_view").innerHTML = "";
+    conversation = [];
+    current_message = -1;
 }

 async function perform() {
-    var client_slot = parseFloat(document.getElementById("client_slot").value);
-    var prompt = document.getElementById("message").value;
-    if (!isNaN(client_slot) && prompt.length > 0) {
+    document.getElementById("btn_send").disabled = true;
+    var slot_id = parseInt(document.getElementById("slot_id").value);
+    var temperature = parseFloat(document.getElementById("temperature").value);
+    var prompt = " " + document.getElementById("message").value;
+    if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
         current_message++;
         conversation.push({
             user: prompt,
             assistant: ''
         });
         updateView();
+        document.getElementById("message").value = "";
         await call_llama({
-            client_slot,
+            slot_id,
+            temperature,
             prompt: generatePrompt()
         });

     } else {
         document.getElementById("conversation_view").innerText = "please, insert valid props.";
     }

@@ -59,9 +59,14 @@ enum stop_type

 enum slot_state
 {
-    BUSY,
     IDLE,
-    NEXT_TOKEN
+    PROCESSING
+};
+
+enum slot_command {
+    NONE,
+    LOAD_PROMPT,
+    RELEASE
 };

 static std::string system_prompt =
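The three booleans on the old slot (process_prompt, release_slot, forced_release) are folded into this explicit state/command pair. A minimal sketch of the resulting transitions, distilled from the hunks in this commit; the enum names match the diff, but the driver struct and its step() function are hypothetical:

```cpp
// Sketch only: how the state/command pair drives a slot in this commit.
// The HTTP thread writes `command` (via start() or RELEASE); the decode
// loop owns `state` and acknowledges commands by resetting them to NONE.
enum slot_state   { IDLE, PROCESSING };
enum slot_command { NONE, LOAD_PROMPT, RELEASE };

struct slot_fsm {
    slot_state   state   = IDLE;
    slot_command command = NONE;

    // one decode-loop iteration for a single slot (hypothetical driver)
    void step() {
        if (state == PROCESSING && command == RELEASE) {
            state   = IDLE;   // generation finished or client disconnected
            command = NONE;   // ack: the slot's KV cache entries are cleared here
        }
        if (state == IDLE && command == LOAD_PROMPT) {
            state   = PROCESSING; // prompt ingestion, then token generation
            command = NONE;
        }
    }
};
```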
@@ -80,15 +85,38 @@ struct llama_client_slot
     int32_t n_prompt = 0;
     int32_t n_decoded = 0;
     int32_t i_batch = -1;
-    bool process_prompt = false;
-    bool release_slot = false;
-    bool forced_release = false;
     string prompt = "";
     string sampled_token_str;
-    string generated_text;
+    string generated_text = "";
     llama_token sampled;
     std::vector<llama_token> tokens_prev;
     slot_state state = IDLE;
+    slot_command command = NONE;
+    bool newToken = false;
+    float temperature = 0.1f;
+
+    void start(string prompt_, float temp_) {
+        prompt = prompt_;
+        command = LOAD_PROMPT;
+        temperature = temp_;
+        newToken = false;
+    }
+
+    bool hasNewToken() {
+        if(newToken) {
+            newToken = false;
+            return true;
+        }
+        return false;
+    }
+
+    bool available() {
+        return state == IDLE && command == NONE;
+    }
+
+    void notify() {
+        newToken = !newToken;
+    }
 };

 struct server_parallel_context {
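The newToken flag added above is a consume-on-read mailbox between the decode thread, which calls notify() after sampling, and the HTTP streaming callback, which polls hasNewToken(). The committed code uses a plain bool and itself asks "thread safe?"; below is a sketch of the same hand-off hardened with std::atomic, offered as an assumption, not as what the commit does:

```cpp
#include <atomic>
#include <string>

// Hypothetical hardened variant of the slot's token hand-off.
struct token_mailbox {
    std::atomic<bool> newToken{false};
    std::string sampled_token_str;   // written by the decode thread before notify()

    void notify() {                  // decode thread: a token is ready
        newToken.store(true, std::memory_order_release);
    }

    bool hasNewToken() {             // HTTP thread: consume-on-read
        bool expected = true;
        return newToken.compare_exchange_strong(expected, false,
                                                std::memory_order_acq_rel);
    }
};
```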
@@ -131,7 +159,7 @@ struct server_parallel_context {
             slot.state = IDLE;
             slot.tokens_prev.resize(std::max(256, params.n_predict));
             std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
-            LOG_TEE(" -> client slot: %i\n", slot.id);
+            LOG_TEE(" - slot %i\n", slot.id);
             slots.push_back(slot);
         }
     }
@@ -169,16 +197,15 @@ struct server_parallel_context {
         return true;
     }

-    llama_client_slot* loadPrompt(int slot_id, string prompt) {
+    llama_client_slot* loadPrompt(int slot_id, string prompt, float temp_) {
         for (llama_client_slot & slot : slots)
         {
             if (
-                slot_id == -1 && slot.state == IDLE ||
+                slot_id == -1 && slot.available() ||
                 slot.id == slot_id)
             {
-                slot.prompt = prompt;
-                slot.process_prompt = true;
-                LOG_TEE("client %i is workloaded\n", slot.id);
+                slot.start(prompt, temp_);
+                LOG_TEE("slot %i is processing\n", slot.id);
                 return &slot; // return a pointer to slot (thread safe?)
             }
         }
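Worth noting about the condition in loadPrompt(): && binds tighter than ||, so it parses as (slot_id == -1 && slot.available()) || (slot.id == slot_id), meaning an explicitly requested slot matches even while it is busy, and start() then restarts it. A parenthesized equivalent of that predicate; the helper name is ours, not the commit's:

```cpp
// Hypothetical helper making the selection rule explicit.
static bool slot_matches(int requested_id, int slot_id, bool slot_available) {
    return (requested_id == -1 && slot_available)  // any idle, command-free slot
        || (slot_id == requested_id);              // or the slot asked for, busy or not
}
```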
@@ -211,24 +238,26 @@ struct server_parallel_context {
         return stop_pos;
     }

+
     bool updateSlots() {
         batch.n_tokens = 0;

         // decode any currently ongoing sequences
         for (auto & slot : slots) {
-            if(slot.release_slot && slot.state == BUSY || slot.forced_release) {
-                if(slot.forced_release) {
+            if (slot.state == PROCESSING && slot.command == RELEASE)
+            {
                 llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
-                    slot.forced_release = false;
-                }
-                LOG_TEE("client %i is released\n", slot.id);
                 slot.state = IDLE;
-                slot.release_slot = false;
+                LOG_TEE("slot %i is released\n", slot.id);
+                slot.command = NONE;
             }
-            if (slot.state == IDLE) {
+
+            // do not decode until the token has been sent to the client;
+            // improves performance and avoids incoherence?
+            if (slot.state == IDLE || slot.newToken) {
                 continue;
             }

             batch.token [batch.n_tokens] = slot.sampled;
             batch.pos   [batch.n_tokens] = n_tokens_system + slot.n_prompt + slot.n_decoded;
             batch.seq_id[batch.n_tokens] = slot.id;
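The new `slot.newToken` check above is the performance fix named in the commit message: a slot whose last sampled token has not yet been consumed by the HTTP stream is left out of the next decode batch, so generation never runs ahead of delivery. A self-contained, simplified view of that gate, with types trimmed down and names of our choosing:

```cpp
#include <vector>

// Simplified from updateSlots() above: collect only slots that both have work
// and have had their previous token consumed by the streaming side.
struct slot_view { bool idle; bool newToken; int sampled; };

static std::vector<int> tokens_to_decode(const std::vector<slot_view> & slots) {
    std::vector<int> batch;
    for (const auto & s : slots) {
        if (s.idle || s.newToken) {
            continue; // idle, or previous token not yet delivered to the client
        }
        batch.push_back(s.sampled);
    }
    return batch;
}
```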
@@ -253,10 +282,11 @@ struct server_parallel_context {
         // assign workload to the slots
         if (params.cont_batching || batch.n_tokens == 0) {
             for (llama_client_slot & slot : slots) {
-                if (slot.state == IDLE && slot.process_prompt) {
-                    slot.state = BUSY;
-                    slot.process_prompt = false;
-                    //LOG_TEE("client %i process prompt:\n%s'------------------------------\n", slot.id, slot.prompt.c_str());
+                // need to process the prompt
+                if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
+                    slot.state = PROCESSING;
+                    slot.command = NONE;
+                    //LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());
                     std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);

                     // do not prepend BOS because we have a system prompt!
@@ -328,7 +358,6 @@ struct server_parallel_context {
                     // retry with half the batch size to try to find a free slot in the KV cache
                     n_batch /= 2;
                     i -= n_batch;
-
                     continue;
                 }

@@ -337,6 +366,7 @@ struct server_parallel_context {
                     continue;
                 }

+                params.temp = slot.temperature;
                 const llama_token id = llama_sample_token(ctx, NULL, NULL, params, slot.tokens_prev, candidates, slot.i_batch - i);

                 // remember which tokens were sampled - used for repetition penalties during sampling
@@ -353,7 +383,8 @@ struct server_parallel_context {
                     findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);

                 slot.sampled_token_str = token_str;
-                slot.state = NEXT_TOKEN;
+                // notify new token
+                slot.notify();

                 if (slot.n_decoded > 2 &&
                     (id == llama_token_eos(ctx) ||
@@ -361,11 +392,9 @@ struct server_parallel_context {
                      slot.n_decoded + slot.n_prompt >=
                      params.n_predict) ||
                     stop_pos != std::string::npos)) {
-                    // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                    llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
-                    //LOG_TEE("client %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
+                    //LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
                     slot.generated_text.clear();
-                    slot.release_slot = true;
+                    slot.command = RELEASE;
                 }

                 slot.i_batch = -1;
@@ -712,16 +741,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }

-
-void processClient(server_parallel_context* ctx)
-{
-    bool running = true;
-    while (running)
-    {
-        running = ctx->updateSlots();
-    }
-}
-
 int main(int argc, char **argv)
 {
     gpt_params params;
@@ -730,6 +749,12 @@ int main(int argc, char **argv)

     server_params_parse(argc, argv, sparams, params);

+#ifndef LOG_DISABLE_LOGS
+    log_set_target(log_filename_generator("server-parallel", "log"));
+    LOG_TEE("Log start\n");
+    log_dump_cmdline(argc, argv);
+#endif // LOG_DISABLE_LOGS
+
     llama_backend_init(params.numa);

     // load the target model
@@ -754,28 +779,29 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const Request &req, Response &res)
         {
             json data = json::parse(req.body);
-            int slot_id = data.value("client_slot", -1);
+            int slot_id = data.value("slot_id", -1);
+            float temperature = data.value("temperature", 0.8f);
             string prompt = data.value("prompt", "");
-            llama_client_slot* slot_client = llama.loadPrompt(slot_id, prompt);
+            llama_client_slot* slot = llama.loadPrompt(slot_id, prompt, temperature);

             // Verify that the slot exists
-            if (slot_client) {
+            if (slot) {
                 res.set_chunked_content_provider("text/event-stream",
-                    [slot_client](size_t /*offset*/, DataSink &sink) {
-                        if(slot_client->state == IDLE && !slot_client->process_prompt) { // slot has been released
+                    [slot](size_t /*offset*/, DataSink &sink) {
+                        if(slot->available()) { // slot has been released
                             sink.done();
                             return false;
                         }
-                        if(slot_client->state == NEXT_TOKEN) { // new token notification
+                        if(slot->hasNewToken()) { // new token notification
                             stringstream ss;
-                            json res_d = {{"token", slot_client->sampled_token_str}};
+                            json res_d = {{"token", slot->sampled_token_str}};
                             ss << "data: " << res_d.dump() << "\n\n";
                             string result = ss.str();
                             if(!sink.write(result.c_str(), result.size())) { // user request release
-                                slot_client->forced_release = true;
+                                slot->command = RELEASE;
                                 return false;
                             }
-                            slot_client->state = BUSY; // process next token
                         }
                         return true;
                     });
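For reference, a payload the updated /completion handler above would accept; the field names and defaults are read straight from the hunk (data.value("slot_id", -1), data.value("temperature", 0.8f), data.value("prompt", "")), while the prompt text itself is only an illustration. The server replies as a text/event-stream of `data: {"token": "..."}` chunks:

```cpp
#include "json.hpp"   // nlohmann::json, already used by the server

using json = nlohmann::json;

// Hypothetical client-side construction of the request body.
json body = {
    {"slot_id",     -1},     // -1 = take any idle slot (handler default)
    {"temperature", 0.8f},   // handler default when the field is absent
    {"prompt",      " Who is Elon Musk?"}
};
// POST body.dump() to /completion and read the SSE stream line by line.
```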
@@ -785,7 +811,13 @@ int main(int argc, char **argv)
             res.set_content("slot_error", "text/plain");
         } });

-    thread t(processClient, &llama);
+    thread t([&llama]()
+        {
+            bool running = true;
+            while (running)
+            {
+                running = llama.updateSlots();
+            } });

     svr.set_read_timeout(sparams.read_timeout);
     svr.set_write_timeout(sparams.write_timeout);