fix multiple clients

FSSRepo 2023-10-17 17:54:56 -04:00
parent d2b1fac6c7
commit c02c52efb5
3 changed files with 2257 additions and 2245 deletions

File diff suppressed because it is too large.


@@ -195,6 +195,7 @@
 import { llama } from '/completion.js';
 import { SchemaConverter } from '/json-schema-to-grammar.mjs';
 let selected_image = false;
+var slot_id = -1;
 const session = signal({
   prompt: "This is a conversation between User and Llama, a friendly chatbot. Llama is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision.",
@@ -222,7 +223,6 @@
   mirostat_eta: 0.1, // learning rate
   grammar: '',
   n_probs: 0, // no completion_probabilities,
-  slot_id: -1,
   image_data: [],
   cache_prompt: true
 })
@@ -389,7 +389,6 @@
       throw new Error("already running");
     }
     controller.value = new AbortController();
-    let slot_id = -1;
     for await (const chunk of llama(prompt, llamaParams, {controller: controller.value})) {
       const data = chunk.data;
@@ -401,7 +400,6 @@
         currentMessages.pop();
       }
       transcriptUpdate([...history, [char, currentMessages]])
-      params.value = {...params.value, slot_id}
       console.log("Completion finished: '", currentMessages.map(msg => msg.content).join(''), "', summary: ", data);
     } else {
       currentMessages.push(data);
@@ -450,6 +448,7 @@
     }
     await runLlama(prompt, {
       ...params.value,
+      slot_id: slot_id,
       stop: ["</s>", template("{{char}}:"), template("{{user}}:")],
     }, "{{char}}");
   }
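Taken together, the client-side hunks pin each browser session to one server slot: slot_id is now a single module-level variable instead of a per-call local plus a field in the shared params object, and it is sent with every completion request, so a session keeps hitting the slot that already holds its cached prompt instead of concurrent clients racing over slot assignments. A minimal sketch of the server-side lookup this relies on, with an assumed helper name and idle check (not the literal server.cpp code):

    // Sketch of the slot routing the client-side slot_id implies
    // (assumed helper; slot.available() is a hypothetical idle check).
    // slot_id == -1 means "give me any idle slot"; otherwise the client
    // asks to be routed back to the slot holding its cached KV state.
    llama_client_slot * get_slot(int id) {
        if (id != -1 && id < (int)slots.size()) {
            return &slots[id];      // sticky: reuse the requested slot
        }
        for (llama_client_slot & slot : slots) {
            if (slot.available()) { // first idle slot for a new client
                return &slot;
            }
        }
        return nullptr;             // all slots busy
    }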


@@ -125,6 +125,7 @@ enum slot_command
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1; // RNG seed
+    int n_keep = 0; // tokens to keep from the initial prompt
     int32_t n_predict = -1; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember the prompt to avoid reprocessing it
@@ -452,6 +453,7 @@ struct llama_server_context
     gpt_params params;
     int n_ctx;
     int n_vocab;
+    int max_ctx_per_slot = -1;
     bool clean_kv_cache = true;
     ~llama_server_context()
@@ -514,16 +516,23 @@ struct llama_server_context
     void initialize() {
         // create slots
-        LOG_TEE("Available slots:\n");
         all_slots_are_idle = true;
+        if (max_ctx_per_slot == -1) {
+            max_ctx_per_slot = n_ctx / params.n_parallel; // split context
+        }
+        if (max_ctx_per_slot * params.n_parallel > n_ctx) {
+            printf("Error: The max context per slot is greater than the model context size\n");
+            return;
+        }
+        LOG_TEE("Available slots:\n");
         for (int i = 0; i < params.n_parallel; i++)
         {
             llama_client_slot slot;
             slot.id = i;
-            slot.last_n_tokens.resize(n_ctx); // a slot can fill context size
+            slot.last_n_tokens.resize(max_ctx_per_slot);
             std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
             slot.reset();
-            LOG_TEE(" -> Slot %i\n", slot.id);
+            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
             slots.push_back(slot);
         }
         batch = llama_batch_init(n_ctx, 0);
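For example (pure arithmetic from the code above): with n_ctx = 4096 and params.n_parallel = 4, leaving max_ctx_per_slot at -1 yields 4096 / 4 = 1024 tokens of context per slot. An explicit value passes the guard only while max_ctx_per_slot * n_parallel <= n_ctx, so 512 * 4 = 2048 <= 4096 is accepted but 2048 * 4 = 8192 > 4096 is rejected.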
@@ -914,18 +923,17 @@ struct llama_server_context
         }
-        // context shift takes effect only when there is a single slot
-        if (params.n_parallel == 1) {
-            llama_client_slot &slot = slots[0];
-            if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
+        // context shift: free up room in any slot that has filled its context
+        for (llama_client_slot &slot : slots) {
+            if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
             {
                 // Shift context
-                const int n_left = slot.n_past - params.n_keep - 1;
+                const int n_left = slot.n_past - slot.params.n_keep - 1;
                 const int n_discard = n_left / 2;
-                llama_kv_cache_seq_rm   (ctx, slot.id, params.n_keep + 1, params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, slot.id, params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
-                for (size_t i = params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
                 {
                     slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                 }
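On toy numbers the shift arithmetic works out as follows: with slot.params.n_keep = 4 and slot.n_past = 20, n_left = 20 - 4 - 1 = 15 and n_discard = 15 / 2 = 7, so the first n_keep + 1 = 5 tokens (positions 0..4) stay, positions 5..11 are evicted from the slot's KV sequence, and positions 12..19 slide down by 7. A standalone illustration of the same compaction on a plain vector (not server code):

    // Toy illustration of the context-shift arithmetic above:
    // keep the first n_keep+1 tokens, drop the next n_discard,
    // shift the rest down to make room for new generation.
    #include <cstdio>
    #include <vector>

    int main() {
        int n_keep = 4, n_past = 20;
        std::vector<int> cache(n_past);
        for (int i = 0; i < n_past; i++) cache[i] = i;

        const int n_left    = n_past - n_keep - 1; // 15 shiftable tokens
        const int n_discard = n_left / 2;          // drop the oldest half: 7

        // same compaction loop as in the diff: overwrite the discarded span
        for (size_t i = n_keep + 1 + n_discard; i < cache.size(); i++) {
            cache[i - n_discard] = cache[i];
        }
        cache.resize(cache.size() - n_discard);

        for (int t : cache) printf("%d ", t);     // prints 0..4 then 12..19
        printf("\n");
        return 0;
    }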
@@ -1022,16 +1030,16 @@ struct llama_server_context
                 slot.n_past = 0;
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             } else {
-                if (params.n_keep < 0 && params.n_parallel == 1)
+                if (slot.params.n_keep < 0)
                 {
-                    params.n_keep = (int)slot.num_prompt_tokens;
+                    slot.params.n_keep = (int)slot.num_prompt_tokens;
                 }
-                params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+                slot.params.n_keep = std::min(max_ctx_per_slot - 4, slot.params.n_keep);
                 // if the input prompt is too big, truncate it
-                if (slot.num_prompt_tokens >= (size_t)n_ctx)
+                if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
                 {
-                    const int n_left = n_ctx - params.n_keep;
-                    std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+                    const int n_left = max_ctx_per_slot - slot.params.n_keep;
+                    std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
                     // Use half the left-over space in the context for the prompt
                     new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
                     LOG_VERBOSE("input truncated", {
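A worked truncation example with per-slot numbers: if max_ctx_per_slot = 1024 and slot.params.n_keep = 24, then n_left = 1024 - 24 = 1000, and a 2000-token prompt is rebuilt as the first 24 tokens plus the last 1000 / 2 = 500 tokens, i.e. 524 tokens in total, leaving roughly half of the slot's context free for generation.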
@@ -1331,6 +1339,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        }
+        else if (arg == "-cps" || arg == "--ctx-per-slot" || arg == "--ctx_per_slot")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            llama.max_ctx_per_slot = std::stoi(argv[i]);
         }
         else if (arg == "--rope-freq-base")
         {
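Example invocations of the new flag (a sketch; -m, -c and -np are assumed to be the server's existing model, context-size and parallel-slot options):

    ./server -m model.gguf -c 4096 -np 4            # default split: 4096 / 4 = 1024 ctx per slot
    ./server -m model.gguf -c 4096 -np 4 -cps 512   # explicit: 512 * 4 = 2048 <= 4096, accepted
    ./server -m model.gguf -c 4096 -np 4 -cps 2048  # 2048 * 4 = 8192 > 4096, rejected at initialize()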
@@ -1717,7 +1734,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", 0);
+    slot->params.n_keep = json_value(body, "n_keep", slot->params.n_keep);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
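With this last hunk, n_keep is read into the slot's own params instead of mutating the shared llama.params, so each request can carry its own value without clobbering other clients. A minimal request-body sketch (field names as used in this diff; values illustrative):

    {
      "prompt": "...",
      "slot_id": 0,
      "n_keep": 24,
      "cache_prompt": true
    }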