token healing : refactor argument parsing
Unify `main` and `server` token healing argument handling.
This commit is contained in:
parent
3ba5c55bc4
commit
ea4abc9d82
5 changed files with 61 additions and 60 deletions
|
@ -1095,21 +1095,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
}
|
||||
if (arg == "-th" || arg == "--token-healing") {
|
||||
CHECK_ARG
|
||||
sparams.token_healing_enabled = true;
|
||||
auto & th_type = sparams.token_healing_type;
|
||||
auto & th_n_rollback = sparams.token_healing_n_rollback;
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
|
||||
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
|
||||
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||
else if (value[0] == 'r' ) {
|
||||
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||
th_n_rollback = std::stoi(value.substr(1));
|
||||
if (th_n_rollback <= 0) {
|
||||
sparams.token_healing_enabled = false;
|
||||
}
|
||||
} else { invalid_param = true; }
|
||||
invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--override-kv") {
|
||||
|
|
|
@ -154,6 +154,25 @@ void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const
|
|||
ctx_sampling->token_healing_prefix = prefix;
|
||||
}
|
||||
|
||||
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
|
||||
th_params.enabled = true;
|
||||
th_params.n_rollback = -1;
|
||||
/**/ if (params == "0" ) { th_params.enabled = false; }
|
||||
else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
|
||||
else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||
else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||
else if (params[0] == 'r' ) {
|
||||
th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||
th_params.n_rollback = std::stoi(params.substr(1));
|
||||
if (th_params.n_rollback <= 0) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// Sampling
|
||||
//
|
||||
|
@ -552,11 +571,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
cur.resize(n_vocab);
|
||||
|
||||
// Constrain tokens based on the remaining token healing prefix (if any)
|
||||
const auto & th_type = params.token_healing_type;
|
||||
const auto & th_prefix = ctx_sampling->token_healing_prefix;
|
||||
if (params.token_healing_enabled && !th_prefix.empty()) {
|
||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
if (params.token_healing.enabled && !th_prefix.empty()) {
|
||||
const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||
params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
|
||||
|
||||
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
||||
|
@ -635,7 +653,7 @@ void llama_sampling_accept(
|
|||
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
||||
}
|
||||
|
||||
if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
|
||||
if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
|
||||
std::string & th_prefix = ctx_sampling->token_healing_prefix;
|
||||
if (!th_prefix.empty()) {
|
||||
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
|
||||
|
|
|
@ -26,6 +26,12 @@ enum class llama_token_healing_type : uint8_t {
|
|||
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
||||
};
|
||||
|
||||
struct llama_token_healing_params {
|
||||
bool enabled = false;
|
||||
llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
|
||||
int n_rollback = -1; // number of tokens to roll back
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
|
@ -70,9 +76,7 @@ typedef struct llama_sampling_params {
|
|||
std::vector<llama_token> penalty_prompt_tokens;
|
||||
bool use_penalty_prompt_tokens = false;
|
||||
|
||||
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
|
||||
bool token_healing_enabled = false;
|
||||
int token_healing_n_rollback = -1; // number of tokens to roll back
|
||||
llama_token_healing_params token_healing;
|
||||
} llama_sampling_params;
|
||||
|
||||
// general sampler context
|
||||
|
@ -190,3 +194,6 @@ llama_token_healing_output llama_token_healing_rollback(
|
|||
int max_to_remove = -1);
|
||||
|
||||
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
|
||||
|
||||
// Helper for parsing token healing params from a string.
|
||||
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);
|
||||
|
|
|
@ -291,14 +291,14 @@ int main(int argc, char ** argv) {
|
|||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||
}
|
||||
|
||||
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
|
||||
sparams.token_healing_enabled = false;
|
||||
if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
|
||||
sparams.token_healing.enabled = false;
|
||||
LOG("token healing: disabled due to custom suffix/conversation mode");
|
||||
}
|
||||
llama_token_healing_output token_healing_out{};
|
||||
if (!params.interactive_first && sparams.token_healing_enabled) {
|
||||
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
|
||||
sparams.token_healing_n_rollback);
|
||||
if (!params.interactive_first && sparams.token_healing.enabled) {
|
||||
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp,
|
||||
sparams.token_healing.n_rollback);
|
||||
}
|
||||
|
||||
// Should not run without any tokens
|
||||
|
@ -956,13 +956,13 @@ int main(int argc, char ** argv) {
|
|||
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
|
||||
|
||||
if (sparams.token_healing_enabled) {
|
||||
if (sparams.token_healing.enabled) {
|
||||
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
|
||||
const int n_new_tokens = embd_inp.size() - original_size;
|
||||
const int max_to_remove = sparams.token_healing_n_rollback < 0
|
||||
const int max_to_remove = sparams.token_healing.n_rollback < 0
|
||||
? n_new_tokens
|
||||
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
|
||||
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
|
||||
: std::min(sparams.token_healing.n_rollback, n_new_tokens);
|
||||
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove);
|
||||
n_bytes_to_skip = token_healing_out.prefix.size();
|
||||
}
|
||||
|
||||
|
|
|
@ -1098,31 +1098,20 @@ struct server_context {
|
|||
|
||||
{
|
||||
const auto & token_healing_str = data.find("token_healing");
|
||||
auto & th_enabled = slot.sparams.token_healing_enabled;
|
||||
th_enabled = default_sparams.token_healing_enabled;
|
||||
if (token_healing_str != data.end() && token_healing_str->is_string()) {
|
||||
const auto value = token_healing_str->get<std::string>();
|
||||
auto & th_type = slot.sparams.token_healing_type;
|
||||
auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
|
||||
th_enabled = true;
|
||||
/**/ if (value == "0" ) { th_enabled = false; }
|
||||
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
|
||||
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||
else if (value[0] == 'r' ) {
|
||||
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||
th_n_rollback = std::stoi(value.substr(1));
|
||||
if (th_n_rollback <= 0) {
|
||||
th_enabled = false;
|
||||
}
|
||||
} else { th_enabled = false; }
|
||||
|
||||
if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
|
||||
send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
LOG_VERBOSE("token healing", {
|
||||
{"id_slot", slot.id},
|
||||
{"enabled", th_enabled},
|
||||
{"type", th_type},
|
||||
{"n_rollback", th_n_rollback}
|
||||
{"enabled", slot.sparams.token_healing.enabled},
|
||||
{"type", slot.sparams.token_healing.type},
|
||||
{"n_rollback", slot.sparams.token_healing.n_rollback}
|
||||
});
|
||||
} else {
|
||||
slot.sparams.token_healing = default_sparams.token_healing;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1406,7 +1395,7 @@ struct server_context {
|
|||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"samplers", samplers_sequence},
|
||||
{"token_healing_enabled", slot.sparams.token_healing_enabled}
|
||||
{"token_healing_enabled", slot.sparams.token_healing.enabled}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -2109,10 +2098,10 @@ struct server_context {
|
|||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
||||
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
|
||||
|
||||
if (slot.sparams.token_healing_enabled) {
|
||||
if (slot.sparams.token_healing.enabled) {
|
||||
// For FIM roll back only the prefix part (i.e. cursor location)
|
||||
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
|
||||
prefix_tokens, slot.sparams.token_healing_n_rollback);
|
||||
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
|
||||
prefix_tokens, slot.sparams.token_healing.n_rollback);
|
||||
}
|
||||
|
||||
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
||||
|
@ -2131,9 +2120,9 @@ struct server_context {
|
|||
} else {
|
||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
||||
|
||||
if (slot.sparams.token_healing_enabled) {
|
||||
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
|
||||
prompt_tokens, slot.sparams.token_healing_n_rollback);
|
||||
if (slot.sparams.token_healing.enabled) {
|
||||
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
|
||||
prompt_tokens, slot.sparams.token_healing.n_rollback);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2149,7 +2138,7 @@ struct server_context {
|
|||
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
||||
});
|
||||
|
||||
if (slot.sparams.token_healing_enabled) {
|
||||
if (slot.sparams.token_healing.enabled) {
|
||||
slot.n_th_prefix = token_healing_out.prefix.size();
|
||||
LOG_VERBOSE("token healing prompt", {
|
||||
{"id_slot", slot.id},
|
||||
|
@ -2224,7 +2213,7 @@ struct server_context {
|
|||
}
|
||||
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
if (slot.sparams.token_healing_enabled) {
|
||||
if (slot.sparams.token_healing.enabled) {
|
||||
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue