Merge pull request #21 from SlyEcho/server_refactor

Server refactor
Randall Fitzgerald 2023-06-12 16:16:20 -04:00 committed by GitHub
commit 50e7c5434f
6 changed files with 257 additions and 204 deletions

Makefile

@@ -4,7 +4,7 @@ BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot
ifdef LLAMA_BUILD_SERVER
BUILD_TARGETS += server
LLAMA_SERVER_VERBOSE ?= 1
server: CXXFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
server: private CXXFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
endif
default: $(BUILD_TARGETS)
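
The `private` modifier added above is a GNU make feature: it keeps the target-specific `CXXFLAGS` addition from being inherited by the `server` target's prerequisites. A minimal sketch of driving these switches from the command line, assuming GNU make is run from the repository root:

```sh
# Build the server target with the verbose logging path compiled out.
# LLAMA_SERVER_VERBOSE defaults to 1; LLAMA_BUILD_SERVER must be defined
# for the target to exist at all (see the ifdef above).
LLAMA_BUILD_SERVER=1 LLAMA_SERVER_VERBOSE=0 make server
```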

examples/server/CMakeLists.txt

@@ -3,13 +3,6 @@ option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp json.hpp httplib.h)
target_compile_definitions(${TARGET} PRIVATE
# single thread
CPPHTTPLIB_THREAD_POOL_COUNT=1
# crash the server in debug mode, otherwise send an http 500 error
$<$<CONFIG:Debug>:
CPPHTTPLIB_NO_EXCEPTIONS=1
>
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
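
The removed block is not lost: the `CPPHTTPLIB_THREAD_POOL_COUNT` and `CPPHTTPLIB_NO_EXCEPTIONS` definitions move into `server.cpp` itself (see the hunk below), leaving CMake with only the `SERVER_VERBOSE` switch. A minimal sketch of toggling it at configure time, assuming the example's target is named `server` and the server example is enabled in the top-level build:

```sh
# LLAMA_SERVER_VERBOSE is the option() declared above; it defaults to ON.
cmake -B build -DLLAMA_SERVER_VERBOSE=OFF
cmake --build build --target server
```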

examples/server/README.md

@@ -167,8 +167,16 @@ node .
### Interactive mode
Check the sample in [chat.mjs](chat.mjs).
Run with node:
Run with NodeJS version 16 or later:
```sh
node chat.mjs
```
Another sample in [chat.sh](chat.sh).
Requires [bash](https://www.gnu.org/software/bash/), [curl](https://curl.se) and [jq](https://jqlang.github.io/jq/).
Run with bash:
```sh
bash chat.sh
```
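
Both samples drive the same HTTP API, so the endpoint can also be exercised directly. A minimal sketch with curl, assuming the server is listening on the default `127.0.0.1:8080` (the prompt string is only illustrative):

```sh
curl --silent --no-buffer --request POST \
  --url "http://127.0.0.1:8080/completion" \
  --data-raw '{"prompt": "### Human: Hello.\n### Assistant:", "n_predict": 64, "stream": true}'
```

With `stream` enabled, each line of the response is prefixed with `data: ` and carries a JSON object with a `content` fragment and a `stop` flag, which is exactly what `chat.mjs` and `chat.sh` parse.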

examples/server/chat.mjs

@@ -1,61 +1,89 @@
import * as readline from 'node:readline/promises';
import { stdin as input, stdout as output } from 'node:process';
import * as readline from 'node:readline'
import { stdin, stdout } from 'node:process'
const API_URL = 'http://127.0.0.1:8080'
const chat = [
{ human: "Hello, Assistant.",
assistant: "Hello. How may I help you today?" },
{ human: "Please tell me the largest city in Europe.",
assistant: "Sure. The largest city in Europe is Moscow, the capital of Russia." },
{
human: "Hello, Assistant.",
assistant: "Hello. How may I help you today?"
},
{
human: "Please tell me the largest city in Europe.",
assistant: "Sure. The largest city in Europe is Moscow, the capital of Russia."
},
]
const instruction = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.`
function format_prompt(question) {
return "A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n"
+ chat.map(m => `### Human: ${m.human}\n### Assistant: ${m.assistant}`).join("\n")
+ `\n### Human: ${question}\n### Assistant:`
return `${instruction}\n${
chat.map(m =>`### Human: ${m.human}\n### Assistant: ${m.assistant}`).join("\n")
}\n### Human: ${question}\n### Assistant:`
}
async function ChatCompletion(question) {
const result = await fetch("http://127.0.0.1:8080/completion", {
method: 'POST',
body: JSON.stringify({
prompt: format_prompt(question),
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: 29,
n_predict: 256,
stop: ["\n### Human:"], // stop completion after generating this
stream: true,
async function tokenize(content) {
const result = await fetch(`${API_URL}/tokenize`, {
method: 'POST',
body: JSON.stringify({ content })
})
})
if (!result.ok) {
return;
}
let answer = ''
for await (var chunk of result.body) {
const t = Buffer.from(chunk).toString('utf8')
if (t.startsWith('data: ')) {
const message = JSON.parse(t.substring(6))
answer += message.content
process.stdout.write(message.content)
if (message.stop) break;
if (!result.ok) {
return []
}
}
process.stdout.write('\n')
chat.push({ human: question, assistant: answer })
return (await result.json()).tokens
}
const rl = readline.createInterface({ input, output });
// resolve the tokenize() promise before taking .length so n_keep is the instruction's token count
const n_keep = (await tokenize(instruction)).length
async function chat_completion(question) {
const result = await fetch(`${API_URL}/completion`, {
method: 'POST',
body: JSON.stringify({
prompt: format_prompt(question),
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: n_keep,
n_predict: 256,
stop: ["\n### Human:"], // stop completion after generating this
stream: true,
})
})
if (!result.ok) {
return
}
let answer = ''
for await (var chunk of result.body) {
const t = Buffer.from(chunk).toString('utf8')
if (t.startsWith('data: ')) {
const message = JSON.parse(t.substring(6))
answer += message.content
process.stdout.write(message.content)
if (message.stop) {
if (message.truncated) {
chat.shift()
}
break
}
}
}
process.stdout.write('\n')
chat.push({ human: question, assistant: answer.trimStart() })
}
const rl = readline.createInterface({ input: stdin, output: stdout });
const readlineQuestion = (rl, query, options) => new Promise((resolve, reject) => {
rl.question(query, options, resolve)
});
while(true) {
const question = await rl.question('> ')
await ChatCompletion(question);
const question = await readlineQuestion(rl, '> ')
await chat_completion(question)
}

examples/server/chat.sh (new file, 65 lines)

@@ -0,0 +1,65 @@
#!/bin/bash
API_URL="http://127.0.0.1:8080"
CHAT=(
"Hello, Assistant."
"Hello. How may I help you today?"
"Please tell me the largest city in Europe."
"Sure. The largest city in Europe is Moscow, the capital of Russia."
)
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
format_prompt() {
echo -n "${INSTRUCTION}"
printf "\n### Human: %s\n### Assistant: %s" "${CHAT[@]}" "$1"
}
tokenize() {
curl \
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
N_KEEP=$(tokenize "${INSTRUCTION}" | wc -l)
chat_completion() {
DATA="$(format_prompt "$1" | jq -Rs --argjson n_keep $N_KEEP '{
prompt: .,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: $n_keep,
n_predict: 256,
stop: ["\n### Human:"],
stream: true
}')"
ANSWER=''
# run curl inside a process substitution rather than a pipeline so the
# ANSWER updates made in the loop survive into the parent shell
while IFS= read -r LINE; do
if [[ $LINE = data:* ]]; then
CONTENT="$(echo "${LINE:5}" | jq -r '.content')"
printf "%s" "${CONTENT}"
ANSWER+="${CONTENT}"
fi
done < <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--data-raw "${DATA}")
printf "\n"
CHAT+=("$1" "${ANSWER:1}")
}
while true; do
read -e -p "> " QUESTION
chat_completion "${QUESTION}"
done

examples/server/server.cpp

@@ -2,6 +2,13 @@
#include "llama.h"
#include "build-info.h"
// single thread
#define CPPHTTPLIB_THREAD_POOL_COUNT 1
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
#endif
#include "httplib.h"
#include "json.hpp"
@@ -105,6 +112,10 @@ struct llama_server_context {
llama_context * ctx = nullptr;
gpt_params params;
bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
bool stopped_limit = false;
std::string stopping_word;
int json_indent = -1;
@@ -122,6 +133,10 @@ struct llama_server_context {
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
truncated = false;
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
@@ -166,6 +181,7 @@ struct llama_server_context {
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
});
truncated = true;
prompt_tokens = new_tokens;
} else {
const size_t ps = prompt_tokens.size();
@@ -207,14 +223,13 @@ struct llama_server_context {
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
embd = new_tokens;
n_past = params.n_keep;
if (server_verbose) {
LOG_VERBOSE("input truncated", {
{ "n_ctx", params.n_ctx },
{ "n_keep", params.n_keep },
{ "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
});
}
truncated = true;
LOG_VERBOSE("input truncated", {
{ "n_ctx", params.n_ctx },
{ "n_keep", params.n_keep },
{ "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
});
}
while (n_past < embd.size()) {
@@ -314,8 +329,9 @@ struct llama_server_context {
--n_remain;
if (!embd.empty() && embd.back() == llama_token_eos()) {
stopping_word = llama_token_to_str(ctx, embd.back());
//stopping_word = llama_token_to_str(ctx, embd.back());
has_next_token = false;
stopped_eos = true;
LOG_VERBOSE("eos token found", {});
return result;
}
@@ -341,6 +357,7 @@ struct llama_server_context {
(stop_pos == std::string::npos || pos < stop_pos)) {
if (type == STOP_FULL) {
stopping_word = word;
stopped_word = true;
has_next_token = false;
}
stop_pos = pos;
@@ -378,17 +395,22 @@ struct llama_server_context {
n_remain++;
}
if (server_verbose) {
LOG_VERBOSE("next token", {
{ "token", token },
{ "token_text", llama_token_to_str(ctx, token) },
{ "has_next_token", has_next_token },
{ "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted },
{ "stopping_word", stopping_word },
});
if (!has_next_token && n_remain == 0) {
stopped_limit = true;
}
LOG_VERBOSE("next token", {
{ "token", token },
{ "token_text", llama_token_to_str(ctx, token) },
{ "has_next_token", has_next_token },
{ "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted },
{ "stopped_eos", stopped_eos },
{ "stopped_word", stopped_word },
{ "stopped_limit", stopped_limit },
{ "stopping_word", stopping_word },
});
return token_text;
}
};
@@ -578,7 +600,7 @@ void server_params_parse(int argc, char ** argv, server_params & sparams,
}
}
json format_generation_settings(llama_server_context & llama) {
static json format_generation_settings(llama_server_context & llama) {
const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@@ -607,98 +629,62 @@ json format_generation_settings(llama_server_context & llama) {
};
}
bool parse_options_completion(json body, llama_server_context & llama, Response & res) {
static json format_final_response(llama_server_context & llama, const std::string & content) {
return json {
{ "content", content },
{ "stop", true },
{ "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted },
{ "generation_settings", format_generation_settings(llama) },
{ "prompt", llama.params.prompt },
{ "truncated", llama.truncated },
{ "stopped_eos", llama.stopped_eos },
{ "stopped_word", llama.stopped_word },
{ "stopped_limit", llama.stopped_limit },
{ "stopping_word", llama.stopping_word },
};
}
static json format_partial_response(const std::string & content) {
return json {
{ "content", content },
{ "stop", false },
};
}
static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
return json {
{ "tokens", tokens }
};
}
bool parse_options_completion(json body, llama_server_context & llama) {
gpt_params default_params;
if (!body["stream"].is_null()) {
llama.stream = body["stream"].get<bool>();
} else {
llama.stream = false;
}
if (!body["n_predict"].is_null()) {
llama.params.n_predict = body["n_predict"].get<int32_t>();
} else {
llama.params.n_predict = default_params.n_predict;
}
if (!body["top_k"].is_null()) {
llama.params.top_k = body["top_k"].get<int32_t>();
} else {
llama.params.top_k = default_params.top_k;
}
if (!body["top_p"].is_null()) {
llama.params.top_p = body["top_p"].get<float>();
} else {
llama.params.top_p = default_params.top_p;
}
if (!body["tfs_z"].is_null()) {
llama.params.tfs_z = body["tfs_z"].get<float>();
} else {
llama.params.tfs_z = default_params.tfs_z;
}
if (!body["typical_p"].is_null()) {
llama.params.typical_p = body["typical_p"].get<float>();
} else {
llama.params.typical_p = default_params.typical_p;
}
if (!body["repeat_last_n"].is_null()) {
llama.params.repeat_last_n = body["repeat_last_n"].get<int32_t>();
} else {
llama.params.repeat_last_n = default_params.repeat_last_n;
}
if (!body["temperature"].is_null()) {
llama.params.temp = body["temperature"].get<float>();
} else {
llama.params.temp = default_params.temp;
}
if (!body["repeat_penalty"].is_null()) {
llama.params.repeat_penalty = body["repeat_penalty"].get<float>();
} else {
llama.params.repeat_penalty = default_params.repeat_penalty;
}
if (!body["presence_penalty"].is_null()) {
llama.params.presence_penalty = body["presence_penalty"].get<float>();
} else {
llama.params.presence_penalty = default_params.presence_penalty;
}
if (!body["frequency_penalty"].is_null()) {
llama.params.frequency_penalty = body["frequency_penalty"].get<float>();
} else {
llama.params.frequency_penalty = default_params.frequency_penalty;
}
if (!body["mirostat"].is_null()) {
llama.params.mirostat = body["mirostat"].get<int>();
} else {
llama.params.mirostat = default_params.mirostat;
}
if (!body["mirostat_tau"].is_null()) {
llama.params.mirostat_tau = body["mirostat_tau"].get<float>();
} else {
llama.params.mirostat_tau = default_params.mirostat_tau;
}
if (!body["mirostat_eta"].is_null()) {
llama.params.mirostat_eta = body["mirostat_eta"].get<float>();
} else {
llama.params.mirostat_eta = default_params.mirostat_eta;
}
if (!body["penalize_nl"].is_null()) {
llama.params.penalize_nl = body["penalize_nl"].get<bool>();
} else {
llama.params.penalize_nl = default_params.penalize_nl;
}
if (!body["n_keep"].is_null()) {
llama.params.n_keep = body["n_keep"].get<int32_t>();
} else {
llama.params.n_keep = default_params.n_keep;
}
if (!body["seed"].is_null()) {
llama.params.seed = body["seed"].get<int32_t>();
} else {
llama.params.seed = time(NULL);
}
llama.stream = body.value("stream", false);
llama.params.n_predict = body.value("n_predict", default_params.n_predict);
llama.params.top_k = body.value("top_k", default_params.top_k);
llama.params.top_p = body.value("top_p", default_params.top_p);
llama.params.tfs_z = body.value("tfs_z", default_params.tfs_z);
llama.params.typical_p = body.value("typical_p", default_params.typical_p);
llama.params.repeat_last_n = body.value("repeat_last_n", default_params.repeat_last_n);
llama.params.temp = body.value("temperature", default_params.temp);
llama.params.repeat_penalty = body.value("repeat_penalty", default_params.repeat_penalty);
llama.params.presence_penalty = body.value("presence_penalty", default_params.presence_penalty);
llama.params.frequency_penalty = body.value("frequency_penalty", default_params.frequency_penalty);
llama.params.mirostat = body.value("mirostat", default_params.mirostat);
llama.params.mirostat_tau = body.value("mirostat_tau", default_params.mirostat_tau);
llama.params.mirostat_eta = body.value("mirostat_eta", default_params.mirostat_eta);
llama.params.penalize_nl = body.value("penalize_nl", default_params.penalize_nl);
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.logit_bias.clear();
if (!body["ignore_eos"].is_null() && body["ignore_eos"].get<bool>()) {
if (body.value("ignore_eos", false)) {
llama.params.logit_bias[llama_token_eos()] = -INFINITY;
}
if (body["logit_bias"].is_array()) {
int n_vocab = llama_n_vocab(llama.ctx);
for (const auto & el : body["logit_bias"]) {
@@ -715,15 +701,6 @@ bool parse_options_completion(json body, llama_server_context & llama, Response
}
}
if (!body["prompt"].is_null()) {
llama.params.prompt = body["prompt"].get<std::string>();
} else {
json data = { {"status", "error"}, {"reason", "You need to provide a prompt"} };
res.set_content(data.dump(llama.json_indent), "application/json");
res.status = 400;
return false;
}
llama.params.antiprompt.clear();
if (!body["stop"].is_null()) {
const auto stop = body["stop"].get<std::vector<std::string>>();
@@ -737,6 +714,17 @@ bool parse_options_completion(json body, llama_server_context & llama, Response
return true;
}
static void log_server_request(const Request & req, const Response & res) {
LOG_INFO("request", {
{ "remote_addr", req.remote_addr },
{ "remote_port", req.remote_port },
{ "status", res.status },
{ "path", req.path },
{ "request", req.body },
{ "response", res.body },
});
}
int main(int argc, char ** argv) {
// own arguments required by this example
gpt_params params;
@@ -788,7 +776,7 @@ int main(int argc, char ** argv) {
llama.rewind();
llama_reset_timings(llama.ctx);
if (!parse_options_completion(json::parse(req.body), llama, res)) {
if (!parse_options_completion(json::parse(req.body), llama)) {
return;
}
@@ -813,15 +801,7 @@ int main(int argc, char ** argv) {
llama.generated_text.end());
}
json data {
{ "content", llama.generated_text },
{ "stop", true },
{ "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted },
{ "generation_settings", format_generation_settings(llama) },
{ "prompt", llama.params.prompt },
{ "stopping_word", llama.stopping_word },
};
json data = format_final_response(llama, llama.generated_text);
llama_print_timings(llama.ctx);
@@ -859,22 +839,10 @@ int main(int argc, char ** argv) {
json data;
if (llama.has_next_token) {
data = {
{ "content", to_send },
{ "stop", false },
};
data = format_partial_response(to_send);
} else {
// Generation is done, send extra information.
data = {
{ "content", to_send },
{ "stop", true },
{ "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted },
{ "generation_settings", format_generation_settings(llama) },
{ "prompt", llama.params.prompt },
{ "stopping_word", llama.stopping_word },
{ "generated_text", llama.generated_text },
};
data = format_final_response(llama, to_send);
}
std::string str =
@@ -910,20 +878,11 @@ int main(int argc, char ** argv) {
json body = json::parse(req.body);
std::string content = body["content"].get<std::string>();
std::vector<llama_token> tokens = ::llama_tokenize(llama.ctx, content, false);
json data {{ "tokens", tokens }};
json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(llama.json_indent), "application/json");
});
svr.set_logger([](const Request & req, const Response & res) {
LOG_INFO("request", {
{ "remote_addr", req.remote_addr },
{ "remote_port", req.remote_port },
{ "status", res.status },
{ "path", req.path },
{ "request", req.body },
{ "response", res.body },
});
});
svr.set_logger(log_server_request);
svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {
const auto * fmt = "500 Internal Server Error\n%s";