token healing : refactor argument parsing
Unify `main` and `server` token healing argument handling.
This commit is contained in:
parent
3ba5c55bc4
commit
ea4abc9d82
5 changed files with 61 additions and 60 deletions
|
@ -1095,21 +1095,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
}
|
}
|
||||||
if (arg == "-th" || arg == "--token-healing") {
|
if (arg == "-th" || arg == "--token-healing") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
sparams.token_healing_enabled = true;
|
|
||||||
auto & th_type = sparams.token_healing_type;
|
|
||||||
auto & th_n_rollback = sparams.token_healing_n_rollback;
|
|
||||||
std::string value(argv[i]);
|
std::string value(argv[i]);
|
||||||
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
|
invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
|
||||||
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
|
|
||||||
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
|
||||||
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
|
||||||
else if (value[0] == 'r' ) {
|
|
||||||
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
|
||||||
th_n_rollback = std::stoi(value.substr(1));
|
|
||||||
if (th_n_rollback <= 0) {
|
|
||||||
sparams.token_healing_enabled = false;
|
|
||||||
}
|
|
||||||
} else { invalid_param = true; }
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (arg == "--override-kv") {
|
if (arg == "--override-kv") {
|
||||||
|
|
|
@ -154,6 +154,25 @@ void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const
|
||||||
ctx_sampling->token_healing_prefix = prefix;
|
ctx_sampling->token_healing_prefix = prefix;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
|
||||||
|
th_params.enabled = true;
|
||||||
|
th_params.n_rollback = -1;
|
||||||
|
/**/ if (params == "0" ) { th_params.enabled = false; }
|
||||||
|
else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
|
||||||
|
else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
|
||||||
|
else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
|
||||||
|
else if (params[0] == 'r' ) {
|
||||||
|
th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
|
||||||
|
th_params.n_rollback = std::stoi(params.substr(1));
|
||||||
|
if (th_params.n_rollback <= 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Sampling
|
// Sampling
|
||||||
//
|
//
|
||||||
|
@ -552,11 +571,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
cur.resize(n_vocab);
|
cur.resize(n_vocab);
|
||||||
|
|
||||||
// Constrain tokens based on the remaining token healing prefix (if any)
|
// Constrain tokens based on the remaining token healing prefix (if any)
|
||||||
const auto & th_type = params.token_healing_type;
|
|
||||||
const auto & th_prefix = ctx_sampling->token_healing_prefix;
|
const auto & th_prefix = ctx_sampling->token_healing_prefix;
|
||||||
if (params.token_healing_enabled && !th_prefix.empty()) {
|
if (params.token_healing.enabled && !th_prefix.empty()) {
|
||||||
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
|
const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
|
||||||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
|
params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
|
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
|
||||||
|
|
||||||
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
|
||||||
|
@ -635,7 +653,7 @@ void llama_sampling_accept(
|
||||||
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
|
if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
|
||||||
std::string & th_prefix = ctx_sampling->token_healing_prefix;
|
std::string & th_prefix = ctx_sampling->token_healing_prefix;
|
||||||
if (!th_prefix.empty()) {
|
if (!th_prefix.empty()) {
|
||||||
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
|
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
|
||||||
|
|
|
@ -26,6 +26,12 @@ enum class llama_token_healing_type : uint8_t {
|
||||||
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct llama_token_healing_params {
|
||||||
|
bool enabled = false;
|
||||||
|
llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
|
||||||
|
int n_rollback = -1; // number of tokens to roll back
|
||||||
|
};
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
typedef struct llama_sampling_params {
|
typedef struct llama_sampling_params {
|
||||||
int32_t n_prev = 64; // number of previous tokens to remember
|
int32_t n_prev = 64; // number of previous tokens to remember
|
||||||
|
@ -70,9 +76,7 @@ typedef struct llama_sampling_params {
|
||||||
std::vector<llama_token> penalty_prompt_tokens;
|
std::vector<llama_token> penalty_prompt_tokens;
|
||||||
bool use_penalty_prompt_tokens = false;
|
bool use_penalty_prompt_tokens = false;
|
||||||
|
|
||||||
llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
|
llama_token_healing_params token_healing;
|
||||||
bool token_healing_enabled = false;
|
|
||||||
int token_healing_n_rollback = -1; // number of tokens to roll back
|
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
|
@ -190,3 +194,6 @@ llama_token_healing_output llama_token_healing_rollback(
|
||||||
int max_to_remove = -1);
|
int max_to_remove = -1);
|
||||||
|
|
||||||
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
|
void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
|
||||||
|
|
||||||
|
// Helper for parsing token healing params from a string.
|
||||||
|
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);
|
||||||
|
|
|
@ -291,14 +291,14 @@ int main(int argc, char ** argv) {
|
||||||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
|
if (sparams.token_healing.enabled && (params.conversation || !params.input_suffix.empty())) {
|
||||||
sparams.token_healing_enabled = false;
|
sparams.token_healing.enabled = false;
|
||||||
LOG("token healing: disabled due to custom suffix/conversation mode");
|
LOG("token healing: disabled due to custom suffix/conversation mode");
|
||||||
}
|
}
|
||||||
llama_token_healing_output token_healing_out{};
|
llama_token_healing_output token_healing_out{};
|
||||||
if (!params.interactive_first && sparams.token_healing_enabled) {
|
if (!params.interactive_first && sparams.token_healing.enabled) {
|
||||||
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
|
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp,
|
||||||
sparams.token_healing_n_rollback);
|
sparams.token_healing.n_rollback);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should not run without any tokens
|
// Should not run without any tokens
|
||||||
|
@ -956,13 +956,13 @@ int main(int argc, char ** argv) {
|
||||||
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
||||||
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
|
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
|
||||||
|
|
||||||
if (sparams.token_healing_enabled) {
|
if (sparams.token_healing.enabled) {
|
||||||
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
|
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
|
||||||
const int n_new_tokens = embd_inp.size() - original_size;
|
const int n_new_tokens = embd_inp.size() - original_size;
|
||||||
const int max_to_remove = sparams.token_healing_n_rollback < 0
|
const int max_to_remove = sparams.token_healing.n_rollback < 0
|
||||||
? n_new_tokens
|
? n_new_tokens
|
||||||
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
|
: std::min(sparams.token_healing.n_rollback, n_new_tokens);
|
||||||
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
|
token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing.type, embd_inp, max_to_remove);
|
||||||
n_bytes_to_skip = token_healing_out.prefix.size();
|
n_bytes_to_skip = token_healing_out.prefix.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1098,31 +1098,20 @@ struct server_context {
|
||||||
|
|
||||||
{
|
{
|
||||||
const auto & token_healing_str = data.find("token_healing");
|
const auto & token_healing_str = data.find("token_healing");
|
||||||
auto & th_enabled = slot.sparams.token_healing_enabled;
|
|
||||||
th_enabled = default_sparams.token_healing_enabled;
|
|
||||||
if (token_healing_str != data.end() && token_healing_str->is_string()) {
|
if (token_healing_str != data.end() && token_healing_str->is_string()) {
|
||||||
const auto value = token_healing_str->get<std::string>();
|
const auto value = token_healing_str->get<std::string>();
|
||||||
auto & th_type = slot.sparams.token_healing_type;
|
if (!llama_token_healing_parse_params(value, slot.sparams.token_healing)) {
|
||||||
auto & th_n_rollback = slot.sparams.token_healing_n_rollback;
|
send_error(task, "\"token_healing\" parse error", ERROR_TYPE_INVALID_REQUEST);
|
||||||
th_enabled = true;
|
return false;
|
||||||
/**/ if (value == "0" ) { th_enabled = false; }
|
}
|
||||||
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
|
|
||||||
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
|
|
||||||
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
|
|
||||||
else if (value[0] == 'r' ) {
|
|
||||||
th_type = llama_token_healing_type::ROLLBACK_MULTI;
|
|
||||||
th_n_rollback = std::stoi(value.substr(1));
|
|
||||||
if (th_n_rollback <= 0) {
|
|
||||||
th_enabled = false;
|
|
||||||
}
|
|
||||||
} else { th_enabled = false; }
|
|
||||||
|
|
||||||
LOG_VERBOSE("token healing", {
|
LOG_VERBOSE("token healing", {
|
||||||
{"id_slot", slot.id},
|
{"id_slot", slot.id},
|
||||||
{"enabled", th_enabled},
|
{"enabled", slot.sparams.token_healing.enabled},
|
||||||
{"type", th_type},
|
{"type", slot.sparams.token_healing.type},
|
||||||
{"n_rollback", th_n_rollback}
|
{"n_rollback", slot.sparams.token_healing.n_rollback}
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
slot.sparams.token_healing = default_sparams.token_healing;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1406,7 +1395,7 @@ struct server_context {
|
||||||
{"min_keep", slot.sparams.min_keep},
|
{"min_keep", slot.sparams.min_keep},
|
||||||
{"grammar", slot.sparams.grammar},
|
{"grammar", slot.sparams.grammar},
|
||||||
{"samplers", samplers_sequence},
|
{"samplers", samplers_sequence},
|
||||||
{"token_healing_enabled", slot.sparams.token_healing_enabled}
|
{"token_healing_enabled", slot.sparams.token_healing.enabled}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2109,10 +2098,10 @@ struct server_context {
|
||||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
||||||
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
|
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
|
||||||
|
|
||||||
if (slot.sparams.token_healing_enabled) {
|
if (slot.sparams.token_healing.enabled) {
|
||||||
// For FIM roll back only the prefix part (i.e. cursor location)
|
// For FIM roll back only the prefix part (i.e. cursor location)
|
||||||
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
|
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
|
||||||
prefix_tokens, slot.sparams.token_healing_n_rollback);
|
prefix_tokens, slot.sparams.token_healing.n_rollback);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
|
||||||
|
@ -2131,9 +2120,9 @@ struct server_context {
|
||||||
} else {
|
} else {
|
||||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
||||||
|
|
||||||
if (slot.sparams.token_healing_enabled) {
|
if (slot.sparams.token_healing.enabled) {
|
||||||
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing_type,
|
token_healing_out = llama_token_healing_rollback(ctx, slot.sparams.token_healing.type,
|
||||||
prompt_tokens, slot.sparams.token_healing_n_rollback);
|
prompt_tokens, slot.sparams.token_healing.n_rollback);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2149,7 +2138,7 @@ struct server_context {
|
||||||
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
||||||
});
|
});
|
||||||
|
|
||||||
if (slot.sparams.token_healing_enabled) {
|
if (slot.sparams.token_healing.enabled) {
|
||||||
slot.n_th_prefix = token_healing_out.prefix.size();
|
slot.n_th_prefix = token_healing_out.prefix.size();
|
||||||
LOG_VERBOSE("token healing prompt", {
|
LOG_VERBOSE("token healing prompt", {
|
||||||
{"id_slot", slot.id},
|
{"id_slot", slot.id},
|
||||||
|
@ -2224,7 +2213,7 @@ struct server_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling_reset(slot.ctx_sampling);
|
llama_sampling_reset(slot.ctx_sampling);
|
||||||
if (slot.sparams.token_healing_enabled) {
|
if (slot.sparams.token_healing.enabled) {
|
||||||
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
|
llama_token_healing_set_prefix(slot.ctx_sampling, token_healing_out.prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue