diff --git a/examples/server-parallel/frontend.h b/examples/server-parallel/frontend.h
index b6909c0ff..8fdc0fb09 100644
--- a/examples/server-parallel/frontend.h
+++ b/examples/server-parallel/frontend.h
@@ -11,14 +11,14 @@ const char* index_html_ = R"(
-
-
Server parallel - PoC
+
Server parallel
@@ -52,8 +51,12 @@ const char* index_html_ = R"(
)";
const char* index_js_ = R"(
- let conversation = [];
- let current_message = -1;
+let conversation = [];
+let current_message = -1;
+let request_cancel = false;
+let canceled = false;
+let running = false;
+let slot_id = -1;
const questions = [
"Who is Elon Musk?",
@@ -71,7 +74,7 @@ const questions = [
let user_name = "";
let assistant_name = "";
-function toggleSP() {
+function toggle_system_prompt() {
if(document.getElementById("system_promt_cb").checked) {
document.getElementById("system_prompt_view").style.display = "block";
} else {
@@ -79,16 +82,15 @@ function toggleSP() {
}
}
-function clearSP() {
+function clear_sp_props() {
document.getElementById("sp_text").value = "";
- document.getElementById("anti_prompt").value = "";
+ document.getElementById("user_name").value = "";
document.getElementById("assistant_name").value = "";
}
docReady(async () => {
document.getElementById("message").value =
questions[Math.floor(Math.random() * questions.length)];
-
// to keep the same prompt format in all clients
const response = await fetch("/props");
if (!response.ok) {
@@ -128,28 +130,31 @@ function updateView() {
}
async function call_llama(options) {
- const response = await fetch("/completion", {
- method: "POST",
- body: JSON.stringify(options),
- headers: {
- Connection: "keep-alive",
- "Content-Type": "application/json",
- Accept: "text/event-stream",
- },
- });
-
- const reader = response.body.getReader();
- let cont = true;
- const decoder = new TextDecoder();
- let leftover = ""; // Buffer for partially read lines
-
try {
- let cont = true;
-
- while (cont) {
+ controller = new AbortController();
+ const response = await fetch("/completion", {
+ method: "POST",
+ body: JSON.stringify(options),
+ headers: {
+ "Content-Type": "application/json",
+ Accept: "text/event-stream",
+ }
+ });
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+ let leftover = ""; // Buffer for partially read lines
+ running = true;
+ while (running) {
+      // NOTE: this does not dispose the slot on the server; /cancel must release it
+ if(request_cancel) {
+ running = false;
+ break;
+ }
const result = await reader.read();
if (result.done) {
document.getElementById("btn_send").disabled = false;
+ document.getElementById("btn_cancel").disabled = true;
+ running = false;
break;
}
@@ -177,8 +182,9 @@ async function call_llama(options) {
if (match) {
result[match[1]] = match[2];
// since we know this is llama.cpp, let's just decode the json in data
- if (result.data) {
+ if (result.data && !request_cancel) {
result.data = JSON.parse(result.data);
+ slot_id = result.data.slot_id;
conversation[current_message].assistant += result.data.content;
updateView();
}
@@ -211,21 +217,56 @@ function generatePrompt() {
return prompt;
}
-function resetBtn() {
+async function resetView() {
+ if(running) {
+ await sendCancelSignal();
+ }
document.getElementById("slot_id").value = "-1";
document.getElementById("temperature").value = "0.1";
document.getElementById("message").value =
questions[Math.floor(Math.random() * questions.length)];
+ document.getElementById("btn_cancel").disabled = true;
+ document.getElementById("btn_cancel").innerText = "Cancel";
+ document.getElementById("btn_send").disabled = false;
document.getElementById("conversation_view").innerHTML = "";
conversation = [];
current_message = -1;
+ canceled = false;
}
-async function perform() {
+async function sendCancelSignal() {
+ await fetch(
+ "/cancel?slot_id=" + slot_id
+ );
+ request_cancel = true;
+}
+
+async function cancel() {
+ if(!canceled) {
+ await sendCancelSignal();
+ document.getElementById("btn_send").disabled = false;
+ document.getElementById("btn_cancel").innerText = "Regenerate response";
+ canceled = true;
+ } else {
+ perform(true);
+ }
+}
+
+async function perform(regen) {
+ if(regen) {
+ document.getElementById("message").value = conversation.pop().user;
+ current_message--;
+ }
+ request_cancel = false;
var slot_id = parseInt(document.getElementById("slot_id").value);
var temperature = parseFloat(document.getElementById("temperature").value);
var prompt = " " + document.getElementById("message").value;
- if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
+ if (prompt.length > 1 && !isNaN(slot_id) && !isNaN(temperature)) {
+ if(!regen && canceled) { // use the new message
+ conversation.pop(); // delete incomplete interaction
+ current_message--;
+ }
+ canceled = false;
let options = {
slot_id,
temperature
@@ -243,6 +284,7 @@ async function perform() {
current_message = -1;
document.getElementById("system_promt_cb").checked = false;
document.getElementById("system_promt_cb").dispatchEvent(new Event("change"));
+ // include system prompt props
options.system_prompt = system_prompt;
options.anti_prompt = anti_prompt;
options.assistant_name = assistant_name_;
@@ -257,12 +299,12 @@ async function perform() {
updateView();
document.getElementById("message").value = "";
document.getElementById("btn_send").disabled = true;
+ document.getElementById("btn_cancel").disabled = false;
+ document.getElementById("btn_cancel").innerText = "Cancel";
options.prompt = generatePrompt();
await call_llama(options);
} else {
- document.getElementById("conversation_view").innerText =
- "please, insert valid props.";
+ alert("please, insert valid props.");
}
}
-
)";
diff --git a/examples/server-parallel/server.cpp b/examples/server-parallel/server.cpp
index dbc361fd3..431211621 100644
--- a/examples/server-parallel/server.cpp
+++ b/examples/server-parallel/server.cpp
@@ -210,6 +210,16 @@ struct server_parallel_context {
update_system_prompt = false;
}
+ bool releaseSlot(int id) {
+ for(llama_client_slot &slot : slots) {
+ if(slot.id == id) {
+ slot.release();
+ return true;
+ }
+ }
+ return false;
+ }
+
void notifySystemPromptChanged() {
// release all slots
for (llama_client_slot &slot : slots)
@@ -824,9 +834,12 @@ int main(int argc, char **argv)
Server svr;
- svr.set_default_headers({{"Server", "llama.cpp"},
- {"Access-Control-Allow-Origin", "*"},
- {"Access-Control-Allow-Headers", "content-type"}});
+ svr.Options("/(.*)",
+ [&](const Request & /*req*/, Response &res) {
+ res.set_header("Access-Control-Allow-Methods", "*");
+ res.set_header("Access-Control-Allow-Headers", "content-type");
+ res.set_header("Access-Control-Allow-Origin", "*");
+ });
svr.Get("/", [&](const Request & /*req*/, Response &res)
{ res.set_content(index_html_, "text/html"); });
@@ -836,14 +849,36 @@ int main(int argc, char **argv)
svr.Get("/props", [&llama](const Request & /*req*/, Response &res)
{
+ res.set_header("Access-Control-Allow-Origin", "*");
json data = {
{ "user_name", llama.user_name.c_str() },
{ "assistant_name", llama.assistant_name.c_str() }
};
res.set_content(data.dump(), "application/json"); });
+ svr.Get("/cancel", [&llama](const Request & req/*req*/, Response &res) {
+ res.set_header("Access-Control-Allow-Origin", "*");
+ if(req.has_param("slot_id")) {
+ int slot_id = std::stoi(req.get_param_value("slot_id"));
+ string result = "done";
+ if(!llama.releaseSlot(slot_id)) {
+ result = "wrong slot ID";
+ }
+ json data = {
+ { "status", result }
+ };
+ res.set_content(data.dump(), "application/json");
+ } else {
+ json data = {
+ { "error", "Missing parameter" }
+ };
+ res.set_content(data.dump(), "application/json");
+ }
+ });
+
svr.Post("/completion", [&llama](const Request &req, Response &res)
{
+ res.set_header("Access-Control-Allow-Origin", "*");
llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
// Verify if the slot exist
if (slot) {
@@ -855,7 +890,9 @@ int main(int argc, char **argv)
}
if(slot->hasNewToken()) { // new token notification
stringstream ss;
- json res_d = {{ "content", slot->sampled_tokens.back() }};
+ json res_d = {
+ { "content", slot->sampled_tokens.back() },
+ { "slot_id", slot->id }};
slot->sampled_tokens.pop_back();
ss << "data: " << res_d.dump() << "\n\n";
string result = ss.str();