main : small token healing cleanup
This commit is contained in:
parent
d4cbccb103
commit
7b6fdc2819
3 changed files with 15 additions and 10 deletions
|
@ -97,6 +97,10 @@ std::string llama_token_healing_prepare(
|
||||||
return prefix;
|
return prefix;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
|
||||||
|
ctx_sampling->token_healing_prefix = prefix;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Sampling
|
// Sampling
|
||||||
//
|
//
|
||||||
|
@ -132,8 +136,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
||||||
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||||
}
|
}
|
||||||
|
|
||||||
result->token_healing_prefix.clear();
|
|
||||||
|
|
||||||
result->prev.resize(params.n_prev);
|
result->prev.resize(params.n_prev);
|
||||||
|
|
||||||
llama_sampling_set_rng_seed(result, params.seed);
|
llama_sampling_set_rng_seed(result, params.seed);
|
||||||
|
@ -425,8 +427,6 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||||
|
|
||||||
// TODO should we skip penalties and grammar while token healing?
|
|
||||||
|
|
||||||
// apply penalties
|
// apply penalties
|
||||||
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
||||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||||
|
|
|
@ -170,9 +170,13 @@ void llama_sampling_accept(
|
||||||
// Token healing
|
// Token healing
|
||||||
//
|
//
|
||||||
|
|
||||||
|
// Roll back `tokens` for constrained generation according to the token healing
|
||||||
|
// strategy. Returns the prefix for constrained generation.
|
||||||
std::string llama_token_healing_prepare(
|
std::string llama_token_healing_prepare(
|
||||||
const llama_context * ctx_main,
|
const llama_context * ctx_main,
|
||||||
llama_token_healing_type th_type,
|
llama_token_healing_type th_type,
|
||||||
std::vector<llama_token> & tokens,
|
std::vector<llama_token> & tokens,
|
||||||
int max_to_remove = -1,
|
int max_to_remove = -1,
|
||||||
int * n_removed = nullptr);
|
int * n_removed = nullptr);
|
||||||
|
|
||||||
|
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
|
||||||
|
|
|
@ -269,9 +269,10 @@ int main(int argc, char ** argv) {
|
||||||
LOG("token_healing: disabled due to custom suffix");
|
LOG("token_healing: disabled due to custom suffix");
|
||||||
}
|
}
|
||||||
std::string token_healing_prefix;
|
std::string token_healing_prefix;
|
||||||
|
int token_healing_n_removed = 0;
|
||||||
if (!params.interactive_first && sparams.token_healing_enabled) {
|
if (!params.interactive_first && sparams.token_healing_enabled) {
|
||||||
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
||||||
sparams.token_healing_n_rollback);
|
sparams.token_healing_n_rollback, &token_healing_n_removed);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should not run without any tokens
|
// Should not run without any tokens
|
||||||
|
@ -293,7 +294,7 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
||||||
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
|
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
|
||||||
|
|
||||||
original_prompt_len = original_inp.size();
|
original_prompt_len = original_inp.size() - token_healing_n_removed;
|
||||||
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
|
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
|
||||||
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
|
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
|
||||||
LOG("guidance_offset: %s", log_tostr(guidance_offset));
|
LOG("guidance_offset: %s", log_tostr(guidance_offset));
|
||||||
|
@ -531,7 +532,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||||
ctx_sampling->token_healing_prefix = token_healing_prefix;
|
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
|
||||||
|
|
||||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
|
@ -834,7 +835,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int token_healing_n_removed = 0;
|
token_healing_n_removed = 0;
|
||||||
if (n_past > 0 && is_interacting) {
|
if (n_past > 0 && is_interacting) {
|
||||||
LOG("waiting for user input\n");
|
LOG("waiting for user input\n");
|
||||||
|
|
||||||
|
@ -926,6 +927,7 @@ int main(int argc, char ** argv) {
|
||||||
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
|
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
|
||||||
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
||||||
max_to_remove, &token_healing_n_removed);
|
max_to_remove, &token_healing_n_removed);
|
||||||
|
n_bytes_to_skip = token_healing_prefix.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
||||||
|
@ -948,8 +950,7 @@ int main(int argc, char ** argv) {
|
||||||
llama_sampling_reset(ctx_sampling);
|
llama_sampling_reset(ctx_sampling);
|
||||||
if (token_healing_n_removed > 0) {
|
if (token_healing_n_removed > 0) {
|
||||||
// Set new prefix after an interaction
|
// Set new prefix after an interaction
|
||||||
ctx_sampling->token_healing_prefix = token_healing_prefix;
|
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
|
||||||
n_bytes_to_skip = ctx_sampling->token_healing_prefix.size();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
is_interacting = false;
|
is_interacting = false;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue