token healing : refactor to return struct

parent db9c018891
commit 414fc13248

3 changed files with 44 additions and 52 deletions
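The refactor collapses the old two-channel result (a returned prefix string plus an `int * n_removed` out-parameter) into a single `llama_token_healing_output` aggregate. A minimal standalone C++ sketch of the before/after pattern (simplified, hypothetical names, not the actual llama.cpp code):

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical miniature of the pattern this commit adopts.
struct rollback_output {
    std::string prefix;           // healing prefix for constrained decoding
    int n_tokens_removed = 0;     // how many prompt tokens were rolled back
};

// Before: the result is split between the return value and an out-parameter
// that every caller must remember to pass and reset.
static std::string rollback_old(std::vector<int> & tokens, int * n_removed) {
    if (n_removed != nullptr) { *n_removed = 1; }
    tokens.pop_back();
    return "he";
}

// After: one aggregate carries both results; `return {};` expresses "nothing removed".
static rollback_output rollback_new(std::vector<int> & tokens) {
    tokens.pop_back();
    return {"he", 1};
}

int main() {
    std::vector<int> tokens = {1, 2, 3};

    int n_removed = 0;
    std::string prefix_old = rollback_old(tokens, &n_removed);  // old style

    rollback_output out = rollback_new(tokens);                 // new style
    std::printf("old: '%s' (%d), new: '%s' (%d)\n",
                prefix_old.c_str(), n_removed, out.prefix.c_str(), out.n_tokens_removed);
}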
@@ -49,18 +49,13 @@ static size_t get_max_token_length(const llama_context * ctx_main) {
     return len;
 }
 
-struct token_healing_info {
-    std::string prefix;
-    int n_tokens_removed;
-};
-
-token_healing_info llama_token_healing_get_prefix(
-        const llama_context * ctx_main,
-        const llama_token_healing_type th_type,
-        const std::vector<llama_token> & tokens,
-        int max_to_remove) {
+static llama_token_healing_output llama_token_healing_get_prefix(
+        const llama_context * ctx_main,
+        const llama_token_healing_type th_type,
+        const std::vector<llama_token> & tokens,
+        int max_to_remove) {
     if (tokens.size() <= 1) {
-        return {"", 0};
+        return {};
     }
 
     const int n_ctx = tokens.size();
@@ -122,34 +117,28 @@ token_healing_info llama_token_healing_get_prefix(
 // Token healing (external)
 //
 
-std::string llama_token_healing_rollback(
+llama_token_healing_output llama_token_healing_rollback(
         const llama_context * ctx_main,
         llama_token_healing_type th_type,
         std::vector<llama_token> & tokens,
-        int max_to_remove,
-        int * n_removed) {
-    if (n_removed != nullptr) {
-        *n_removed = 0;
-    }
+        int max_to_remove) {
     // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
     // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
-    token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
+    llama_token_healing_output out = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
 
     // If constrained decoding would give back the original prompt, there is no need to modify the prompt.
     const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step);
-    LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed);
-    if (info.n_tokens_removed == 1 && candidates.size() == 1) {
+    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, out.prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", out.prefix.c_str(), out.n_tokens_removed);
+    if (out.n_tokens_removed == 1 && candidates.size() == 1) {
         LOG("token_healing: nothing to heal\n");
-        return "";
+        return {};
     }
-    // Finalize outputs
-    if (n_removed != nullptr) {
-        *n_removed = info.n_tokens_removed;
-    }
-    tokens.resize(tokens.size() - info.n_tokens_removed);
-    return info.prefix;
+
+    // Finally, trim prompt tokens
+    tokens.resize(tokens.size() - out.n_tokens_removed);
+    return out;
 }
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
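A note on the early-out in the hunk above: when exactly one token was rolled back and its prefix has exactly one viable candidate, constrained decoding could only regenerate the token that was removed, so (per the in-code comment) the function returns an empty output and leaves `tokens` untrimmed.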
@@ -176,13 +176,17 @@ void llama_sampling_accept(
 // Token healing
 //
 
-// Roll back `tokens` for constrained generation according to the token healing
-// strategy. Returns the prefix for constrained generation.
-std::string llama_token_healing_rollback(
-        const llama_context * ctx_main,
-        llama_token_healing_type th_type,
-        std::vector<llama_token> & tokens,
-        int max_to_remove = -1,
-        int * n_removed = nullptr);
+struct llama_token_healing_output {
+    std::string prefix;
+    int n_tokens_removed;
+};
+
+// Roll back `tokens` for constrained generation according to the token healing strategy.
+// Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling.
+llama_token_healing_output llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove = -1);
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
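The new header boils down to one call pattern. A condensed sketch, assuming this branch's sampling header and already-initialized `ctx`, `ctx_sampling`, `sparams`, and `embd_inp` (it mirrors the call sites in the hunks that follow):

// Sketch only: these objects come from the usual sampling-context setup.
llama_token_healing_output out = llama_token_healing_rollback(
        ctx, sparams.token_healing_type, embd_inp, sparams.token_healing_n_rollback);
// embd_inp is now out.n_tokens_removed tokens shorter; seed the sampler so the
// first sampled pieces must continue out.prefix.
llama_token_healing_set_prefix(ctx_sampling, out.prefix);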
@@ -295,11 +295,10 @@ int main(int argc, char ** argv) {
         sparams.token_healing_enabled = false;
         LOG("token healing: disabled due to custom suffix/conversation mode");
     }
-    std::string token_healing_prefix;
-    int token_healing_n_removed = 0;
+    llama_token_healing_output token_healing_out{};
     if (!params.interactive_first && sparams.token_healing_enabled) {
-        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                            sparams.token_healing_n_rollback, &token_healing_n_removed);
+        token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                         sparams.token_healing_n_rollback);
     }
 
     // Should not run without any tokens
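The `token_healing_out{}` declaration above relies on C++ value-initialization: it zeroes `n_tokens_removed` and default-constructs `prefix`, which is what lets later hunks reset both fields with a single `token_healing_out = {};` in place of the old separate `token_healing_n_removed = 0;`. A self-contained check of that semantics:

#include <cassert>
#include <string>

struct llama_token_healing_output {
    std::string prefix;
    int n_tokens_removed;
};

int main() {
    llama_token_healing_output out{};  // value-init: prefix == "", n_tokens_removed == 0
    out = {"he", 2};
    out = {};                          // one assignment resets both members
    assert(out.prefix.empty() && out.n_tokens_removed == 0);
}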
@@ -326,7 +325,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size() - token_healing_n_removed;
+        original_prompt_len = original_inp.size() - token_healing_out.n_tokens_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -548,7 +547,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
-    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
 
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
@@ -883,7 +882,8 @@ int main(int argc, char ** argv) {
             assistant_ss << llama_token_to_piece(ctx, id, false);
         }
 
-        token_healing_n_removed = 0;
+        token_healing_out = {};
+
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -962,9 +962,8 @@ int main(int argc, char ** argv) {
                     const int max_to_remove = sparams.token_healing_n_rollback < 0
                                               ? n_new_tokens
                                               : std::min(sparams.token_healing_n_rollback, n_new_tokens);
-                    token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                                        max_to_remove, &token_healing_n_removed);
-                    n_bytes_to_skip = token_healing_prefix.size();
+                    token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
+                    n_bytes_to_skip = token_healing_out.prefix.size();
                 }
 
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
@@ -976,7 +975,7 @@ int main(int argc, char ** argv) {
                 // reset assistant message
                 assistant_ss.str("");
 
-                n_remain -= line_inp.size() + token_healing_n_removed;
+                n_remain -= line_inp.size() + token_healing_out.n_tokens_removed;
                 LOG("n_remain: %d\n", n_remain);
             } else {
                 LOG("empty line, passing control back\n");
@@ -988,9 +987,9 @@ int main(int argc, char ** argv) {
         if (n_past > 0) {
             if (is_interacting) {
                 llama_sampling_reset(ctx_sampling);
-                if (token_healing_n_removed > 0) {
+                if (token_healing_out.n_tokens_removed > 0) {
                     // Set new prefix after an interaction
-                    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+                    llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
                 }
             }
             is_interacting = false;