diff --git a/common/sampling.cpp b/common/sampling.cpp
index 03c2664bb..7e7bf5ea1 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -97,6 +97,10 @@ std::string llama_token_healing_prepare(
     return prefix;
 }
 
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
 //
 // Sampling
 //
@@ -132,8 +136,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
                 grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
     }
 
-    result->token_healing_prefix.clear();
-
     result->prev.resize(params.n_prev);
 
     llama_sampling_set_rng_seed(result, params.seed);
@@ -425,8 +427,6 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
-    // TODO should we skip penalties and grammar while token healing?
-
     // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
diff --git a/common/sampling.h b/common/sampling.h
index 90198bec9..2aa7bc2bd 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -170,9 +170,13 @@ void llama_sampling_accept(
 //
 // Token healing
 //
+// Roll back `tokens` for constrained generation according to the token healing
+// strategy. Returns the prefix for constrained generation.
 std::string llama_token_healing_prepare(
         const llama_context * ctx_main,
         llama_token_healing_type th_type,
         std::vector<llama_token> & tokens,
         int max_to_remove = -1,
         int * n_removed = nullptr);
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index fd26fc380..70834b01a 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -269,9 +269,10 @@ int main(int argc, char ** argv) {
         LOG("token_healing: disabled due to custom suffix");
     }
     std::string token_healing_prefix;
+    int token_healing_n_removed = 0;
     if (!params.interactive_first && sparams.token_healing_enabled) {
         token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
-                                                           sparams.token_healing_n_rollback);
+                                                           sparams.token_healing_n_rollback, &token_healing_n_removed);
     }
 
     // Should not run without any tokens
@@ -293,7 +294,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size();
+        original_prompt_len = original_inp.size() - token_healing_n_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -531,7 +532,7 @@ int main(int argc, char ** argv) {
     }
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
-    ctx_sampling->token_healing_prefix = token_healing_prefix;
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
@@ -834,7 +835,7 @@ int main(int argc, char ** argv) {
            }
         }
 
-        int token_healing_n_removed = 0;
+        token_healing_n_removed = 0;
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -926,6 +927,7 @@ int main(int argc, char ** argv) {
                                           : std::min(sparams.token_healing_n_rollback, n_new_tokens);
                 token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
                                                                    max_to_remove, &token_healing_n_removed);
+                n_bytes_to_skip = token_healing_prefix.size();
             }
 
             for (size_t i = original_size; i < embd_inp.size(); ++i) {
@@ -948,8 +950,7 @@ int main(int argc, char ** argv) {
                     llama_sampling_reset(ctx_sampling);
                     if (token_healing_n_removed > 0) {
                         // Set new prefix after an interaction
-                        ctx_sampling->token_healing_prefix = token_healing_prefix;
-                        n_bytes_to_skip = ctx_sampling->token_healing_prefix.size();
+                        llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
                     }
                 }
                 is_interacting = false;
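
Note: taken together, the hunks above replace direct writes to ctx_sampling->token_healing_prefix with the new llama_token_healing_set_prefix() setter, and thread the number of rolled-back tokens (token_healing_n_removed) out of llama_token_healing_prepare() so the caller can adjust prompt lengths and byte-skip counts. A minimal sketch of the call order implied by the patch, using only the functions that appear in this diff (ctx, sparams, and params are assumed to be initialized as in examples/main/main.cpp; local names are illustrative):

    // Sketch only: intended usage under this patch, not a drop-in snippet.
    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);

    int n_removed = 0;
    std::string prefix = llama_token_healing_prepare(
            ctx, sparams.token_healing_type, embd_inp,
            sparams.token_healing_n_rollback, &n_removed);   // rolls back up to n_rollback tokens

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    llama_token_healing_set_prefix(ctx_sampling, prefix);    // instead of assigning token_healing_prefix directly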