sampling : refactor init to use llama_sampling_params

This commit is contained in:
Georgi Gerganov 2023-10-20 14:58:20 +03:00
parent 8cf19d60dc
commit cd1e937821
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
12 changed files with 110 additions and 142 deletions

View file

@ -232,7 +232,7 @@ struct llama_server_context
void rewind()
{
params.antiprompt.clear();
params.grammar.clear();
params.sparams.grammar.clear();
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
@ -250,7 +250,7 @@ struct llama_server_context
if (ctx_sampling != nullptr) {
llama_sampling_free(ctx_sampling);
}
ctx_sampling = llama_sampling_init(params);
ctx_sampling = llama_sampling_init(params.sparams);
}
bool loadModel(const gpt_params &params_)
@ -313,7 +313,7 @@ struct llama_server_context
bool loadGrammar()
{
ctx_sampling = llama_sampling_init(params);
ctx_sampling = llama_sampling_init(params.sparams);
return true;
}
@ -530,8 +530,8 @@ struct llama_server_context
llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
const int32_t n_probs = params.sampling_params.n_probs;
if (params.sampling_params.temp <= 0 && n_probs > 0)
const int32_t n_probs = params.sparams.n_probs;
if (params.sparams.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &cur_p);
@ -606,7 +606,7 @@ struct llama_server_context
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
generated_text += token_text;
if (params.sampling_params.n_probs > 0)
if (params.sparams.n_probs > 0)
{
generated_token_probs.push_back(token_with_probs);
}
@ -1004,7 +1004,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
static json format_generation_settings(llama_server_context &llama)
{
const auto & sparams = llama.params.sampling_params;
const auto & sparams = llama.params.sparams;
const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@ -1033,7 +1033,7 @@ static json format_generation_settings(llama_server_context &llama)
{"stream", llama.stream},
{"logit_bias", sparams.logit_bias},
{"n_probs", sparams.n_probs},
{"grammar", llama.params.grammar},
{"grammar", llama.params.sparams.grammar},
};
}
@ -1081,7 +1081,7 @@ static json format_final_response(llama_server_context &llama, const std::string
{"timings", format_timings(llama)},
};
if (llama.params.sampling_params.n_probs > 0)
if (llama.params.sparams.n_probs > 0)
{
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
@ -1097,7 +1097,7 @@ static json format_partial_response(
{"stop", false},
};
if (llama.params.sampling_params.n_probs > 0)
if (llama.params.sparams.n_probs > 0)
{
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
@ -1129,11 +1129,13 @@ static T json_value(const json &body, const std::string &key, const T &default_v
static void parse_options_completion(const json &body, llama_server_context &llama)
{
gpt_params default_params;
const auto & default_sparams = default_params.sampling_params;
auto & sparams = llama.params.sampling_params;
const auto & default_sparams = default_params.sparams;
auto & params = llama.params;
auto & sparams = llama.params.sparams;
llama.stream = json_value(body, "stream", false);
llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
params.n_predict = json_value(body, "n_predict", default_params.n_predict);
sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
@ -1147,9 +1149,9 @@ static void parse_options_completion(const json &body, llama_server_context &lla
sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
llama.params.seed = json_value(body, "seed", default_params.seed);
llama.params.grammar = json_value(body, "grammar", default_params.grammar);
params.n_keep = json_value(body, "n_keep", default_params.n_keep);
params.seed = json_value(body, "seed", default_params.seed);
sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
if (body.count("prompt") != 0)
@ -1204,7 +1206,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
}
}
llama.ctx_sampling = llama_sampling_init(llama.params);
llama.ctx_sampling = llama_sampling_init(llama.params.sparams);
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
@ -1414,7 +1416,7 @@ int main(int argc, char **argv)
}
auto probs = llama.generated_token_probs;
if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
if (llama.params.sparams.n_probs > 0 && llama.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
}
@ -1466,7 +1468,7 @@ int main(int argc, char **argv)
std::vector<completion_token_output> probs_output = {};
if (llama.params.sampling_params.n_probs > 0) {
if (llama.params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@ -1587,7 +1589,7 @@ int main(int argc, char **argv)
std::vector<completion_token_output> probs_output = {};
if (llama.params.sampling_params.n_probs > 0) {
if (llama.params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());