fix cors + regen + cancel funcs
This commit is contained in:
parent
f861ff916d
commit
8a8535bb6d
2 changed files with 122 additions and 43 deletions
|
@ -11,14 +11,14 @@ const char* index_html_ = R"(
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>llama.cpp - server parallel PoC</title>
|
||||
<title>llama.cpp - server parallel</title>
|
||||
<script src="index.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<div style="width: 90%;margin: auto;">
|
||||
<h2>Server parallel - PoC</h2>
|
||||
<h2>Server parallel</h2>
|
||||
<form id="myForm">
|
||||
<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggleSP() ">
|
||||
<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggle_system_prompt() ">
|
||||
<label for="system_promt_cb">Use custom system prompt</label>
|
||||
<br>
|
||||
<div id="system_prompt_view" style="display: none;">
|
||||
|
@ -27,7 +27,7 @@ const char* index_html_ = R"(
|
|||
<input type="text" id="user_name" value="" placeholder="Anti prompt" required>
|
||||
<label for="assistant_name">Assistant name</label>
|
||||
<input type="text" id="assistant_name" value="" placeholder="Assistant:" required>
|
||||
<button type="button" id="btn_reset" onclick="clearSP() " >Clear all</button>
|
||||
<button type="button" id="btn_reset" onclick="clear_sp_props() " >Clear all</button>
|
||||
</div>
|
||||
<br>
|
||||
<label for="slot_id">Slot ID (-1 load in a idle slot)</label>
|
||||
|
@ -39,10 +39,9 @@ const char* index_html_ = R"(
|
|||
<label for="message">Message</label>
|
||||
<input id="message" style="width: 80%;" required>
|
||||
<br><br>
|
||||
<button type="button" id="btn_send" onclick="perform() " >Send</button>
|
||||
<br>
|
||||
<br>
|
||||
<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
|
||||
<button type="button" id="btn_send" style="margin-right: 1rem;" onclick="perform(false) " >Send</button>
|
||||
<button type="button" id="btn_cancel" style="margin-right: 1rem;" onclick="cancel() " disabled>Cancel</button>
|
||||
<button type="button" id="btn_reset" onclick="resetView() " >Reset</button>
|
||||
</form>
|
||||
<div id="conversation_view">
|
||||
</div>
|
||||
|
@ -52,8 +51,12 @@ const char* index_html_ = R"(
|
|||
)";
|
||||
|
||||
const char* index_js_ = R"(
|
||||
let conversation = [];
|
||||
let current_message = -1;
|
||||
let conversation = [];
|
||||
let current_message = -1;
|
||||
let request_cancel = false;
|
||||
let canceled = false;
|
||||
let running = false;
|
||||
let slot_id = -1;
|
||||
|
||||
const questions = [
|
||||
"Who is Elon Musk?",
|
||||
|
@ -71,7 +74,7 @@ const questions = [
|
|||
let user_name = "";
|
||||
let assistant_name = "";
|
||||
|
||||
function toggleSP() {
|
||||
function toggle_system_prompt() {
|
||||
if(document.getElementById("system_promt_cb").checked) {
|
||||
document.getElementById("system_prompt_view").style.display = "block";
|
||||
} else {
|
||||
|
@ -79,16 +82,15 @@ function toggleSP() {
|
|||
}
|
||||
}
|
||||
|
||||
function clearSP() {
|
||||
function clear_sp_props() {
|
||||
document.getElementById("sp_text").value = "";
|
||||
document.getElementById("anti_prompt").value = "";
|
||||
document.getElementById("user_name").value = "";
|
||||
document.getElementById("assistant_name").value = "";
|
||||
}
|
||||
|
||||
docReady(async () => {
|
||||
document.getElementById("message").value =
|
||||
questions[Math.floor(Math.random() * questions.length)];
|
||||
|
||||
// to keep the same prompt format in all clients
|
||||
const response = await fetch("/props");
|
||||
if (!response.ok) {
|
||||
|
@ -128,28 +130,31 @@ function updateView() {
|
|||
}
|
||||
|
||||
async function call_llama(options) {
|
||||
const response = await fetch("/completion", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(options),
|
||||
headers: {
|
||||
Connection: "keep-alive",
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
});
|
||||
|
||||
const reader = response.body.getReader();
|
||||
let cont = true;
|
||||
const decoder = new TextDecoder();
|
||||
let leftover = ""; // Buffer for partially read lines
|
||||
|
||||
try {
|
||||
let cont = true;
|
||||
|
||||
while (cont) {
|
||||
controller = new AbortController();
|
||||
const response = await fetch("/completion", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(options),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
}
|
||||
});
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let leftover = ""; // Buffer for partially read lines
|
||||
running = true;
|
||||
while (running) {
|
||||
// this no disposes the slot
|
||||
if(request_cancel) {
|
||||
running = false;
|
||||
break;
|
||||
}
|
||||
const result = await reader.read();
|
||||
if (result.done) {
|
||||
document.getElementById("btn_send").disabled = false;
|
||||
document.getElementById("btn_cancel").disabled = true;
|
||||
running = false;
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -177,8 +182,9 @@ async function call_llama(options) {
|
|||
if (match) {
|
||||
result[match[1]] = match[2];
|
||||
// since we know this is llama.cpp, let's just decode the json in data
|
||||
if (result.data) {
|
||||
if (result.data && !request_cancel) {
|
||||
result.data = JSON.parse(result.data);
|
||||
slot_id = result.data.slot_id;
|
||||
conversation[current_message].assistant += result.data.content;
|
||||
updateView();
|
||||
}
|
||||
|
@ -211,21 +217,56 @@ function generatePrompt() {
|
|||
return prompt;
|
||||
}
|
||||
|
||||
function resetBtn() {
|
||||
async function resetView() {
|
||||
if(running) {
|
||||
await sendCancelSignal();
|
||||
}
|
||||
document.getElementById("slot_id").value = "-1";
|
||||
document.getElementById("temperature").value = "0.1";
|
||||
document.getElementById("message").value =
|
||||
questions[Math.floor(Math.random() * questions.length)];
|
||||
document.getElementById("btn_cancel").disabled = true;
|
||||
document.getElementById("btn_cancel").innerText = "Cancel";
|
||||
document.getElementById("btn_send").disabled = false;
|
||||
document.getElementById("conversation_view").innerHTML = "";
|
||||
conversation = [];
|
||||
current_message = -1;
|
||||
canceled = false;
|
||||
}
|
||||
|
||||
async function perform() {
|
||||
async function sendCancelSignal() {
|
||||
await fetch(
|
||||
"/cancel?slot_id=" + slot_id
|
||||
);
|
||||
request_cancel = true;
|
||||
}
|
||||
|
||||
async function cancel() {
|
||||
if(!canceled) {
|
||||
await sendCancelSignal();
|
||||
document.getElementById("btn_send").disabled = false;
|
||||
document.getElementById("btn_cancel").innerText = "Regenerate response";
|
||||
canceled = true;
|
||||
} else {
|
||||
perform(true);
|
||||
}
|
||||
}
|
||||
|
||||
async function perform(regen) {
|
||||
if(regen) {
|
||||
document.getElementById("message").value = conversation.pop().user;
|
||||
current_message--;
|
||||
}
|
||||
request_cancel = false;
|
||||
var slot_id = parseInt(document.getElementById("slot_id").value);
|
||||
var temperature = parseFloat(document.getElementById("temperature").value);
|
||||
var prompt = " " + document.getElementById("message").value;
|
||||
if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
|
||||
if (prompt.length > 1 && !isNaN(slot_id) && !isNaN(temperature)) {
|
||||
if(!regen && canceled) { // use the new message
|
||||
conversation.pop(); // delete incomplete interaction
|
||||
current_message--;
|
||||
}
|
||||
canceled = false;
|
||||
let options = {
|
||||
slot_id,
|
||||
temperature
|
||||
|
@ -243,6 +284,7 @@ async function perform() {
|
|||
current_message = -1;
|
||||
document.getElementById("system_promt_cb").checked = false;
|
||||
document.getElementById("system_promt_cb").dispatchEvent(new Event("change"));
|
||||
// include system prompt props
|
||||
options.system_prompt = system_prompt;
|
||||
options.anti_prompt = anti_prompt;
|
||||
options.assistant_name = assistant_name_;
|
||||
|
@ -257,12 +299,12 @@ async function perform() {
|
|||
updateView();
|
||||
document.getElementById("message").value = "";
|
||||
document.getElementById("btn_send").disabled = true;
|
||||
document.getElementById("btn_cancel").disabled = false;
|
||||
document.getElementById("btn_cancel").innerText = "Cancel";
|
||||
options.prompt = generatePrompt();
|
||||
await call_llama(options);
|
||||
} else {
|
||||
document.getElementById("conversation_view").innerText =
|
||||
"please, insert valid props.";
|
||||
alert("please, insert valid props.");
|
||||
}
|
||||
}
|
||||
|
||||
)";
|
||||
|
|
|
@ -210,6 +210,16 @@ struct server_parallel_context {
|
|||
update_system_prompt = false;
|
||||
}
|
||||
|
||||
bool releaseSlot(int id) {
|
||||
for(llama_client_slot &slot : slots) {
|
||||
if(slot.id == id) {
|
||||
slot.release();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void notifySystemPromptChanged() {
|
||||
// release all slots
|
||||
for (llama_client_slot &slot : slots)
|
||||
|
@ -824,9 +834,12 @@ int main(int argc, char **argv)
|
|||
|
||||
Server svr;
|
||||
|
||||
svr.set_default_headers({{"Server", "llama.cpp"},
|
||||
{"Access-Control-Allow-Origin", "*"},
|
||||
{"Access-Control-Allow-Headers", "content-type"}});
|
||||
svr.Options("/(.*)",
|
||||
[&](const Request & /*req*/, Response &res) {
|
||||
res.set_header("Access-Control-Allow-Methods", "*");
|
||||
res.set_header("Access-Control-Allow-Headers", "content-type");
|
||||
res.set_header("Access-Control-Allow-Origin", "*");
|
||||
});
|
||||
|
||||
svr.Get("/", [&](const Request & /*req*/, Response &res)
|
||||
{ res.set_content(index_html_, "text/html"); });
|
||||
|
@ -836,14 +849,36 @@ int main(int argc, char **argv)
|
|||
|
||||
svr.Get("/props", [&llama](const Request & /*req*/, Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", "*");
|
||||
json data = {
|
||||
{ "user_name", llama.user_name.c_str() },
|
||||
{ "assistant_name", llama.assistant_name.c_str() }
|
||||
};
|
||||
res.set_content(data.dump(), "application/json"); });
|
||||
|
||||
svr.Get("/cancel", [&llama](const Request & req/*req*/, Response &res) {
|
||||
res.set_header("Access-Control-Allow-Origin", "*");
|
||||
if(req.has_param("slot_id")) {
|
||||
int slot_id = std::stoi(req.get_param_value("slot_id"));
|
||||
string result = "done";
|
||||
if(!llama.releaseSlot(slot_id)) {
|
||||
result = "wrong slot ID";
|
||||
}
|
||||
json data = {
|
||||
{ "status", result }
|
||||
};
|
||||
res.set_content(data.dump(), "application/json");
|
||||
} else {
|
||||
json data = {
|
||||
{ "error", "Missing parameter" }
|
||||
};
|
||||
res.set_content(data.dump(), "application/json");
|
||||
}
|
||||
});
|
||||
|
||||
svr.Post("/completion", [&llama](const Request &req, Response &res)
|
||||
{
|
||||
res.set_header("Access-Control-Allow-Origin", "*");
|
||||
llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
|
||||
// Verify if the slot exist
|
||||
if (slot) {
|
||||
|
@ -855,7 +890,9 @@ int main(int argc, char **argv)
|
|||
}
|
||||
if(slot->hasNewToken()) { // new token notification
|
||||
stringstream ss;
|
||||
json res_d = {{ "content", slot->sampled_tokens.back() }};
|
||||
json res_d = {
|
||||
{ "content", slot->sampled_tokens.back() },
|
||||
{ "slot_id", slot->id }};
|
||||
slot->sampled_tokens.pop_back();
|
||||
ss << "data: " << res_d.dump() << "\n\n";
|
||||
string result = ss.str();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue