diff --git a/examples/server-parallel/frontend.h b/examples/server-parallel/frontend.h
index 23d439677..abe9eb701 100644
--- a/examples/server-parallel/frontend.h
+++ b/examples/server-parallel/frontend.h
@@ -53,10 +53,9 @@ const char* index_html_ = R"(
 const char* index_js_ = R"(
 let conversation = [];
 let current_message = -1;
-let request_cancel = false;
 let canceled = false;
-let running = false;
 let slot_id = -1;
+let controller = null;
 
 const questions = [
   "Who is Elon Musk?",
@@ -92,7 +91,7 @@ docReady(async () => {
   document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
 
   // to keep the same prompt format in all clients
-  const response = await fetch("/props");
+  const response = await fetch("http://localhost:8080/props");
   if (!response.ok) {
     alert(`HTTP error! Status: ${response.status}`);
   }
@@ -131,29 +130,25 @@ function updateView() {
 
 async function call_llama(options) {
   try {
-    const response = await fetch("/completion", {
+    controller = new AbortController();
+    const signal = controller.signal;
+    const response = await fetch("http://localhost:8080/completion", {
       method: "POST",
       body: JSON.stringify(options),
       headers: {
         "Content-Type": "application/json",
         Accept: "text/event-stream",
-      }
+      },
+      signal: signal
     });
     const reader = response.body.getReader();
     const decoder = new TextDecoder();
     let leftover = ""; // Buffer for partially read lines
-    running = true;
-    while (running) {
-      // this no disposes the slot
-      if(request_cancel) {
-        running = false;
-        break;
-      }
+    while (current_message >= 0) {
       const result = await reader.read();
       if (result.done) {
         document.getElementById("btn_send").disabled = false;
         document.getElementById("btn_cancel").disabled = true;
-        running = false;
         break;
       }
 
@@ -181,7 +176,7 @@ async function call_llama(options) {
       if (match) {
         result[match[1]] = match[2];
         // since we know this is llama.cpp, let's just decode the json in data
-        if (result.data && !request_cancel) {
+        if (result.data && current_message >= 0) {
           result.data = JSON.parse(result.data);
           slot_id = result.data.slot_id;
           conversation[current_message].assistant += result.data.content;
@@ -194,7 +189,6 @@ async function call_llama(options) {
     if (e.name !== "AbortError") {
       console.error("llama error: ", e);
     }
-    throw e;
   }
 }
 
@@ -213,12 +207,14 @@ function generatePrompt() {
       prompt += assistant_name + conversation[index].assistant;
     }
   }
+  console.log(prompt);
   return prompt;
 }
 
 async function resetView() {
-  if(running) {
-    await sendCancelSignal();
+  if(controller) {
+    controller.abort();
+    controller = null;
   }
   document.getElementById("slot_id").value = "-1";
   document.getElementById("temperature").value = "0.1";
@@ -233,16 +229,12 @@
   canceled = false;
 }
 
-async function sendCancelSignal() {
-  await fetch(
-    "/cancel?slot_id=" + slot_id
-  );
-  request_cancel = true;
-}
-
 async function cancel() {
   if(!canceled) {
-    await sendCancelSignal();
+    if(controller) {
+      controller.abort();
+      controller = null;
+    }
     document.getElementById("btn_send").disabled = false;
     document.getElementById("btn_cancel").innerText = "Regenerate response";
     canceled = true;
@@ -256,7 +248,6 @@ async function perform(regen) {
    document.getElementById("message").value = conversation.pop().user;
     current_message--;
   }
-  request_cancel = false;
   var slot_id = parseInt(document.getElementById("slot_id").value);
   var temperature = parseFloat(document.getElementById("temperature").value);
   var prompt = " " + document.getElementById("message").value;
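The frontend now cancels a generation by aborting the in-flight `fetch` instead of calling a dedicated `/cancel` endpoint: `controller.abort()` closes the connection, the pending `reader.read()` rejects with an `AbortError`, and the server (see the server.cpp diff below) releases the slot from the stream's completion callback. A minimal standalone sketch of that pattern; the endpoint URL and request payload here are illustrative, not taken from the diff:

```js
let controller = null;

async function stream() {
  controller = new AbortController();
  const response = await fetch("http://localhost:8080/completion", {
    method: "POST",
    body: JSON.stringify({ prompt: "Hello" }),
    signal: controller.signal, // abort() will reject the pending read()
  });
  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  try {
    for (;;) {
      const { done, value } = await reader.read();
      if (done) break;                    // server closed the stream
      console.log(decoder.decode(value)); // one SSE chunk
    }
  } catch (e) {
    if (e.name !== "AbortError") throw e; // aborts are the expected path
  }
}

function cancel() {
  if (controller) {
    controller.abort(); // closes the connection; the server sees the disconnect
    controller = null;
  }
}
```

Compared with the old `request_cancel` flag plus `/cancel` round-trip, the abort path has no race: the slot is freed exactly when the connection goes away, whether via the cancel button, navigation, or a closed tab.

diff --git a/examples/server-parallel/server.cpp b/examples/server-parallel/server.cpp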
index 431211621..13dd1fcc0 100644
--- a/examples/server-parallel/server.cpp
+++ b/examples/server-parallel/server.cpp
@@ -210,16 +210,6 @@ struct server_parallel_context {
         update_system_prompt = false;
     }
 
-    bool releaseSlot(int id) {
-        for(llama_client_slot &slot : slots) {
-            if(slot.id == id) {
-                slot.release();
-                return true;
-            }
-        }
-        return false;
-    }
-
     void notifySystemPromptChanged() {
         // release all slots
         for (llama_client_slot &slot : slots)
@@ -357,6 +347,7 @@ struct server_parallel_context {
             std::vector<llama_token> tokens_prompt;
             tokens_prompt = ::llama_tokenize(ctx, slot.prompt, false);
             slot.n_tokens_predicted = 0;
+            slot.sampled_tokens.clear();
 
             for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                 batch.token [batch.n_tokens] = tokens_prompt[i];
@@ -856,34 +847,13 @@ int main(int argc, char **argv)
         };
         res.set_content(data.dump(), "application/json");
     });
 
-    svr.Get("/cancel", [&llama](const Request & req/*req*/, Response &res) {
-        res.set_header("Access-Control-Allow-Origin", "*");
-        if(req.has_param("slot_id")) {
-            int slot_id = std::stoi(req.get_param_value("slot_id"));
-            string result = "done";
-            if(!llama.releaseSlot(slot_id)) {
-                result = "wrong slot ID";
-            }
-            json data = {
-                { "status", result }
-            };
-            res.set_content(data.dump(), "application/json");
-        } else {
-            json data = {
-                { "error", "Missing parameter" }
-            };
-            res.set_content(data.dump(), "application/json");
-        }
-    });
-
     svr.Post("/completion", [&llama](const Request &req, Response &res) {
         res.set_header("Access-Control-Allow-Origin", "*");
         llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
         // Verify if the slot exist
         if (slot) {
-            res.set_chunked_content_provider("text/event-stream",
-                [slot](size_t /*offset*/, DataSink &sink) {
+            auto content_provider = [slot](size_t /*offset*/, DataSink &sink) {
                 if(slot->available()) { // slot has been released
                     sink.done();
                     return false;
@@ -902,7 +872,11 @@
                 }
             }
             return true;
-            });
+            };
+            auto on_complete = [slot] (bool) {
+                slot->release();
+            };
+            res.set_chunked_content_provider("text/event-stream", content_provider, on_complete);
         } else {
             LOG_TEE("slot unavailable\n");
             res.status = 404;
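The `/cancel` route becomes unnecessary because cpp-httplib's `set_chunked_content_provider` accepts a third argument, a resource releaser that is invoked when the response ends for any reason, including the client disconnecting. A minimal sketch of that pattern, assuming the bundled cpp-httplib (where, as in the diff, the releaser receives a success flag) and using a hypothetical `Slot` in place of `llama_client_slot`:

```cpp
#include "httplib.h"
#include <string>

// Stand-in for llama_client_slot: just enough state to demo the pattern.
struct Slot {
    int  n_sent   = 0;
    bool released = false;
    void release() { released = true; } // would return the slot to the pool
};

static Slot g_slot;

int main() {
    httplib::Server svr;

    svr.Post("/completion", [](const httplib::Request &, httplib::Response &res) {
        // Streams SSE chunks until generation is "done".
        auto content_provider = [](size_t /*offset*/, httplib::DataSink &sink) {
            if (g_slot.n_sent >= 3) { // pretend generation finished
                sink.done();
                return false;         // stop the chunked stream
            }
            const std::string chunk = "data: {\"content\": \"tok\"}\n\n";
            sink.write(chunk.data(), chunk.size());
            g_slot.n_sent++;
            return true;              // more chunks to come
        };
        // Runs when the stream ends normally *or* the client disconnects,
        // so the slot can never stay stuck in the busy state.
        auto on_complete = [](bool /*success*/) { g_slot.release(); };
        res.set_chunked_content_provider("text/event-stream", content_provider, on_complete);
    });

    svr.listen("0.0.0.0", 8080);
}
```

Because the completion callback also fires on abrupt disconnects, a client that simply aborts its fetch, as the frontend above now does, can no longer leak a busy slot.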