diff --git a/examples/server-parallel/index.h b/examples/server-parallel/index.h
index f3a160292..4d305fc51 100644
--- a/examples/server-parallel/index.h
+++ b/examples/server-parallel/index.h
@@ -1,22 +1,27 @@
 const auto index_html = R"(
[markup hunk lost in extraction: the page heading changes from "Server parallel - Proof of Concept" to "Server parallel - PoC", and the form gains the slot_id and temperature inputs, the btn_send button and the conversation_view element referenced by the script below]
@@ -26,9 +31,20 @@ const auto index_html = R"(
      let conversation = [];
      let current_message = -1;
      const questions = ["Who is Elon Musk?", "Who is Jeff Bezos?", "How to get a job at google?", "What are you?", "When was born Abraham Lincoln?"];
-      window.onload = function() {
+
+      docReady(() => {
        document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
-      };
+      });
+
+      function docReady(fn) {
+        // see if the DOM is already available
+        if (document.readyState === "complete" || document.readyState === "interactive") {
+          // call on the next available tick
+          setTimeout(fn, 1);
+        } else {
+          document.addEventListener("DOMContentLoaded", fn);
+        }
+      }

      function updateView() {
        let conv_view = document.getElementById("conversation_view");
@@ -64,6 +80,7 @@ const auto index_html = R"(
        while (cont) {
          const result = await reader.read();
          if (result.done) {
+            document.getElementById("btn_send").disabled = false;
            break;
          }
@@ -108,43 +125,51 @@ const auto index_html = R"(
      }

      function generatePrompt() {
+        // rebuild the full prompt so the model keeps the conversation coherent
        let prompt = '';
        for(let index in conversation) {
          if(index == 0) {
            prompt += conversation[index].user + "\n";
          } else {
-            prompt += "User: " + conversation[index].user + "\n";
+            prompt += "User:" + conversation[index].user + "\n";
          }
          if(index == current_message) {
            prompt += "Assistant:";
          } else {
-            prompt += "Assistant: " + conversation[index].assistant;
+            prompt += "Assistant:" + conversation[index].assistant;
          }
        }
        return prompt;
      }

-      function reset() {
-        conversation = [];
-        document.getElementById("client_slot").value = "-1";
-        document.getElementById("message").value = "";
+      function resetBtn() {
+        document.getElementById("slot_id").value = "-1";
+        document.getElementById("temperature").value = "0.1";
+        document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
        document.getElementById("conversation_view").innerHTML = "";
+        conversation = [];
+        current_message = -1;
      }

      async function perform() {
-        var client_slot = parseFloat(document.getElementById("client_slot").value);
-        var prompt = document.getElementById("message").value;
-        if (!isNaN(client_slot) && prompt.length > 0) {
+        document.getElementById("btn_send").disabled = true;
+        var slot_id = parseInt(document.getElementById("slot_id").value);
+        var temperature = parseFloat(document.getElementById("temperature").value);
+        var prompt = " " + document.getElementById("message").value;
+        if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
          current_message++;
          conversation.push({ user: prompt, assistant: '' });
          updateView();
+          document.getElementById("message").value = "";
          await call_llama({
-            client_slot,
+            slot_id,
+            temperature,
            prompt: generatePrompt()
          });
+
        } else {
          document.getElementById("conversation_view").innerText = "please, insert valid props.";
        }
diff --git a/examples/server-parallel/server.cpp b/examples/server-parallel/server.cpp
index f3453f148..ef7a9a87d 100644
--- a/examples/server-parallel/server.cpp
+++ b/examples/server-parallel/server.cpp
@@ -59,9 +59,14 @@ enum stop_type
 enum slot_state
 {
-    BUSY,
     IDLE,
-    NEXT_TOKEN
+    PROCESSING
 };
+
+enum slot_command {
+    NONE,
+    LOAD_PROMPT,
+    RELEASE
+};

 static std::string system_prompt =
@@ -80,15 +85,38 @@ struct llama_client_slot
     int32_t n_prompt  = 0;
     int32_t n_decoded = 0;
     int32_t i_batch   = -1;
-    bool process_prompt = false;
-    bool release_slot = false;
-    bool forced_release = false;
     string prompt = "";
     string sampled_token_str;
-    string generated_text;
+    string generated_text = "";
     llama_token sampled;
     std::vector<llama_token> tokens_prev;
     slot_state state = IDLE;
+    slot_command command = NONE;
+    bool newToken = false;
+    float temperature = 0.1f;
+
+    void start(string prompt_, float temp_) {
+        prompt = prompt_;
+        command = LOAD_PROMPT;
+        temperature = temp_;
+        newToken = false;
+    }
+
+    bool hasNewToken() {
+        if(newToken) {
+            newToken = false;
+            return true;
+        }
+        return false;
+    }
+
+    bool available() {
+        return state == IDLE && command == NONE;
+    }
+
+    void notify() {
+        newToken = true;
+    }
 };

 struct server_parallel_context {
@@ -131,7 +159,7 @@ struct server_parallel_context {
             slot.state = IDLE;
             slot.tokens_prev.resize(std::max(256, params.n_predict));
             std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
-            LOG_TEE(" -> client slot: %i\n", slot.id);
+            LOG_TEE(" - slot %i\n", slot.id);
             slots.push_back(slot);
         }
     }
@@ -169,16 +197,15 @@ struct server_parallel_context {
         return true;
     }

-    llama_client_slot* loadPrompt(int slot_id, string prompt) {
+    llama_client_slot* loadPrompt(int slot_id, string prompt, float temp_) {
         for (llama_client_slot & slot : slots) {
             if (
-                slot_id == -1 && slot.state == IDLE ||
+                slot_id == -1 && slot.available() ||
                 slot.id == slot_id) {
-                slot.prompt = prompt;
-                slot.process_prompt = true;
-                LOG_TEE("client %i is workloaded\n", slot.id);
+                slot.start(prompt, temp_);
+                LOG_TEE("slot %i is processing\n", slot.id);
                 return &slot; // return a pointer to slot (thread safe?)
             }
         }
@@ -211,24 +238,26 @@
         return stop_pos;
     }

-
     bool updateSlots() {
         batch.n_tokens = 0;

         // decode any currently ongoing sequences
         for (auto & slot : slots) {
-            if(slot.release_slot && slot.state == BUSY || slot.forced_release) {
-                if(slot.forced_release) {
-                    llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
-                    slot.forced_release = false;
-                }
-                LOG_TEE("client %i is released\n", slot.id);
+            if (slot.state == PROCESSING && slot.command == RELEASE)
+            {
+                llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
                 slot.state = IDLE;
-                slot.release_slot = false;
+                LOG_TEE("slot %i is released\n", slot.id);
+                slot.command = NONE;
             }
-            if (slot.state == IDLE) {
+
+            // do not decode this slot again until its last token has been sent to the client
+            // (improves performance and avoids incoherent output?)
+            if (slot.state == IDLE || slot.newToken) {
                 continue;
             }
+
             batch.token [batch.n_tokens] = slot.sampled;
             batch.pos   [batch.n_tokens] = n_tokens_system + slot.n_prompt + slot.n_decoded;
             batch.seq_id[batch.n_tokens] = slot.id;
@@ -253,10 +282,11 @@ struct server_parallel_context {
         // assign workload to the slots
         if (params.cont_batching || batch.n_tokens == 0) {
             for (llama_client_slot & slot : slots) {
-                if (slot.state == IDLE && slot.process_prompt) {
-                    slot.state = BUSY;
-                    slot.process_prompt = false;
-                    //LOG_TEE("client %i process prompt:\n%s'------------------------------\n", slot.id, slot.prompt.c_str());
+                // this slot still has a prompt to process
+                if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
+                    slot.state = PROCESSING;
+                    slot.command = NONE;
+                    //LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());

                     std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);

                     // do not prepend BOS because we have a system prompt!
@@ -328,7 +358,6 @@ struct server_parallel_context {
                     // retry with half the batch size to try to find a free slot in the KV cache
                     n_batch /= 2;
                     i -= n_batch;
-
                     continue;
                 }
@@ -337,6 +366,7 @@
                 continue;
             }

+            params.temp = slot.temperature;
             const llama_token id = llama_sample_token(ctx, NULL, NULL, params, slot.tokens_prev, candidates, slot.i_batch - i);

             // remember which tokens were sampled - used for repetition penalties during sampling
@@ -353,7 +383,8 @@
                     findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);

                 slot.sampled_token_str = token_str;
-                slot.state = NEXT_TOKEN;
+                // notify the HTTP handler that a new token can be streamed
+                slot.notify();

                 if (slot.n_decoded > 2 &&
                         (id == llama_token_eos(ctx) ||
@@ -361,11 +392,9 @@
                         (params.n_predict > 0 &&
                          slot.n_decoded + slot.n_prompt >= params.n_predict) ||
                          stop_pos != std::string::npos)) {
-                    // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                    llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
-                    //LOG_TEE("client %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
+                    //LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
                     slot.generated_text.clear();
-                    slot.release_slot = true;
+                    slot.command = RELEASE;
                 }

                 slot.i_batch = -1;
@@ -712,16 +741,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }

-
-void processClient(server_parallel_context* ctx)
-{
-    bool running = true;
-    while (running)
-    {
-        running = ctx->updateSlots();
-    }
-}
-
 int main(int argc, char **argv)
 {
     gpt_params params;
@@ -730,6 +749,12 @@
     server_params_parse(argc, argv, sparams, params);

+#ifndef LOG_DISABLE_LOGS
+    log_set_target(log_filename_generator("server-parallel", "log"));
+    LOG_TEE("Log start\n");
+    log_dump_cmdline(argc, argv);
+#endif // LOG_DISABLE_LOGS
+
     llama_backend_init(params.numa);

     // load the target model
@@ -754,28 +779,29 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const Request &req, Response &res)
             {
                 json data = json::parse(req.body);
-                int slot_id = data.value("client_slot", -1);
+                int slot_id = data.value("slot_id", -1);
+                float temperature = data.value("temperature", 0.8f);
                 string prompt = data.value("prompt", "");
-                llama_client_slot* slot_client = llama.loadPrompt(slot_id, prompt);
+                llama_client_slot* slot = llama.loadPrompt(slot_id, prompt, temperature);

                 // Verify if the slot exist
-                if (slot_client) {
+                if (slot) {
                     res.set_chunked_content_provider("text/event-stream",
-                        [slot_client](size_t /*offset*/, DataSink &sink) {
-                            if(slot_client->state == IDLE && !slot_client->process_prompt) { // slot has been released
+                        [slot](size_t /*offset*/, DataSink &sink) {
+                            if(slot->available()) { // slot has been released
                                 sink.done();
                                 return false;
                             }
-                            if(slot_client->state == NEXT_TOKEN) { // new token notification
+
+                            if(slot->hasNewToken()) { // new token notification
                                 stringstream ss;
-                                json res_d = {{"token", slot_client->sampled_token_str}};
+                                json res_d = {{"token", slot->sampled_token_str}};
                                 ss << "data: " << res_d.dump() << "\n\n";
                                 string result = ss.str();
                                 if(!sink.write(result.c_str(), result.size())) {
                                     // user request release
-                                    slot_client->forced_release = true;
+                                    slot->command = RELEASE;
                                     return false;
                                 }
-                                slot_client->state = BUSY; // process next token
                             }
                             return true;
                         });
@@ -785,7 +811,13 @@ int main(int argc, char **argv)
"text/plain"); } }); - thread t(processClient, &llama); + thread t([&llama]() + { + bool running = true; + while (running) + { + running = llama.updateSlots(); + } }); svr.set_read_timeout(sparams.read_timeout); svr.set_write_timeout(sparams.write_timeout);