sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params
* llama : combine repetition, frequency and presence penalties in 1 call
* examples : remove embd-input and gptneox-wip
* sampling : rename penalty params + reduce size of "prev" vector
* sampling : add llama_sampling_print helper
* sampling : hide prev behind API and apply #3661

ggml-ci
This commit is contained in:
parent 8cf19d60dc
commit d1031cf49c
30 changed files with 365 additions and 4502 deletions
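The core of the refactor is that sampling settings now live in a dedicated llama_sampling_params struct (gpt_params::sparams), and llama_sampling_init takes that struct instead of the whole gpt_params. Below is a minimal sketch of the new call pattern; it uses only field and function names that appear in this diff, and the surrounding scaffolding is illustrative, not taken from the commit.

// Sketch: initializing a sampling context with the refactored API.
// Field and function names are those visible in this commit's diff;
// the helper itself is illustrative scaffolding.
#include "common.h"
#include "sampling.h"

static llama_sampling_context * make_sampler(gpt_params & params) {
    // penalties were renamed: repeat_last_n     -> penalty_last_n
    //                         repeat_penalty    -> penalty_repeat
    //                         frequency_penalty -> penalty_freq
    //                         presence_penalty  -> penalty_present
    params.sparams.penalty_last_n  = 64;
    params.sparams.penalty_repeat  = 1.1f;
    params.sparams.penalty_freq    = 0.0f;
    params.sparams.penalty_present = 0.0f;

    // init now takes only the sampling params, not the whole gpt_params
    return llama_sampling_init(params.sparams);
}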
examples/server/server.cpp:

@@ -195,10 +195,12 @@ struct llama_server_context
     json prompt;
     std::vector<llama_token> embd;

-    gpt_params params;

     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
+    gpt_params params;
     llama_sampling_context *ctx_sampling = nullptr;

+    int n_ctx;
+
     bool truncated = false;
@@ -232,7 +234,7 @@ struct llama_server_context
     void rewind()
     {
         params.antiprompt.clear();
-        params.grammar.clear();
+        params.sparams.grammar.clear();
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
@@ -246,11 +248,14 @@ struct llama_server_context
         multibyte_pending = 0;
         n_remain = 0;
         n_past = 0;
+        params.sparams.n_prev = n_ctx;
     }

     void initSampling() {
         if (ctx_sampling != nullptr) {
             llama_sampling_free(ctx_sampling);
         }
-        ctx_sampling = llama_sampling_init(params);
+        ctx_sampling = llama_sampling_init(params.sparams);
     }

     bool loadModel(const gpt_params &params_)
@@ -311,16 +316,32 @@ struct llama_server_context
         return prompt_tokens;
     }

-    bool loadGrammar()
-    {
-        ctx_sampling = llama_sampling_init(params);
-        return true;
+    void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
+        const int n_left = n_ctx - params.n_keep;
+        const int n_block_size = n_left / 2;
+        const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
+
+        // Keep n_keep tokens at start of prompt (at most n_ctx - 4)
+        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+
+        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+        LOG_VERBOSE("input truncated", {
+            {"n_ctx", n_ctx},
+            {"n_keep", params.n_keep},
+            {"n_left", n_left},
+            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+            {"num_prompt_tokens", new_tokens.size()}
+        });
+
+        truncated = true;
+        prompt_tokens = new_tokens;
     }

     void loadInfill()
     {
         bool suff_rm_leading_spc = true;
-        if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
             params.input_suffix.erase(0, 1);
             suff_rm_leading_spc = false;
         }
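For reference, the arithmetic in the new truncatePrompt helper keeps the first params.n_keep tokens and then drops whole blocks of n_left/2 tokens from the front of the remainder, so that what survives fits in n_ctx. A small standalone sketch of the same arithmetic, with made-up sizes (n_ctx, n_keep and the prompt length are hypothetical, not from the commit):

// Sketch of the block-based truncation used by truncatePrompt,
// with hypothetical sizes to show the effect.
#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 2048;
    const int n_keep = 256;
    std::vector<int> prompt_tokens(4000); // pretend prompt of 4000 tokens

    const int n_left        = n_ctx - n_keep; // 1792
    const int n_block_size  = n_left / 2;     // 896
    const int erased_blocks =
        ((int) prompt_tokens.size() - n_keep - n_block_size) / n_block_size; // (4000-256-896)/896 = 3

    // keep [0, n_keep) plus everything from n_keep + erased_blocks*n_block_size onward
    const int kept = n_keep + ((int) prompt_tokens.size() - (n_keep + erased_blocks * n_block_size));
    printf("kept %d of %zu tokens (fits in n_ctx = %d)\n", kept, prompt_tokens.size(), n_ctx);
    // -> kept 1312 of 4000 tokens (fits in n_ctx = 2048)
    return 0;
}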
@@ -336,6 +357,7 @@ struct llama_server_context
         prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
         prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
         prefix_tokens.push_back(llama_token_middle(ctx));
+
         auto prompt_tokens = prefix_tokens;

         num_prompt_tokens = prompt_tokens.size();
@@ -347,31 +369,18 @@ struct llama_server_context
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);

         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)params.n_ctx)
+        if (num_prompt_tokens >= (size_t) n_ctx)
         {
-            printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
-            // todo we probably want to cut from both sides
-            const int n_left = (params.n_ctx - params.n_keep) / 2;
-            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();

-            LOG_VERBOSE("input truncated", {
-                {"n_ctx", params.n_ctx},
-                {"n_keep", params.n_keep},
-                {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-            });
-
-            truncated = true;
-            prompt_tokens = new_tokens;
+            GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
         }
-        else
+
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
         {
-            const size_t ps = num_prompt_tokens;
-            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
         }

         // compare the evaluated prompt with the new prompt
@@ -409,29 +418,18 @@ struct llama_server_context
         params.n_keep = std::min(n_ctx - 4, params.n_keep);

         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)n_ctx)
+        if (num_prompt_tokens >= (size_t) n_ctx)
         {
-            const int n_left = (n_ctx - params.n_keep) / 2;
-            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();

-            LOG_VERBOSE("input truncated", {
-                {"n_ctx", n_ctx},
-                {"n_keep", params.n_keep},
-                {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-            });
-
-            truncated = true;
-            prompt_tokens = new_tokens;
+            GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
         }
-        else
+
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
         {
-            const size_t ps = num_prompt_tokens;
-            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
         }

         // compare the evaluated prompt with the new prompt
@@ -530,8 +528,8 @@ struct llama_server_context

         llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };

-        const int32_t n_probs = params.sampling_params.n_probs;
-        if (params.sampling_params.temp <= 0 && n_probs > 0)
+        const int32_t n_probs = params.sparams.n_probs;
+        if (params.sparams.temp <= 0 && n_probs > 0)
         {
             // For llama_sample_token_greedy we need to sort candidates
             llama_sample_softmax(ctx, &cur_p);
@@ -542,7 +540,7 @@ struct llama_server_context
                 result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
             }

-            llama_sampling_accept(ctx_sampling, ctx, result.tok);
+            llama_sampling_accept(ctx_sampling, ctx, result.tok, true);

             if (tg) {
                 num_tokens_predicted++;
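The extra boolean added to llama_sampling_accept controls whether the accepted token is also fed to the grammar: tokens merely replayed from the prompt are accepted with false (the in-diff comment says "do not apply grammar"), while tokens the server actually sampled are accepted with true. A condensed sketch of that pattern, assuming the common/sampling.h API as it appears in this diff (the helper functions and token sources are placeholders, not part of the commit):

// Sketch: replay prompt tokens without advancing the grammar, then
// record sampled tokens with the grammar applied.
#include <vector>
#include "sampling.h"

static void feed_prompt(llama_sampling_context * ctx_sampling,
                        llama_context          * ctx,
                        const std::vector<llama_token> & prompt_tokens) {
    for (const llama_token tok : prompt_tokens) {
        // update the sampler's "prev" window only
        llama_sampling_accept(ctx_sampling, ctx, tok, false);
    }
}

static void record_generated(llama_sampling_context * ctx_sampling,
                             llama_context          * ctx,
                             llama_token tok) {
    // sampled token: also advance the grammar state
    llama_sampling_accept(ctx_sampling, ctx, tok, true);
}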
@@ -606,7 +604,7 @@ struct llama_server_context
         const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
         generated_text += token_text;

-        if (params.sampling_params.n_probs > 0)
+        if (params.sparams.n_probs > 0)
         {
             generated_token_probs.push_back(token_with_probs);
         }
@@ -1004,36 +1002,36 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,

 static json format_generation_settings(llama_server_context &llama)
 {
-    const auto & sparams = llama.params.sampling_params;
+    const auto & sparams = llama.params.sparams;
     const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
     const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);

     return json{
-        {"n_ctx", llama.n_ctx},
-        {"model", llama.params.model_alias},
-        {"seed", llama.params.seed},
-        {"temp", sparams.temp},
-        {"top_k", sparams.top_k},
-        {"top_p", sparams.top_p},
-        {"tfs_z", sparams.tfs_z},
-        {"typical_p", sparams.typical_p},
-        {"repeat_last_n", sparams.repeat_last_n},
-        {"repeat_penalty", sparams.repeat_penalty},
-        {"presence_penalty", sparams.presence_penalty},
-        {"frequency_penalty", sparams.frequency_penalty},
-        {"mirostat", sparams.mirostat},
-        {"mirostat_tau", sparams.mirostat_tau},
-        {"mirostat_eta", sparams.mirostat_eta},
-        {"penalize_nl", sparams.penalize_nl},
-        {"stop", llama.params.antiprompt},
-        {"n_predict", llama.params.n_predict},
-        {"n_keep", llama.params.n_keep},
-        {"ignore_eos", ignore_eos},
-        {"stream", llama.stream},
-        {"logit_bias", sparams.logit_bias},
-        {"n_probs", sparams.n_probs},
-        {"grammar", llama.params.grammar},
+        {"n_ctx", llama.n_ctx},
+        {"model", llama.params.model_alias},
+        {"seed", llama.params.seed},
+        {"temp", sparams.temp},
+        {"top_k", sparams.top_k},
+        {"top_p", sparams.top_p},
+        {"tfs_z", sparams.tfs_z},
+        {"typical_p", sparams.typical_p},
+        {"repeat_last_n", sparams.penalty_last_n},
+        {"repeat_penalty", sparams.penalty_repeat},
+        {"frequency_penalty", sparams.penalty_freq},
+        {"presence_penalty", sparams.penalty_present},
+        {"mirostat", sparams.mirostat},
+        {"mirostat_tau", sparams.mirostat_tau},
+        {"mirostat_eta", sparams.mirostat_eta},
+        {"penalize_nl", sparams.penalize_nl},
+        {"stop", llama.params.antiprompt},
+        {"n_predict", llama.params.n_predict},
+        {"n_keep", llama.params.n_keep},
+        {"ignore_eos", ignore_eos},
+        {"stream", llama.stream},
+        {"logit_bias", sparams.logit_bias},
+        {"n_probs", sparams.n_probs},
+        {"grammar", llama.params.sparams.grammar},
     };
 }

@@ -1081,7 +1079,7 @@ static json format_final_response(llama_server_context &llama, const std::string
         {"timings", format_timings(llama)},
     };

-    if (llama.params.sampling_params.n_probs > 0)
+    if (llama.params.sparams.n_probs > 0)
     {
         res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
@@ -1097,7 +1095,7 @@ static json format_partial_response(
         {"stop", false},
     };

-    if (llama.params.sampling_params.n_probs > 0)
+    if (llama.params.sparams.n_probs > 0)
    {
         res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
@@ -1129,28 +1127,30 @@ static T json_value(const json &body, const std::string &key, const T &default_v
 static void parse_options_completion(const json &body, llama_server_context &llama)
 {
     gpt_params default_params;
-    const auto & default_sparams = default_params.sampling_params;
-    auto & sparams = llama.params.sampling_params;
+    const auto & default_sparams = default_params.sparams;

-    llama.stream = json_value(body, "stream", false);
-    llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
-    sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
-    sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
-    sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
-    sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
-    sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
-    sparams.temp = json_value(body, "temperature", default_sparams.temp);
-    sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
-    sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
-    sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
-    sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
-    sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
-    sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
-    sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
-    llama.params.seed = json_value(body, "seed", default_params.seed);
-    llama.params.grammar = json_value(body, "grammar", default_params.grammar);
-    sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
+    auto & params = llama.params;
+    auto & sparams = llama.params.sparams;
+
+    llama.stream = json_value(body, "stream", false);
+    params.n_predict = json_value(body, "n_predict", default_params.n_predict);
+    sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
+    sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
+    sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
+    sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
+    sparams.temp = json_value(body, "temperature", default_sparams.temp);
+    sparams.penalty_last_n = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
+    sparams.penalty_repeat = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
+    sparams.penalty_freq = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
+    sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
+    sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
+    sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+    sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+    sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
+    params.n_keep = json_value(body, "n_keep", default_params.n_keep);
+    params.seed = json_value(body, "seed", default_params.seed);
+    sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
+    sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);

     if (body.count("prompt") != 0)
     {
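Note that only the internal field names change here; the JSON keys accepted by the /completion endpoint (temperature, repeat_last_n, repeat_penalty, frequency_penalty, presence_penalty, and so on) stay the same. For illustration, a hypothetical request body built with the nlohmann::json initializer style the server already uses; the values are made up, and the comments show where each key now lands:

// Sketch: a /completion request body using the unchanged public key names.
// Values are arbitrary; internally they are parsed into params.sparams.*.
#include "json.hpp"
using json = nlohmann::json;

static json example_completion_request() {
    return json{
        {"prompt",            "Once upon a time"},
        {"n_predict",         64},
        {"temperature",       0.8},   // -> sparams.temp
        {"top_k",             40},
        {"top_p",             0.95},
        {"repeat_last_n",     64},    // -> sparams.penalty_last_n
        {"repeat_penalty",    1.1},   // -> sparams.penalty_repeat
        {"frequency_penalty", 0.0},   // -> sparams.penalty_freq
        {"presence_penalty",  0.0},   // -> sparams.penalty_present
        {"n_probs",           0},
    };
}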
@@ -1204,8 +1204,6 @@ static void parse_options_completion(const json &body, llama_server_context &lla
         }
     }

-    llama.ctx_sampling = llama_sampling_init(llama.params);
-
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }

@@ -1374,15 +1372,9 @@ int main(int argc, char **argv)
             llama.rewind();

             llama_reset_timings(llama.ctx);

             parse_options_completion(json::parse(req.body), llama);

-            if (!llama.loadGrammar())
-            {
-                res.status = 400;
-                return;
-            }
-
+            llama.initSampling();
             llama.loadPrompt();
             llama.beginCompletion();

@@ -1414,7 +1406,7 @@ int main(int argc, char **argv)
             }

             auto probs = llama.generated_token_probs;
-            if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
+            if (llama.params.sparams.n_probs > 0 && llama.stopped_word) {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
                 probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
             }
@@ -1466,7 +1458,7 @@ int main(int argc, char **argv)

                     std::vector<completion_token_output> probs_output = {};

-                    if (llama.params.sampling_params.n_probs > 0) {
+                    if (llama.params.sparams.n_probs > 0) {
                         const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                         size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                         size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@@ -1537,14 +1529,9 @@ int main(int argc, char **argv)
             llama.rewind();

             llama_reset_timings(llama.ctx);

             parse_options_infill(json::parse(req.body), llama);

-            if (!llama.loadGrammar())
-            {
-                res.status = 400;
-                return;
-            }
+            llama.initSampling();
             llama.loadInfill();
             llama.beginCompletion();
             const auto chunked_content_provider = [&](size_t, DataSink & sink) {
@@ -1587,7 +1574,7 @@ int main(int argc, char **argv)

                     std::vector<completion_token_output> probs_output = {};

-                    if (llama.params.sampling_params.n_probs > 0) {
+                    if (llama.params.sparams.n_probs > 0) {
                         const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                         size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                         size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@@ -1694,7 +1681,9 @@ int main(int argc, char **argv)
             const json body = json::parse(req.body);

             llama.rewind();

+            llama_reset_timings(llama.ctx);
+
             if (body.count("content") != 0)
             {
                 llama.prompt = body["content"];
@@ -1704,6 +1693,8 @@ int main(int argc, char **argv)
                 llama.prompt = "";
             }
             llama.params.n_predict = 0;
+
+            llama.initSampling();
             llama.loadPrompt();
             llama.beginCompletion();
             llama.doCompletion();