main : better token healing support for interactive mode
This commit is contained in:
parent
951b6593b2
commit
7d0cc78bc3
4 changed files with 45 additions and 16 deletions
|
@ -1298,7 +1298,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
auto & th_n_rollback = sparams.token_healing_n_rollback;
|
auto & th_n_rollback = sparams.token_healing_n_rollback;
|
||||||
std::string value(argv[i]);
|
std::string value(argv[i]);
|
||||||
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
|
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
|
||||||
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
|
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
|
||||||
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||||
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||||
else if (value[0] == 'r' ) {
|
else if (value[0] == 'r' ) {
|
||||||
|
|
|
@ -46,20 +46,26 @@ std::string llama_token_healing_prepare(
|
||||||
const llama_context * ctx_main,
|
const llama_context * ctx_main,
|
||||||
llama_token_healing_type th_type,
|
llama_token_healing_type th_type,
|
||||||
std::vector<llama_token> & tokens,
|
std::vector<llama_token> & tokens,
|
||||||
int n_rollback) {
|
int max_to_remove,
|
||||||
|
int * n_removed) {
|
||||||
|
if (n_removed != nullptr) {
|
||||||
|
*n_removed = 0;
|
||||||
|
}
|
||||||
if (tokens.empty()) {
|
if (tokens.empty()) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_model * model = llama_get_model(ctx_main);
|
const llama_model * model = llama_get_model(ctx_main);
|
||||||
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
const int n_ctx = tokens.size();
|
const int n_ctx = tokens.size();
|
||||||
const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
|
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
|
||||||
int n_removed = 0;
|
max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx);
|
||||||
|
int removed = 0;
|
||||||
std::string prefix;
|
std::string prefix;
|
||||||
// Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
|
// Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
|
||||||
// and stop early if a special token is encountered
|
// and stop early if a special token is encountered
|
||||||
while (n_removed < max_to_remove) {
|
while (removed < max_to_remove) {
|
||||||
const llama_token next_token_id = tokens[n_ctx - n_removed - 1];
|
const llama_token next_token_id = tokens[n_ctx - removed - 1];
|
||||||
if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
|
if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
|
||||||
// Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
|
// Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
|
||||||
break;
|
break;
|
||||||
|
@ -68,23 +74,26 @@ std::string llama_token_healing_prepare(
|
||||||
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
|
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
n_removed += 1;
|
removed += 1;
|
||||||
prefix = new_prefix;
|
prefix = new_prefix;
|
||||||
}
|
}
|
||||||
|
if (removed == 0) { // E.g. if the last token is a special token
|
||||||
if (n_removed == 0) { // E.g. if the last token is a special token
|
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
// If constrained decoding would give back the original prompt, there is no need to modify the context
|
||||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
|
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
|
||||||
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
|
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
|
||||||
if (n_removed == 1 && candidates.size() == 1) {
|
if (removed == 1 && candidates.size() == 1) {
|
||||||
LOG("token_healing: nothing to heal\n");
|
LOG("token_healing: nothing to heal\n");
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
tokens.resize(n_ctx - n_removed);
|
// Finalize outputs
|
||||||
|
if (n_removed != nullptr) {
|
||||||
|
*n_removed = removed;
|
||||||
|
}
|
||||||
|
tokens.resize(n_ctx - removed);
|
||||||
return prefix;
|
return prefix;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -72,7 +72,7 @@ typedef struct llama_sampling_params {
|
||||||
|
|
||||||
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
|
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
|
||||||
bool token_healing_enabled = false;
|
bool token_healing_enabled = false;
|
||||||
int token_healing_n_rollback = 1; // number of tokens to roll back
|
int token_healing_n_rollback = -1; // number of tokens to roll back
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
|
@ -174,4 +174,5 @@ std::string llama_token_healing_prepare(
|
||||||
const llama_context * ctx_main,
|
const llama_context * ctx_main,
|
||||||
llama_token_healing_type th_type,
|
llama_token_healing_type th_type,
|
||||||
std::vector<llama_token> & tokens,
|
std::vector<llama_token> & tokens,
|
||||||
int n_rollback = 1);
|
int max_to_remove = -1,
|
||||||
|
int * n_removed = nullptr);
|
||||||
|
|
|
@ -264,8 +264,12 @@ int main(int argc, char ** argv) {
|
||||||
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
|
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
|
||||||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||||
|
|
||||||
|
if (sparams.token_healing_enabled && (params.instruct || params.chatml || !params.input_suffix.empty())) {
|
||||||
|
sparams.token_healing_enabled = false;
|
||||||
|
LOG("token_healing: disabled due to custom suffix");
|
||||||
|
}
|
||||||
std::string token_healing_prefix;
|
std::string token_healing_prefix;
|
||||||
if (sparams.token_healing_enabled) {
|
if (!params.interactive_first && sparams.token_healing_enabled) {
|
||||||
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
||||||
sparams.token_healing_n_rollback);
|
sparams.token_healing_n_rollback);
|
||||||
}
|
}
|
||||||
|
@ -820,6 +824,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int token_healing_n_removed = 0;
|
||||||
if (n_past > 0 && is_interacting) {
|
if (n_past > 0 && is_interacting) {
|
||||||
LOG("waiting for user input\n");
|
LOG("waiting for user input\n");
|
||||||
|
|
||||||
|
@ -903,13 +908,23 @@ int main(int argc, char ** argv) {
|
||||||
embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
|
embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sparams.token_healing_enabled) {
|
||||||
|
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
|
||||||
|
const int n_new_tokens = embd_inp.size() - original_size;
|
||||||
|
const int max_to_remove = sparams.token_healing_n_rollback < 0
|
||||||
|
? n_new_tokens
|
||||||
|
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
|
||||||
|
token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
|
||||||
|
max_to_remove, &token_healing_n_removed);
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
||||||
const llama_token token = embd_inp[i];
|
const llama_token token = embd_inp[i];
|
||||||
output_tokens.push_back(token);
|
output_tokens.push_back(token);
|
||||||
output_ss << llama_token_to_piece(ctx, token);
|
output_ss << llama_token_to_piece(ctx, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
n_remain -= line_inp.size();
|
n_remain -= line_inp.size() + token_healing_n_removed;
|
||||||
LOG("n_remain: %d\n", n_remain);
|
LOG("n_remain: %d\n", n_remain);
|
||||||
} else {
|
} else {
|
||||||
LOG("empty line, passing control back\n");
|
LOG("empty line, passing control back\n");
|
||||||
|
@ -921,6 +936,10 @@ int main(int argc, char ** argv) {
|
||||||
if (n_past > 0) {
|
if (n_past > 0) {
|
||||||
if (is_interacting) {
|
if (is_interacting) {
|
||||||
llama_sampling_reset(ctx_sampling);
|
llama_sampling_reset(ctx_sampling);
|
||||||
|
if (token_healing_n_removed > 0) {
|
||||||
|
// Set new prefix after an interaction
|
||||||
|
ctx_sampling->token_healing_prefix = token_healing_prefix;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
is_interacting = false;
|
is_interacting = false;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue