Add option to ignore tokens with 2+ English characters

parent 916248af1f
commit 3c5acd512a

3 changed files with 31 additions and 3 deletions
@@ -90,7 +90,8 @@ struct gpt_params {
     float   yarn_beta_slow = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx  = 0;     // YaRN original context length
     float   defrag_thold   = -1.0f; // KV cache defragmentation threshold
 
+    bool ignore_english_tokens = false; // Experimental: Attempt to not sample tokens containing English characters
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data                 = nullptr;
@@ -121,10 +121,12 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
-            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
+            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
+            "\tignore_english_tokens = %s",
             params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
             params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
-            params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            params.mirostat, params.mirostat_eta, params.mirostat_tau,
+            params.ignore_english_tokens ? "true" : "false");
 
     return std::string(result);
 }
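With this format string, the printed sampling report gains one extra tab-indented line of the form "ignore_english_tokens = false" (or "true" when the new flag is set).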
@@ -423,6 +425,13 @@ static llama_token_data_array llama_sampling_prepare_impl(
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
+    if (params.ignore_english_tokens) {
+        for (size_t i = 0; i < cur_p.size; ++i) {
+            if (is_english_token(ctx_main, cur_p.data[i].id)) {
+                cur_p.data[i].logit = -INFINITY;
+            }
+        }
+    }
     return cur_p;
 }
 
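The masking relies on standard softmax behaviour: a logit of -INFINITY ends up with probability zero after normalization, so a flagged candidate can never be sampled. A minimal, self-contained sketch of that effect (not part of this commit; the logit values are made up):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Three hypothetical candidate logits; the second one has been masked out.
    std::vector<float> logits = {2.0f, -INFINITY, 0.5f};

    // Softmax: exp(-INFINITY) == 0, so the masked candidate gets probability 0.
    float sum = 0.0f;
    std::vector<float> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i]);
        sum += probs[i];
    }
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] /= sum;
        std::printf("candidate %zu: p = %.4f\n", i, probs[i]);
    }
    return 0;
}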
@@ -457,3 +466,20 @@ void llama_sampling_accept(
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
 }
+
+bool is_english_token(const llama_context * ctx, llama_token token) {
+    const std::string token_str = llama_token_to_piece(ctx, token);
+    int  english_char_count = 0;
+    bool has_angle_bracket  = false;
+
+    for (char c : token_str) {
+        if (c >= 'a' && c <= 'z') {
+            english_char_count++;
+        }
+        if (c == '<' || c == '>') {
+            has_angle_bracket = true;
+        }
+    }
+
+    return english_char_count >= 2 && !has_angle_bracket;
+}
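Note (not part of the diff): the heuristic counts only lowercase ASCII letters, and any piece containing '<' or '>' is never treated as English, presumably so that markup-style special tokens stay samplable. A standalone sketch of the same check applied directly to strings; the example pieces are hypothetical:

#include <cstdio>
#include <string>
#include <vector>

// Same rule as is_english_token() above, but applied to a raw string instead
// of a detokenized llama_token piece.
static bool looks_english(const std::string & s) {
    int  english_char_count = 0;
    bool has_angle_bracket  = false;
    for (char c : s) {
        if (c >= 'a' && c <= 'z') { english_char_count++; }
        if (c == '<' || c == '>') { has_angle_bracket = true; }
    }
    return english_char_count >= 2 && !has_angle_bracket;
}

int main() {
    std::vector<std::string> pieces = {" the", "a", "</s>", "你好", "ing"};
    for (const auto & p : pieces) {
        std::printf("%-8s -> %s\n", p.c_str(), looks_english(p) ? "ignored" : "kept");
    }
    return 0;
}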
@@ -41,6 +41,7 @@ typedef struct llama_sampling_params {
     float mirostat_eta = 0.10f;         // learning rate
     bool penalize_nl = false;           // consider newlines as a repeatable token
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    bool ignore_english_tokens = false; // Ignore tokens with 2+ English characters (except those with angle brackets)
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
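How the option gets enabled is not shown in this diff; the snippet below is only an assumed usage sketch that sets the new field on the two structs touched above (gpt_params and llama_sampling_params). The include paths and the main() wrapper are assumptions, not part of the commit:

#include "common.h"   // gpt_params            (path assumed)
#include "sampling.h" // llama_sampling_params (path assumed)

int main() {
    gpt_params            params;
    llama_sampling_params sparams;

    // Field added by this commit; when set, llama_sampling_prepare_impl()
    // masks every candidate for which is_english_token() returns true.
    params.ignore_english_tokens  = true;
    sparams.ignore_english_tokens = true;
    return 0;
}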