Add option to ignore tokens with 2+ English characters

GC 2024-07-03 15:54:58 +01:00
parent 916248af1f
commit 3c5acd512a
3 changed files with 31 additions and 3 deletions

View file

@@ -90,7 +90,8 @@ struct gpt_params {
     float   yarn_beta_slow   = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx    = 0;     // YaRN original context length
     float   defrag_thold     = -1.0f; // KV cache defragmentation threshold
+    bool    ignore_english_tokens = false; // Experimental: Attempt to not sample tokens containing English characters
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;

View file

@@ -121,10 +121,12 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
-            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
+            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
+            "\tignore_english_tokens = %s",
             params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
             params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
-            params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            params.mirostat, params.mirostat_eta, params.mirostat_tau,
+            params.ignore_english_tokens ? "true" : "false");
 
     return std::string(result);
 }
@@ -423,6 +425,13 @@ static llama_token_data_array llama_sampling_prepare_impl(
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
+    if (params.ignore_english_tokens) {
+        for (size_t i = 0; i < cur_p.size; ++i) {
+            if (is_english_token(ctx_main, cur_p.data[i].id)) {
+                cur_p.data[i].logit = -INFINITY;
+            }
+        }
+    }
 
     return cur_p;
 }
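The masking works because of how the downstream samplers treat logits: greedy sampling picks the highest logit, and the probabilistic samplers softmax the logits first, so a candidate forced to -INFINITY ends up with probability zero and can never be drawn. A minimal standalone sketch of that effect (plain C++, independent of llama.cpp; the logit values are illustrative):

```cpp
// Standalone sketch: a logit of -INFINITY becomes probability 0 after softmax,
// so the masked candidate can never be sampled.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Three candidate logits; the second one has been masked.
    std::vector<float> logits = {2.0f, -INFINITY, 0.5f};

    // Softmax: std::exp(-INFINITY) evaluates to 0.
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i]);
        sum += probs[i];
    }
    for (size_t i = 0; i < probs.size(); ++i) {
        probs[i] /= sum;
        std::printf("candidate %zu: p = %.4f\n", i, probs[i]);
    }
    return 0;
}
```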
@@ -457,3 +466,20 @@ void llama_sampling_accept(
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
 }
+
+bool is_english_token(const llama_context * ctx, llama_token token) {
+    const std::string token_str = llama_token_to_piece(ctx, token);
+    int english_char_count = 0;
+    bool has_angle_bracket = false;
+
+    for (char c : token_str) {
+        if (c >= 'a' && c <= 'z') {
+            english_char_count++;
+        }
+        if (c == '<' || c == '>') {
+            has_angle_bracket = true;
+        }
+    }
+
+    return english_char_count >= 2 && !has_angle_bracket;
+}
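Note that the heuristic counts only lowercase ASCII letters, so all-uppercase pieces and non-Latin scripts pass through, and any piece containing '<' or '>' is always kept, which spares special tokens such as <s> or chat-template markers. Below is a hedged standalone sketch that applies the same rule to raw piece strings; the real is_english_token() obtains the piece via llama_token_to_piece(), and the example pieces here are illustrative assumptions rather than output of any particular tokenizer:

```cpp
// Standalone sketch of the same heuristic on raw piece strings.
#include <cstdio>
#include <string>
#include <vector>

static bool looks_english(const std::string & piece) {
    int  lowercase_count   = 0;
    bool has_angle_bracket = false;
    for (char c : piece) {
        if (c >= 'a' && c <= 'z') {
            lowercase_count++;
        }
        if (c == '<' || c == '>') {
            has_angle_bracket = true;
        }
    }
    // Same rule as the commit: 2+ lowercase ASCII letters, unless the piece
    // contains angle brackets.
    return lowercase_count >= 2 && !has_angle_bracket;
}

int main() {
    const std::vector<std::string> pieces = {
        " hello",           // masked: several lowercase letters
        " a",               // kept: only one lowercase letter
        "<|end|>",          // kept: contains angle brackets
        "123",              // kept: no letters
        "\xd0\xbf\xd1\x80", // kept: UTF-8 bytes outside 'a'..'z'
    };
    for (const auto & p : pieces) {
        std::printf("%-12s -> %s\n", p.c_str(), looks_english(p) ? "masked" : "kept");
    }
    return 0;
}
```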

View file

@@ -41,6 +41,7 @@ typedef struct llama_sampling_params {
     float   mirostat_eta = 0.10f; // learning rate
     bool    penalize_nl  = false; // consider newlines as a repeatable token
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    bool ignore_english_tokens = false; // Ignore tokens with 2+ English characters (except those with angle brackets)
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
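
For completeness, a hedged usage sketch of how the new option might be switched on from application code. It assumes the common sampling API of this era (llama_sampling_init() / llama_sampling_free() from common/sampling.h); none of this wiring is part of the commit itself.

```cpp
// Hedged sketch, not part of the commit: enable the option on the sampling
// parameters before building the sampling context.
#include "sampling.h"   // from llama.cpp's common/ directory

void setup_sampling_example() {
    llama_sampling_params sparams;          // defaults from sampling.h
    sparams.ignore_english_tokens = true;   // mask tokens with 2+ lowercase ASCII letters

    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

    // ... the usual generation loop would call llama_sampling_sample() /
    //     llama_sampling_accept() here; with the flag set, candidates flagged
    //     by is_english_token() get their logit forced to -INFINITY ...

    llama_sampling_free(ctx_sampling);
}
```

The parallel flag added to gpt_params would presumably be propagated into the sampling parameters by whichever example wires it up; that plumbing is not part of the hunks shown here.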