fix cors + regen + cancel funcs
This commit is contained in:
parent
f861ff916d
commit
8a8535bb6d
2 changed files with 122 additions and 43 deletions
|
@ -11,14 +11,14 @@ const char* index_html_ = R"(
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html>
|
<html>
|
||||||
<head>
|
<head>
|
||||||
<title>llama.cpp - server parallel PoC</title>
|
<title>llama.cpp - server parallel</title>
|
||||||
<script src="index.js"></script>
|
<script src="index.js"></script>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div style="width: 90%;margin: auto;">
|
<div style="width: 90%;margin: auto;">
|
||||||
<h2>Server parallel - PoC</h2>
|
<h2>Server parallel</h2>
|
||||||
<form id="myForm">
|
<form id="myForm">
|
||||||
<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggleSP() ">
|
<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggle_system_prompt() ">
|
||||||
<label for="system_promt_cb">Use custom system prompt</label>
|
<label for="system_promt_cb">Use custom system prompt</label>
|
||||||
<br>
|
<br>
|
||||||
<div id="system_prompt_view" style="display: none;">
|
<div id="system_prompt_view" style="display: none;">
|
||||||
|
@ -27,7 +27,7 @@ const char* index_html_ = R"(
|
||||||
<input type="text" id="user_name" value="" placeholder="Anti prompt" required>
|
<input type="text" id="user_name" value="" placeholder="Anti prompt" required>
|
||||||
<label for="assistant_name">Assistant name</label>
|
<label for="assistant_name">Assistant name</label>
|
||||||
<input type="text" id="assistant_name" value="" placeholder="Assistant:" required>
|
<input type="text" id="assistant_name" value="" placeholder="Assistant:" required>
|
||||||
<button type="button" id="btn_reset" onclick="clearSP() " >Clear all</button>
|
<button type="button" id="btn_reset" onclick="clear_sp_props() " >Clear all</button>
|
||||||
</div>
|
</div>
|
||||||
<br>
|
<br>
|
||||||
<label for="slot_id">Slot ID (-1 load in a idle slot)</label>
|
<label for="slot_id">Slot ID (-1 load in a idle slot)</label>
|
||||||
|
@ -39,10 +39,9 @@ const char* index_html_ = R"(
|
||||||
<label for="message">Message</label>
|
<label for="message">Message</label>
|
||||||
<input id="message" style="width: 80%;" required>
|
<input id="message" style="width: 80%;" required>
|
||||||
<br><br>
|
<br><br>
|
||||||
<button type="button" id="btn_send" onclick="perform() " >Send</button>
|
<button type="button" id="btn_send" style="margin-right: 1rem;" onclick="perform(false) " >Send</button>
|
||||||
<br>
|
<button type="button" id="btn_cancel" style="margin-right: 1rem;" onclick="cancel() " disabled>Cancel</button>
|
||||||
<br>
|
<button type="button" id="btn_reset" onclick="resetView() " >Reset</button>
|
||||||
<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
|
|
||||||
</form>
|
</form>
|
||||||
<div id="conversation_view">
|
<div id="conversation_view">
|
||||||
</div>
|
</div>
|
||||||
|
@ -52,8 +51,12 @@ const char* index_html_ = R"(
|
||||||
)";
|
)";
|
||||||
|
|
||||||
const char* index_js_ = R"(
|
const char* index_js_ = R"(
|
||||||
let conversation = [];
|
let conversation = [];
|
||||||
let current_message = -1;
|
let current_message = -1;
|
||||||
|
let request_cancel = false;
|
||||||
|
let canceled = false;
|
||||||
|
let running = false;
|
||||||
|
let slot_id = -1;
|
||||||
|
|
||||||
const questions = [
|
const questions = [
|
||||||
"Who is Elon Musk?",
|
"Who is Elon Musk?",
|
||||||
|
@ -71,7 +74,7 @@ const questions = [
|
||||||
let user_name = "";
|
let user_name = "";
|
||||||
let assistant_name = "";
|
let assistant_name = "";
|
||||||
|
|
||||||
function toggleSP() {
|
function toggle_system_prompt() {
|
||||||
if(document.getElementById("system_promt_cb").checked) {
|
if(document.getElementById("system_promt_cb").checked) {
|
||||||
document.getElementById("system_prompt_view").style.display = "block";
|
document.getElementById("system_prompt_view").style.display = "block";
|
||||||
} else {
|
} else {
|
||||||
|
@ -79,16 +82,15 @@ function toggleSP() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function clearSP() {
|
function clear_sp_props() {
|
||||||
document.getElementById("sp_text").value = "";
|
document.getElementById("sp_text").value = "";
|
||||||
document.getElementById("anti_prompt").value = "";
|
document.getElementById("user_name").value = "";
|
||||||
document.getElementById("assistant_name").value = "";
|
document.getElementById("assistant_name").value = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
docReady(async () => {
|
docReady(async () => {
|
||||||
document.getElementById("message").value =
|
document.getElementById("message").value =
|
||||||
questions[Math.floor(Math.random() * questions.length)];
|
questions[Math.floor(Math.random() * questions.length)];
|
||||||
|
|
||||||
// to keep the same prompt format in all clients
|
// to keep the same prompt format in all clients
|
||||||
const response = await fetch("/props");
|
const response = await fetch("/props");
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
|
@ -128,28 +130,31 @@ function updateView() {
|
||||||
}
|
}
|
||||||
|
|
||||||
async function call_llama(options) {
|
async function call_llama(options) {
|
||||||
const response = await fetch("/completion", {
|
|
||||||
method: "POST",
|
|
||||||
body: JSON.stringify(options),
|
|
||||||
headers: {
|
|
||||||
Connection: "keep-alive",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
Accept: "text/event-stream",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const reader = response.body.getReader();
|
|
||||||
let cont = true;
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
let leftover = ""; // Buffer for partially read lines
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
let cont = true;
|
controller = new AbortController();
|
||||||
|
const response = await fetch("/completion", {
|
||||||
while (cont) {
|
method: "POST",
|
||||||
|
body: JSON.stringify(options),
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Accept: "text/event-stream",
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let leftover = ""; // Buffer for partially read lines
|
||||||
|
running = true;
|
||||||
|
while (running) {
|
||||||
|
// this no disposes the slot
|
||||||
|
if(request_cancel) {
|
||||||
|
running = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
const result = await reader.read();
|
const result = await reader.read();
|
||||||
if (result.done) {
|
if (result.done) {
|
||||||
document.getElementById("btn_send").disabled = false;
|
document.getElementById("btn_send").disabled = false;
|
||||||
|
document.getElementById("btn_cancel").disabled = true;
|
||||||
|
running = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,8 +182,9 @@ async function call_llama(options) {
|
||||||
if (match) {
|
if (match) {
|
||||||
result[match[1]] = match[2];
|
result[match[1]] = match[2];
|
||||||
// since we know this is llama.cpp, let's just decode the json in data
|
// since we know this is llama.cpp, let's just decode the json in data
|
||||||
if (result.data) {
|
if (result.data && !request_cancel) {
|
||||||
result.data = JSON.parse(result.data);
|
result.data = JSON.parse(result.data);
|
||||||
|
slot_id = result.data.slot_id;
|
||||||
conversation[current_message].assistant += result.data.content;
|
conversation[current_message].assistant += result.data.content;
|
||||||
updateView();
|
updateView();
|
||||||
}
|
}
|
||||||
|
@ -211,21 +217,56 @@ function generatePrompt() {
|
||||||
return prompt;
|
return prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
function resetBtn() {
|
async function resetView() {
|
||||||
|
if(running) {
|
||||||
|
await sendCancelSignal();
|
||||||
|
}
|
||||||
document.getElementById("slot_id").value = "-1";
|
document.getElementById("slot_id").value = "-1";
|
||||||
document.getElementById("temperature").value = "0.1";
|
document.getElementById("temperature").value = "0.1";
|
||||||
document.getElementById("message").value =
|
document.getElementById("message").value =
|
||||||
questions[Math.floor(Math.random() * questions.length)];
|
questions[Math.floor(Math.random() * questions.length)];
|
||||||
|
document.getElementById("btn_cancel").disabled = true;
|
||||||
|
document.getElementById("btn_cancel").innerText = "Cancel";
|
||||||
|
document.getElementById("btn_send").disabled = false;
|
||||||
document.getElementById("conversation_view").innerHTML = "";
|
document.getElementById("conversation_view").innerHTML = "";
|
||||||
conversation = [];
|
conversation = [];
|
||||||
current_message = -1;
|
current_message = -1;
|
||||||
|
canceled = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function perform() {
|
async function sendCancelSignal() {
|
||||||
|
await fetch(
|
||||||
|
"/cancel?slot_id=" + slot_id
|
||||||
|
);
|
||||||
|
request_cancel = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function cancel() {
|
||||||
|
if(!canceled) {
|
||||||
|
await sendCancelSignal();
|
||||||
|
document.getElementById("btn_send").disabled = false;
|
||||||
|
document.getElementById("btn_cancel").innerText = "Regenerate response";
|
||||||
|
canceled = true;
|
||||||
|
} else {
|
||||||
|
perform(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function perform(regen) {
|
||||||
|
if(regen) {
|
||||||
|
document.getElementById("message").value = conversation.pop().user;
|
||||||
|
current_message--;
|
||||||
|
}
|
||||||
|
request_cancel = false;
|
||||||
var slot_id = parseInt(document.getElementById("slot_id").value);
|
var slot_id = parseInt(document.getElementById("slot_id").value);
|
||||||
var temperature = parseFloat(document.getElementById("temperature").value);
|
var temperature = parseFloat(document.getElementById("temperature").value);
|
||||||
var prompt = " " + document.getElementById("message").value;
|
var prompt = " " + document.getElementById("message").value;
|
||||||
if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
|
if (prompt.length > 1 && !isNaN(slot_id) && !isNaN(temperature)) {
|
||||||
|
if(!regen && canceled) { // use the new message
|
||||||
|
conversation.pop(); // delete incomplete interaction
|
||||||
|
current_message--;
|
||||||
|
}
|
||||||
|
canceled = false;
|
||||||
let options = {
|
let options = {
|
||||||
slot_id,
|
slot_id,
|
||||||
temperature
|
temperature
|
||||||
|
@ -243,6 +284,7 @@ async function perform() {
|
||||||
current_message = -1;
|
current_message = -1;
|
||||||
document.getElementById("system_promt_cb").checked = false;
|
document.getElementById("system_promt_cb").checked = false;
|
||||||
document.getElementById("system_promt_cb").dispatchEvent(new Event("change"));
|
document.getElementById("system_promt_cb").dispatchEvent(new Event("change"));
|
||||||
|
// include system prompt props
|
||||||
options.system_prompt = system_prompt;
|
options.system_prompt = system_prompt;
|
||||||
options.anti_prompt = anti_prompt;
|
options.anti_prompt = anti_prompt;
|
||||||
options.assistant_name = assistant_name_;
|
options.assistant_name = assistant_name_;
|
||||||
|
@ -257,12 +299,12 @@ async function perform() {
|
||||||
updateView();
|
updateView();
|
||||||
document.getElementById("message").value = "";
|
document.getElementById("message").value = "";
|
||||||
document.getElementById("btn_send").disabled = true;
|
document.getElementById("btn_send").disabled = true;
|
||||||
|
document.getElementById("btn_cancel").disabled = false;
|
||||||
|
document.getElementById("btn_cancel").innerText = "Cancel";
|
||||||
options.prompt = generatePrompt();
|
options.prompt = generatePrompt();
|
||||||
await call_llama(options);
|
await call_llama(options);
|
||||||
} else {
|
} else {
|
||||||
document.getElementById("conversation_view").innerText =
|
alert("please, insert valid props.");
|
||||||
"please, insert valid props.";
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
)";
|
)";
|
||||||
|
|
|
@ -210,6 +210,16 @@ struct server_parallel_context {
|
||||||
update_system_prompt = false;
|
update_system_prompt = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool releaseSlot(int id) {
|
||||||
|
for(llama_client_slot &slot : slots) {
|
||||||
|
if(slot.id == id) {
|
||||||
|
slot.release();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
void notifySystemPromptChanged() {
|
void notifySystemPromptChanged() {
|
||||||
// release all slots
|
// release all slots
|
||||||
for (llama_client_slot &slot : slots)
|
for (llama_client_slot &slot : slots)
|
||||||
|
@ -824,9 +834,12 @@ int main(int argc, char **argv)
|
||||||
|
|
||||||
Server svr;
|
Server svr;
|
||||||
|
|
||||||
svr.set_default_headers({{"Server", "llama.cpp"},
|
svr.Options("/(.*)",
|
||||||
{"Access-Control-Allow-Origin", "*"},
|
[&](const Request & /*req*/, Response &res) {
|
||||||
{"Access-Control-Allow-Headers", "content-type"}});
|
res.set_header("Access-Control-Allow-Methods", "*");
|
||||||
|
res.set_header("Access-Control-Allow-Headers", "content-type");
|
||||||
|
res.set_header("Access-Control-Allow-Origin", "*");
|
||||||
|
});
|
||||||
|
|
||||||
svr.Get("/", [&](const Request & /*req*/, Response &res)
|
svr.Get("/", [&](const Request & /*req*/, Response &res)
|
||||||
{ res.set_content(index_html_, "text/html"); });
|
{ res.set_content(index_html_, "text/html"); });
|
||||||
|
@ -836,14 +849,36 @@ int main(int argc, char **argv)
|
||||||
|
|
||||||
svr.Get("/props", [&llama](const Request & /*req*/, Response &res)
|
svr.Get("/props", [&llama](const Request & /*req*/, Response &res)
|
||||||
{
|
{
|
||||||
|
res.set_header("Access-Control-Allow-Origin", "*");
|
||||||
json data = {
|
json data = {
|
||||||
{ "user_name", llama.user_name.c_str() },
|
{ "user_name", llama.user_name.c_str() },
|
||||||
{ "assistant_name", llama.assistant_name.c_str() }
|
{ "assistant_name", llama.assistant_name.c_str() }
|
||||||
};
|
};
|
||||||
res.set_content(data.dump(), "application/json"); });
|
res.set_content(data.dump(), "application/json"); });
|
||||||
|
|
||||||
|
svr.Get("/cancel", [&llama](const Request & req/*req*/, Response &res) {
|
||||||
|
res.set_header("Access-Control-Allow-Origin", "*");
|
||||||
|
if(req.has_param("slot_id")) {
|
||||||
|
int slot_id = std::stoi(req.get_param_value("slot_id"));
|
||||||
|
string result = "done";
|
||||||
|
if(!llama.releaseSlot(slot_id)) {
|
||||||
|
result = "wrong slot ID";
|
||||||
|
}
|
||||||
|
json data = {
|
||||||
|
{ "status", result }
|
||||||
|
};
|
||||||
|
res.set_content(data.dump(), "application/json");
|
||||||
|
} else {
|
||||||
|
json data = {
|
||||||
|
{ "error", "Missing parameter" }
|
||||||
|
};
|
||||||
|
res.set_content(data.dump(), "application/json");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
svr.Post("/completion", [&llama](const Request &req, Response &res)
|
svr.Post("/completion", [&llama](const Request &req, Response &res)
|
||||||
{
|
{
|
||||||
|
res.set_header("Access-Control-Allow-Origin", "*");
|
||||||
llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
|
llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
|
||||||
// Verify if the slot exist
|
// Verify if the slot exist
|
||||||
if (slot) {
|
if (slot) {
|
||||||
|
@ -855,7 +890,9 @@ int main(int argc, char **argv)
|
||||||
}
|
}
|
||||||
if(slot->hasNewToken()) { // new token notification
|
if(slot->hasNewToken()) { // new token notification
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
json res_d = {{ "content", slot->sampled_tokens.back() }};
|
json res_d = {
|
||||||
|
{ "content", slot->sampled_tokens.back() },
|
||||||
|
{ "slot_id", slot->id }};
|
||||||
slot->sampled_tokens.pop_back();
|
slot->sampled_tokens.pop_back();
|
||||||
ss << "data: " << res_d.dump() << "\n\n";
|
ss << "data: " << res_d.dump() << "\n\n";
|
||||||
string result = ss.str();
|
string result = ss.str();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue