fix cors + regen + cancel funcs

This commit is contained in:
FSSRepo 2023-10-08 22:30:43 -04:00
parent f861ff916d
commit 8a8535bb6d
2 changed files with 122 additions and 43 deletions

View file

@ -11,14 +11,14 @@ const char* index_html_ = R"(
<!DOCTYPE html>
<html>
<head>
<title>llama.cpp - server parallel PoC</title>
<title>llama.cpp - server parallel</title>
<script src="index.js"></script>
</head>
<body>
<div style="width: 90%;margin: auto;">
<h2>Server parallel - PoC</h2>
<h2>Server parallel</h2>
<form id="myForm">
<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggleSP() ">
<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggle_system_prompt() ">
<label for="system_promt_cb">Use custom system prompt</label>
<br>
<div id="system_prompt_view" style="display: none;">
@ -27,7 +27,7 @@ const char* index_html_ = R"(
<input type="text" id="user_name" value="" placeholder="Anti prompt" required>
<label for="assistant_name">Assistant name</label>
<input type="text" id="assistant_name" value="" placeholder="Assistant:" required>
<button type="button" id="btn_reset" onclick="clearSP() " >Clear all</button>
<button type="button" id="btn_reset" onclick="clear_sp_props() " >Clear all</button>
</div>
<br>
<label for="slot_id">Slot ID (-1 load in a idle slot)</label>
@ -39,10 +39,9 @@ const char* index_html_ = R"(
<label for="message">Message</label>
<input id="message" style="width: 80%;" required>
<br><br>
<button type="button" id="btn_send" onclick="perform() " >Send</button>
<br>
<br>
<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
<button type="button" id="btn_send" style="margin-right: 1rem;" onclick="perform(false) " >Send</button>
<button type="button" id="btn_cancel" style="margin-right: 1rem;" onclick="cancel() " disabled>Cancel</button>
<button type="button" id="btn_reset" onclick="resetView() " >Reset</button>
</form>
<div id="conversation_view">
</div>
@ -52,8 +51,12 @@ const char* index_html_ = R"(
)";
const char* index_js_ = R"(
let conversation = [];
let current_message = -1;
let conversation = [];
let current_message = -1;
let request_cancel = false;
let canceled = false;
let running = false;
let slot_id = -1;
const questions = [
"Who is Elon Musk?",
@ -71,7 +74,7 @@ const questions = [
let user_name = "";
let assistant_name = "";
function toggleSP() {
function toggle_system_prompt() {
if(document.getElementById("system_promt_cb").checked) {
document.getElementById("system_prompt_view").style.display = "block";
} else {
@ -79,16 +82,15 @@ function toggleSP() {
}
}
function clearSP() {
// Blank out every input of the custom system prompt form.
function clear_sp_props() {
  const field_ids = ["sp_text", "anti_prompt", "user_name", "assistant_name"];
  for (const field_id of field_ids) {
    document.getElementById(field_id).value = "";
  }
}
docReady(async () => {
document.getElementById("message").value =
questions[Math.floor(Math.random() * questions.length)];
// to keep the same prompt format in all clients
const response = await fetch("/props");
if (!response.ok) {
@ -128,28 +130,31 @@ function updateView() {
}
async function call_llama(options) {
try {
controller = new AbortController();
const response = await fetch("/completion", {
method: "POST",
body: JSON.stringify(options),
headers: {
Connection: "keep-alive",
"Content-Type": "application/json",
Accept: "text/event-stream",
},
}
});
const reader = response.body.getReader();
let cont = true;
const decoder = new TextDecoder();
let leftover = ""; // Buffer for partially read lines
try {
let cont = true;
while (cont) {
running = true;
while (running) {
// this no disposes the slot
if(request_cancel) {
running = false;
break;
}
const result = await reader.read();
if (result.done) {
document.getElementById("btn_send").disabled = false;
document.getElementById("btn_cancel").disabled = true;
running = false;
break;
}
@ -177,8 +182,9 @@ async function call_llama(options) {
if (match) {
result[match[1]] = match[2];
// since we know this is llama.cpp, let's just decode the json in data
if (result.data) {
if (result.data && !request_cancel) {
result.data = JSON.parse(result.data);
slot_id = result.data.slot_id;
conversation[current_message].assistant += result.data.content;
updateView();
}
@ -211,21 +217,56 @@ function generatePrompt() {
return prompt;
}
function resetBtn() {
// Restore the form to its defaults and drop the whole conversation.
async function resetView() {
  // A response that is still streaming is cancelled on the server first,
  // before the local state is thrown away.
  if (running) {
    await sendCancelSignal();
  }

  const random_question = questions[Math.floor(Math.random() * questions.length)];
  // [element id, property, value] triples describing the pristine view
  const defaults = [
    ["slot_id", "value", "-1"],
    ["temperature", "value", "0.1"],
    ["message", "value", random_question],
    ["btn_cancel", "disabled", true],
    ["btn_cancel", "innerText", "Cancel"],
    ["btn_send", "disabled", false],
    ["conversation_view", "innerHTML", ""],
  ];
  for (const [element_id, property, value] of defaults) {
    document.getElementById(element_id)[property] = value;
  }

  // forget every interaction recorded so far
  conversation = [];
  current_message = -1;
  canceled = false;
}
async function perform() {
// Ask the server to cancel the slot currently in use, then raise the
// request_cancel flag once that request has completed.
async function sendCancelSignal() {
  const cancel_url = `/cancel?slot_id=${slot_id}`;
  await fetch(cancel_url);
  request_cancel = true;
}
// Click handler of the cancel button, which has two roles:
//  - first click: cancel the running generation and relabel the button,
//  - click after a cancel: regenerate the interrupted response.
async function cancel() {
  if (canceled) {
    perform(true);
    return;
  }

  await sendCancelSignal();
  const send_button = document.getElementById("btn_send");
  const cancel_button = document.getElementById("btn_cancel");
  send_button.disabled = false;
  cancel_button.innerText = "Regenerate response";
  canceled = true;
}
async function perform(regen) {
if(regen) {
document.getElementById("message").value = conversation.pop().user;
current_message--;
}
request_cancel = false;
var slot_id = parseInt(document.getElementById("slot_id").value);
var temperature = parseFloat(document.getElementById("temperature").value);
var prompt = " " + document.getElementById("message").value;
if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
if (prompt.length > 1 && !isNaN(slot_id) && !isNaN(temperature)) {
if(!regen && canceled) { // use the new message
conversation.pop(); // delete incomplete interaction
current_message--;
}
canceled = false;
let options = {
slot_id,
temperature
@ -243,6 +284,7 @@ async function perform() {
current_message = -1;
document.getElementById("system_promt_cb").checked = false;
document.getElementById("system_promt_cb").dispatchEvent(new Event("change"));
// include system prompt props
options.system_prompt = system_prompt;
options.anti_prompt = anti_prompt;
options.assistant_name = assistant_name_;
@ -257,12 +299,12 @@ async function perform() {
updateView();
document.getElementById("message").value = "";
document.getElementById("btn_send").disabled = true;
document.getElementById("btn_cancel").disabled = false;
document.getElementById("btn_cancel").innerText = "Cancel";
options.prompt = generatePrompt();
await call_llama(options);
} else {
document.getElementById("conversation_view").innerText =
"please, insert valid props.";
alert("please, insert valid props.");
}
}
)";

View file

@ -210,6 +210,16 @@ struct server_parallel_context {
update_system_prompt = false;
}
bool releaseSlot(int id) {
for(llama_client_slot &slot : slots) {
if(slot.id == id) {
slot.release();
return true;
}
}
return false;
}
void notifySystemPromptChanged() {
// release all slots
for (llama_client_slot &slot : slots)
@ -824,9 +834,12 @@ int main(int argc, char **argv)
Server svr;
svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});
// Answer OPTIONS requests on any path matching "/(.*)" with CORS headers:
// every method, the content-type request header and any origin are allowed.
svr.Options("/(.*)",
[&](const Request & /*req*/, Response &res) {
res.set_header("Access-Control-Allow-Methods", "*");
res.set_header("Access-Control-Allow-Headers", "content-type");
res.set_header("Access-Control-Allow-Origin", "*");
});
svr.Get("/", [&](const Request & /*req*/, Response &res)
{ res.set_content(index_html_, "text/html"); });
@ -836,14 +849,36 @@ int main(int argc, char **argv)
// GET /props: returns llama.user_name and llama.assistant_name as a JSON
// object with the keys "user_name" and "assistant_name".
svr.Get("/props", [&llama](const Request & /*req*/, Response &res)
{
res.set_header("Access-Control-Allow-Origin", "*");
json data = {
{ "user_name", llama.user_name.c_str() },
{ "assistant_name", llama.assistant_name.c_str() }
};
res.set_content(data.dump(), "application/json"); });
// GET /cancel?slot_id=N: releases the slot with id N.
// Responses (JSON):
//   {"status": "done"}            - the slot was found and released
//   {"status": "wrong slot ID"}   - no slot has that id
//   {"error": "Missing parameter"} - the slot_id query parameter is absent
//   {"error": "Invalid slot_id"}  - slot_id is not a whole int (HTTP 400)
// NOTE(review): releaseSlot runs from the HTTP handler; confirm this is safe
// relative to whatever code processes the slots.
svr.Get("/cancel", [&llama](const Request &req, Response &res) {
    res.set_header("Access-Control-Allow-Origin", "*");

    if (!req.has_param("slot_id")) {
        json data = {
            { "error", "Missing parameter" }
        };
        res.set_content(data.dump(), "application/json");
        return;
    }

    // The value comes straight from the query string. std::stoi throws on
    // non-numeric or out-of-range text and silently accepts trailing garbage
    // ("3abc" -> 3), so parse defensively and require the whole value to be
    // consumed instead of letting an exception escape the handler.
    int slot_id = 0;
    bool parsed = false;
    try {
        const std::string raw = req.get_param_value("slot_id");
        size_t consumed = 0;
        slot_id = std::stoi(raw, &consumed);
        parsed = (consumed == raw.size());
    } catch (...) {
        parsed = false;
    }
    if (!parsed) {
        res.status = 400;
        json data = {
            { "error", "Invalid slot_id" }
        };
        res.set_content(data.dump(), "application/json");
        return;
    }

    string result = "done";
    if (!llama.releaseSlot(slot_id)) {
        result = "wrong slot ID";
    }
    json data = {
        { "status", result }
    };
    res.set_content(data.dump(), "application/json");
});
svr.Post("/completion", [&llama](const Request &req, Response &res)
{
res.set_header("Access-Control-Allow-Origin", "*");
llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
// Verify if the slot exist
if (slot) {
@ -855,7 +890,9 @@ int main(int argc, char **argv)
}
if(slot->hasNewToken()) { // new token notification
stringstream ss;
json res_d = {{ "content", slot->sampled_tokens.back() }};
json res_d = {
{ "content", slot->sampled_tokens.back() },
{ "slot_id", slot->id }};
slot->sampled_tokens.pop_back();
ss << "data: " << res_d.dump() << "\n\n";
string result = ss.str();