some var names, state fixes + performance improvements

FSSRepo 2023-10-04 14:49:35 -04:00
parent 9a1039d9ee
commit 9e6e714dc5
2 changed files with 127 additions and 70 deletions


@@ -7,16 +7,21 @@ const auto index_html = R"(
 </head>
 <body>
 <div style="width: 90%;margin: auto;">
-<h2>Server parallel - Proof of Concept</h2>
+<h2>Server parallel - PoC</h2>
 <form id="myForm" >
-<label for="client_slot">Client Slot (-1 load in a idle client)</label>
-<input type="number" id="client_slot" value="-1" required>
-<br><br>
+<label for="slot_id">Slot ID (-1 load in a idle slot)</label>
+<input type="number" id="slot_id" value="-1" required>
+<br>
+<label for="temperature">Temperature</label>
+<input type="number" id="temperature" value="0.1" required>
+<br>
 <label for="message">Message</label>
 <input id="message" style="width: 80%;" required>
 <br><br>
 <button type="button" id="btn_send" onclick="perform() " >Send</button>
-<button type="button" onclick="reset() " >Reset</button>
+<br>
+<br>
+<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
 </form>
 <div id="conversation_view">
 </div>
@@ -26,9 +31,20 @@ const auto index_html = R"(
 let conversation = [];
 let current_message = -1;
 const questions = ["Who is Elon Musk?", "Who is Jeff Bezos?", "How to get a job at google?", "What are you?", "When was born Abraham Lincoln?"];
-window.onload = function() {
+docReady(() => {
     document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
-};
+});
+
+function docReady(fn) {
+    // see if DOM is already available
+    if (document.readyState === "complete" || document.readyState === "interactive") {
+        // call on next available tick
+        setTimeout(fn, 1);
+    } else {
+        document.addEventListener("DOMContentLoaded", fn);
+    }
+}
 
 function updateView() {
     let conv_view = document.getElementById("conversation_view");
@@ -64,6 +80,7 @@ const auto index_html = R"(
     while (cont) {
         const result = await reader.read();
         if (result.done) {
+            document.getElementById("btn_send").disabled = false;
             break;
         }
@@ -108,6 +125,7 @@ const auto index_html = R"(
 }
 function generatePrompt() {
+    // generate a good prompt to have coherence
     let prompt = '';
     for(let index in conversation) {
         if(index == 0) {
@@ -124,27 +142,34 @@ const auto index_html = R"(
     return prompt;
 }
-function reset() {
-    conversation = [];
-    document.getElementById("client_slot").value = "-1";
-    document.getElementById("message").value = "";
+function resetBtn() {
+    document.getElementById("slot_id").value = "-1";
+    document.getElementById("temperature").value = "0.1";
+    document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
     document.getElementById("conversation_view").innerHTML = "";
+    conversation = [];
+    current_message = -1;
 }
 async function perform() {
-    var client_slot = parseFloat(document.getElementById("client_slot").value);
-    var prompt = document.getElementById("message").value;
-    if (!isNaN(client_slot) && prompt.length > 0) {
+    document.getElementById("btn_send").disabled = true;
+    var slot_id = parseInt(document.getElementById("slot_id").value);
+    var temperature = parseFloat(document.getElementById("temperature").value);
+    var prompt = " " + document.getElementById("message").value;
+    if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
         current_message++;
         conversation.push({
             user: prompt,
             assistant: ''
         });
         updateView();
+        document.getElementById("message").value = "";
         await call_llama({
-            client_slot,
+            slot_id,
+            temperature,
             prompt: generatePrompt()
         });
     } else {
         document.getElementById("conversation_view").innerText = "please, insert valid props.";
     }
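
The hunks above call call_llama(), whose body lies outside this diff. For orientation only, here is a minimal sketch of what such a helper could look like, assuming the data: {"token": ...} event-stream format produced by the /completion handler in the second file and the reader loop visible in the hunk above; the chunk parsing and the way tokens are appended to conversation[current_message].assistant are assumptions, not code from this commit.

    // sketch only: stream the tokens for one message (names follow the page's globals)
    async function call_llama(params) { // params = { slot_id, temperature, prompt }
        const response = await fetch("/completion", {
            method: "POST",
            headers: { "Content-Type": "application/json" },
            body: JSON.stringify(params)
        });
        const reader = response.body.getReader();
        const decoder = new TextDecoder();
        let cont = true;
        while (cont) {
            const result = await reader.read();
            if (result.done) {
                document.getElementById("btn_send").disabled = false;
                break;
            }
            // each chunk carries one or more "data: {...}\n\n" events
            for (const line of decoder.decode(result.value).split("\n")) {
                if (!line.startsWith("data: ")) continue;
                const data = JSON.parse(line.substring(6));
                conversation[current_message].assistant += data.token; // assumed field usage
            }
            updateView();
        }
    }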


@@ -59,9 +59,14 @@ enum stop_type
 enum slot_state
 {
-    BUSY,
     IDLE,
-    NEXT_TOKEN
+    PROCESSING
+};
+
+enum slot_command {
+    NONE,
+    LOAD_PROMPT,
+    RELEASE
 };
 
 static std::string system_prompt =
@@ -80,15 +85,38 @@ struct llama_client_slot
     int32_t n_prompt = 0;
     int32_t n_decoded = 0;
     int32_t i_batch = -1;
-    bool process_prompt = false;
-    bool release_slot = false;
-    bool forced_release = false;
     string prompt = "";
     string sampled_token_str;
-    string generated_text;
+    string generated_text = "";
     llama_token sampled;
     std::vector<llama_token> tokens_prev;
     slot_state state = IDLE;
+    slot_command command = NONE;
+    bool newToken = false;
+    float temperature = 0.1f;
+
+    void start(string prompt_, float temp_) {
+        prompt = prompt_;
+        command = LOAD_PROMPT;
+        temperature = temp_;
+        newToken = false;
+    }
+
+    bool hasNewToken() {
+        if(newToken) {
+            newToken = false;
+            return true;
+        }
+        return false;
+    }
+
+    bool available() {
+        return state == IDLE && command == NONE;
+    }
+
+    void nofity() {
+        newToken = !newToken;
+    }
 };
 
 struct server_parallel_context {
@@ -131,7 +159,7 @@ struct server_parallel_context {
             slot.state = IDLE;
             slot.tokens_prev.resize(std::max(256, params.n_predict));
             std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
-            LOG_TEE(" -> client slot: %i\n", slot.id);
+            LOG_TEE(" - slot %i\n", slot.id);
             slots.push_back(slot);
         }
     }
@@ -169,16 +197,15 @@ struct server_parallel_context {
         return true;
     }
 
-    llama_client_slot* loadPrompt(int slot_id, string prompt) {
+    llama_client_slot* loadPrompt(int slot_id, string prompt, float temp_) {
         for (llama_client_slot & slot : slots)
         {
             if (
-                slot_id == -1 && slot.state == IDLE ||
+                slot_id == -1 && slot.available() ||
                 slot.id == slot_id)
             {
-                slot.prompt = prompt;
-                slot.process_prompt = true;
-                LOG_TEE("client %i is workloaded\n", slot.id);
+                slot.start(prompt, temp_);
+                LOG_TEE("slot %i is processing\n", slot.id);
                 return &slot; // return a pointer to slot (thread safe?)
             }
         }
@@ -211,24 +238,26 @@ struct server_parallel_context {
         return stop_pos;
     }
 
     bool updateSlots() {
         batch.n_tokens = 0;
 
         // decode any currently ongoing sequences
         for (auto & slot : slots) {
-            if(slot.release_slot && slot.state == BUSY || slot.forced_release) {
-                if(slot.forced_release) {
-                    llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
-                    slot.forced_release = false;
-                }
-                LOG_TEE("client %i is released\n", slot.id);
-                slot.state = IDLE;
-                slot.release_slot = false;
+            if (slot.state == PROCESSING && slot.command == RELEASE)
+            {
+                llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
+                slot.state = IDLE;
+                LOG_TEE("slot %i is released\n", slot.id);
+                slot.command = NONE;
             }
 
-            if (slot.state == IDLE) {
+            // no decode wait until the token had been send to client
+            // improves performance and avoid decoherence?
+            if (slot.state == IDLE || slot.newToken) {
                 continue;
            }
 
            batch.token [batch.n_tokens] = slot.sampled;
            batch.pos   [batch.n_tokens] = n_tokens_system + slot.n_prompt + slot.n_decoded;
            batch.seq_id[batch.n_tokens] = slot.id;
@@ -253,10 +282,11 @@ struct server_parallel_context {
         // assign workload to the slots
         if (params.cont_batching || batch.n_tokens == 0) {
             for (llama_client_slot & slot : slots) {
-                if (slot.state == IDLE && slot.process_prompt) {
-                    slot.state = BUSY;
-                    slot.process_prompt = false;
-                    //LOG_TEE("client %i process prompt:\n%s'------------------------------\n", slot.id, slot.prompt.c_str());
+                // need process the prompt
+                if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
+                    slot.state = PROCESSING;
+                    slot.command = NONE;
+                    //LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());
                     std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
 
                     // do not prepend BOS because we have a system prompt!
@@ -328,7 +358,6 @@ struct server_parallel_context {
                 // retry with half the batch size to try to find a free slot in the KV cache
                 n_batch /= 2;
                 i -= n_batch;
-
                 continue;
             }
@@ -337,6 +366,7 @@ struct server_parallel_context {
                 continue;
             }
 
+            params.temp = slot.temperature;
             const llama_token id = llama_sample_token(ctx, NULL, NULL, params, slot.tokens_prev, candidates, slot.i_batch - i);
 
             // remember which tokens were sampled - used for repetition penalties during sampling
@@ -353,7 +383,8 @@ struct server_parallel_context {
             findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
 
             slot.sampled_token_str = token_str;
-            slot.state = NEXT_TOKEN;
+            // notify new token
+            slot.nofity();
 
             if (slot.n_decoded > 2 &&
                 (id == llama_token_eos(ctx) ||
@@ -361,11 +392,9 @@ struct server_parallel_context {
                 slot.n_decoded + slot.n_prompt >=
                 params.n_predict) ||
                 stop_pos != std::string::npos)) {
-                // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
-                //LOG_TEE("client %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
+                //LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
                 slot.generated_text.clear();
-                slot.release_slot = true;
+                slot.command = RELEASE;
             }
 
             slot.i_batch = -1;
@@ -712,16 +741,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }
 
-void processClient(server_parallel_context* ctx)
-{
-    bool running = true;
-    while (running)
-    {
-        running = ctx->updateSlots();
-    }
-}
-
 int main(int argc, char **argv)
 {
     gpt_params params;
@@ -730,6 +749,12 @@ int main(int argc, char **argv)
 
     server_params_parse(argc, argv, sparams, params);
 
+#ifndef LOG_DISABLE_LOGS
+    log_set_target(log_filename_generator("server-parallel", "log"));
+    LOG_TEE("Log start\n");
+    log_dump_cmdline(argc, argv);
+#endif // LOG_DISABLE_LOGS
+
     llama_backend_init(params.numa);
 
     // load the target model
@@ -754,28 +779,29 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const Request &req, Response &res)
     {
         json data = json::parse(req.body);
-        int slot_id = data.value("client_slot", -1);
+        int slot_id = data.value("slot_id", -1);
+        float temperature = data.value("temperature", 0.8f);
         string prompt = data.value("prompt", "");
-        llama_client_slot* slot_client = llama.loadPrompt(slot_id, prompt);
+        llama_client_slot* slot = llama.loadPrompt(slot_id, prompt, temperature);
         // Verify if the slot exist
-        if (slot_client) {
+        if (slot) {
             res.set_chunked_content_provider("text/event-stream",
-                [slot_client](size_t /*offset*/, DataSink &sink) {
-                    if(slot_client->state == IDLE && !slot_client->process_prompt) { // slot has been released
+                [slot](size_t /*offset*/, DataSink &sink) {
+                    if(slot->available()) { // slot has been released
                         sink.done();
                         return false;
                     }
-                    if(slot_client->state == NEXT_TOKEN) {
+                    // new token notification
+                    if(slot->hasNewToken()) {
                         stringstream ss;
-                        json res_d = {{"token", slot_client->sampled_token_str}};
+                        json res_d = {{"token", slot->sampled_token_str}};
                         ss << "data: " << res_d.dump() << "\n\n";
                         string result = ss.str();
                         if(!sink.write(result.c_str(), result.size())) { // user request release
-                            slot_client->forced_release = true;
+                            slot->command = RELEASE;
                            return false;
                         }
-                        slot_client->state = BUSY;
+                        // process next token
                     }
                     return true;
                 });
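
Because a failed sink.write() in the handler above sets slot->command = RELEASE, a client that no longer wants the stream can free its slot simply by dropping the connection. A minimal sketch of how a fetch-based client could expose that, assuming an AbortController is acceptable here; startCompletion() and its return shape are illustrative names, not part of this commit.

    // sketch only: start a completion and keep a handle that can cancel it later
    function startCompletion(slot_id, temperature, prompt) {
        const controller = new AbortController();
        const done = fetch("/completion", {
            method: "POST",
            headers: { "Content-Type": "application/json" },
            body: JSON.stringify({ slot_id, temperature, prompt }),
            signal: controller.signal
        });
        // aborting closes the connection; the next sink.write() on the server
        // then fails and the handler marks the slot with RELEASE
        return { done, cancel: () => controller.abort() };
    }

For example, const req = startCompletion(0, 0.1, " Who is Elon Musk?"); followed later by req.cancel(); would stop the stream and let updateSlots() reclaim slot 0.
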
@@ -785,7 +811,13 @@ int main(int argc, char **argv)
             res.set_content("slot_error", "text/plain");
         } });
 
-    thread t(processClient, &llama);
+    thread t([&llama]()
+    {
+        bool running = true;
+        while (running)
+        {
+            running = llama.updateSlots();
+        } });
 
     svr.set_read_timeout(sparams.read_timeout);
     svr.set_write_timeout(sparams.write_timeout);