some var renames, state fixes + performance improvements

FSSRepo 2023-10-04 14:49:35 -04:00
parent 598e74c157
commit cbd632a658
2 changed files with 127 additions and 70 deletions

View file

@@ -1,22 +1,27 @@
const auto index_html = R"(
<!DOCTYPE html>
<!DOCTYPE html>
<html>
<head>
<title>llama.cpp - server parallel PoC</title>
</head>
<body>
<div style="width: 90%;margin: auto;">
<h2>Server parallel - Proof of Concept</h2>
<form id="myForm">
<label for="client_slot">Client Slot (-1 load in a idle client)</label>
<input type="number" id="client_slot" value="-1" required>
<br><br>
<h2>Server parallel - PoC</h2>
<form id="myForm" >
<label for="slot_id">Slot ID (-1 load in a idle slot)</label>
<input type="number" id="slot_id" value="-1" required>
<br>
<label for="temperature">Temperature</label>
<input type="number" id="temperature" value="0.1" required>
<br>
<label for="message">Message</label>
<input id="message" style="width: 80%;" required>
<br><br>
<button type="button" id="btn_send" onclick="perform() " >Send</button>
<button type="button" onclick="reset() " >Reset</button>
<br>
<br>
<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
</form>
<div id="conversation_view">
</div>
@@ -26,9 +31,20 @@ const auto index_html = R"(
let conversation = [];
let current_message = -1;
const questions = ["Who is Elon Musk?", "Who is Jeff Bezos?", "How to get a job at Google?", "What are you?", "When was Abraham Lincoln born?"];
window.onload = function() {
docReady(() => {
document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
};
});
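// docReady runs the callback even if this script executes after DOMContentLoaded has already fired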
function docReady(fn) {
// see if DOM is already available
if (document.readyState === "complete" || document.readyState === "interactive") {
// call on next available tick
setTimeout(fn, 1);
} else {
document.addEventListener("DOMContentLoaded", fn);
}
}
function updateView() {
let conv_view = document.getElementById("conversation_view");
@@ -64,6 +80,7 @@ const auto index_html = R"(
while (cont) {
const result = await reader.read();
if (result.done) {
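// stream finished: re-enable the send button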
document.getElementById("btn_send").disabled = false;
break;
}
@@ -108,43 +125,51 @@ const auto index_html = R"(
}
function generatePrompt() {
// build a prompt that keeps the conversation coherent
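// the first turn is added without a prefix; later turns are prefixed with User:/Assistant:
// the current turn ends with "Assistant:" so the model completes it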
let prompt = '';
for(let index in conversation) {
if(index == 0) {
prompt += conversation[index].user + "\n";
} else {
prompt += "User: " + conversation[index].user + "\n";
prompt += "User:" + conversation[index].user + "\n";
}
if(index == current_message) {
prompt += "Assistant:";
} else {
prompt += "Assistant: " + conversation[index].assistant;
prompt += "Assistant:" + conversation[index].assistant;
}
}
return prompt;
}
function reset() {
conversation = [];
document.getElementById("client_slot").value = "-1";
document.getElementById("message").value = "";
function resetBtn() {
document.getElementById("slot_id").value = "-1";
document.getElementById("temperature").value = "0.1";
document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
document.getElementById("conversation_view").innerHTML = "";
conversation = [];
current_message = -1;
}
async function perform() {
var client_slot = parseFloat(document.getElementById("client_slot").value);
var prompt = document.getElementById("message").value;
if (!isNaN(client_slot) && prompt.length > 0) {
document.getElementById("btn_send").disabled = true;
var slot_id = parseInt(document.getElementById("slot_id").value);
var temperature = parseFloat(document.getElementById("temperature").value);
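// the leading space likely helps the first word tokenize consistently (assumption)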
var prompt = " " + document.getElementById("message").value;
if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
current_message++;
conversation.push({
user: prompt,
assistant: ''
});
updateView();
document.getElementById("message").value = "";
await call_llama({
client_slot,
slot_id,
temperature,
prompt: generatePrompt()
});
} else {
document.getElementById("conversation_view").innerText = "please, insert valid props.";
}

View file

@@ -59,9 +59,14 @@ enum stop_type
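// a slot is either waiting for work (IDLE) or generating tokens (PROCESSING)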
enum slot_state
{
BUSY,
IDLE,
NEXT_TOKEN
PROCESSING
};
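// pending action for a slot, consumed by updateSlots()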
enum slot_command {
NONE,
LOAD_PROMPT,
RELEASE
};
static std::string system_prompt =
@@ -80,15 +85,38 @@ struct llama_client_slot
int32_t n_prompt = 0;
int32_t n_decoded = 0;
int32_t i_batch = -1;
bool process_prompt = false;
bool release_slot = false;
bool forced_release = false;
string prompt = "";
string sampled_token_str;
string generated_text;
string generated_text = "";
llama_token sampled;
std::vector<llama_token> tokens_prev;
slot_state state = IDLE;
slot_command command = NONE;
bool newToken = false;
float temperature = 0.1f;
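// queue a prompt on this slot; the update loop picks it up via the LOAD_PROMPT command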
void start(string prompt_, float temp_) {
prompt = prompt_;
command = LOAD_PROMPT;
temperature = temp_;
newToken = false;
}
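// one-shot check: returns true once after notify(), then clears the flag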
bool hasNewToken() {
if(newToken) {
newToken = false;
return true;
}
return false;
}
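// a slot can accept new work only when it is idle and has no pending command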
bool available() {
return state == IDLE && command == NONE;
}
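// flag a freshly sampled token for the streaming callback (toggles the new-token flag)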
void notify() {
newToken = !newToken;
}
};
struct server_parallel_context {
@@ -131,7 +159,7 @@ struct server_parallel_context {
slot.state = IDLE;
slot.tokens_prev.resize(std::max(256, params.n_predict));
std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
LOG_TEE(" -> client slot: %i\n", slot.id);
LOG_TEE(" - slot %i\n", slot.id);
slots.push_back(slot);
}
}
@@ -169,16 +197,15 @@ struct server_parallel_context {
return true;
}
llama_client_slot* loadPrompt(int slot_id, string prompt) {
llama_client_slot* loadPrompt(int slot_id, string prompt, float temp_) {
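// assign the prompt to the requested slot, or to any available slot when slot_id is -1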
for (llama_client_slot & slot : slots)
{
if (
slot_id == -1 && slot.state == IDLE ||
slot_id == -1 && slot.available() ||
slot.id == slot_id)
{
slot.prompt = prompt;
slot.process_prompt = true;
LOG_TEE("client %i is workloaded\n", slot.id);
slot.start(prompt, temp_);
LOG_TEE("slot %i is processing\n", slot.id);
return &slot; // return a pointer to the slot (is this thread safe?)
}
}
@@ -211,24 +238,26 @@ struct server_parallel_context {
return stop_pos;
}
bool updateSlots() {
batch.n_tokens = 0;
// decode any currently ongoing sequences
for (auto & slot : slots) {
if(slot.release_slot && slot.state == BUSY || slot.forced_release) {
if(slot.forced_release) {
llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
slot.forced_release = false;
}
LOG_TEE("client %i is released\n", slot.id);
if (slot.state == PROCESSING && slot.command == RELEASE)
{
llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
slot.state = IDLE;
slot.release_slot = false;
LOG_TEE("slot %i is released\n", slot.id);
slot.command = NONE;
}
if (slot.state == IDLE) {
// skip decoding until the current token has been sent to the client
// improves performance and avoids incoherence?
if (slot.state == IDLE || slot.newToken) {
continue;
}
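// queue this slot's last sampled token for the next decode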
batch.token [batch.n_tokens] = slot.sampled;
batch.pos [batch.n_tokens] = n_tokens_system + slot.n_prompt + slot.n_decoded;
batch.seq_id[batch.n_tokens] = slot.id;
@@ -253,10 +282,11 @@ struct server_parallel_context {
// assign workload to the slots
if (params.cont_batching || batch.n_tokens == 0) {
for (llama_client_slot & slot : slots) {
if (slot.state == IDLE && slot.process_prompt) {
slot.state = BUSY;
slot.process_prompt = false;
//LOG_TEE("client %i process prompt:\n%s'------------------------------\n", slot.id, slot.prompt.c_str());
// the prompt still needs to be processed
if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING;
slot.command = NONE;
//LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());
std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
// do not prepend BOS because we have a system prompt!
@@ -328,7 +358,6 @@ struct server_parallel_context {
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
continue;
}
@@ -337,6 +366,7 @@ struct server_parallel_context {
continue;
}
params.temp = slot.temperature;
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, slot.tokens_prev, candidates, slot.i_batch - i);
// remember which tokens were sampled - used for repetition penalties during sampling
@@ -353,7 +383,8 @@ struct server_parallel_context {
findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
slot.sampled_token_str = token_str;
slot.state = NEXT_TOKEN;
// notify new token
slot.notify();
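// release the slot on EOS, when the prediction limit is reached, or when a stop string was found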
if (slot.n_decoded > 2 &&
(id == llama_token_eos(ctx) ||
@@ -361,11 +392,9 @@ struct server_parallel_context {
slot.n_decoded + slot.n_prompt >=
params.n_predict) ||
stop_pos != std::string::npos)) {
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
//LOG_TEE("client %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
//LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
slot.generated_text.clear();
slot.release_slot = true;
slot.command = RELEASE;
}
slot.i_batch = -1;
@@ -712,16 +741,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
}
void processClient(server_parallel_context* ctx)
{
bool running = true;
while (running)
{
running = ctx->updateSlots();
}
}
int main(int argc, char **argv)
{
gpt_params params;
@@ -730,6 +749,12 @@ int main(int argc, char **argv)
server_params_parse(argc, argv, sparams, params);
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("server-parallel", "log"));
LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS
llama_backend_init(params.numa);
// load the target model
@@ -754,28 +779,29 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama](const Request &req, Response &res)
{
json data = json::parse(req.body);
int slot_id = data.value("client_slot", -1);
int slot_id = data.value("slot_id", -1);
float temperature = data.value("temperature", 0.8f);
string prompt = data.value("prompt", "");
llama_client_slot* slot_client = llama.loadPrompt(slot_id, prompt);
llama_client_slot* slot = llama.loadPrompt(slot_id, prompt, temperature);
// verify that the slot exists
if (slot_client) {
if (slot) {
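// stream each sampled token to the client as a server-sent event until the slot is released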
res.set_chunked_content_provider("text/event-stream",
[slot_client](size_t /*offset*/, DataSink &sink) {
if(slot_client->state == IDLE && !slot_client->process_prompt) { // slot has been released
[slot](size_t /*offset*/, DataSink &sink) {
if(slot->available()) { // slot has been released
sink.done();
return false;
}
if(slot_client->state == NEXT_TOKEN) { // new token notification
if(slot->hasNewToken()) { // new token notification
stringstream ss;
json res_d = {{"token", slot_client->sampled_token_str}};
json res_d = {{"token", slot->sampled_token_str}};
ss << "data: " << res_d.dump() << "\n\n";
string result = ss.str();
if(!sink.write(result.c_str(), result.size())) { // user request release
slot_client->forced_release = true;
slot->command = RELEASE;
return false;
}
slot_client->state = BUSY; // process next token
}
return true;
});
@@ -785,7 +811,13 @@ int main(int argc, char **argv)
res.set_content("slot_error", "text/plain");
} });
thread t(processClient, &llama);
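// run the slot update loop on its own thread while the HTTP server handles requests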
thread t([&llama]()
{
bool running = true;
while (running)
{
running = llama.updateSlots();
} });
svr.set_read_timeout(sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);