commit b716eeb72a
Author: FSSRepo
Date:   2023-10-12 12:55:08 -04:00
37 changed files with 13747 additions and 2584 deletions


@@ -380,6 +380,7 @@ struct llama_server_context
std::vector<llama_token_data> candidates;
bool all_slots_are_idle = false;
gpt_params params;
llama_sampling_context ctx_sampling;
int n_ctx;
int n_vocab;
bool clean_kv_cache = true;
@@ -402,11 +403,29 @@ struct llama_server_context
llama_free_model(model);
model = nullptr;
}
}
for (auto &slot : slots) {
    if (slot.grammar) {
        llama_grammar_free(slot.grammar);
    }
}
void rewind()
{
params.antiprompt.clear();
params.grammar.clear();
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(n_ctx);
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
n_remain = 0;
n_past = 0;
if (grammar != nullptr) {
llama_grammar_free(grammar);
grammar = nullptr;
}
}
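// rewind() resets all per-request state (token counters, generated text, stop flags,
// pending multibyte bytes) and frees the grammar, so the same context can serve the
// next completion from a clean slate.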
@@ -491,59 +510,28 @@ struct llama_server_context
return prompt_tokens;
}
void processPrompt() {
//params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
// if (num_prompt_tokens >= (size_t)n_ctx)
// {
// const int n_left = (n_ctx - params.n_keep) / 2;
// std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
// const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
// new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
// std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
// LOG_VERBOSE("input truncated", {
// {"n_ctx", n_ctx},
// {"n_keep", params.n_keep},
// {"n_left", n_left},
// {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
// });
// truncated = true;
// prompt_tokens = new_tokens;
// }
// else
// {
// const size_t ps = num_prompt_tokens;
// std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
// std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
// }
// compare the evaluated prompt with the new prompt
}
llama_client_slot* getSlot(int id) {
for (llama_client_slot & slot : slots)
{
if ((id == -1 && slot.available()) || slot.id == id)
{
return &slot;
bool loadGrammar()
{
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
return false;
}
}
return nullptr;
}
grammar_parser::print_grammar(stderr, parsed_grammar);
bool launchSlot(llama_client_slot* &slot) {
if(!slot->loadGrammar()) {
return false;
{
auto it = params.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
all_slots_are_idle = false;
slot->command = LOAD_PROMPT;
LOG_TEE("slot %i is processing\n", slot->id);
return true;
}
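// launchSlot() refuses the request if the slot's grammar fails to load; otherwise it
// clears all_slots_are_idle and marks the slot with LOAD_PROMPT so the main loop
// starts ingesting its prompt.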
@@ -604,15 +592,15 @@ struct llama_server_context
// std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
// }
// // compare the evaluated prompt with the new prompt
// n_past = common_part(embd, prompt_tokens);
// embd = prompt_tokens;
// if (n_past == num_prompt_tokens)
// {
// // we have to evaluate at least 1 token to generate logits.
// printf("we have to evaluate at least 1 token to generate logits\n");
// n_past--;
// }
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
n_past--;
}
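// common_part() counts how many leading tokens of the cached sequence match the new
// prompt; those tokens stay in the KV cache and are not re-evaluated. If the whole
// prompt matches, n_past is rolled back by one because llama_decode must evaluate at
// least one token to produce fresh logits.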
// LOG_VERBOSE("prompt ingested", {
// {"n_past", n_past},
@@ -629,77 +617,168 @@ struct llama_server_context
{
llama_kv_cache_seq_rm(ctx, i, 0, -1);
}
clean_kv_cache = false;
params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)n_ctx)
{
const int n_left = (n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
}
else
{
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
}
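// Illustrative numbers (not from the source): with n_ctx = 2048 and params.n_keep = 256,
// n_left = (2048 - 256) / 2 = 896; the first 256 prompt tokens are kept verbatim and the
// middle of the prompt is dropped in n_left-sized blocks so the tail still fits in context.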
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
// since #3228 we now have to manually manage the KV cache
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
// we have to evaluate at least 1 token to generate logits.
n_past--;
}
LOG_VERBOSE("prompt ingested", {
{"n_past", n_past},
{"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
{"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = true;
}
void updateSystemPrompt() {
tokens_system = ::llama_tokenize(ctx, system_prompt, true);
n_tokens_system = tokens_system.size();
batch.n_tokens = n_tokens_system;
cleanKVCache();
for (int32_t i = 0; i < batch.n_tokens; ++i)
{
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}
if (llama_decode(ctx, batch) != 0)
{
LOG_TEE("%s: llama_decode() failed\n", __func__);
return;
}
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < params.n_parallel; ++i)
{
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
}
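// llama_kv_cache_seq_cp() copies the system prompt's KV entries from sequence 0 to every
// other parallel sequence, so the shared prefix is decoded once and reused by all slots.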
LOG_TEE("system prompt updated\n");
update_system_prompt = false;
void beginCompletion()
{
// number of tokens to keep when resetting context
n_remain = params.n_predict;
llama_set_rng_seed(ctx, params.seed);
}
void notifySystemPromptChanged() {
// release all slots
for (llama_client_slot &slot : slots)
completion_token_output nextToken()
{
completion_token_output result;
result.tok = -1;
if (embd.size() >= (size_t)n_ctx)
{
slot.release();
}
waitAllAreIdle();
all_slots_are_idle = true;
// wait until system prompt load
update_system_prompt = true;
while(update_system_prompt) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
// system prompt loaded, continue
}
// Shift context
void processSystemPromptData(json sys_props) {
system_prompt = sys_props.value("system_prompt", "");
user_name = sys_props.value("anti_prompt", "");
assistant_name = sys_props.value("assistant_name", "");
notifySystemPromptChanged();
}
const int n_left = n_past - params.n_keep - 1;
const int n_discard = n_left/2;
void waitAllAreIdle() {
bool wait = true;
while(wait) {
wait = false;
for (auto &slot : slots)
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
{
if (!slot.available())
{
wait = true;
break;
}
embd[i - n_discard] = embd[i];
}
embd.resize(embd.size() - n_discard);
n_past -= n_discard;
truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
});
}
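// Illustrative numbers (not from the source): with n_past = 1024 and params.n_keep = 8,
// n_left = 1015 and n_discard = 507; cache positions [9, 516) are evicted, the remaining
// entries are shifted down by 507 positions, and generation continues in the freed space.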
bool tg = true;
while (n_past < embd.size())
{
int n_eval = (int)embd.size() - n_past;
tg = n_eval == 1;
if (n_eval > params.n_batch)
{
n_eval = params.n_batch;
}
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
{
LOG_ERROR("failed to eval", {
{"n_eval", n_eval},
{"n_past", n_past},
{"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = false;
return result;
}
n_past += n_eval;
}
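// Pending tokens are decoded in chunks of at most params.n_batch; tg remains true only
// when a single token is evaluated, i.e. during token generation rather than prompt
// processing, which is what gates the num_tokens_predicted counter below.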
if (params.n_predict == 0)
{
has_next_token = false;
result.tok = llama_token_eos(ctx);
return result;
}
{
// out of user input, sample next token
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(model));
result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates);
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
const int32_t n_probs = params.n_probs;
if (params.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &candidates_p);
}
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
{
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok);
if (tg) {
num_tokens_predicted++;
}
}
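// llama_sample_token() applies the sampling parameters (and the grammar, if one is set)
// to pick the next token. When n_probs > 0 the top candidates are recorded so the API
// can report per-token probabilities; for greedy sampling (temp <= 0) they are first
// sorted with llama_sample_softmax.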
// add it to the context
embd.push_back(result.tok);
// decrement remaining sampling budget
--n_remain;
if (!embd.empty() && embd.back() == llama_token_eos(ctx))
{
// stopping_word = llama_token_to_piece(ctx, embd.back());
has_next_token = false;
stopped_eos = true;
LOG_VERBOSE("eos token found", {});
return result;
}
has_next_token = params.n_predict == -1 || n_remain != 0;
return result;
}
size_t findStoppingStrings(const size_t last_token_size,
@@ -754,7 +833,7 @@ struct llama_server_context
params.n_predict) ||
stop_pos != std::string::npos));
if (slot.params.n_probs > 0)
if (params.n_probs > 0)
{
slot.generated_token_probs.push_back(result);
}
@@ -1013,15 +1092,16 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("usage: %s [options]\n", argv0);
printf("\n");
printf("options:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
printf(" -h, --help show this help message and exit\n");
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n");
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
if (llama_mlock_supported())
{
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
@@ -1166,6 +1246,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_threads = std::stoi(argv[i]);
}
else if (arg == "--threads-batch" || arg == "-tb")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_threads_batch = std::stoi(argv[i]);
}
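// Hypothetical invocation (model path and the -m flag are assumed, not shown in this hunk):
//   ./server -m models/7B/model.gguf -t 8 -tb 16 -c 4096 -b 512
// -tb lets prompt/batch processing use more threads than single-token generation.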
else if (arg == "-b" || arg == "--batch-size")
{
if (++i >= argc)
@@ -1343,35 +1432,35 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
static json format_generation_settings(llama_server_context &llama, llama_client_slot* &slot)
{
const auto eos_bias = slot->params.logit_bias.find(llama_token_eos(llama.ctx));
const bool ignore_eos = eos_bias != slot->params.logit_bias.end() &&
const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{
{"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias},
{"seed", slot->params.seed},
{"temp", slot->params.temp},
{"top_k", slot->params.top_k},
{"top_p", slot->params.top_p},
{"tfs_z", slot->params.tfs_z},
{"typical_p", slot->params.typical_p},
{"repeat_last_n", slot->params.repeat_last_n},
{"repeat_penalty", slot->params.repeat_penalty},
{"presence_penalty",slot->params.presence_penalty},
{"frequency_penalty", slot->params.frequency_penalty},
{"mirostat", slot->params.mirostat},
{"mirostat_tau", slot->params.mirostat_tau},
{"mirostat_eta", slot->params.mirostat_eta},
{"penalize_nl", slot->params.penalize_nl},
{"stop", slot->params.antiprompt},
{"n_predict", slot->params.n_predict},
// {"n_keep", slot.params.n_keep},
{"seed", llama.params.seed},
{"temp", llama.params.temp},
{"top_k", llama.params.top_k},
{"top_p", llama.params.top_p},
{"tfs_z", llama.params.tfs_z},
{"typical_p", llama.params.typical_p},
{"repeat_last_n", llama.params.repeat_last_n},
{"repeat_penalty", llama.params.repeat_penalty},
{"presence_penalty", llama.params.presence_penalty},
{"frequency_penalty", llama.params.frequency_penalty},
{"mirostat", llama.params.mirostat},
{"mirostat_tau", llama.params.mirostat_tau},
{"mirostat_eta", llama.params.mirostat_eta},
{"penalize_nl", llama.params.penalize_nl},
{"stop", llama.params.antiprompt},
{"n_predict", llama.params.n_predict},
{"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos},
{"stream", slot->params.stream},
{"logit_bias", slot->params.logit_bias},
{"n_probs", slot->params.n_probs},
{"grammar", slot->params.grammar},
{"stream", llama.stream},
{"logit_bias", llama.params.logit_bias},
{"n_probs", llama.params.n_probs},
{"grammar", llama.params.grammar},
};
}
@@ -1419,7 +1508,7 @@ static json format_final_response(llama_server_context &llama, llama_client_slot
// {"timings", format_timings(llama)},
};
if (slot->params.n_probs > 0)
if (llama.params.n_probs > 0)
{
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
@@ -1436,7 +1525,7 @@ static json format_partial_response(
{ "slot_id", slot->id }
};
if (slot->params.n_probs > 0)
if (llama.params.n_probs > 0)
{
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
@@ -1467,27 +1556,27 @@ static T json_value(const json &body, const std::string &key, const T &default_v
static void parse_options_completion(const json &body, llama_client_slot* &slot, llama_server_context &llama)
{
slot_params default_params;
gpt_params default_params;
slot->params.stream = json_value(body, "stream", false);
slot->params.n_predict = json_value(body, "n_predict", default_params.n_predict);
slot->params.top_k = json_value(body, "top_k", default_params.top_k);
slot->params.top_p = json_value(body, "top_p", default_params.top_p);
slot->params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
slot->params.typical_p = json_value(body, "typical_p", default_params.typical_p);
slot->params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
slot->params.temp = json_value(body, "temperature", default_params.temp);
slot->params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
slot->params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
slot->params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
slot->params.mirostat = json_value(body, "mirostat", default_params.mirostat);
slot->params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
slot->params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
slot->params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
//llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
slot->params.seed = json_value(body, "seed", default_params.seed);
slot->params.grammar = json_value(body, "grammar", default_params.grammar);
slot->params.n_probs = json_value(body, "n_probs", default_params.n_probs);
llama.stream = json_value(body, "stream", false);
llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
llama.params.top_k = json_value(body, "top_k", default_params.top_k);
llama.params.top_p = json_value(body, "top_p", default_params.top_p);
llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
llama.params.temp = json_value(body, "temperature", default_params.temp);
llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
llama.params.seed = json_value(body, "seed", default_params.seed);
llama.params.grammar = json_value(body, "grammar", default_params.grammar);
llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
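// Illustrative request body these options are parsed from (field names taken from the
// json_value calls above, values are made up):
//   { "prompt": "Hello", "stream": true, "n_predict": 128, "temperature": 0.8,
//     "top_k": 40, "top_p": 0.95, "repeat_penalty": 1.1, "n_probs": 5, "grammar": "" }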
if (body.count("prompt") != 0)
{
@@ -1498,10 +1587,10 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot,
slot->prompt = "";
}
slot->params.logit_bias.clear();
llama.params.logit_bias.clear();
if (json_value(body, "ignore_eos", false))
{
slot->params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
}
const auto &logit_bias = body.find("logit_bias");
@@ -1517,11 +1606,11 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot,
{
if (el[1].is_number())
{
slot->params.logit_bias[tok] = el[1].get<float>();
llama.params.logit_bias[tok] = el[1].get<float>();
}
else if (el[1].is_boolean() && !el[1].get<bool>())
{
slot->params.logit_bias[tok] = -INFINITY;
llama.params.logit_bias[tok] = -INFINITY;
}
}
}
@@ -1541,6 +1630,8 @@ static void parse_options_completion(const json &body, llama_client_slot* &slot,
}
}
llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
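// The sampling context is rebuilt from the freshly parsed params and grammar on every
// request, so sampling state cannot leak from one completion into the next.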
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama, slot));
}
@@ -1774,11 +1865,11 @@ int main(int argc, char **argv)
// }
// }
// auto probs = llama.generated_token_probs;
// if (llama.params.n_probs > 0 && llama.stopped_word) {
// const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
// probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
// }
auto probs = llama.generated_token_probs;
if (llama.params.n_probs > 0 && llama.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
}
// const json data = format_final_response(llama, llama.generated_text, probs);
@@ -1796,32 +1887,70 @@ int main(int argc, char **argv)
// const completion_token_output token = slot->next();
// std::string token_str = llama_token_to_piece(llama.ctx, token.tok);
// std::vector<completion_token_output> probs_output = {};
size_t pos = std::min(sent_count, llama.generated_text.size());
// const json data = format_partial_response(llama, slot, token_str, probs_output);
// const std::string str =
// "data: " +
// data.dump(-1, ' ', false, json::error_handler_t::replace) +
// "\n\n";
const std::string str_test = llama.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
is_stop_full = true;
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
is_stop_full = false;
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL);
}
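// STOP_FULL: a stop string occurs completely inside the pending text, so everything from
// the match onward is cut before sending. STOP_PARTIAL: the pending text merely ends with
// a prefix of a stop string, so sending is held back until more tokens arrive or
// generation ends.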
// LOG_VERBOSE("data stream", {
// { "to_send", str }
// });
// if(!sink.write(str.c_str(), str.size())) {
// slot->release();
// return false;
// }
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
if (
stop_pos == std::string::npos ||
// Send rest of the text if we are at the end of the generation
(!llama.has_next_token && !is_stop_full && stop_pos > 0)
) {
const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
sent_count += to_send.size();
std::vector<completion_token_output> probs_output = {};
if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
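// sent_token_probs_index tracks how many probability entries have already been streamed;
// each chunk therefore carries only the probabilities belonging to the newly sent text.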
const json data = format_partial_response(llama, to_send, probs_output);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.data(), str.size())) {
LOG_VERBOSE("stream closed", {});
llama_print_timings(llama.ctx);
return false;
}
}
// const json data = format_final_response(
// llama, slot,
// "",
// std::vector<completion_token_output>(
// slot->generated_token_probs.begin(),
// slot->generated_token_probs.begin() + sent_token_probs_index)
// );
if (!llama.has_next_token) {
// Generation is done, send extra information.
const json data = format_final_response(
llama,
"",
std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
);
// const std::string str =
// "data: " +
@@ -1907,15 +2036,15 @@ int main(int argc, char **argv)
// std::vector<completion_token_output> probs_output = {};
// if (llama.params.n_probs > 0) {
// const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
// size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
// size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
// if (probs_pos < probs_stop_pos) {
// probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
// }
// sent_token_probs_index = probs_stop_pos;
// }
if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
// const json data = format_partial_response(llama, to_send, probs_output);