main : small token healing cleanup
commit 7b6fdc2819
parent d4cbccb103

3 changed files with 15 additions and 10 deletions
common/sampling.cpp

@@ -97,6 +97,10 @@ std::string llama_token_healing_prepare(
     return prefix;
 }
 
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
 //
 // Sampling
 //
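For orientation, a minimal sketch of the intended call order after this change (hypothetical wrapper code, not part of the commit; it assumes the common/sampling.h API shown in these hunks and the usual llama.cpp common headers):

static void setup_token_healing(llama_context * ctx,
                                const llama_sampling_params & sparams,
                                std::vector<llama_token> & embd_inp) {
    // Roll back the tail of the prompt; the removed text becomes the
    // constraint prefix for generation.
    int n_removed = 0;
    std::string prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
                                                     /*max_to_remove=*/-1, &n_removed);
    // New in this commit: hand the prefix to the sampler through the setter
    // instead of assigning ctx_sampling->token_healing_prefix directly.
    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    llama_token_healing_set_prefix(ctx_sampling, prefix);
}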
@@ -132,8 +136,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
                 grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
     }
 
-    result->token_healing_prefix.clear();
-
     result->prev.resize(params.n_prev);
 
     llama_sampling_set_rng_seed(result, params.seed);
@@ -425,8 +427,6 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
-    // TODO should we skip penalties and grammar while token healing?
-
     // apply penalties
     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
common/sampling.h

@@ -170,9 +170,13 @@ void llama_sampling_accept(
 // Token healing
 //
 
+// Roll back `tokens` for constrained generation according to the token healing
+// strategy. Returns the prefix for constrained generation.
 std::string llama_token_healing_prepare(
     const llama_context * ctx_main,
     llama_token_healing_type th_type,
     std::vector<llama_token> & tokens,
     int max_to_remove = -1,
     int * n_removed = nullptr);
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
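To make the declared contract concrete: token healing rolls back trailing prompt tokens and re-generates their text under a prefix constraint, so the continuation is not forced onto an awkward token boundary. A standalone toy illustration of the rollback step (plain strings instead of llama_token; purely illustrative, not the library implementation):

#include <iostream>
#include <string>
#include <vector>

int main() {
    // Suppose the prompt "http:" was tokenized as ["http", ":"]. Rolling back
    // the trailing token yields the prefix ":"; constrained generation can
    // then emit the more natural single token "://".
    std::vector<std::string> tokens = {"http", ":"};
    int n_removed = 0;
    std::string prefix;
    if (!tokens.empty()) { // roll back one trailing token
        prefix = tokens.back() + prefix;
        tokens.pop_back();
        n_removed++;
    }
    std::cout << "removed " << n_removed << " token(s), prefix = \"" << prefix << "\"\n";
    return 0;
}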
examples/main/main.cpp

@@ -269,9 +269,10 @@ int main(int argc, char ** argv) {
         LOG("token_healing: disabled due to custom suffix");
     }
     std::string token_healing_prefix;
+    int token_healing_n_removed = 0;
     if (!params.interactive_first && sparams.token_healing_enabled) {
         token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
-                                                           sparams.token_healing_n_rollback);
+                                                           sparams.token_healing_n_rollback, &token_healing_n_removed);
     }
 
     // Should not run without any tokens
@@ -293,7 +294,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size();
+        original_prompt_len = original_inp.size() - token_healing_n_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -531,7 +532,7 @@ int main(int argc, char ** argv) {
     }
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
-    ctx_sampling->token_healing_prefix = token_healing_prefix;
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
@@ -834,7 +835,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        int token_healing_n_removed = 0;
+        token_healing_n_removed = 0;
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -926,6 +927,7 @@ int main(int argc, char ** argv) {
                         : std::min(sparams.token_healing_n_rollback, n_new_tokens);
                     token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
                                                                        max_to_remove, &token_healing_n_removed);
+                    n_bytes_to_skip = token_healing_prefix.size();
                 }
 
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
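The n_bytes_to_skip assignment added above suggests that the healed prefix's bytes are re-generated by the sampler even though they are already on screen as part of the prompt. A hedged sketch of how a display path could consume the counter (print_healed_piece is a hypothetical helper; llama_token_to_piece is the existing llama.cpp conversion function; assumes <algorithm>, <cstdio>, and the common headers are included):

// Assumed rationale, not code from this commit: drop the first
// n_bytes_to_skip bytes of regenerated text, since they repeat the
// rolled-back prompt tail that the user already sees.
static void print_healed_piece(llama_context * ctx, llama_token id, int & n_bytes_to_skip) {
    const std::string token_str = llama_token_to_piece(ctx, id);
    const int skip = std::min(n_bytes_to_skip, (int) token_str.size());
    printf("%s", token_str.c_str() + skip);
    n_bytes_to_skip -= skip;
}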
@@ -948,8 +950,7 @@ int main(int argc, char ** argv) {
                 llama_sampling_reset(ctx_sampling);
                 if (token_healing_n_removed > 0) {
                     // Set new prefix after an interaction
-                    ctx_sampling->token_healing_prefix = token_healing_prefix;
-                    n_bytes_to_skip = ctx_sampling->token_healing_prefix.size();
+                    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
                 }
             }
             is_interacting = false;