From 8a8535bb6dad549f5fec276e768a3e91c5d14a57 Mon Sep 17 00:00:00 2001 From: FSSRepo Date: Sun, 8 Oct 2023 22:30:43 -0400 Subject: [PATCH] fix cors + regen + cancel funcs --- examples/server-parallel/frontend.h | 120 +++++++++++++++++++--------- examples/server-parallel/server.cpp | 45 ++++++++++- 2 files changed, 122 insertions(+), 43 deletions(-) diff --git a/examples/server-parallel/frontend.h b/examples/server-parallel/frontend.h index b6909c0ff..8fdc0fb09 100644 --- a/examples/server-parallel/frontend.h +++ b/examples/server-parallel/frontend.h @@ -11,14 +11,14 @@ const char* index_html_ = R"( - llama.cpp - server parallel PoC + llama.cpp - server parallel
-

Server parallel - PoC

+

Server parallel

- +

@@ -39,10 +39,9 @@ const char* index_html_ = R"(

- -
-
- + + +
@@ -52,8 +51,12 @@ const char* index_html_ = R"( )"; const char* index_js_ = R"( - let conversation = []; - let current_message = -1; +let conversation = []; +let current_message = -1; +let request_cancel = false; +let canceled = false; +let running = false; +let slot_id = -1; const questions = [ "Who is Elon Musk?", @@ -71,7 +74,7 @@ const questions = [ let user_name = ""; let assistant_name = ""; -function toggleSP() { +function toggle_system_prompt() { if(document.getElementById("system_promt_cb").checked) { document.getElementById("system_prompt_view").style.display = "block"; } else { @@ -79,16 +82,15 @@ function toggleSP() { } } -function clearSP() { +function clear_sp_props() { document.getElementById("sp_text").value = ""; - document.getElementById("anti_prompt").value = ""; + document.getElementById("user_name").value = ""; document.getElementById("assistant_name").value = ""; } docReady(async () => { document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)]; - // to keep the same prompt format in all clients const response = await fetch("/props"); if (!response.ok) { @@ -128,28 +130,31 @@ function updateView() { } async function call_llama(options) { - const response = await fetch("/completion", { - method: "POST", - body: JSON.stringify(options), - headers: { - Connection: "keep-alive", - "Content-Type": "application/json", - Accept: "text/event-stream", - }, - }); - - const reader = response.body.getReader(); - let cont = true; - const decoder = new TextDecoder(); - let leftover = ""; // Buffer for partially read lines - try { - let cont = true; - - while (cont) { + controller = new AbortController(); + const response = await fetch("/completion", { + method: "POST", + body: JSON.stringify(options), + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + } + }); + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let leftover = ""; // Buffer for partially read lines + running = true; + while (running) { + // this no disposes the slot + if(request_cancel) { + running = false; + break; + } const result = await reader.read(); if (result.done) { document.getElementById("btn_send").disabled = false; + document.getElementById("btn_cancel").disabled = true; + running = false; break; } @@ -177,8 +182,9 @@ async function call_llama(options) { if (match) { result[match[1]] = match[2]; // since we know this is llama.cpp, let's just decode the json in data - if (result.data) { + if (result.data && !request_cancel) { result.data = JSON.parse(result.data); + slot_id = result.data.slot_id; conversation[current_message].assistant += result.data.content; updateView(); } @@ -211,21 +217,56 @@ function generatePrompt() { return prompt; } -function resetBtn() { +async function resetView() { + if(running) { + await sendCancelSignal(); + } document.getElementById("slot_id").value = "-1"; document.getElementById("temperature").value = "0.1"; document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)]; + document.getElementById("btn_cancel").disabled = true; + document.getElementById("btn_cancel").innerText = "Cancel"; + document.getElementById("btn_send").disabled = false; document.getElementById("conversation_view").innerHTML = ""; conversation = []; current_message = -1; + canceled = false; } -async function perform() { +async function sendCancelSignal() { + await fetch( + "/cancel?slot_id=" + slot_id + ); + request_cancel = true; +} + +async function cancel() { + if(!canceled) { + await sendCancelSignal(); + document.getElementById("btn_send").disabled = false; + document.getElementById("btn_cancel").innerText = "Regenerate response"; + canceled = true; + } else { + perform(true); + } +} + +async function perform(regen) { + if(regen) { + document.getElementById("message").value = conversation.pop().user; + current_message--; + } + request_cancel = false; var slot_id = parseInt(document.getElementById("slot_id").value); var temperature = parseFloat(document.getElementById("temperature").value); var prompt = " " + document.getElementById("message").value; - if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) { + if (prompt.length > 1 && !isNaN(slot_id) && !isNaN(temperature)) { + if(!regen && canceled) { // use the new message + conversation.pop(); // delete incomplete interaction + current_message--; + } + canceled = false; let options = { slot_id, temperature @@ -243,6 +284,7 @@ async function perform() { current_message = -1; document.getElementById("system_promt_cb").checked = false; document.getElementById("system_promt_cb").dispatchEvent(new Event("change")); + // include system prompt props options.system_prompt = system_prompt; options.anti_prompt = anti_prompt; options.assistant_name = assistant_name_; @@ -257,12 +299,12 @@ async function perform() { updateView(); document.getElementById("message").value = ""; document.getElementById("btn_send").disabled = true; + document.getElementById("btn_cancel").disabled = false; + document.getElementById("btn_cancel").innerText = "Cancel"; options.prompt = generatePrompt(); await call_llama(options); } else { - document.getElementById("conversation_view").innerText = - "please, insert valid props."; + alert("please, insert valid props."); } } - )"; diff --git a/examples/server-parallel/server.cpp b/examples/server-parallel/server.cpp index dbc361fd3..431211621 100644 --- a/examples/server-parallel/server.cpp +++ b/examples/server-parallel/server.cpp @@ -210,6 +210,16 @@ struct server_parallel_context { update_system_prompt = false; } + bool releaseSlot(int id) { + for(llama_client_slot &slot : slots) { + if(slot.id == id) { + slot.release(); + return true; + } + } + return false; + } + void notifySystemPromptChanged() { // release all slots for (llama_client_slot &slot : slots) @@ -824,9 +834,12 @@ int main(int argc, char **argv) Server svr; - svr.set_default_headers({{"Server", "llama.cpp"}, - {"Access-Control-Allow-Origin", "*"}, - {"Access-Control-Allow-Headers", "content-type"}}); + svr.Options("/(.*)", + [&](const Request & /*req*/, Response &res) { + res.set_header("Access-Control-Allow-Methods", "*"); + res.set_header("Access-Control-Allow-Headers", "content-type"); + res.set_header("Access-Control-Allow-Origin", "*"); + }); svr.Get("/", [&](const Request & /*req*/, Response &res) { res.set_content(index_html_, "text/html"); }); @@ -836,14 +849,36 @@ int main(int argc, char **argv) svr.Get("/props", [&llama](const Request & /*req*/, Response &res) { + res.set_header("Access-Control-Allow-Origin", "*"); json data = { { "user_name", llama.user_name.c_str() }, { "assistant_name", llama.assistant_name.c_str() } }; res.set_content(data.dump(), "application/json"); }); + svr.Get("/cancel", [&llama](const Request & req/*req*/, Response &res) { + res.set_header("Access-Control-Allow-Origin", "*"); + if(req.has_param("slot_id")) { + int slot_id = std::stoi(req.get_param_value("slot_id")); + string result = "done"; + if(!llama.releaseSlot(slot_id)) { + result = "wrong slot ID"; + } + json data = { + { "status", result } + }; + res.set_content(data.dump(), "application/json"); + } else { + json data = { + { "error", "Missing parameter" } + }; + res.set_content(data.dump(), "application/json"); + } + }); + svr.Post("/completion", [&llama](const Request &req, Response &res) { + res.set_header("Access-Control-Allow-Origin", "*"); llama_client_slot* slot = llama.requestCompletion(json::parse(req.body)); // Verify if the slot exist if (slot) { @@ -855,7 +890,9 @@ int main(int argc, char **argv) } if(slot->hasNewToken()) { // new token notification stringstream ss; - json res_d = {{ "content", slot->sampled_tokens.back() }}; + json res_d = { + { "content", slot->sampled_tokens.back() }, + { "slot_id", slot->id }}; slot->sampled_tokens.pop_back(); ss << "data: " << res_d.dump() << "\n\n"; string result = ss.str();