diff --git a/examples/server-parallel/index.h b/examples/server-parallel/index.h
index f3a160292..4d305fc51 100644
--- a/examples/server-parallel/index.h
+++ b/examples/server-parallel/index.h
@@ -1,22 +1,27 @@
const auto index_html = R"(
-
+
-
Server parallel - Proof of Concept
-
@@ -26,9 +31,20 @@ const auto index_html = R"(
let conversation = [];
let current_message = -1;
const questions = ["Who is Elon Musk?", "Who is Jeff Bezos?", "How to get a job at google?", "What are you?", "When was born Abraham Lincoln?"];
- window.onload = function() {
+
+ docReady(() => {
document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
- };
+ });
+
+ function docReady(fn) {
+ // see if DOM is already available
+ if (document.readyState === "complete" || document.readyState === "interactive") {
+ // call on next available tick
+ setTimeout(fn, 1);
+ } else {
+ document.addEventListener("DOMContentLoaded", fn);
+ }
+ }
function updateView() {
let conv_view = document.getElementById("conversation_view");
@@ -64,6 +80,7 @@ const auto index_html = R"(
while (cont) {
const result = await reader.read();
if (result.done) {
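+ // re-enable the send button once the stream has finished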
+ document.getElementById("btn_send").disabled = false;
break;
}
@@ -108,43 +125,51 @@ const auto index_html = R"(
}
function generatePrompt() {
+ // build the prompt from the conversation history so replies stay coherent
let prompt = '';
for(let index in conversation) {
if(index == 0) {
prompt += conversation[index].user + "\n";
} else {
- prompt += "User: " + conversation[index].user + "\n";
+ prompt += "User:" + conversation[index].user + "\n";
}
if(index == current_message) {
prompt += "Assistant:";
} else {
- prompt += "Assistant: " + conversation[index].assistant;
+ prompt += "Assistant:" + conversation[index].assistant;
}
}
return prompt;
}
- function reset() {
- conversation = [];
- document.getElementById("client_slot").value = "-1";
- document.getElementById("message").value = "";
+ function resetBtn() {
+ document.getElementById("slot_id").value = "-1";
+ document.getElementById("temperature").value = "0.1";
+ document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
document.getElementById("conversation_view").innerHTML = "";
+ conversation = [];
+ current_message = -1;
}
async function perform() {
- var client_slot = parseFloat(document.getElementById("client_slot").value);
- var prompt = document.getElementById("message").value;
- if (!isNaN(client_slot) && prompt.length > 0) {
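+ // disable the send button while a completion is being streamed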
+ document.getElementById("btn_send").disabled = true;
+ var slot_id = parseInt(document.getElementById("slot_id").value);
+ var temperature = parseFloat(document.getElementById("temperature").value);
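+ // prepend a space so the text still follows "User:" with a space in the generated prompt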
+ var prompt = " " + document.getElementById("message").value;
+ if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
current_message++;
conversation.push({
user: prompt,
assistant: ''
});
updateView();
+ document.getElementById("message").value = "";
await call_llama({
- client_slot,
+ slot_id,
+ temperature,
prompt: generatePrompt()
});
+
} else {
document.getElementById("conversation_view").innerText = "please, insert valid props.";
}
diff --git a/examples/server-parallel/server.cpp b/examples/server-parallel/server.cpp
index f3453f148..ef7a9a87d 100644
--- a/examples/server-parallel/server.cpp
+++ b/examples/server-parallel/server.cpp
@@ -59,9 +59,14 @@ enum stop_type
enum slot_state
{
- BUSY,
IDLE,
- NEXT_TOKEN
+ PROCESSING
+};
+
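+// pending action for a slot, picked up by updateSlots()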
+enum slot_command {
+ NONE,
+ LOAD_PROMPT,
+ RELEASE
};
static std::string system_prompt =
@@ -80,15 +85,38 @@ struct llama_client_slot
int32_t n_prompt = 0;
int32_t n_decoded = 0;
int32_t i_batch = -1;
- bool process_prompt = false;
- bool release_slot = false;
- bool forced_release = false;
string prompt = "";
string sampled_token_str;
- string generated_text;
+ string generated_text = "";
llama_token sampled;
std::vector<llama_token> tokens_prev;
slot_state state = IDLE;
+ slot_command command = NONE;
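+ // one-shot flag: set when a token is sampled, cleared once the HTTP stream reads it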
+ bool newToken = false;
+ float temperature = 0.1f;
+
+ void start(string prompt_, float temp_) {
+ prompt = prompt_;
+ command = LOAD_PROMPT;
+ temperature = temp_;
+ newToken = false;
+ }
+
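+ // consume the new-token flag; returns true at most once per sampled token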
+ bool hasNewToken() {
+ if(newToken) {
+ newToken = false;
+ return true;
+ }
+ return false;
+ }
+
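+ // a slot can accept new work only when it is idle and has no pending command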
+ bool available() {
+ return state == IDLE && command == NONE;
+ }
+
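+ // signal that a newly sampled token is ready for the client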
+ void notify() {
+ newToken = true;
+ }
};
struct server_parallel_context {
@@ -131,7 +159,7 @@ struct server_parallel_context {
slot.state = IDLE;
slot.tokens_prev.resize(std::max(256, params.n_predict));
std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
- LOG_TEE(" -> client slot: %i\n", slot.id);
+ LOG_TEE(" - slot %i\n", slot.id);
slots.push_back(slot);
}
}
@@ -169,16 +197,15 @@ struct server_parallel_context {
return true;
}
- llama_client_slot* loadPrompt(int slot_id, string prompt) {
+ llama_client_slot* loadPrompt(int slot_id, string prompt, float temp_) {
for (llama_client_slot & slot : slots)
{
if (
- slot_id == -1 && slot.state == IDLE ||
+ slot_id == -1 && slot.available() ||
slot.id == slot_id)
{
- slot.prompt = prompt;
- slot.process_prompt = true;
- LOG_TEE("client %i is workloaded\n", slot.id);
+ slot.start(prompt, temp_);
+ LOG_TEE("slot %i is processing\n", slot.id);
return &slot; // return a pointer to slot (thread safe?)
}
}
@@ -211,24 +238,26 @@ struct server_parallel_context {
return stop_pos;
}
-
bool updateSlots() {
batch.n_tokens = 0;
// decode any currently ongoing sequences
for (auto & slot : slots) {
- if(slot.release_slot && slot.state == BUSY || slot.forced_release) {
- if(slot.forced_release) {
- llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
- slot.forced_release = false;
- }
- LOG_TEE("client %i is released\n", slot.id);
+ if (slot.state == PROCESSING && slot.command == RELEASE)
+ {
+ llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
slot.state = IDLE;
- slot.release_slot = false;
+ LOG_TEE("slot %i is released\n", slot.id);
+ slot.command = NONE;
}
- if (slot.state == IDLE) {
+
+ // do not decode this slot again until the sampled token has been sent to the client;
+ // this improves performance and helps avoid incoherent output
+
+ if (slot.state == IDLE || slot.newToken) {
continue;
}
+
batch.token [batch.n_tokens] = slot.sampled;
batch.pos [batch.n_tokens] = n_tokens_system + slot.n_prompt + slot.n_decoded;
batch.seq_id[batch.n_tokens] = slot.id;
@@ -253,10 +282,11 @@ struct server_parallel_context {
// assign workload to the slots
if (params.cont_batching || batch.n_tokens == 0) {
for (llama_client_slot & slot : slots) {
- if (slot.state == IDLE && slot.process_prompt) {
- slot.state = BUSY;
- slot.process_prompt = false;
- //LOG_TEE("client %i process prompt:\n%s'------------------------------\n", slot.id, slot.prompt.c_str());
+ // this slot still needs to process its prompt
+ if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
+ slot.state = PROCESSING;
+ slot.command = NONE;
+ //LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());
std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
// do not prepend BOS because we have a system prompt!
@@ -328,7 +358,6 @@ struct server_parallel_context {
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
-
continue;
}
@@ -337,6 +366,7 @@ struct server_parallel_context {
continue;
}
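+ // sample with the temperature requested for this slot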
+ params.temp = slot.temperature;
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, slot.tokens_prev, candidates, slot.i_batch - i);
// remember which tokens were sampled - used for repetition penalties during sampling
@@ -353,7 +383,8 @@ struct server_parallel_context {
findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
slot.sampled_token_str = token_str;
- slot.state = NEXT_TOKEN;
+ // notify the HTTP stream that a new token is available
+ slot.notify();
if (slot.n_decoded > 2 &&
(id == llama_token_eos(ctx) ||
@@ -361,11 +392,9 @@ struct server_parallel_context {
slot.n_decoded + slot.n_prompt >=
params.n_predict) ||
stop_pos != std::string::npos)) {
- // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
- llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
- //LOG_TEE("client %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
+ //LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
slot.generated_text.clear();
- slot.release_slot = true;
+ slot.command = RELEASE;
}
slot.i_batch = -1;
@@ -712,16 +741,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
}
-
-void processClient(server_parallel_context* ctx)
-{
- bool running = true;
- while (running)
- {
- running = ctx->updateSlots();
- }
-}
-
int main(int argc, char **argv)
{
gpt_params params;
@@ -730,6 +749,12 @@ int main(int argc, char **argv)
server_params_parse(argc, argv, sparams, params);
+#ifndef LOG_DISABLE_LOGS
+ log_set_target(log_filename_generator("server-parallel", "log"));
+ LOG_TEE("Log start\n");
+ log_dump_cmdline(argc, argv);
+#endif // LOG_DISABLE_LOGS
+
llama_backend_init(params.numa);
// load the target model
@@ -754,28 +779,29 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama](const Request &req, Response &res)
{
json data = json::parse(req.body);
- int slot_id = data.value("client_slot", -1);
+ int slot_id = data.value("slot_id", -1);
+ float temperature = data.value("temperature", 0.8f);
string prompt = data.value("prompt", "");
- llama_client_slot* slot_client = llama.loadPrompt(slot_id, prompt);
+ llama_client_slot* slot = llama.loadPrompt(slot_id, prompt, temperature);
// Verify if the slot exist
- if (slot_client) {
+ if (slot) {
res.set_chunked_content_provider("text/event-stream",
- [slot_client](size_t /*offset*/, DataSink &sink) {
- if(slot_client->state == IDLE && !slot_client->process_prompt) { // slot has been released
+ [slot](size_t /*offset*/, DataSink &sink) {
+ if(slot->available()) { // slot has been released
sink.done();
return false;
}
- if(slot_client->state == NEXT_TOKEN) { // new token notification
+
+ if(slot->hasNewToken()) { // new token notification
stringstream ss;
- json res_d = {{"token", slot_client->sampled_token_str}};
+ json res_d = {{"token", slot->sampled_token_str}};
ss << "data: " << res_d.dump() << "\n\n";
string result = ss.str();
if(!sink.write(result.c_str(), result.size())) { // user request release
- slot_client->forced_release = true;
+ slot->command = RELEASE;
return false;
}
- slot_client->state = BUSY; // process next token
}
return true;
});
@@ -785,7 +811,13 @@ int main(int argc, char **argv)
res.set_content("slot_error", "text/plain");
} });
- thread t(processClient, &llama);
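+ // run the slot update/decoding loop on its own thread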
+ thread t([&llama]()
+ {
+ bool running = true;
+ while (running)
+ {
+ running = llama.updateSlots();
+ } });
svr.set_read_timeout(sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);