token healing : refactor argument parsing

Unify `main` and `server` token healing argument handling.
This commit is contained in:
mare5x 2024-07-01 11:51:39 +02:00
parent 3ba5c55bc4
commit ea4abc9d82
5 changed files with 61 additions and 60 deletions

View file

@ -1095,21 +1095,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
} }
if (arg == "-th" || arg == "--token-healing") { if (arg == "-th" || arg == "--token-healing") {
CHECK_ARG CHECK_ARG
sparams.token_healing_enabled = true;
auto & th_type = sparams.token_healing_type;
auto & th_n_rollback = sparams.token_healing_n_rollback;
std::string value(argv[i]); std::string value(argv[i]);
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; } invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (value[0] == 'r' ) {
th_type = llama_token_healing_type::ROLLBACK_MULTI;
th_n_rollback = std::stoi(value.substr(1));
if (th_n_rollback <= 0) {
sparams.token_healing_enabled = false;
}
} else { invalid_param = true; }
return true; return true;
} }
if (arg == "--override-kv") { if (arg == "--override-kv") {

View file

@ -154,6 +154,25 @@ void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const
ctx_sampling->token_healing_prefix = prefix; ctx_sampling->token_healing_prefix = prefix;
} }
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
th_params.enabled = true;
th_params.n_rollback = -1;
/**/ if (params == "0" ) { th_params.enabled = false; }
else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (params[0] == 'r' ) {
th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
th_params.n_rollback = std::stoi(params.substr(1));
if (th_params.n_rollback <= 0) {
return false;
}
} else {
return false;
}
return true;
}
// //
// Sampling // Sampling
// //
@ -552,11 +571,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
cur.resize(n_vocab); cur.resize(n_vocab);
// Constrain tokens based on the remaining token healing prefix (if any) // Constrain tokens based on the remaining token healing prefix (if any)
const auto & th_type = params.token_healing_type;
const auto & th_prefix = ctx_sampling->token_healing_prefix; const auto & th_prefix = ctx_sampling->token_healing_prefix;
if (params.token_healing_enabled && !th_prefix.empty()) { if (params.token_healing.enabled && !th_prefix.empty()) {
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI || const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI; params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step); std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str()); LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
@ -635,7 +653,7 @@ void llama_sampling_accept(
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
} }
if (ctx_sampling->params.token_healing_enabled && apply_grammar) { if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
std::string & th_prefix = ctx_sampling->token_healing_prefix; std::string & th_prefix = ctx_sampling->token_healing_prefix;
if (!th_prefix.empty()) { if (!th_prefix.empty()) {
const std::string new_token_piece = llama_token_to_piece(ctx_main, id); const std::string new_token_piece = llama_token_to_piece(ctx_main, id);

View file

@ -26,6 +26,12 @@ enum class llama_token_healing_type : uint8_t {
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
}; };
struct llama_token_healing_params {
bool enabled = false;
llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
int n_rollback = -1; // number of tokens to roll back
};
// sampling parameters // sampling parameters
typedef struct llama_sampling_params { typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember int32_t n_prev = 64; // number of previous tokens to remember
@ -70,9 +76,7 @@ typedef struct llama_sampling_params {
std::vector<llama_token> penalty_prompt_tokens; std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false; bool use_penalty_prompt_tokens = false;
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST; llama_token_healing_params token_healing;
bool token_healing_enabled = false;
int token_healing_n_rollback = -1; // number of tokens to roll back
} llama_sampling_params; } llama_sampling_params;
// general sampler context // general sampler context
@ -190,3 +194,6 @@ llama_token_healing_output llama_token_healing_rollback(
int max_to_remove = -1); int max_to_remove = -1);
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix); void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
// Helper for parsing token healing params from a string.
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);

View file

@ -291,14 +291,14 @@ int main(int argc, char ** argv) {
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
} }
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) { if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
sparams.token_healing_enabled = false; sparams.token_healing.enabled = false;
LOG("token healing: disabled due to custom suffix/conversation mode"); LOG("token healing: disabled due to custom suffix/conversation mode");
} }
llama_token_healing_output token_healing_out{}; llama_token_healing_output token_healing_out{};
if (!params.interactive_first && sparams.token_healing_enabled) { if (!params.interactive_first && sparams.token_healing.enabled) {
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp,
sparams.token_healing_n_rollback); sparams.token_healing.n_rollback);
} }
// Should not run without any tokens // Should not run without any tokens
@ -956,13 +956,13 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end()); embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
if (sparams.token_healing_enabled) { if (sparams.token_healing.enabled) {
// Limit token healing rollback to new tokens only (otherwise would need to shift everything) // Limit token healing rollback to new tokens only (otherwise would need to shift everything)
const int n_new_tokens = embd_inp.size() - original_size; const int n_new_tokens = embd_inp.size() - original_size;
const int max_to_remove = sparams.token_healing_n_rollback < 0 const int max_to_remove = sparams.token_healing.n_rollback < 0
? n_new_tokens ? n_new_tokens
: std::min(sparams.token_healing_n_rollback, n_new_tokens); : std::min(sparams.token_healing.n_rollback, n_new_tokens);
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove); token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove);
n_bytes_to_skip = token_healing_out.prefix.size(); n_bytes_to_skip = token_healing_out.prefix.size();
} }

View file

@ -1098,31 +1098,20 @@ struct server_context {
{ {
const auto & token_healing_str = data.find("token_healing"); const auto & token_healing_str = data.find("token_healing");
auto & th_enabled = slot.sparams.token_healing_enabled;
th_enabled = default_sparams.token_healing_enabled;
if (token_healing_str != data.end() && token_healing_str->is_string()) { if (token_healing_str != data.end() && token_healing_str->is_string()) {
const auto value = token_healing_str->get<std::string>(); const auto value = token_healing_str->get<std::string>();
auto & th_type = slot.sparams.token_healing_type; if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
auto & th_n_rollback = slot.sparams.token_healing_n_rollback; send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
th_enabled = true; return false;
/**/ if (value == "0" ) { th_enabled = false; } }
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (value[0] == 'r' ) {
th_type = llama_token_healing_type::ROLLBACK_MULTI;
th_n_rollback = std::stoi(value.substr(1));
if (th_n_rollback <= 0) {
th_enabled = false;
}
} else { th_enabled = false; }
LOG_VERBOSE("token healing", { LOG_VERBOSE("token healing", {
{"id_slot", slot.id}, {"id_slot", slot.id},
{"enabled", th_enabled}, {"enabled", slot.sparams.token_healing.enabled},
{"type", th_type}, {"type", slot.sparams.token_healing.type},
{"n_rollback", th_n_rollback} {"n_rollback", slot.sparams.token_healing.n_rollback}
}); });
} else {
slot.sparams.token_healing = default_sparams.token_healing;
} }
} }
@ -1406,7 +1395,7 @@ struct server_context {
{"min_keep", slot.sparams.min_keep}, {"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar}, {"grammar", slot.sparams.grammar},
{"samplers", samplers_sequence}, {"samplers", samplers_sequence},
{"token_healing_enabled", slot.sparams.token_healing_enabled} {"token_healing_enabled", slot.sparams.token_healing.enabled}
}; };
} }
@ -2109,10 +2098,10 @@ struct server_context {
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model)); suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
if (slot.sparams.token_healing_enabled) { if (slot.sparams.token_healing.enabled) {
// For FIM roll back only the prefix part (i.e. cursor location) // For FIM roll back only the prefix part (i.e. cursor location)
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type, token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
prefix_tokens, slot.sparams.token_healing_n_rollback); prefix_tokens, slot.sparams.token_healing.n_rollback);
} }
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens; auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
@ -2131,9 +2120,9 @@ struct server_context {
} else { } else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
if (slot.sparams.token_healing_enabled) { if (slot.sparams.token_healing.enabled) {
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type, token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
prompt_tokens, slot.sparams.token_healing_n_rollback); prompt_tokens, slot.sparams.token_healing.n_rollback);
} }
} }
@ -2149,7 +2138,7 @@ struct server_context {
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
}); });
if (slot.sparams.token_healing_enabled) { if (slot.sparams.token_healing.enabled) {
slot.n_th_prefix = token_healing_out.prefix.size(); slot.n_th_prefix = token_healing_out.prefix.size();
LOG_VERBOSE("token healing prompt", { LOG_VERBOSE("token healing prompt", {
{"id_slot", slot.id}, {"id_slot", slot.id},
@ -2224,7 +2213,7 @@ struct server_context {
} }
llama_sampling_reset(slot.ctx_sampling); llama_sampling_reset(slot.ctx_sampling);
if (slot.sparams.token_healing_enabled) { if (slot.sparams.token_healing.enabled) {
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix); llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
} }