token healing : refactor argument parsing

Unify `main` and `server` token healing argument handling.
mare5x 2024-07-01 11:51:39 +02:00
parent 3ba5c55bc4
commit ea4abc9d82
5 changed files with 61 additions and 60 deletions

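Both `main` and `server` previously re-implemented the same string grammar for the token healing option; this commit moves it into a single helper, `llama_token_healing_parse_params` (see `common/sampling.cpp` below). As a minimal standalone sketch, here is that grammar with the types copied from this diff; the `main()` driver is hypothetical and only illustrates the accepted values:

// Standalone sketch of the unified parse grammar (types and parser copied
// from this diff; main() is a hypothetical driver, not part of the commit).
#include <cstdint>
#include <iostream>
#include <string>

enum class llama_token_healing_type : uint8_t {
    ROLLBACK_LAST,
    ROLLBACK_MULTI,
    DYNAMIC_ONCE,
    DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
};

struct llama_token_healing_params {
    bool enabled = false;
    llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
    int n_rollback = -1; // number of tokens to roll back
};

bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
    th_params.enabled = true;
    th_params.n_rollback = -1;
    /**/ if (params == "0" ) { th_params.enabled = false; }
    else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
    else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
    else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
    else if (params[0] == 'r' ) {
        th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
        th_params.n_rollback = std::stoi(params.substr(1)); // note: std::stoi throws on a non-numeric suffix
        if (th_params.n_rollback <= 0) {
            return false;
        }
    } else {
        return false;
    }
    return true;
}

int main() {
    for (const std::string s : { "0", "1", "d1", "d", "r3", "r0", "x" }) {
        llama_token_healing_params p;
        const bool ok = llama_token_healing_parse_params(s, p);
        std::cout << s << " -> ok=" << ok << " enabled=" << p.enabled
                  << " n_rollback=" << p.n_rollback << "\n";
    }
}
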
common/common.cpp

@@ -1095,21 +1095,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "-th" || arg == "--token-healing") {
         CHECK_ARG
-        sparams.token_healing_enabled = true;
-        auto & th_type = sparams.token_healing_type;
-        auto & th_n_rollback = sparams.token_healing_n_rollback;
         std::string value(argv[i]);
-        /**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
-        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
-        else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
-        else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
-        else if (value[0] == 'r' ) {
-            th_type = llama_token_healing_type::ROLLBACK_MULTI;
-            th_n_rollback = std::stoi(value.substr(1));
-            if (th_n_rollback <= 0) {
-                sparams.token_healing_enabled = false;
-            }
-        } else { invalid_param = true; }
+        invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
         return true;
     }
     if (arg == "--override-kv") {

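With this change both tools accept the same values through the shared parser: `-th 0` disables healing, `-th 1` rolls back only the last token, `-th d1` and `-th d` select the dynamic variants (one or many constrained decoding steps), and `-th rN` rolls back up to N tokens. One behavioral fix falls out of the refactor: a non-positive N (e.g. `-th r0`) now makes the parser return false, so `invalid_param` is set and the argument is rejected, instead of silently disabling healing as before.
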
common/sampling.cpp

@@ -154,6 +154,25 @@ void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const
     ctx_sampling->token_healing_prefix = prefix;
 }
 
+bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
+    th_params.enabled = true;
+    th_params.n_rollback = -1;
+    /**/ if (params == "0" ) { th_params.enabled = false; }
+    else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
+    else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
+    else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
+    else if (params[0] == 'r' ) {
+        th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
+        th_params.n_rollback = std::stoi(params.substr(1));
+        if (th_params.n_rollback <= 0) {
+            return false;
+        }
+    } else {
+        return false;
+    }
+    return true;
+}
+
 //
 // Sampling
 //
@@ -552,11 +571,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
     cur.resize(n_vocab);
 
     // Constrain tokens based on the remaining token healing prefix (if any)
-    const auto & th_type = params.token_healing_type;
     const auto & th_prefix = ctx_sampling->token_healing_prefix;
-    if (params.token_healing_enabled && !th_prefix.empty()) {
-        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
-                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    if (params.token_healing.enabled && !th_prefix.empty()) {
+        const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
         std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
 
         LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
@@ -635,7 +653,7 @@ void llama_sampling_accept(
         llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
     }
 
-    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+    if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
        std::string & th_prefix = ctx_sampling->token_healing_prefix;
        if (!th_prefix.empty()) {
            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);

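The `llama_sampling_prepare_impl` hunk above only renames fields; the constraint itself lives in `token_healing_get_candidates`, which this diff does not show. A toy sketch of the underlying idea, under the assumption that a token is kept when its text piece is consistent with the remaining healing prefix (multi-step modes also accepting tokens that cover only part of the prefix, since later constrained steps can consume the rest):

// Toy illustration only: not the library's token_healing_get_candidates.
#include <string>
#include <vector>

struct toy_token { int id; std::string piece; };

// single-step: the piece must start with the whole remaining prefix;
// multi-step:  a piece that is itself a prefix of th_prefix is also allowed.
static bool piece_matches(const std::string & piece, const std::string & th_prefix, bool is_multi_step) {
    if (piece.size() >= th_prefix.size()) {
        return piece.compare(0, th_prefix.size(), th_prefix) == 0;
    }
    return is_multi_step && th_prefix.compare(0, piece.size(), piece) == 0;
}

std::vector<int> toy_get_candidates(const std::vector<toy_token> & vocab,
                                    const std::string & th_prefix, bool is_multi_step) {
    std::vector<int> ids;
    for (const auto & t : vocab) {
        if (piece_matches(t.piece, th_prefix, is_multi_step)) {
            ids.push_back(t.id); // sampling is then restricted to these candidates
        }
    }
    return ids;
}
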
common/sampling.h

@@ -26,6 +26,12 @@ enum class llama_token_healing_type : uint8_t {
     DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
 };
 
+struct llama_token_healing_params {
+    bool enabled = false;
+    llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
+    int n_rollback = -1; // number of tokens to roll back
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -70,9 +76,7 @@ typedef struct llama_sampling_params {
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
 
-    llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
-    bool token_healing_enabled = false;
-    int token_healing_n_rollback = -1; // number of tokens to roll back
+    llama_token_healing_params token_healing;
 } llama_sampling_params;
 
 // general sampler context
@@ -190,3 +194,6 @@ llama_token_healing_output llama_token_healing_rollback(
     int max_to_remove = -1);
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
+
+// Helper for parsing token healing params from a string.
+bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);

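One subtle consequence of the consolidation: the old loose field defaulted to `ROLLBACK_LAST`, while the new `llama_token_healing_params::type` defaults to `DYNAMIC_MULTI` (healing remains off by default either way, since `enabled = false`). For consumers the change is otherwise mechanical; a hypothetical before/after:

llama_sampling_params sparams;
// before: three loose fields on llama_sampling_params
//   sparams.token_healing_enabled / token_healing_type / token_healing_n_rollback
// after: one aggregate that can be copied as a unit
// (as the server does with default_sparams.token_healing below)
sparams.token_healing.enabled    = true;
sparams.token_healing.type       = llama_token_healing_type::DYNAMIC_ONCE;
sparams.token_healing.n_rollback = -1; // < 0: no explicit rollback limit
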
examples/main/main.cpp

@@ -291,14 +291,14 @@ int main(int argc, char ** argv) {
         LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
-    if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
-        sparams.token_healing_enabled = false;
+    if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
+        sparams.token_healing.enabled = false;
         LOG("token healing: disabled due to custom suffix/conversation mode");
     }
 
     llama_token_healing_output token_healing_out{};
-    if (!params.interactive_first && sparams.token_healing_enabled) {
-        token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                         sparams.token_healing_n_rollback);
+    if (!params.interactive_first && sparams.token_healing.enabled) {
+        token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp,
+                                                         sparams.token_healing.n_rollback);
     }
 
     // Should not run without any tokens
@@ -956,13 +956,13 @@ int main(int argc, char ** argv) {
             embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
             embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
 
-            if (sparams.token_healing_enabled) {
+            if (sparams.token_healing.enabled) {
                 // Limit token healing rollback to new tokens only (otherwise would need to shift everything)
                 const int n_new_tokens = embd_inp.size() - original_size;
-                const int max_to_remove = sparams.token_healing_n_rollback < 0
+                const int max_to_remove = sparams.token_healing.n_rollback < 0
                     ? n_new_tokens
-                    : std::min(sparams.token_healing_n_rollback, n_new_tokens);
-                token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
+                    : std::min(sparams.token_healing.n_rollback, n_new_tokens);
+                token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove);
                 n_bytes_to_skip = token_healing_out.prefix.size();
             }

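The rollback clamp in the interactive branch above is easy to misread, so here it is extracted as a hypothetical helper (not part of the commit; names mirror the diff) with worked values:

#include <algorithm>

// n_rollback < 0 means "no explicit user limit": heal within the new tokens only.
// Otherwise the user limit still cannot exceed the number of new tokens.
int compute_max_to_remove(int n_rollback, int n_new_tokens) {
    return n_rollback < 0 ? n_new_tokens
                          : std::min(n_rollback, n_new_tokens);
}
// e.g. n_rollback = -1, n_new_tokens = 5 -> 5
//      n_rollback =  3, n_new_tokens = 5 -> 3
//      n_rollback =  9, n_new_tokens = 5 -> 5
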
examples/server/server.cpp

@@ -1098,31 +1098,20 @@ struct server_context {
         {
             const auto & token_healing_str = data.find("token_healing");
-            auto & th_enabled = slot.sparams.token_healing_enabled;
-            th_enabled = default_sparams.token_healing_enabled;
             if (token_healing_str != data.end() && token_healing_str->is_string()) {
                 const auto value = token_healing_str->get<std::string>();
-                auto & th_type = slot.sparams.token_healing_type;
-                auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
-                th_enabled = true;
-                /**/ if (value == "0" ) { th_enabled = false; }
-                else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
-                else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
-                else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
-                else if (value[0] == 'r' ) {
-                    th_type = llama_token_healing_type::ROLLBACK_MULTI;
-                    th_n_rollback = std::stoi(value.substr(1));
-                    if (th_n_rollback <= 0) {
-                        th_enabled = false;
-                    }
-                } else { th_enabled = false; }
+                if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
+                    send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
+                    return false;
+                }
                 LOG_VERBOSE("token healing", {
                     {"id_slot", slot.id},
-                    {"enabled", th_enabled},
-                    {"type", th_type},
-                    {"n_rollback", th_n_rollback}
+                    {"enabled", slot.sparams.token_healing.enabled},
+                    {"type", slot.sparams.token_healing.type},
+                    {"n_rollback", slot.sparams.token_healing.n_rollback}
                 });
+            } else {
+                slot.sparams.token_healing = default_sparams.token_healing;
             }
         }
@@ -1406,7 +1395,7 @@
             {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
             {"samplers", samplers_sequence},
-            {"token_healing_enabled", slot.sparams.token_healing_enabled}
+            {"token_healing_enabled", slot.sparams.token_healing.enabled}
         };
     }
@@ -2109,10 +2098,10 @@ struct server_context {
                 prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
                 suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
 
-                if (slot.sparams.token_healing_enabled) {
+                if (slot.sparams.token_healing.enabled) {
                     // For FIM roll back only the prefix part (i.e. cursor location)
-                    token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
-                                                                     prefix_tokens, slot.sparams.token_healing_n_rollback);
+                    token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
+                                                                     prefix_tokens, slot.sparams.token_healing.n_rollback);
                 }
 
                 auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
@@ -2131,9 +2120,9 @@ struct server_context {
             } else {
                 prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
 
-                if (slot.sparams.token_healing_enabled) {
-                    token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
-                                                                     prompt_tokens, slot.sparams.token_healing_n_rollback);
+                if (slot.sparams.token_healing.enabled) {
+                    token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
+                                                                     prompt_tokens, slot.sparams.token_healing.n_rollback);
                 }
             }
@@ -2149,7 +2138,7 @@ struct server_context {
                 {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
             });
 
-            if (slot.sparams.token_healing_enabled) {
+            if (slot.sparams.token_healing.enabled) {
                 slot.n_th_prefix = token_healing_out.prefix.size();
                 LOG_VERBOSE("token healing prompt", {
                     {"id_slot", slot.id},
@@ -2224,7 +2213,7 @@ struct server_context {
             }
 
             llama_sampling_reset(slot.ctx_sampling);
-            if (slot.sparams.token_healing_enabled) {
+            if (slot.sparams.token_healing.enabled) {
                 llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
             }
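
On the server side the shared parser also upgrades a malformed `token_healing` string from a silent disable to an explicit request error (`ERROR_TYPE_INVALID_REQUEST`, via `send_error` above). A hypothetical client-side request body, built with nlohmann::json (the JSON library the server already uses); all fields other than `token_healing` are illustrative:

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
    nlohmann::json body = {
        {"prompt",        "def fib(n"},
        {"n_predict",     32},
        {"token_healing", "r3"} // roll back up to 3 tokens; an invalid value such
                                // as "r0" is now rejected as an invalid request
    };
    // POST body.dump() to the server's /completion (or /infill) endpoint.
    std::cout << body.dump(2) << "\n";
    return 0;
}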