main : small token healing cleanup

This commit is contained in:
mare5x 2024-05-06 21:25:12 +02:00
parent d4cbccb103
commit 7b6fdc2819
3 changed files with 15 additions and 10 deletions

View file

@ -97,6 +97,10 @@ std::string llama_token_healing_prepare(
return prefix;
}
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
ctx_sampling->token_healing_prefix = prefix;
}
//
// Sampling
//
@ -132,8 +136,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}
result->token_healing_prefix.clear();
result->prev.resize(params.n_prev);
llama_sampling_set_rng_seed(result, params.seed);
@ -425,8 +427,6 @@ static llama_token_data_array llama_sampling_prepare_impl(
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
// TODO should we skip penalties and grammar while token healing?
// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);

View file

@ -170,9 +170,13 @@ void llama_sampling_accept(
// Token healing
//
// Roll back `tokens` for constrained generation according to the token healing
// strategy. Returns the prefix for constrained generation.
std::string llama_token_healing_prepare(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int max_to_remove = -1,
int * n_removed = nullptr);
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);

View file

@ -269,9 +269,10 @@ int main(int argc, char ** argv) {
LOG("token_healing: disabled due to custom suffix");
}
std::string token_healing_prefix;
int token_healing_n_removed = 0;
if (!params.interactive_first && sparams.token_healing_enabled) {
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
sparams.token_healing_n_rollback);
sparams.token_healing_n_rollback, &token_healing_n_removed);
}
// Should not run without any tokens
@ -293,7 +294,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
original_prompt_len = original_inp.size() - token_healing_n_removed;
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
LOG("guidance_offset: %s", log_tostr(guidance_offset));
@ -531,7 +532,7 @@ int main(int argc, char ** argv) {
}
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
ctx_sampling->token_healing_prefix = token_healing_prefix;
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
@ -834,7 +835,7 @@ int main(int argc, char ** argv) {
}
}
int token_healing_n_removed = 0;
token_healing_n_removed = 0;
if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");
@ -926,6 +927,7 @@ int main(int argc, char ** argv) {
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
max_to_remove, &token_healing_n_removed);
n_bytes_to_skip = token_healing_prefix.size();
}
for (size_t i = original_size; i < embd_inp.size(); ++i) {
@ -948,8 +950,7 @@ int main(int argc, char ** argv) {
llama_sampling_reset(ctx_sampling);
if (token_healing_n_removed > 0) {
// Set new prefix after an interaction
ctx_sampling->token_healing_prefix = token_healing_prefix;
n_bytes_to_skip = ctx_sampling->token_healing_prefix.size();
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
}
}
is_interacting = false;