add change system prompt on runtime, improve README

This commit is contained in:
FSSRepo 2023-10-05 14:36:58 -04:00
parent 9e6e714dc5
commit dc102b4493
4 changed files with 446 additions and 283 deletions

View file

@ -1,6 +1,6 @@
# llama.cpp/example/server-parallel # llama.cpp/example/server-parallel
This example demonstrates a PoC HTTP API server that handles simulataneus requests. This example demonstrates a PoC HTTP API server that handles simulataneus requests. Long prompts are not supported.
## Quick Start ## Quick Start
@ -9,15 +9,65 @@ To get started right away, run the following command, making sure to use the cor
### Unix-based systems (Linux, macOS, etc.): ### Unix-based systems (Linux, macOS, etc.):
```bash ```bash
./server-parallel -m models/7B/ggml-model.gguf --ctx_size 2048 -t 4 -ngl 33 --batch-size 512 --parallel 3 -n 512 --cont-batching --reverse-prompt "User:" ./server-parallel -m models/7B/ggml-model.gguf --ctx_size 2048 -t 4 -ngl 33 --batch-size 512 --parallel 3 -n 512 --cont-batching
``` ```
### Windows: ### Windows:
```powershell ```powershell
server-parallel.exe -m models\7B\ggml-model.gguf --ctx_size 2048 -t 4 -ngl 33 --batch-size 512 --parallel 3 -n 512 --cont-batching --reverse-prompt "User:" server-parallel.exe -m models\7B\ggml-model.gguf --ctx_size 2048 -t 4 -ngl 33 --batch-size 512 --parallel 3 -n 512 --cont-batching
``` ```
The above command will start a server that by default listens on `127.0.0.1:8080`. The above command will start a server that by default listens on `127.0.0.1:8080`.
You can consume the endpoints with Postman or NodeJS with axios library. You can visit the web front end at the same url.
# This example is a Proof of Concept, have bugs and unexpected behaivors ## API Endpoints
- **GET** `/props`: Return the user and assistant name for generate the prompt.
*Response:*
```json
{
"user_name": "User:",
"assistant_name": "Assistant:"
}
```
- **POST** `/completion`: Given a prompt, it returns the predicted completion, just streaming mode.
*Options:*
`temperature`: Adjust the randomness of the generated text (default: 0.1).
`prompt`: Provide a prompt as a string, It should be a coherent continuation of the system prompt.
`system_prompt`: Provide a system prompt as a string.
`anti_prompt`: Provide the name of the user coherent with the system prompt.
`assistant_name`: Provide the name of the assistant coherent with the system prompt.
*Example request:*
```json
{
// this changes the system prompt on runtime
"system_prompt": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: Hello
Assistant: Hi, how may I help you?
Human:",
"anti_prompt": "Human:",
"assistant_name": "Assistant:",
// required options
"prompt": "When is the day of independency of US?",
"temperature": 0.2
}
```
*Response:*
```json
{
"content": "<token_str>"
}
```
# This example is a Proof of Concept, have some bugs and unexpected behaivors, this not supports long prompts.

View file

@ -0,0 +1,263 @@
const char* system_prompt_default =
R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Recommend a nice restaurant in the area.
Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
User: Who is Richard Feynman?
Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
User:)";
const char* index_html_ = R"(
<!DOCTYPE html>
<html>
<head>
<title>llama.cpp - server parallel PoC</title>
<script src="index.js"></script>
</head>
<body>
<div style="width: 90%;margin: auto;">
<h2>Server parallel - PoC</h2>
<form id="myForm">
<input type="checkbox" id="system_promt_cb" name="myCheckbox" onchange="toggleSP() ">
<label for="system_promt_cb">Use custom system prompt</label>
<br>
<div id="system_prompt_view" style="display: none;">
<textarea id="sp_text" name="systemPrompt" style="width: 100%;height: 4rem;" placeholder="System Prompt"></textarea>
<label for="user_name">User name</label>
<input type="text" id="user_name" value="" placeholder="Anti prompt" required>
<label for="assistant_name">Assistant name</label>
<input type="text" id="assistant_name" value="" placeholder="Assistant:" required>
<button type="button" id="btn_reset" onclick="clearSP() " >Clear all</button>
</div>
<br>
<label for="slot_id">Slot ID (-1 load in a idle slot)</label>
<input type="number" id="slot_id" value="-1" required>
<br>
<label for="temperature">Temperature</label>
<input type="number" id="temperature" value="0.1" required>
<br>
<label for="message">Message</label>
<input id="message" style="width: 80%;" required>
<br><br>
<button type="button" id="btn_send" onclick="perform() " >Send</button>
<br>
<br>
<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
</form>
<div id="conversation_view">
</div>
</div>
</body>
</html>
)";
const char* index_js_ = R"(
let conversation = [];
let current_message = -1;
const questions = [
"Who is Elon Musk?",
"Who is Jeff Bezos?",
"How to get a job at google?",
"What are you?",
"When was born Abraham Lincoln?",
];
let user_name = "";
let assistant_name = "";
function toggleSP() {
if(document.getElementById("system_promt_cb").checked) {
document.getElementById("system_prompt_view").style.display = "block";
} else {
document.getElementById("system_prompt_view").style.display = "none";
}
}
function clearSP() {
document.getElementById("sp_text").value = "";
document.getElementById("anti_prompt").value = "";
document.getElementById("assistant_name").value = "";
}
docReady(async () => {
document.getElementById("message").value =
questions[Math.floor(Math.random() * questions.length)];
// to keep the same prompt format in all clients
const response = await fetch("/props");
if (!response.ok) {
alert(`HTTP error! Status: ${response.status}`);
}
const data = await response.json();
user_name = data.user_name;
assistant_name = data.assistant_name;
});
function docReady(fn) {
// see if DOM is already available
if (
document.readyState === "complete" ||
document.readyState === "interactive"
) {
// call on next available tick
setTimeout(fn, 1);
} else {
document.addEventListener("DOMContentLoaded", fn);
}
}
function updateView() {
let conv_view = document.getElementById("conversation_view");
// build view
conv_view.innerHTML = "";
for (let index in conversation) {
conversation[index].assistant = conversation[index].assistant.replace(
user_name,
""
);
conv_view.innerHTML += `
<p><span style="font-weight: bold">User:</span> ${conversation[index].user}<p>
<p style="white-space: pre-line;"><span style="font-weight: bold">Assistant:</span> ${conversation[index].assistant}<p>`;
}
}
async function call_llama(options) {
const response = await fetch("/completion", {
method: "POST",
body: JSON.stringify(options),
headers: {
Connection: "keep-alive",
"Content-Type": "application/json",
Accept: "text/event-stream",
},
});
const reader = response.body.getReader();
let cont = true;
const decoder = new TextDecoder();
let leftover = ""; // Buffer for partially read lines
try {
let cont = true;
while (cont) {
const result = await reader.read();
if (result.done) {
document.getElementById("btn_send").disabled = false;
break;
}
// Add any leftover data to the current chunk of data
const text = leftover + decoder.decode(result.value);
// Check if the last character is a line break
const endsWithLineBreak = text.endsWith("\n");
// Split the text into lines
let lines = text.split("\n");
// If the text doesn't end with a line break, then the last line is incomplete
// Store it in leftover to be added to the next chunk of data
if (!endsWithLineBreak) {
leftover = lines.pop();
} else {
leftover = ""; // Reset leftover if we have a line break at the end
}
// Parse all sse events and add them to result
const regex = /^(\S+):\s(.*)$/gm;
for (const line of lines) {
const match = regex.exec(line);
if (match) {
result[match[1]] = match[2];
// since we know this is llama.cpp, let's just decode the json in data
if (result.data) {
result.data = JSON.parse(result.data);
conversation[current_message].assistant += result.data.content;
updateView();
}
}
}
}
} catch (e) {
if (e.name !== "AbortError") {
console.error("llama error: ", e);
}
throw e;
}
}
function generatePrompt() {
// generate a good prompt to have coherence
let prompt = "";
for (let index in conversation) {
if (index == 0) {
prompt += conversation[index].user + "\n";
} else {
prompt += user_name + conversation[index].user + "\n";
}
if (index == current_message) {
prompt += assistant_name;
} else {
prompt += assistant_name + conversation[index].assistant;
}
}
return prompt;
}
function resetBtn() {
document.getElementById("slot_id").value = "-1";
document.getElementById("temperature").value = "0.1";
document.getElementById("message").value =
questions[Math.floor(Math.random() * questions.length)];
document.getElementById("conversation_view").innerHTML = "";
conversation = [];
current_message = -1;
}
async function perform() {
var slot_id = parseInt(document.getElementById("slot_id").value);
var temperature = parseFloat(document.getElementById("temperature").value);
var prompt = " " + document.getElementById("message").value;
if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
let options = {
slot_id,
temperature
};
if(document.getElementById("system_promt_cb").checked) {
let system_prompt = document.getElementById("sp_text").value;
let anti_prompt = document.getElementById("user_name").value;
let assistant_name_ = document.getElementById("assistant_name").value;
if(!system_prompt || !anti_prompt || !assistant_name_) {
document.getElementById("conversation_view").innerText =
"please, insert valid props.";
return;
}
conversation = [];
current_message = -1;
document.getElementById("system_promt_cb").checked = false;
document.getElementById("system_promt_cb").dispatchEvent(new Event("change"));
options.system_prompt = system_prompt;
options.anti_prompt = anti_prompt;
options.assistant_name = assistant_name_;
user_name = anti_prompt;
assistant_name = assistant_name_;
}
current_message++;
conversation.push({
user: prompt,
assistant: "",
});
updateView();
document.getElementById("message").value = "";
document.getElementById("btn_send").disabled = true;
options.prompt = generatePrompt();
await call_llama(options);
} else {
document.getElementById("conversation_view").innerText =
"please, insert valid props.";
}
}
)";

View file

@ -1,180 +0,0 @@
const auto index_html = R"(
<!DOCTYPE html>
<html>
<head>
<title>llama.cpp - server parallel PoC</title>
</head>
<body>
<div style="width: 90%;margin: auto;">
<h2>Server parallel - PoC</h2>
<form id="myForm" >
<label for="slot_id">Slot ID (-1 load in a idle slot)</label>
<input type="number" id="slot_id" value="-1" required>
<br>
<label for="temperature">Temperature</label>
<input type="number" id="temperature" value="0.1" required>
<br>
<label for="message">Message</label>
<input id="message" style="width: 80%;" required>
<br><br>
<button type="button" id="btn_send" onclick="perform() " >Send</button>
<br>
<br>
<button type="button" id="btn_reset" onclick="resetBtn() " >Reset</button>
</form>
<div id="conversation_view">
</div>
</div>
<script>
let conversation = [];
let current_message = -1;
const questions = ["Who is Elon Musk?", "Who is Jeff Bezos?", "How to get a job at google?", "What are you?", "When was born Abraham Lincoln?"];
docReady(() => {
document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
});
function docReady(fn) {
// see if DOM is already available
if (document.readyState === "complete" || document.readyState === "interactive") {
// call on next available tick
setTimeout(fn, 1);
} else {
document.addEventListener("DOMContentLoaded", fn);
}
}
function updateView() {
let conv_view = document.getElementById("conversation_view");
// build view
conv_view.innerHTML = "";
for(let index in conversation) {
conversation[index].assistant = conversation[index].assistant.replace("User:", "");
conv_view.innerHTML += `
<p><span style="font-weight: bold">User:</span> ${conversation[index].user}<p>
<p style="white-space: pre-line;"><span style="font-weight: bold">Assistant:</span> ${conversation[index].assistant}<p>`;
}
}
async function call_llama(options) {
const response = await fetch("/completion", {
method: 'POST',
body: JSON.stringify(options),
headers: {
'Connection': 'keep-alive',
'Content-Type': 'application/json',
'Accept': 'text/event-stream'
}
});
const reader = response.body.getReader();
let cont = true;
const decoder = new TextDecoder();
let leftover = ""; // Buffer for partially read lines
try {
let cont = true;
while (cont) {
const result = await reader.read();
if (result.done) {
document.getElementById("btn_send").disabled = false;
break;
}
// Add any leftover data to the current chunk of data
const text = leftover + decoder.decode(result.value);
// Check if the last character is a line break
const endsWithLineBreak = text.endsWith('\n');
// Split the text into lines
let lines = text.split('\n');
// If the text doesn't end with a line break, then the last line is incomplete
// Store it in leftover to be added to the next chunk of data
if (!endsWithLineBreak) {
leftover = lines.pop();
} else {
leftover = ""; // Reset leftover if we have a line break at the end
}
// Parse all sse events and add them to result
const regex = /^(\S+):\s(.*)$/gm;
for (const line of lines) {
const match = regex.exec(line);
if (match) {
result[match[1]] = match[2]
// since we know this is llama.cpp, let's just decode the json in data
if (result.data) {
result.data = JSON.parse(result.data);
conversation[current_message].assistant += result.data.token;
updateView();
}
}
}
}
} catch (e) {
if (e.name !== 'AbortError') {
console.error("llama error: ", e);
}
throw e;
}
}
function generatePrompt() {
// generate a good prompt to have coherence
let prompt = '';
for(let index in conversation) {
if(index == 0) {
prompt += conversation[index].user + "\n";
} else {
prompt += "User:" + conversation[index].user + "\n";
}
if(index == current_message) {
prompt += "Assistant:";
} else {
prompt += "Assistant:" + conversation[index].assistant;
}
}
return prompt;
}
function resetBtn() {
document.getElementById("slot_id").value = "-1";
document.getElementById("temperature").value = "0.1";
document.getElementById("message").value = questions[Math.floor(Math.random() * questions.length)];
document.getElementById("conversation_view").innerHTML = "";
conversation = [];
current_message = -1;
}
async function perform() {
document.getElementById("btn_send").disabled = true;
var slot_id = parseInt(document.getElementById("slot_id").value);
var temperature = parseFloat(document.getElementById("temperature").value);
var prompt = " " + document.getElementById("message").value;
if (!isNaN(slot_id) && !isNaN(temperature) && prompt.length > 0) {
current_message++;
conversation.push({
user: prompt,
assistant: ''
});
updateView();
document.getElementById("message").value = "";
await call_llama({
slot_id,
temperature,
prompt: generatePrompt()
});
} else {
document.getElementById("conversation_view").innerText = "please, insert valid props.";
}
}
</script>
</body>
</html>
)";

View file

@ -5,7 +5,7 @@
#include <sstream> #include <sstream>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "index.h" #include "frontend.h"
#include "common.h" #include "common.h"
#include "llama.h" #include "llama.h"
@ -69,15 +69,6 @@ enum slot_command {
RELEASE RELEASE
}; };
static std::string system_prompt =
R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Recommend a nice restaurant in the area.
Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
User: Who is Richard Feynman?
Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
User:)";
struct llama_client_slot struct llama_client_slot
{ {
@ -117,11 +108,23 @@ struct llama_client_slot
void nofity() { void nofity() {
newToken = !newToken; newToken = !newToken;
} }
void release() {
if(state == PROCESSING) {
command = RELEASE;
}
}
}; };
struct server_parallel_context { struct server_parallel_context {
// example props // example props
vector<llama_client_slot> slots; vector<llama_client_slot> slots;
std::string system_prompt = "";
bool update_system_prompt = true;
// broadcast to all clients to keep the same prompt format
std::string user_name = ""; // this should be the anti prompt
std::string assistant_name = ""; // this is for generate the prompt
// llama native props // llama native props
gpt_params params; gpt_params params;
@ -131,9 +134,8 @@ struct server_parallel_context {
int n_vocab; int n_vocab;
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
std::vector<llama_token> tokens_system; std::vector<llama_token> tokens_system;
int32_t n_tokens_system; int32_t n_tokens_system = 0;
llama_batch batch; llama_batch batch;
bool request_clean_kv = true;
bool loadModel(gpt_params params_) { bool loadModel(gpt_params params_) {
params = params_; params = params_;
@ -149,7 +151,8 @@ struct server_parallel_context {
return true; return true;
} }
void initializeSlots() { void initialize() {
// create slots
LOG_TEE("Available slots:\n"); LOG_TEE("Available slots:\n");
for (int i = 0; i < params.n_parallel; i++) for (int i = 0; i < params.n_parallel; i++)
{ {
@ -162,49 +165,87 @@ struct server_parallel_context {
LOG_TEE(" - slot %i\n", slot.id); LOG_TEE(" - slot %i\n", slot.id);
slots.push_back(slot); slots.push_back(slot);
} }
}
bool loadSystemPrompt() {
tokens_system = ::llama_tokenize(ctx, system_prompt, true);
n_tokens_system = tokens_system.size();
batch = llama_batch_init(params.n_ctx, 0); batch = llama_batch_init(params.n_ctx, 0);
{ // always assign a default system prompt
LOG_TEE("Evaluating the system prompt ...\n"); system_prompt = system_prompt_default;
user_name = "User:";
batch.n_tokens = n_tokens_system; assistant_name = "Assistant:";
params.antiprompt.push_back(user_name);
for (int32_t i = 0; i < batch.n_tokens; ++i)
{
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}
if (llama_decode(ctx, batch) != 0)
{
LOG_TEE("%s: llama_decode() failed\n", __func__);
return false;
}
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < params.n_parallel; ++i)
{
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
}
}
return true;
} }
llama_client_slot* loadPrompt(int slot_id, string prompt, float temp_) { void updateSystemPrompt() {
tokens_system = ::llama_tokenize(ctx, system_prompt, true);
n_tokens_system = tokens_system.size();
batch.n_tokens = n_tokens_system;
// clear the entire KV cache
for (int i = 0; i < params.n_parallel; ++i)
{
llama_kv_cache_seq_rm(ctx, i, 0, -1);
}
for (int32_t i = 0; i < batch.n_tokens; ++i)
{
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}
if (llama_decode(ctx, batch) != 0)
{
LOG_TEE("%s: llama_decode() failed\n", __func__);
return;
}
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < params.n_parallel; ++i)
{
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
}
LOG_TEE("system prompt updated\n");
update_system_prompt = false;
}
void notifySystemPromptChanged() {
// release all slots
for (llama_client_slot &slot : slots)
{
slot.release();
}
waitAllAreIdle();
// wait until system prompt load
update_system_prompt = true;
while(update_system_prompt) {
this_thread::sleep_for(chrono::milliseconds(5));
}
// system prompt loaded, continue
}
llama_client_slot* requestCompletion(json data) {
if(data.contains("system_prompt") &&
data.contains("anti_prompt") &&
data.contains("assistant_name")) {
system_prompt = data.value("system_prompt", "");
user_name = data.value("anti_prompt", "");
assistant_name = data.value("assistant_name", "");
params.antiprompt.clear();
params.antiprompt.push_back(user_name);
notifySystemPromptChanged();
}
int slot_id = data.value("slot_id", -1);
float temperature = data.value("temperature", 0.1f);
string prompt = data.value("prompt", "");
for (llama_client_slot & slot : slots) for (llama_client_slot & slot : slots)
{ {
if ( if (
slot_id == -1 && slot.available() || slot_id == -1 && slot.available() ||
slot.id == slot_id) slot.id == slot_id)
{ {
slot.start(prompt, temp_); slot.start(prompt, temperature);
LOG_TEE("slot %i is processing\n", slot.id); LOG_TEE("slot %i is processing\n", slot.id);
return &slot; // return a pointer to slot (thread safe?) return &slot; // return a pointer to slot (thread safe?)
} }
@ -238,17 +279,38 @@ struct server_parallel_context {
return stop_pos; return stop_pos;
} }
void waitAllAreIdle() {
bool wait = true;
while(wait) {
wait = false;
for (auto &slot : slots)
{
if (!slot.available())
{
wait = true;
break;
}
}
}
}
bool updateSlots() { bool updateSlots() {
// update the system prompt wait until all slots are idle state
if(update_system_prompt) {
updateSystemPrompt();
}
batch.n_tokens = 0; batch.n_tokens = 0;
// decode any currently ongoing sequences // decode any currently ongoing sequences
for (auto & slot : slots) { for (auto & slot : slots) {
if (slot.state == PROCESSING && slot.command == RELEASE) if (slot.state == PROCESSING && slot.command == RELEASE)
{ {
LOG_TEE("slot %i released\n", slot.id);
llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx); llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
slot.state = IDLE; slot.state = IDLE;
LOG_TEE("slot %i is released\n", slot.id);
slot.command = NONE; slot.command = NONE;
continue;
} }
// no decode wait until the token had been send to client // no decode wait until the token had been send to client
@ -269,16 +331,6 @@ struct server_parallel_context {
batch.n_tokens += 1; batch.n_tokens += 1;
} }
if (batch.n_tokens == 0 && request_clean_kv) {
// all sequences have ended - clear the entire KV cache
for (int i = 0; i < params.n_parallel; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
}
request_clean_kv = false;
LOG_TEE("%s: clearing the KV cache\n", __func__);
}
// assign workload to the slots // assign workload to the slots
if (params.cont_batching || batch.n_tokens == 0) { if (params.cont_batching || batch.n_tokens == 0) {
for (llama_client_slot & slot : slots) { for (llama_client_slot & slot : slots) {
@ -394,13 +446,12 @@ struct server_parallel_context {
stop_pos != std::string::npos)) { stop_pos != std::string::npos)) {
//LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str()); //LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
slot.generated_text.clear(); slot.generated_text.clear();
slot.command = RELEASE; slot.release();
} }
slot.i_batch = -1; slot.i_batch = -1;
} }
return true;
} }
return true;
} }
}; };
@ -666,7 +717,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])}); params.lora_adapter.push_back(make_tuple(lora_adapter, std::stof(argv[i])));
params.use_mmap = false; params.use_mmap = false;
} }
else if (arg == "--lora-base") else if (arg == "--lora-base")
@ -703,27 +754,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break; break;
} }
params.n_predict = std::stoi(argv[i]); params.n_predict = std::stoi(argv[i]);
} else if (arg == "-f" || arg == "--file") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(system_prompt));
if (system_prompt.back() == '\n') {
system_prompt.pop_back();
}
} else if (arg == "-r" || arg == "--reverse-prompt") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.antiprompt.push_back(argv[i]);
} }
else else
{ {
@ -765,25 +795,26 @@ int main(int argc, char **argv)
return 1; return 1;
} }
// create slots llama.initialize();
llama.initializeSlots();
// process system prompt
llama.loadSystemPrompt();
Server svr; Server svr;
svr.Get("/", [&](const Request & /*req*/, Response &res) svr.Get("/", [&](const Request & /*req*/, Response &res)
{ res.set_content(index_html, "text/html"); }); { res.set_content(index_html_, "text/html"); });
svr.Get("/index.js", [&](const Request & /*req*/, Response &res)
{ res.set_content(index_js_, "text/html"); });
svr.Get("/props", [&llama](const Request & /*req*/, Response &res)
{
json data = {
{ "user_name", llama.user_name.c_str() },
{ "assistant_name", llama.assistant_name.c_str() }
};
res.set_content(data.dump(), "application/json"); });
svr.Post("/completion", [&llama](const Request &req, Response &res) svr.Post("/completion", [&llama](const Request &req, Response &res)
{ {
json data = json::parse(req.body); llama_client_slot* slot = llama.requestCompletion(json::parse(req.body));
int slot_id = data.value("slot_id", -1);
float temperature = data.value("temperature", 0.8f);
string prompt = data.value("prompt", "");
llama_client_slot* slot = llama.loadPrompt(slot_id, prompt, temperature);
// Verify if the slot exist // Verify if the slot exist
if (slot) { if (slot) {
res.set_chunked_content_provider("text/event-stream", res.set_chunked_content_provider("text/event-stream",
@ -792,14 +823,13 @@ int main(int argc, char **argv)
sink.done(); sink.done();
return false; return false;
} }
if(slot->hasNewToken()) { // new token notification if(slot->hasNewToken()) { // new token notification
stringstream ss; stringstream ss;
json res_d = {{"token", slot->sampled_token_str}}; json res_d = {{ "content", slot->sampled_token_str }};
ss << "data: " << res_d.dump() << "\n\n"; ss << "data: " << res_d.dump() << "\n\n";
string result = ss.str(); string result = ss.str();
if(!sink.write(result.c_str(), result.size())) { // user request release if(!sink.write(result.c_str(), result.size())) {
slot->command = RELEASE; slot->release();
return false; return false;
} }
} }