fix cors + regen + cancel funcs

FSSRepo 2023-10-08 22:30:43 -04:00
parent f861ff916d
commit 8a8535bb6d
2 changed files with 122 additions and 43 deletions
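For context: the CORS fix in this commit drops the blanket set_default_headers call in favor of an explicit OPTIONS preflight handler plus a per-route Access-Control-Allow-Origin header. Below is a minimal standalone sketch of that pattern with cpp-httplib, the HTTP library the example server already uses; the host, port, and the /props payload are placeholders, not the example's real setup.

```cpp
// Minimal sketch of the CORS pattern this commit moves to, assuming
// cpp-httplib (the HTTP library the example server already uses).
// Host, port, and the /props body are placeholders for illustration.
#include "httplib.h"

int main() {
    httplib::Server svr;

    // Answer CORS preflight (OPTIONS) requests for every route.
    svr.Options("/(.*)", [](const httplib::Request & /*req*/, httplib::Response &res) {
        res.set_header("Access-Control-Allow-Methods", "*");
        res.set_header("Access-Control-Allow-Headers", "content-type");
        res.set_header("Access-Control-Allow-Origin", "*");
    });

    // Regular routes still attach Access-Control-Allow-Origin to their own responses.
    svr.Get("/props", [](const httplib::Request & /*req*/, httplib::Response &res) {
        res.set_header("Access-Control-Allow-Origin", "*");
        res.set_content("{}", "application/json");
    });

    svr.listen("127.0.0.1", 8080);
    return 0;
}
```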

View file

@@ -11,14 +11,14 @@ const char* index_html_ = R"(
 <!DOCTYPE html>
 <html>
 <head>
-<title>llama.cpp - server parallel PoC</title>
+<title>llama.cpp - server parallel</title>
 <script src="index.js"></script>
 </head>
 <body>
 <div style="width: 90%;margin: auto;">
-<h2>Server parallel - PoC</h2>
+<h2>Server parallel</h2>
 <form id="myForm">
-<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggleSP() ">
+<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggle_system_prompt() ">
 <label for="system_promt_cb">Use custom system prompt</label>
 <br>
 <div id="system_prompt_view" style="display: none;">
@@ -27,7 +27,7 @@ const char* index_html_ = R"(
 <input type="text" id="user_name" value="" placeholder="Anti prompt" required>
 <label for="assistant_name">Assistant name</label>
 <input type="text" id="assistant_name" value="" placeholder="Assistant:" required>
-<button type="button" id="btn_reset" onclick="clearSP() " >Clear all</button>
+<button type="button" id="btn_reset" onclick="clear_sp_props() " >Clear all</button>
 </div>
 <br>
 <label for="slot_id">Slot ID (-1 load in a idle slot)</label>
@@ -39,10 +39,9 @@ const char* index_html_ = R"(
 <label for="message">Message</label>
 <input id="message" style="width: 80%;" required>
 <br><br>
-<button type="button" id="btn_send" onclick="perform() " >Send</button>
-<br>
-<br>
-<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
+<button type="button" id="btn_send" style="margin-right: 1rem;" onclick="perform(false) " >Send</button>
+<button type="button" id="btn_cancel" style="margin-right: 1rem;" onclick="cancel() " disabled>Cancel</button>
+<button type="button" id="btn_reset" onclick="resetView() " >Reset</button>
 </form>
 <div id="conversation_view">
 </div>
@@ -52,8 +51,12 @@ const char* index_html_ = R"(
 )";
 const char* index_js_ = R"(
 let conversation = [];
 let current_message = -1;
+let request_cancel = false;
+let canceled = false;
+let running = false;
+let slot_id = -1;
 const questions = [
 "Who is Elon Musk?",
@@ -71,7 +74,7 @@ const questions = [
 let user_name = "";
 let assistant_name = "";
-function toggleSP() {
+function toggle_system_prompt() {
 if(document.getElementById("system_promt_cb").checked) {
 document.getElementById("system_prompt_view").style.display = "block";
 } else {
@@ -79,16 +82,15 @@ function toggleSP() {
 }
 }
-function clearSP() {
+function clear_sp_props() {
 document.getElementById("sp_text").value = "";
-document.getElementById("anti_prompt").value = "";
+document.getElementById("user_name").value = "";
 document.getElementById("assistant_name").value = "";
 }
 docReady(async () => {
 document.getElementById("message").value =
 questions[Math.floor(Math.random() * questions.length)];
 // to keep the same prompt format in all clients
 const response = await fetch("/props");
 if (!response.ok) {
@@ -128,28 +130,31 @@ function updateView() {
 }
 async function call_llama(options) {
+try {
+controller = new AbortController();
 const response = await fetch("/completion", {
 method: "POST",
 body: JSON.stringify(options),
 headers: {
-Connection: "keep-alive",
 "Content-Type": "application/json",
 Accept: "text/event-stream",
-},
+}
 });
 const reader = response.body.getReader();
+let cont = true;
 const decoder = new TextDecoder();
 let leftover = ""; // Buffer for partially read lines
-try {
-let cont = true;
-while (cont) {
+running = true;
+while (running) {
+// this no disposes the slot
+if(request_cancel) {
+running = false;
+break;
+}
 const result = await reader.read();
 if (result.done) {
 document.getElementById("btn_send").disabled = false;
+document.getElementById("btn_cancel").disabled = true;
+running = false;
 break;
 }
@@ -177,8 +182,9 @@ async function call_llama(options) {
 if (match) {
 result[match[1]] = match[2];
 // since we know this is llama.cpp, let's just decode the json in data
-if (result.data) {
+if (result.data && !request_cancel) {
 result.data = JSON.parse(result.data);
+slot_id = result.data.slot_id;
 conversation[current_message].assistant += result.data.content;
 updateView();
 }
@@ -211,21 +217,56 @@ function generatePrompt() {
 return prompt;
 }
-function resetBtn() {
+async function resetView() {
+if(running) {
+await sendCancelSignal();
+}
 document.getElementById("slot_id").value = "-1";
 document.getElementById("temperature").value = "0.1";
 document.getElementById("message").value =
 questions[Math.floor(Math.random() * questions.length)];
+document.getElementById("btn_cancel").disabled = true;
+document.getElementById("btn_cancel").innerText = "Cancel";
+document.getElementById("btn_send").disabled = false;
 document.getElementById("conversation_view").innerHTML = "";
 conversation = [];
 current_message = -1;
+canceled = false;
 }
-async function perform() {
+async function sendCancelSignal() {
+await fetch(
+"/cancel?slot_id=" + slot_id
+);
+request_cancel = true;
+}
+async function cancel() {
+if(!canceled) {
+await sendCancelSignal();
+document.getElementById("btn_send").disabled = false;
+document.getElementById("btn_cancel").innerText = "Regenerate response";
+canceled = true;
+} else {
+perform(true);
+}
+}
+async function perform(regen) {
+if(regen) {
+document.getElementById("message").value = conversation.pop().user;
+current_message--;
+}
+request_cancel = false;
 var slot_id = parseInt(document.getElementById("slot_id").value);
 var temperature = parseFloat(document.getElementById("temperature").value);
 var prompt = " " + document.getElementById("message").value;
-if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
+if (prompt.length > 1 && !isNaN(slot_id) && !isNaN(temperature)) {
+if(!regen && canceled) { // use the new message
+conversation.pop(); // delete incomplete interaction
+current_message--;
+}
+canceled = false;
 let options = {
 slot_id,
 temperature
@@ -243,6 +284,7 @@ async function perform() {
 current_message = -1;
 document.getElementById("system_promt_cb").checked = false;
 document.getElementById("system_promt_cb").dispatchEvent(new Event("change"));
+// include system prompt props
 options.system_prompt = system_prompt;
 options.anti_prompt = anti_prompt;
 options.assistant_name = assistant_name_;
@@ -257,12 +299,12 @@ async function perform() {
 updateView();
 document.getElementById("message").value = "";
 document.getElementById("btn_send").disabled = true;
+document.getElementById("btn_cancel").disabled = false;
+document.getElementById("btn_cancel").innerText = "Cancel";
 options.prompt = generatePrompt();
 await call_llama(options);
 } else {
-document.getElementById("conversation_view").innerText =
-"please, insert valid props.";
+alert("please, insert valid props.");
 }
 }
 )";

View file

@@ -210,6 +210,16 @@ struct server_parallel_context {
 update_system_prompt = false;
 }
+bool releaseSlot(int id) {
+for(llama_client_slot &slot : slots) {
+if(slot.id == id) {
+slot.release();
+return true;
+}
+}
+return false;
+}
 void notifySystemPromptChanged() {
 // release all slots
 for (llama_client_slot &slot : slots)
@@ -824,9 +834,12 @@ int main(int argc, char **argv)
 Server svr;
-svr.set_default_headers({{"Server", "llama.cpp"},
-{"Access-Control-Allow-Origin", "*"},
-{"Access-Control-Allow-Headers", "content-type"}});
+svr.Options("/(.*)",
+[&](const Request & /*req*/, Response &res) {
+res.set_header("Access-Control-Allow-Methods", "*");
+res.set_header("Access-Control-Allow-Headers", "content-type");
+res.set_header("Access-Control-Allow-Origin", "*");
+});
 svr.Get("/", [&](const Request & /*req*/, Response &res)
 { res.set_content(index_html_, "text/html"); });
@@ -836,14 +849,36 @@ int main(int argc, char **argv)
 svr.Get("/props", [&llama](const Request & /*req*/, Response &res)
 {
+res.set_header("Access-Control-Allow-Origin", "*");
 json data = {
 { "user_name", llama.user_name.c_str() },
 { "assistant_name", llama.assistant_name.c_str() }
 };
 res.set_content(data.dump(), "application/json"); });
+svr.Get("/cancel", [&llama](const Request & req/*req*/, Response &res) {
+res.set_header("Access-Control-Allow-Origin", "*");
+if(req.has_param("slot_id")) {
+int slot_id = std::stoi(req.get_param_value("slot_id"));
+string result = "done";
+if(!llama.releaseSlot(slot_id)) {
+result = "wrong slot ID";
+}
+json data = {
+{ "status", result }
+};
+res.set_content(data.dump(), "application/json");
+} else {
+json data = {
+{ "error", "Missing parameter" }
+};
+res.set_content(data.dump(), "application/json");
+}
+});
 svr.Post("/completion", [&llama](const Request &req, Response &res)
 {
+res.set_header("Access-Control-Allow-Origin", "*");
 llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
 // Verify if the slot exist
 if (slot) {
@@ -855,7 +890,9 @@ int main(int argc, char **argv)
 }
 if(slot->hasNewToken()) { // new token notification
 stringstream ss;
-json res_d = {{ "content", slot->sampled_tokens.back() }};
+json res_d = {
+{ "content", slot->sampled_tokens.back() },
+{ "slot_id", slot->id }};
 slot->sampled_tokens.pop_back();
 ss << "data: " << res_d.dump() << "\n\n";
 string result = ss.str();