fixed cancel + removed useless code

This commit is contained in:
FSSRepo 2023-10-09 07:53:00 -04:00
parent c8d7b1b897
commit 59e7c0c51b
2 changed files with 24 additions and 59 deletions

View file

@ -53,10 +53,9 @@ const char* index_html_ = R"(
const char* index_js_ = R"(
let conversation = [];
let current_message = -1;
let request_cancel = false;
let canceled = false;
let running = false;
let slot_id = -1;
var controller;
const questions = [
"Who is Elon Musk?",
@ -92,7 +91,7 @@ docReady(async () => {
document.getElementById("message").value =
questions[Math.floor(Math.random() * questions.length)];
// to keep the same prompt format in all clients
const response = await fetch("/props");
const response = await fetch("http://localhost:8080/props");
if (!response.ok) {
alert(`HTTP error! Status: ${response.status}`);
}
@ -131,29 +130,25 @@ function updateView() {
async function call_llama(options) {
try {
const response = await fetch("/completion", {
controller = new AbortController();
signal = controller.signal;
const response = await fetch("http://localhost:8080/completion", {
method: "POST",
body: JSON.stringify(options),
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
}
},
signal: signal
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
let leftover = ""; // Buffer for partially read lines
running = true;
while (running) {
// this no disposes the slot
if(request_cancel) {
running = false;
break;
}
while (current_message >= 0) {
const result = await reader.read();
if (result.done) {
document.getElementById("btn_send").disabled = false;
document.getElementById("btn_cancel").disabled = true;
running = false;
break;
}
@ -181,7 +176,7 @@ async function call_llama(options) {
if (match) {
result[match[1]] = match[2];
// since we know this is llama.cpp, let's just decode the json in data
if (result.data && !request_cancel) {
if (result.data && current_message >= 0) {
result.data = JSON.parse(result.data);
slot_id = result.data.slot_id;
conversation[current_message].assistant += result.data.content;
@ -194,7 +189,6 @@ async function call_llama(options) {
if (e.name !== "AbortError") {
console.error("llama error: ", e);
}
throw e;
}
}
@ -213,12 +207,14 @@ function generatePrompt() {
prompt += assistant_name + conversation[index].assistant;
}
}
console.log(prompt)
return prompt;
}
async function resetView() {
if(running) {
await sendCancelSignal();
if(controller) {
controller.abort();
controller = null;
}
document.getElementById("slot_id").value = "-1";
document.getElementById("temperature").value = "0.1";
@ -233,16 +229,12 @@ async function resetView() {
canceled = false;
}
async function sendCancelSignal() {
await fetch(
"/cancel?slot_id=" + slot_id
);
request_cancel = true;
}
async function cancel() {
if(!canceled) {
await sendCancelSignal();
if(controller) {
controller.abort();
controller = null;
}
document.getElementById("btn_send").disabled = false;
document.getElementById("btn_cancel").innerText = "Regenerate response";
canceled = true;
@ -256,7 +248,6 @@ async function perform(regen) {
document.getElementById("message").value = conversation.pop().user;
current_message--;
}
request_cancel = false;
var slot_id = parseInt(document.getElementById("slot_id").value);
var temperature = parseFloat(document.getElementById("temperature").value);
var prompt = " " + document.getElementById("message").value;

View file

@ -210,16 +210,6 @@ struct server_parallel_context {
update_system_prompt = false;
}
bool releaseSlot(int id) {
for(llama_client_slot &slot : slots) {
if(slot.id == id) {
slot.release();
return true;
}
}
return false;
}
void notifySystemPromptChanged() {
// release all slots
for (llama_client_slot &slot : slots)
@ -357,6 +347,7 @@ struct server_parallel_context {
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, slot.prompt, false);
slot.n_tokens_predicted = 0;
slot.sampled_tokens.clear();
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch.token [batch.n_tokens] = tokens_prompt[i];
@ -856,34 +847,13 @@ int main(int argc, char **argv)
};
res.set_content(data.dump(), "application/json"); });
svr.Get("/cancel", [&llama](const Request & req/*req*/, Response &res) {
res.set_header("Access-Control-Allow-Origin", "*");
if(req.has_param("slot_id")) {
int slot_id = std::stoi(req.get_param_value("slot_id"));
string result = "done";
if(!llama.releaseSlot(slot_id)) {
result = "wrong slot ID";
}
json data = {
{ "status", result }
};
res.set_content(data.dump(), "application/json");
} else {
json data = {
{ "error", "Missing parameter" }
};
res.set_content(data.dump(), "application/json");
}
});
svr.Post("/completion", [&llama](const Request &req, Response &res)
{
res.set_header("Access-Control-Allow-Origin", "*");
llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
// Verify if the slot exist
if (slot) {
res.set_chunked_content_provider("text/event-stream",
[slot](size_t /*offset*/, DataSink &sink) {
auto content_provider = [slot](size_t /*offset*/, DataSink &sink) {
if(slot->available()) { // slot has been released
sink.done();
return false;
@ -902,7 +872,11 @@ int main(int argc, char **argv)
}
}
return true;
});
};
auto on_complete = [slot] (bool) {
slot->release();
};
res.set_chunked_content_provider("text/event-stream", content_provider, on_complete);
} else {
LOG_TEE("slot unavailable\n");
res.status = 404;