added support for custom prompts and more functions

FSSRepo 2023-05-17 00:12:45 -06:00
parent 0cfbd1d7d7
commit da7f370a94
4 changed files with 625 additions and 636 deletions

examples/server/CMakeLists.txt

@@ -1,6 +1,6 @@
set(TARGET server)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-add_executable(${TARGET} server.cpp json.hpp httplib.h server.h)
+add_executable(${TARGET} server.cpp json.hpp httplib.h)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)

examples/server/README.md

@@ -7,8 +7,9 @@ This example allow you to have a llama.cpp http server to interact from a web pa
1. [Quick Start](#quick-start)
2. [Node JS Test](#node-js-test)
3. [API Endpoints](#api-endpoints)
-4. [Common Options](#common-options)
-5. [Performance Tuning and Memory Options](#performance-tuning-and-memory-options)
+4. [More examples](#more-examples)
+5. [Common Options](#common-options)
+6. [Performance Tuning and Memory Options](#performance-tuning-and-memory-options)
## Quick Start
@@ -17,13 +18,13 @@ To get started right away, run the following command, making sure to use the cor
#### Unix-based systems (Linux, macOS, etc.):
```bash
-./server -m models/7B/ggml-model.bin --keep -1 --ctx_size 2048
+./server -m models/7B/ggml-model.bin --ctx_size 2048
```
#### Windows:
```powershell
-server.exe -m models\7B\ggml-model.bin --keep -1 --ctx_size 2048
+server.exe -m models\7B\ggml-model.bin --ctx_size 2048
```
That will start a server that by default listens on `127.0.0.1:8080`. You can consume the endpoints with Postman or NodeJS with the axios library.
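If you just want to confirm the server is up before writing a client, a quick check is enough. This is a minimal sketch, assuming the default host and port and Node 18+ (for the built-in `fetch`); the server's root endpoint simply returns a small HTML status page:

```javascript
// quick sanity check: the root endpoint returns a small HTML status page
fetch("http://127.0.0.1:8080/")
  .then((res) => res.text())
  .then((body) => console.log(body))
  .catch((err) => console.error("server not reachable:", err.message));
```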
@@ -42,45 +43,22 @@ npm install axios
Create an index.js file and put this inside:
```javascript
-const axios = require('axios');
-async function LLamaTest() {
-    let result = await axios.post("http://127.0.0.1:8080/setting-context", {
-        context: [
-            { role: "system", content: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." },
-            { role: "user", content: "Hello, Assistant." },
-            { role: "assistant", content: "Hello. How may I help you today?" },
-            { role: "user", content: "Please tell me the largest city in Europe." },
-            { role: "assistant", content: "Sure. The largest city in Europe is Moscow, the capital of Russia." }
-        ],
-        batch_size: 64,
-        temperature: 0.2,
-        top_k: 40,
-        top_p: 0.9,
-        n_predict: 2048,
-        threads: 5
-    });
-    result = await axios.post("http://127.0.0.1:8080/set-message", {
-        message: ' What is linux?'
-    });
-    if(result.data.can_inference) {
-        result = await axios.get("http://127.0.0.1:8080/completion?stream=true", { responseType: 'stream' });
-        result.data.on('data', (data) => {
-            let completion = JSON.parse(data.toString());
-            // token by token completion like Chat GPT
-            process.stdout.write(completion.content);
-        });
-        /*
-        Wait the entire completion (takes long time for response)
-        result = await axios.get("http://127.0.0.1:8080/completion");
-        console.log(result.data.content);
-        */
-    }
-}
-LLamaTest();
+const axios = require("axios");
+
+const prompt = `Building a website can be done in 10 simple steps:`;
+
+async function Test() {
+    let result = await axios.post("http://127.0.0.1:8080/completion", {
+        prompt,
+        batch_size: 128,
+        n_predict: 512,
+    });
+    // the response is received after the completion finishes
+    console.log(result.data.content);
+}
+
+Test();
```
And run it:
@@ -93,7 +71,7 @@ node .
You can interact with these API endpoints. This implementation only supports chat-style interaction.
-- `POST hostname:port/setting-context`: Setting up the Llama Context to begin the completions tasks.
+- **POST** `hostname:port/completion`: Set up the Llama context and start the completion task.
Options:
`batch_size`: Set the batch size for prompt processing (default: 512).
@@ -108,38 +86,200 @@ Options:
`threads`: Set the number of threads to use during computation.
-`context`: Set a short conversation as context.
-Insert items to an array of this form: `{ role: "user", content: "Hello, Assistant." }`, where:
-`role` can be `system`, `assistant` and `user`.
-`content` the message content.
-- `POST hostname:port/set-message`: Set the message of the user to Llama.
-`message`: Set the message content.
-- `GET hostname:port/completion`: Receive the response, it can be a stream or wait until finish the completion.
-`stream`: Set `true` if you want to receive a stream response.
+`n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context. By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
+`as_loop`: Receive each predicted token in real time instead of waiting for the completion to finish. To enable this, set it to `true`.
+`interactive`: Allow interacting with the completion; generation stops as soon as it encounters a `stop` word. To enable this, set it to `true`.
+`prompt`: Provide a prompt. Internally, the prompt is compared with what has already been evaluated, the evaluated part is reused, and only the remaining part is evaluated (see the sketch after the endpoint list).
+`stop`: Specify the words or characters that indicate a stop. These words will not be included in the completion, so make sure to add them to the prompt for the next iteration.
+`exclude`: Specify the words or characters you do not want to appear in the completion. These words will not be included in the completion, so make sure to add them to the prompt for the next iteration.
+- **POST** `hostname:port/embedding`: Generate the embedding of a given text.
+`content`: Set the text to generate the embedding for.
+`threads`: Set the number of threads to use during computation.
+To use this endpoint, you need to start the server with the `--embedding` option.
+- **POST** `hostname:port/tokenize`: Tokenize a given text.
+`content`: Set the text to tokenize.
+- **GET** `hostname:port/next-token`: Receive the next predicted token; execute this request in a loop. Make sure to set `as_loop` to `true` in the completion request.
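Because the server compares the `prompt` against what it has already evaluated, chaining requests that extend the previous prompt avoids re-evaluating the shared prefix. The following is a minimal sketch of that behavior (endpoint and field names come from the list above; the prompt text and parameter values are only illustrative):

```javascript
const axios = require("axios");

const base = `Building a website can be done in 10 simple steps:`;

async function Test() {
  // first request: the whole prompt has to be evaluated
  let result = await axios.post("http://127.0.0.1:8080/completion", {
    prompt: base,
    n_predict: 128,
    threads: 8,
  });
  console.log(result.data.content);

  // second request: the prompt extends the first one, so only the new part
  // (the previous completion) should need to be evaluated by the server
  result = await axios.post("http://127.0.0.1:8080/completion", {
    prompt: base + result.data.content,
    n_predict: 128,
    threads: 8,
  });
  console.log(result.data.content);
}

Test();
```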
## More examples
### Interactive mode
This mode allows interacting in a chat-like manner. It is recommended for models designed as assistants such as `Vicuna`, `WizardLM`, `Koala`, among others. Make sure to add the correct stop word for the corresponding model.
The prompt should be generated by you, according to the model's guidelines. You should keep adding the model's completions to the context as well.
This example works well for `Vicuna - version 1`.
```javascript
const axios = require("axios");
let prompt = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
### Human: Hello, Assistant.
### Assistant: Hello. How may I help you today?
### Human: Please tell me the largest city in Europe.
### Assistant: Sure. The largest city in Europe is Moscow, the capital of Russia.`;
async function ChatCompletion(answer) {
// append the user's next question to the prompt
prompt += `\n### Human: ${answer}\n`
result = await axios.post("http://127.0.0.1:8080/completion", {
prompt,
batch_size: 128,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: -1,
n_predict: 2048,
stop: ["\n### Human:"], // stop the completion when this is detected
exclude: ["### Assistant:"], // do not show this in the completion
threads: 8,
as_loop: true, // use this to request the completion token by token
interactive: true, // enable the detection of a stop word
});
// create a loop to receive every token predicted
// note: this operation is blocking, avoid using it in a UI thread
let message = "";
while (true) {
result = await axios.get("http://127.0.0.1:8080/next-token");
process.stdout.write(result.data.content);
message += result.data.content;
// to avoid an infinite loop
if (result.data.stop) {
console.log("Completed");
// make sure to add the completion to the prompt.
prompt += `### Assistant: ${message}`;
break;
}
}
}
// This function should be called every time a question to the model is needed.
async function Test() {
// the server can't run inference in parallel
await ChatCompletion("Write a long story about a time magician in a fantasy world");
await ChatCompletion("Summary the story");
}
Test();
```
### Alpaca example
**Temporary note:** not tested yet. If you have the model, please test it and report any issues.
```javascript
const axios = require("axios");
let prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
`;
async function DoInstruction(instruction) {
prompt += `\n\n### Instruction:\n\n${instruction}\n\n### Response:\n\n`;
result = await axios.post("http://127.0.0.1:8080/completion", {
prompt,
batch_size: 128,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: -1,
n_predict: 2048,
stop: ["### Instruction:\n\n"], // stop the completion when this is detected
exclude: [], // words that should not appear in the completion
threads: 8,
as_loop: true, // use this to request the completion token by token
interactive: true, // enable the detection of a stop word
});
// create a loop to receive every token predicted
// note: this operation is blocking, avoid using it in a UI thread
let message = "";
while (true) {
result = await axios.get("http://127.0.0.1:8080/next-token");
process.stdout.write(result.data.content);
message += result.data.content;
// to avoid an infinite loop
if (result.data.stop) {
console.log("Completed");
// make sure to add the completion to the prompt.
prompt += message;
break;
}
}
}
// This function should be called every time an instruction for the model is needed.
DoInstruction("Destroy the world");
```
### Embeddings
First, run the server with the `--embedding` option:
```bash
server -m models/7B/ggml-model.bin --ctx_size 2048 --embedding
```
Run this code in NodeJS:
```javascript
const axios = require('axios');
async function Test() {
let result = await axios.post("http://127.0.0.1:8080/embedding", {
content: `Hello`,
threads: 5
});
// print the embedding array
console.log(result.data.embedding);
}
Test();
```
### Tokenize
Run this code in NodeJS:
```javascript
const axios = require('axios');
async function Test() {
let result = await axios.post("http://127.0.0.1:8080/tokenize", {
content: `Hello`
});
// print the tokens array
console.log(result.data.tokens);
}
Test();
```
## Common Options
- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
- `-c N, --ctx_size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
- `--embedding`: Enable embedding mode. **The completion function doesn't work in this mode.**
- `--host`: Set the hostname or IP address to listen on. Default: `127.0.0.1`.
- `--port`: Set the port to listen on. Default: `8080`.
### Keep Prompt
The `--keep` option allows users to retain the original prompt when the model runs out of context, ensuring a connection to the initial instruction or conversation topic is maintained.
- `--keep N`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context. By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
By utilizing context management options like `--ctx_size` and `--keep`, you can maintain a more coherent and consistent interaction with the LLaMA models, ensuring that the generated text remains relevant to the original prompt or conversation.
### RNG Seed
- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, < 0 = random seed).
@@ -150,12 +290,12 @@ The RNG seed is used to initialize the random number generator that influences t
### No Memory Mapping
-- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. However, if the model is larger than your total amount of RAM or if your system is low on available memory, using mmap might increase the risk of pageouts, negatively impacting performance. Disabling mmap results in slower load times but may reduce pageouts if you're not using `--mlock`. Note that if the model is larger than the total amount of RAM, turning off mmap would prevent the model from loading at all.
+- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. However, if the model is larger than your total amount of RAM or if your system is low on available memory, using mmap might increase the risk of pageouts, negatively impacting performance.
### Memory Float 32
- `--memory_f32`: Use 32-bit floats instead of 16-bit floats for memory key+value, allowing higher quality inference at the cost of higher memory usage.
## Limitations:
-* The actual implementation of llama.cpp need a `llama-state` for support multiple contexts and clients.
-* The context can't be reset during runtime.
+- The current implementation of llama.cpp needs a `llama-state` to handle multiple contexts and clients, but this could require more powerful hardware.

examples/server/server.cpp

@ -1,269 +1,110 @@
#include <server.h> #include <httplib.h>
#include <json.hpp>
#include "common.h"
#include "llama.h"
using namespace httplib; struct server_params
using json = nlohmann::json; {
std::string hostname = "127.0.0.1";
int32_t port = 8080;
};
bool Llama::load_context() { struct llama_server_context
// load the model {
bool context_config = false;
bool has_next_token = false;
bool is_interacting = false;
int32_t tokens_completion = 0;
int32_t n_past = 0;
int32_t n_consumed = 0;
int32_t n_session_consumed = 0;
int32_t n_remain = 0;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
std::vector<llama_token> processed_tokens;
std::vector<llama_token> llama_token_newline;
std::vector<llama_token> embd_inp;
std::vector<std::vector<llama_token>> no_show_words;
llama_context *ctx;
gpt_params params;
bool loadModel(gpt_params params_)
{ {
auto lparams = llama_context_default_params(); params = params_;
ctx = llama_init_from_gpt_params(params);
lparams.n_ctx = params.n_ctx;
lparams.n_parts = params.n_parts;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.n_gpu_layers = params.n_gpu_layers;
ctx = llama_init_from_file(params.model.c_str(), lparams);
if (ctx == NULL) if (ctx == NULL)
{ {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); fprintf(stderr, "%s: error: unable to load model\n", __func__);
return false; return false;
} }
}
n_ctx = llama_n_ctx(ctx);
// enable interactive mode if reverse prompt or interactive start is specified
if (params.antiprompt.size() != 0 || params.interactive_first)
{
params.interactive = true;
}
// determine newline token // determine newline token
llama_token_newline = ::llama_tokenize(ctx, "\n", false); llama_token_newline = ::llama_tokenize(ctx, "\n", false);
last_n_tokens.resize(params.n_ctx);
last_n_tokens.resize(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true; return true;
} }
bool Llama::prompt_test() { bool loadPrompt() {
embd_inp = ::llama_tokenize(ctx, params.prompt, true); params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
if ((int)embd_inp.size() > n_ctx - 4) // compare the evaluated prompt with the new prompt
int new_prompt_len = 0;
for (int i = 0;i < prompt_tokens.size(); i++) {
if (i < processed_tokens.size() &&
processed_tokens[i] == prompt_tokens[i])
{
continue;
}
else
{
embd_inp.push_back(prompt_tokens[i]);
if(new_prompt_len == 0) {
if(i - 1 < n_past) {
processed_tokens.erase(processed_tokens.begin() + i, processed_tokens.end());
}
// Evaluate the new fragment prompt from the last token processed.
n_past = processed_tokens.size();
}
new_prompt_len ++;
}
}
if(n_past > 0 && params.interactive) {
n_remain -= new_prompt_len;
}
if ((int)embd_inp.size() > params.n_ctx - 4)
{ {
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int)embd_inp.size(), n_ctx - 4);
return false; return false;
} }
has_next_token = true;
return true; return true;
} }
void Llama::setting_context() {
user_tag_tokens = ::llama_tokenize(ctx, user_tag, false);
assistant_tag_tokens = ::llama_tokenize(ctx, assistant_tag, false);
n_remain = params.n_predict;
void beginCompletion()
{
if(n_remain == 0) {
// number of tokens to keep when resetting context // number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
{ {
params.n_keep = (int)embd_inp.size(); params.n_keep = (int)embd_inp.size();
} }
}
// print system information n_remain = params.n_predict;
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
} }
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", llama_token nextToken() {
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); llama_token result = -1;
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
while(true) {
if (embd.size() > 0) if (embd.size() > 0)
{ {
if (n_past + (int)embd.size() > n_ctx) if (n_past + (int)embd.size() > params.n_ctx)
{ {
// Reset context
const int n_left = n_past - params.n_keep; const int n_left = n_past - params.n_keep;
n_past = params.n_keep; n_past = std::max(1, params.n_keep);
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size()); processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
} embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
for (int i = 0; i < (int)embd.size(); i += params.n_batch)
{
int n_eval = (int)embd.size() - i;
if (n_eval > params.n_batch)
{
n_eval = params.n_batch;
}
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
{
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
n_past += n_eval;
}
}
embd.clear();
if ((int)embd_inp.size() <= n_consumed && !is_interacting)
{
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
{
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
{
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
logits[llama_token_nl()] = nl_logit;
}
if (temp <= 0)
{
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
}
else
{
if (mirostat == 1)
{
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}
else if (mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
{
// Temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
}
// replace end of text token with newline token when in interactive mode
if (id == llama_token_eos() && params.interactive && !params.instruct)
{
id = llama_token_newline.front();
if (params.antiprompt.size() != 0)
{
// tokenize and inject first reverse prompt
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
}
}
// add it to the context
embd.push_back(id);
// decrement remaining sampling budget
--n_remain;
}
else
{
// some user input remains from prompt or interaction, forward it to processing
while ((int)embd_inp.size() > n_consumed)
{
embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int)embd.size() >= params.n_batch)
{
break;
}
}
}
if (params.interactive && (int)embd_inp.size() <= n_consumed) {
// check for reverse prompt
if (params.antiprompt.size())
{
std::string last_output;
for (auto id : last_n_tokens)
{
last_output += llama_token_to_str(ctx, id);
}
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string &antiprompt : params.antiprompt)
{
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos)
{
is_interacting = true;
is_antiprompt = true;
context_config = true;
return;
}
}
}
}
}
}
int Llama::set_message(std::string msg) {
if (msg.length() > 1)
{
auto line_inp = ::llama_tokenize(ctx, msg, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
n_remain -= (int)line_inp.size();
is_antiprompt = false;
return (int)line_inp.size();
} else {
return 0;
}
}
llama_token Llama::nextToken() {
llama_token result = -1;
if (embd.size() > 0) {
if (n_past + (int)embd.size() > n_ctx)
{
const int n_left = n_past - params.n_keep;
n_past = params.n_keep;
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
} }
for (int i = 0; i < (int)embd.size(); i += params.n_batch) for (int i = 0; i < (int)embd.size(); i += params.n_batch)
{ {
@ -275,6 +116,7 @@ llama_token Llama::nextToken() {
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
{ {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
has_next_token = false;
return result; return result;
} }
n_past += n_eval; n_past += n_eval;
@ -289,7 +131,7 @@ llama_token Llama::nextToken() {
const float top_p = params.top_p; const float top_p = params.top_p;
const float tfs_z = params.tfs_z; const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p; const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty; const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty; const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty; const float alpha_frequency = params.frequency_penalty;
@ -319,7 +161,7 @@ llama_token Llama::nextToken() {
// Apply penalties // Apply penalties
float nl_logit = logits[llama_token_nl()]; float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p, llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty); last_n_repeat, repeat_penalty);
@ -363,10 +205,11 @@ llama_token Llama::nextToken() {
} }
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id); last_n_tokens.push_back(id);
processed_tokens.push_back(id);
} }
// replace end of text token with newline token when in interactive mode // replace end of text token with newline token when in interactive mode
if (id == llama_token_eos() && params.interactive && !params.instruct) if (id == llama_token_eos() && params.interactive)
{ {
id = llama_token_newline.front(); id = llama_token_newline.front();
if (params.antiprompt.size() != 0) if (params.antiprompt.size() != 0)
@ -379,7 +222,6 @@ llama_token Llama::nextToken() {
// add it to the context // add it to the context
embd.push_back(id); embd.push_back(id);
for (auto id : embd) for (auto id : embd)
{ {
result = id; result = id;
@ -396,6 +238,7 @@ llama_token Llama::nextToken() {
embd.push_back(embd_inp[n_consumed]); embd.push_back(embd_inp[n_consumed]);
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[n_consumed]); last_n_tokens.push_back(embd_inp[n_consumed]);
processed_tokens.push_back(embd_inp[n_consumed]);
++n_consumed; ++n_consumed;
if ((int)embd.size() >= params.n_batch) if ((int)embd.size() >= params.n_batch)
{ {
@ -403,7 +246,8 @@ llama_token Llama::nextToken() {
} }
} }
} }
if (params.interactive && (int)embd_inp.size() <= n_consumed) { if (params.interactive && (int)embd_inp.size() <= n_consumed)
{
// check for reverse prompt // check for reverse prompt
if (params.antiprompt.size()) if (params.antiprompt.size())
{ {
@ -412,14 +256,14 @@ llama_token Llama::nextToken() {
{ {
last_output += llama_token_to_str(ctx, id); last_output += llama_token_to_str(ctx, id);
} }
is_antiprompt = false; has_next_token = true;
// Check if each of the reverse prompts appears at the end of the output. // Check if each of the reverse prompts appears at the end of the output.
for (std::string &antiprompt : params.antiprompt) for (std::string &antiprompt : params.antiprompt)
{ {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos)
{ {
is_interacting = true; is_interacting = true;
is_antiprompt = true; has_next_token = false;
return result; return result;
} }
} }
@ -429,145 +273,162 @@ llama_token Llama::nextToken() {
is_interacting = false; is_interacting = false;
} }
} }
if (!embd.empty() && embd.back() == llama_token_eos()) {
has_next_token = false;
}
if (params.interactive && n_remain <= 0 && params.n_predict != -1) if (params.interactive && n_remain <= 0 && params.n_predict != -1)
{ {
n_remain = params.n_predict; n_remain = params.n_predict;
is_interacting = true; is_interacting = true;
} }
has_next_token = n_remain != 0;
return result; return result;
} }
std::string Llama::inference() { std::string inference()
llama_token tkn = nextToken(); {
if(tkn == -1) { llama_token token = nextToken();
if (token == -1) {
return ""; return "";
} }
std::vector<llama_token> tokens_completion; std::vector<llama_token> tokens_completion;
tokens_completion.push_back(tkn); tokens_completion.push_back(token);
// Avoid add the no show words to the response
// Avoid add the user or assistant tag to the response for (std::vector<llama_token> word_tokens : no_show_words)
{
int match_token = 1; int match_token = 1;
if(tokens_completion[0] == user_tag_tokens[0]) { if (tokens_completion[0] == word_tokens[0])
while(true) { {
if(match_token == user_tag_tokens.size()) { // all user tag tokens matched, return empty inference
return "";
}
tkn = nextToken();
tokens_completion.push_back(tkn);
if(tkn == user_tag_tokens[match_token]) { // the token follow the sequence
match_token++;
} else if(match_token < user_tag_tokens.size()) { // no complete all user tag
break;
}
}
}
if(tokens_completion[0] == assistant_tag_tokens[0]) {
bool execute_matching = true; bool execute_matching = true;
if(tokens_completion.size() > 1) { // if user tag had been tested if (tokens_completion.size() > 1) { // if previus tokens had been tested
for(int i = 1;i < assistant_tag_tokens.size(); i++) { for (int i = 1; i < word_tokens.size(); i++)
if(i >= tokens_completion.size()) { {
if (i >= tokens_completion.size()) {
match_token = i; match_token = i;
break; break;
} }
if(tokens_completion[i] == assistant_tag_tokens[i]) { if (tokens_completion[i] == word_tokens[i])
{
continue; continue;
} else { }
else
{
execute_matching = false; execute_matching = false;
break; break;
} }
} }
} }
while(execute_matching) { while (execute_matching) {
if(match_token == assistant_tag_tokens.size()) { // all assistant tag tokens matched, return empty inference if (match_token == word_tokens.size()) {
return ""; return "";
} }
tkn = nextToken(); token = nextToken();
tokens_completion.push_back(tkn); tokens_completion.push_back(token);
if(tkn == assistant_tag_tokens[match_token]) { // the token follow the sequence if (token == word_tokens[match_token])
{ // the token follow the sequence
match_token++; match_token++;
} else if(match_token < assistant_tag_tokens.size()) { // no complete all user tag }
else if (match_token < word_tokens.size())
{ // no complete all user tag
break; break;
} }
} }
} }
}
std::string result = ""; std::string result = "";
for(llama_token token : tokens_completion) { for (llama_token tkn : tokens_completion)
result += llama_token_to_str(ctx, token); {
result += llama_token_to_str(ctx, tkn);
} }
return result; return result;
} }
void Llama::release() { std::vector<float> embedding(std::string content, int threads) {
// TODO: Clean the context content.insert(0, 1, ' ');
// llama_free(ctx); std::vector<llama_token> tokens = ::llama_tokenize(ctx, content, true);
} if (tokens.size() > 0)
{
if (llama_eval(ctx, tokens.data(), tokens.size(), 0, threads))
{
fprintf(stderr, "%s : failed to eval\n", __func__);
std::vector<float> embeddings_;
return embeddings_;
}
}
const int n_embd = llama_n_embd(ctx);
const auto embeddings = llama_get_embeddings(ctx);
std::vector<float> embeddings_(embeddings, embeddings + n_embd);
return embeddings_;
}
};
void server_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { using namespace httplib;
using json = nlohmann::json;
void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
{
fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n"); fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
fprintf(stderr, " --embedding enable embedding mode\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
if (llama_mlock_supported()) { if (llama_mlock_supported())
{
fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n"); fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
} }
if (llama_mmap_supported()) { if (llama_mmap_supported())
{
fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
} }
fprintf(stderr, " -ngl N, --n-gpu-layers N\n"); fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
fprintf(stderr, " number of layers to store in VRAM\n"); fprintf(stderr, " number of layers to store in VRAM\n");
fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, " -host ip address to listen (default 0.0.0.0)\n"); fprintf(stderr, " -host ip address to listen (default 127.0.0.1)\n");
fprintf(stderr, " -port PORT port to listen (default 8080)\n"); fprintf(stderr, " -port PORT port to listen (default 8080)\n");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
int main(int argc, char ** argv) { bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
{
// own arguments required by this example
gpt_params default_params; gpt_params default_params;
gpt_params params;
params.model = "ggml-model.bin";
std::string hostname = "0.0.0.0";
int port = 8080;
std::string arg; std::string arg;
bool invalid_param = false; bool invalid_param = false;
for (int i = 1; i < argc; i++) for (int i = 1; i < argc; i++)
{ {
arg = argv[i]; arg = argv[i];
if (arg == "--port") { if (arg == "--port")
if (++i >= argc) { {
invalid_param = true;
break;
}
port = std::stoi(argv[i]);
} else if (arg == "--host") {
if (++i >= argc) if (++i >= argc)
{ {
invalid_param = true; invalid_param = true;
break; break;
} }
hostname = argv[i]; sparams.port = std::stoi(argv[i]);
} else if (arg == "--keep") { }
if (++i >= argc) { else if (arg == "--host")
{
if (++i >= argc)
{
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_keep = std::stoi(argv[i]); sparams.hostname = argv[i];
} }
else if (arg == "-s" || arg == "--seed") { else if (arg == "-s" || arg == "--seed")
{
#if defined(GGML_USE_CUBLAS) #if defined(GGML_USE_CUBLAS)
fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n"); fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
#endif #endif
if (++i >= argc) { if (++i >= argc)
{
invalid_param = true; invalid_param = true;
break; break;
} }
@ -582,6 +443,10 @@ int main(int argc, char ** argv) {
} }
params.model = argv[i]; params.model = argv[i];
} }
else if (arg == "--embedding")
{
params.embedding = true;
}
else if (arg == "-h" || arg == "--help") else if (arg == "-h" || arg == "--help")
{ {
server_print_usage(argc, argv, default_params); server_print_usage(argc, argv, default_params);
@ -599,8 +464,11 @@ int main(int argc, char ** argv) {
else if (arg == "--memory_f32") else if (arg == "--memory_f32")
{ {
params.memory_f16 = false; params.memory_f16 = false;
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { }
if (++i >= argc) { else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
{
if (++i >= argc)
{
invalid_param = true; invalid_param = true;
break; break;
} }
@ -620,6 +488,23 @@ int main(int argc, char ** argv) {
server_print_usage(argc, argv, default_params); server_print_usage(argc, argv, default_params);
exit(1); exit(1);
} }
return true;
}
int main(int argc, char **argv)
{
// own arguments required by this example
gpt_params params;
server_params sparams;
// struct that contains llama context and inference
llama_server_context llama;
params.model = "ggml-model.bin";
if (server_params_parse(argc, argv, sparams, params) == false)
{
return 1;
}
if (params.seed <= 0) if (params.seed <= 0)
{ {
@ -628,92 +513,96 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
Llama* llama = new Llama(params); // load the model
if(!llama->load_context()) { if (!llama.loadModel(params))
{
return 1; return 1;
} }
Server svr; Server svr;
svr.Get("/", [](const Request &req, Response &res) svr.Get("/", [](const Request &req, Response &res)
{ { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
res.set_content("<h1>llama.cpp server works</h1>", "text/html");
} svr.Post("/completion", [&llama](const Request &req, Response &res)
); {
if(llama.params.embedding) {
json data = {
{"status", "error"},
{"reason", "To use completion function disable embedding mode"}};
res.set_content(data.dump(), "application/json");
res.status = 400;
return;
}
svr.Post("/setting-context", [&llama](const Request &req, Response &res) {
if(!llama->context_config) {
json body = json::parse(req.body); json body = json::parse(req.body);
/* llama.params.antiprompt.clear();
Seed whould be passed by the request, but seem llama.no_show_words.clear();
the current implementation need it in the load file bool as_loop = false;
*/
if (!body["threads"].is_null()) if (!body["threads"].is_null())
{ {
llama->params.n_threads = body["threads"].get<int>(); llama.params.n_threads = body["threads"].get<int>();
} }
if (!body["n_predict"].is_null()) if (!body["n_predict"].is_null())
{ {
llama->params.n_predict = body["n_predict"].get<int>(); llama.params.n_predict = body["n_predict"].get<int>();
} }
if (!body["top_k"].is_null()) if (!body["top_k"].is_null())
{ {
llama->params.top_k = body["top_k"].get<int>(); llama.params.top_k = body["top_k"].get<int>();
} }
if (!body["top_p"].is_null()) if (!body["top_p"].is_null())
{ {
llama->params.top_p = (float)body["top_p"].get<float>(); llama.params.top_p = body["top_p"].get<float>();
} }
if (!body["temperature"].is_null()) if (!body["temperature"].is_null())
{ {
llama->params.temp = (float)body["temperature"].get<float>(); llama.params.temp = body["temperature"].get<float>();
} }
if (!body["batch_size"].is_null()) if (!body["batch_size"].is_null())
{ {
llama->params.n_batch = body["batch_size"].get<int>(); llama.params.n_batch = body["batch_size"].get<int>();
} }
if (!body["tags"].is_null()) if (!body["n_keep"].is_null())
{ {
json tags = body["tags"].get<json>(); llama.params.n_keep = body["n_keep"].get<int>();
llama->user_tag = tags["user"].get<std::string>();
llama->assistant_tag = tags["assistant"].get<std::string>();
} }
if (!body["context"].is_null()) if (!body["as_loop"].is_null())
{ {
llama->params.prompt = ""; as_loop = body["as_loop"].get<bool>();
std::vector<json> context_messages = body["context"].get<std::vector<json>>();
for (json ctx_msg : context_messages)
{
auto role = ctx_msg["role"].get<std::string>();
if (role == "system")
{
llama->params.prompt = ctx_msg["content"].get<std::string>() + "\n\n";
} }
else if (role == "user") if (!body["interactive"].is_null())
{ {
llama->params.prompt += llama->user_tag + " " + ctx_msg["content"].get<std::string>() + "\n"; llama.params.interactive = body["interactive"].get<bool>();
} }
else if (role == "assistant") if (!body["prompt"].is_null())
{ {
llama->params.prompt += llama->assistant_tag + " " + ctx_msg["content"].get<std::string>() + "\n"; llama.params.prompt = body["prompt"].get<std::string>();
}
}
llama->params.prompt += llama->user_tag;
}
else if (!body["prompt"].is_null())
{
llama->params.prompt = body["prompt"].get<std::string>();
} }
else else
{ {
json data = { json data = {
{"status", "error"}, {"status", "error"},
{"reason", "You need to pass the context or prompt"}}; {"reason", "You need to pass the prompt"}};
res.set_content(data.dump(), "application/json"); res.set_content(data.dump(), "application/json");
res.status = 400; res.status = 400;
return; return;
} }
if(!llama->prompt_test()) if (!body["stop"].is_null()) {
std::vector<std::string> stop_words = body["stop"].get<std::vector<std::string>>();
for (std::string stop_word : stop_words) {
llama.params.antiprompt.push_back(stop_word);
llama.no_show_words.push_back(::llama_tokenize(llama.ctx, stop_word, false));
}
}
if (!body["exclude"].is_null()) {
std::vector<std::string> no_show_words = body["exclude"].get<std::vector<std::string>>();
for (std::string no_show : no_show_words) {
llama.no_show_words.push_back(::llama_tokenize(llama.ctx, no_show, false));
}
}
if (!llama.loadPrompt())
{ {
json data = { json data = {
{"status", "error"}, {"status", "error"},
@ -722,73 +611,83 @@ int main(int argc, char ** argv) {
res.status = 400; res.status = 400;
return; return;
} }
// Default configs for interactive with Vicuna model llama.beginCompletion();
llama->params.interactive = true; llama.tokens_completion = 0;
llama->params.antiprompt.push_back(llama->user_tag); if(as_loop) {
llama->params.repeat_last_n = 64;
llama->params.repeat_penalty = 1.1f;
llama->setting_context();
}
json data = { json data = {
{ "status", "done" }}; {"status", "done" } };
res.set_content(data.dump(), "application/json"); return res.set_content(data.dump(), "application/json");
});
svr.Post("/set-message", [&llama](const Request &req, Response &res) {
bool result = false;
if (llama->context_config)
{
json body = json::parse(req.body);
result = llama->set_message(body["message"].get<std::string>() + "\n");
}
json data = {
{"can_inference", result }};
res.set_content(data.dump(), "application/json");
});
svr.Get("/completion", [&llama](const Request &req, Response &res)
{
bool stream = false;
if (req.has_param("stream")) {
stream = req.get_param_value("stream") == "true";
}
if(stream) {
// Stream token by token like Chat GPT
res.set_content_provider(
"application/json",
[&llama](size_t offset, DataSink &sink)
{
llama->tokens_completion = 0;
while(!llama->is_antiprompt) {
std::string result = llama->inference();
json data = {
{"content", result },
{"tokens_consumed", 1},
{"stop", llama->is_antiprompt }};
std::string json_data = data.dump();
sink.write(json_data.c_str(), json_data.length());
}
sink.done(); // No more data
return true; // return 'false' if you want to cancel the process.
});
} else { } else {
// Send all completion when finish // Send all completion when finish
std::string completion = ""; std::string completion = "";
llama->tokens_completion = 0; while (llama.has_next_token)
while (!llama->is_antiprompt)
{ {
completion += llama->inference(); completion += llama.inference();
} }
json data = { json data = {
{ "content", completion.c_str() }, {"content", completion.c_str()},
{ "total_tokens", llama->tokens_completion } {"total_tokens", llama.tokens_completion}};
}; return res.set_content(data.dump(), "application/json");
res.set_content(data.dump(), "application/json"); } });
svr.Post("/tokenize", [&llama](const Request &req, Response &res)
{
json body = json::parse(req.body);
json data = {
{"tokens", ::llama_tokenize(llama.ctx, body["content"].get<std::string>(), false) } };
return res.set_content(data.dump(), "application/json");
});
svr.Post("/embedding", [&llama](const Request &req, Response &res)
{
if(!llama.params.embedding) {
std::vector<float> empty;
json data = {
{"embedding", empty}};
fprintf(stderr, "[llama-server] : You need enable embedding mode adding: --embedding option\n");
return res.set_content(data.dump(), "application/json");
}
json body = json::parse(req.body);
std::string content = body["content"].get<std::string>();
int threads = body["threads"].get<int>();
json data = {
{"embedding", llama.embedding(content, threads) } };
return res.set_content(data.dump(), "application/json");
});
svr.Get("/next-token", [&llama](const Request &req, Response &res)
{
if(llama.params.embedding) {
res.set_content("{}", "application/json");
return;
}
std::string result = "";
if (req.has_param("stop")) {
llama.has_next_token = false;
llama.is_interacting = true;
} else {
result = llama.inference();
}
try {
json data = {
{"content", result },
{"stop", !llama.has_next_token }};
return res.set_content(data.dump(), "application/json");
} catch (json::exception e) {
// Some tokens have bad UTF-8 strings, the json parser is very sensitive
json data = {
{"content", "" },
{"stop", !llama.has_next_token }};
return res.set_content(data.dump(), "application/json");
} }
}); });
printf("llama.cpp HTTP Server Listening at http://%s:%i", hostname.c_str(), port); fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);
if(params.embedding) {
fprintf(stderr, "NOTE: Mode embedding enabled. Completion function doesn't work in this mode.\n");
}
// change hostname and port // change hostname and port
svr.listen(hostname, port); svr.listen(sparams.hostname, sparams.port);
} }

examples/server/server.h (deleted)

@ -1,50 +0,0 @@
#include <httplib.h>
#include <json.hpp>
#include <cstring>
#include "common.h"
#include "llama.h"
/*
This isn't the best way to do this.
Missing:
- Clean context (insert new prompt for change the behavior,
this implies clean kv cache and emb_inp in runtime)
- Release context (free memory) after shutdown the server
*/
class Llama{
public:
Llama(gpt_params params_) : params(params_){};
bool load_context();
bool prompt_test();
void setting_context();
int set_message(std::string msg);
void release();
llama_token nextToken();
std::string inference();
bool context_config = false;
bool is_antiprompt = false;
int tokens_completion = 0;
gpt_params params;
std::string user_tag = "### Human:", assistant_tag = "### Assistant:";
private:
llama_context *ctx;
int n_ctx;
int n_past = 0;
int n_consumed = 0;
int n_session_consumed = 0;
int n_remain = 0;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
bool is_interacting = false;
std::vector<int> llama_token_newline;
std::vector<int> embd_inp;
// to ignore this in the completion
std::vector<int> user_tag_tokens;
std::vector<int> assistant_tag_tokens;
};