common : fix mirostat state when using multiple sequences (#3543)
* Fix mirostat state when using multiple sequences
* Fix mirostat by completely refactoring sampling!
* Try to fix zig build.
* Export function to fetch/create default sampler states.
  Code formatting cleanups and add some comments.
  Silence a warning about id not being used when logging is disabled.
* Apply some renaming suggestions. Fix comments that were out of sync with the pull.
* Use a more consistent naming convention for sampling contexts.
parent 8c70a5ff25
commit 70c29da118
14 changed files with 495 additions and 334 deletions
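The refactor moves sampler state that used to be shared across sequences, most importantly mirostat's running mu, into a llama_sampling_context that keeps one state per sequence, so sampling one sequence can no longer clobber another's mirostat state. The sampling context types themselves are defined in other files of this commit; only the server example (llama_server_context and its request handling) is shown below. As a rough orientation, the sketch that follows illustrates the per-sequence idea with made-up names; it is not the actual llama.cpp API.

    // Illustrative sketch only -- not the actual llama.cpp types. The point of
    // the refactor: each sequence id gets its own mirostat state instead of
    // sharing one value across sequences.
    #include <unordered_map>

    struct seq_sampling_state {
        float mirostat_mu = 10.0f;   // per-sequence adaptive target (commonly seeded as 2 * tau)
    };

    struct sampling_context_sketch {
        const void * grammar = nullptr;                      // borrowed grammar pointer (opaque here)
        std::unordered_map<int, seq_sampling_state> states;  // keyed by sequence id

        seq_sampling_state & get(int seq_id) {
            return states[seq_id];   // creates a fresh default state on first use
        }
    };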
@@ -200,6 +200,7 @@ struct llama_server_context
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     gpt_params params;
+    llama_sampling_context ctx_sampling;
     int n_ctx;
 
     grammar_parser::parse_state parsed_grammar;
@@ -254,6 +255,7 @@ struct llama_server_context
         if (grammar != nullptr) {
             llama_grammar_free(grammar);
             grammar = nullptr;
+            ctx_sampling = llama_sampling_context_init(params, NULL);
         }
     }
 
@@ -329,8 +331,8 @@ struct llama_server_context
             grammar_parser::print_grammar(stderr, parsed_grammar);
 
             {
-                auto it = params.logit_bias.find(llama_token_eos(ctx));
-                if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
+                if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
                     LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
                 }
             }
@@ -339,6 +341,7 @@ struct llama_server_context
             grammar = llama_grammar_init(
                 grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
         }
+        ctx_sampling = llama_sampling_context_init(params, grammar);
         return true;
     }
 
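Re-initializing the sampling context right after the grammar is (re)loaded means any per-sequence state created later refers to the new grammar rather than a stale pointer. A minimal sketch of what an init helper along these lines does, reusing the illustrative types from the top of this page (the real llama_sampling_context_init belongs to the commit's new common sampling code, which this view does not show):

    // Illustrative only: remember (but do not own) the grammar and start with
    // no per-sequence mirostat state, so each sequence is seeded lazily the
    // first time it is sampled.
    sampling_context_sketch make_sampling_context(const void * grammar) {
        sampling_context_sketch ctx;
        ctx.grammar = grammar;   // borrowed; freeing it stays the caller's job
        return ctx;
    }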
@@ -550,12 +553,12 @@ struct llama_server_context
         std::vector<llama_token_data> candidates;
         candidates.reserve(llama_n_vocab(model));
 
-        result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates);
+        result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
 
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
-        const int32_t n_probs = params.n_probs;
-        if (params.temp <= 0 && n_probs > 0)
+        const int32_t n_probs = params.sampling_params.n_probs;
+        if (params.sampling_params.temp <= 0 && n_probs > 0)
         {
             // For llama_sample_token_greedy we need to sort candidates
             llama_sample_softmax(ctx, &candidates_p);
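This is the functional core of the fix in the server: llama_sample_token(ctx, NULL, grammar, params, ...) becomes llama_sampling_sample(ctx, NULL, ctx_sampling, ...), so mirostat's running mu now lives inside the sampling context instead of shared state. Below is a sketch of the per-sequence update such a call implies, again using the illustrative types from the top of this page rather than the real implementation (which also handles the candidate list, guidance context, and grammar):

    // Illustrative per-sequence mirostat feedback step. The key point is that
    // mu is fetched by sequence id, so concurrently decoded sequences keep
    // independent state.
    float mirostat_step_sketch(sampling_context_sketch & ctx, int seq_id,
                               float tau, float eta, float observed_surprise) {
        float & mu = ctx.get(seq_id).mirostat_mu;   // lazily created per sequence
        mu -= eta * (observed_surprise - tau);      // steer surprise toward tau
        return mu;                                  // used to truncate candidates on the next step
    }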
@@ -630,7 +633,7 @@ struct llama_server_context
         const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
         generated_text += token_text;
 
-        if (params.n_probs > 0)
+        if (params.sampling_params.n_probs > 0)
         {
             generated_token_probs.push_back(token_with_probs);
         }
@@ -1018,34 +1021,35 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 
 static json format_generation_settings(llama_server_context &llama)
 {
-    const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
-    const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
+    const auto & sparams = llama.params.sampling_params;
+    const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
+    const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
 
     return json{
         {"n_ctx", llama.n_ctx},
         {"model", llama.params.model_alias},
         {"seed", llama.params.seed},
-        {"temp", llama.params.temp},
-        {"top_k", llama.params.top_k},
-        {"top_p", llama.params.top_p},
-        {"tfs_z", llama.params.tfs_z},
-        {"typical_p", llama.params.typical_p},
-        {"repeat_last_n", llama.params.repeat_last_n},
-        {"repeat_penalty", llama.params.repeat_penalty},
-        {"presence_penalty", llama.params.presence_penalty},
-        {"frequency_penalty", llama.params.frequency_penalty},
-        {"mirostat", llama.params.mirostat},
-        {"mirostat_tau", llama.params.mirostat_tau},
-        {"mirostat_eta", llama.params.mirostat_eta},
-        {"penalize_nl", llama.params.penalize_nl},
+        {"temp", sparams.temp},
+        {"top_k", sparams.top_k},
+        {"top_p", sparams.top_p},
+        {"tfs_z", sparams.tfs_z},
+        {"typical_p", sparams.typical_p},
+        {"repeat_last_n", sparams.repeat_last_n},
+        {"repeat_penalty", sparams.repeat_penalty},
+        {"presence_penalty", sparams.presence_penalty},
+        {"frequency_penalty", sparams.frequency_penalty},
+        {"mirostat", sparams.mirostat},
+        {"mirostat_tau", sparams.mirostat_tau},
+        {"mirostat_eta", sparams.mirostat_eta},
+        {"penalize_nl", sparams.penalize_nl},
         {"stop", llama.params.antiprompt},
         {"n_predict", llama.params.n_predict},
         {"n_keep", llama.params.n_keep},
         {"ignore_eos", ignore_eos},
         {"stream", llama.stream},
-        {"logit_bias", llama.params.logit_bias},
-        {"n_probs", llama.params.n_probs},
+        {"logit_bias", sparams.logit_bias},
+        {"n_probs", sparams.n_probs},
         {"grammar", llama.params.grammar},
     };
 }
@@ -1094,7 +1098,7 @@ static json format_final_response(llama_server_context &llama, const std::string
         {"timings", format_timings(llama)},
     };
 
-    if (llama.params.n_probs > 0)
+    if (llama.params.sampling_params.n_probs > 0)
     {
         res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
@@ -1110,7 +1114,7 @@ static json format_partial_response(
         {"stop", false},
     };
 
-    if (llama.params.n_probs > 0)
+    if (llama.params.sampling_params.n_probs > 0)
    {
         res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
@@ -1142,26 +1146,28 @@ static T json_value(const json &body, const std::string &key, const T &default_v
 static void parse_options_completion(const json &body, llama_server_context &llama)
 {
     gpt_params default_params;
+    const auto & default_sparams = default_params.sampling_params;
+    auto & sparams = llama.params.sampling_params;
 
     llama.stream = json_value(body, "stream", false);
     llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
-    llama.params.top_k = json_value(body, "top_k", default_params.top_k);
-    llama.params.top_p = json_value(body, "top_p", default_params.top_p);
-    llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
-    llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
-    llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
-    llama.params.temp = json_value(body, "temperature", default_params.temp);
-    llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
-    llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
-    llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
-    llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
-    llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
-    llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
-    llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
+    sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
+    sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
+    sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
+    sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
+    sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
+    sparams.temp = json_value(body, "temperature", default_sparams.temp);
+    sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
+    sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
+    sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
+    sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
+    sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+    sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+    sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
     llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
     llama.params.seed = json_value(body, "seed", default_params.seed);
     llama.params.grammar = json_value(body, "grammar", default_params.grammar);
-    llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
+    sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
 
     if (body.count("prompt") != 0)
     {
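Every request option above goes through the json_value helper visible in the hunk header: return the value at key if the request body carries one, otherwise the supplied default. A minimal sketch of such a helper based on nlohmann::json (null handling and other details in the real server.cpp may differ):

    // Minimal json_value-style helper: prefer the request's value, fall back
    // to the compiled-in default. Details may differ from the real server.cpp.
    #include <nlohmann/json.hpp>
    #include <string>

    using json = nlohmann::json;

    template <typename T>
    static T json_value_sketch(const json & body, const std::string & key, const T & default_value) {
        if (body.contains(key) && !body.at(key).is_null()) {
            return body.value(key, default_value);
        }
        return default_value;
    }

With that shape, sparams.temp = json_value(body, "temperature", default_sparams.temp); keeps the built-in default whenever the client does not send a temperature.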
@@ -1172,10 +1178,10 @@ static void parse_options_completion(const json &body, llama_server_context &lla
         llama.prompt = "";
     }
 
-    llama.params.logit_bias.clear();
+    sparams.logit_bias.clear();
     if (json_value(body, "ignore_eos", false))
     {
-        llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
+        sparams.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
     }
 
     const auto &logit_bias = body.find("logit_bias");
@@ -1191,11 +1197,11 @@ static void parse_options_completion(const json &body, llama_server_context &lla
                 {
                     if (el[1].is_number())
                     {
-                        llama.params.logit_bias[tok] = el[1].get<float>();
+                        sparams.logit_bias[tok] = el[1].get<float>();
                     }
                     else if (el[1].is_boolean() && !el[1].get<bool>())
                     {
-                        llama.params.logit_bias[tok] = -INFINITY;
+                        sparams.logit_bias[tok] = -INFINITY;
                     }
                 }
             }
@@ -1215,6 +1221,8 @@ static void parse_options_completion(const json &body, llama_server_context &lla
         }
     }
 
+    llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
+
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
 
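The context is rebuilt once more after the request options have been parsed, so any per-sequence mirostat state left over from a previous request is dropped and the freshly parsed parameters take effect from the first sampled token. In terms of the illustrative types above (not the real API), the effect is simply:

    // Conceptual effect of the per-request re-init: forget every sequence's
    // mirostat mu so the new tau/eta/mirostat mode apply from token one.
    void reset_for_new_request(sampling_context_sketch & ctx) {
        ctx.states.clear();
    }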
@@ -1423,7 +1431,7 @@ int main(int argc, char **argv)
            }
 
            auto probs = llama.generated_token_probs;
-           if (llama.params.n_probs > 0 && llama.stopped_word) {
+           if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
                const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
                probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
            }
@@ -1475,7 +1483,7 @@ int main(int argc, char **argv)
 
                    std::vector<completion_token_output> probs_output = {};
 
-                   if (llama.params.n_probs > 0) {
+                   if (llama.params.sampling_params.n_probs > 0) {
                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@@ -1596,7 +1604,7 @@ int main(int argc, char **argv)
 
                    std::vector<completion_token_output> probs_output = {};
 
-                   if (llama.params.n_probs > 0) {
+                   if (llama.params.sampling_params.n_probs > 0) {
                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
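The server only ever samples a single sequence per request, so most of the hunks above are plumbing; the titular bug matters for callers that interleave several sequences (for example batched or parallel decoding), which previously shared one mirostat mu. The self-contained demo below (illustrative, not repository code) contrasts the shared-state behaviour with the per-sequence behaviour this commit introduces:

    // Before: one shared mu, so interleaving two sequences corrupts both.
    // After:  mu keyed by sequence id, so interleaving is harmless.
    #include <cstdio>
    #include <unordered_map>

    int main() {
        float shared_mu = 10.0f;                      // "before": single shared state
        std::unordered_map<int, float> mu_per_seq;    // "after": one mu per sequence

        auto step = [](float & mu, float surprise) {  // simplified mirostat update
            const float tau = 5.0f, eta = 0.1f;
            mu -= eta * (surprise - tau);
        };

        for (int i = 0; i < 4; ++i) {
            // interleaved decoding: seq 0 produces low-surprise tokens, seq 1 high
            step(shared_mu, 2.0f);   // seq 0 token pushes the shared mu up
            step(shared_mu, 9.0f);   // seq 1 token drags the same mu back down

            auto & mu0 = mu_per_seq.try_emplace(0, 10.0f).first->second;
            auto & mu1 = mu_per_seq.try_emplace(1, 10.0f).first->second;
            step(mu0, 2.0f);         // seq 0 keeps its own trajectory
            step(mu1, 9.0f);         // seq 1 keeps its own trajectory
        }

        std::printf("shared mu: %.2f   per-seq mu: %.2f / %.2f\n",
                    shared_mu, mu_per_seq[0], mu_per_seq[1]);
        return 0;
    }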