sampling : refactor init to use llama_sampling_params

This commit is contained in:
Georgi Gerganov 2023-10-20 14:58:20 +03:00
parent 8cf19d60dc
commit cd1e937821
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
12 changed files with 110 additions and 142 deletions

View file

@ -39,8 +39,8 @@ static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool is_interacting = false;
static void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model,
@ -104,7 +104,7 @@ static void sigint_handler(int signo) {
int main(int argc, char ** argv) {
gpt_params params;
llama_sampling_params & sparams = params.sampling_params;
llama_sampling_params & sparams = params.sparams;
g_params = &params;
if (!gpt_params_parse(argc, argv, params)) {
@ -363,31 +363,6 @@ int main(int argc, char ** argv) {
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
struct llama_grammar * grammar = NULL;
grammar_parser::parse_state parsed_grammar;
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
return 1;
}
LOG_TEE("%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar);
LOG_TEE("\n");
{
auto it = sparams.logit_bias.find(llama_token_eos(ctx));
if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
@ -430,7 +405,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
while (n_remain != 0 || params.interactive) {
// predict
@ -740,15 +715,7 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
// reset grammar state if we're restarting generation
if (grammar != NULL) {
llama_grammar_free(grammar);
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(),
parsed_grammar.symbol_ids.at("root"));
}
llama_sampling_reset(ctx_sampling);
}
is_interacting = false;
}
@ -778,9 +745,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
llama_free_model(model);
if (grammar != NULL) {
llama_grammar_free(grammar);
}
llama_sampling_free(ctx_sampling);
llama_backend_free();
#ifndef LOG_DISABLE_LOGS