diff --git a/common/common.h b/common/common.h
index 65c0ef81a..cc97ec9c2 100644
--- a/common/common.h
+++ b/common/common.h
@@ -90,7 +90,8 @@ struct gpt_params {
     float    yarn_beta_slow = 1.0f;  // YaRN high correction dim
     int32_t  yarn_orig_ctx  = 0;     // YaRN original context length
     float    defrag_thold   = -1.0f; // KV cache defragmentation threshold
-
+    bool     ignore_english_tokens = false; // Experimental: attempt to avoid sampling tokens containing English characters
+
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
 
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 9f332fe57..e04f5254a 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -121,10 +121,12 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
-            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
+            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
+            "\tignore_english_tokens = %s",
             params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
             params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
-            params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            params.mirostat, params.mirostat_eta, params.mirostat_tau,
+            params.ignore_english_tokens ? "true" : "false");
 
     return std::string(result);
 }
@@ -423,6 +425,13 @@ static llama_token_data_array llama_sampling_prepare_impl(
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
+    if (params.ignore_english_tokens) {
+        for (size_t i = 0; i < cur_p.size; ++i) {
+            if (is_english_token(ctx_main, cur_p.data[i].id)) {
+                cur_p.data[i].logit = -INFINITY;
+            }
+        }
+    }
     return cur_p;
 }
 
@@ -457,3 +466,20 @@ void llama_sampling_accept(
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
 }
+
+bool is_english_token(const llama_context * ctx, llama_token token) {
+    const std::string token_str = llama_token_to_piece(ctx, token);
+    int english_char_count = 0;
+    bool has_angle_bracket = false;
+
+    for (char c : token_str) {
+        if (c >= 'a' && c <= 'z') {
+            english_char_count++;
+        }
+        if (c == '<' || c == '>') {
+            has_angle_bracket = true;
+        }
+    }
+
+    return english_char_count >= 2 && !has_angle_bracket;
+}
\ No newline at end of file
diff --git a/common/sampling.h b/common/sampling.h
index eeaa53b8b..e2462ea9a 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -41,6 +41,7 @@ typedef struct llama_sampling_params {
     float       mirostat_eta = 0.10f; // learning rate
     bool        penalize_nl  = false; // consider newlines as a repeatable token
     uint32_t    seed         = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    bool        ignore_english_tokens = false; // Ignore tokens containing 2+ lowercase English characters (except those with angle brackets)
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
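
For context, the core of this patch is a post-grammar logit mask: any candidate token whose decoded piece looks "English" gets its logit set to -INFINITY, so the softmax assigns it zero probability and it can never be sampled. The sketch below is a minimal, self-contained approximation of that heuristic for experimenting outside llama.cpp; is_english_piece, the Candidate struct, and the sample strings are hypothetical stand-ins (the real code decodes pieces via llama_token_to_piece and mutates a llama_token_data_array), but the counting rule is the same one the patch applies.

// Minimal standalone sketch of the patch's filtering heuristic.
// Assumption: pieces are plain std::strings here; in the patch they come
// from llama_token_to_piece(ctx, token), and the logits live in a
// llama_token_data_array rather than this hypothetical Candidate struct.
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

// Same rule as is_english_token in the patch: two or more lowercase
// ASCII letters, and no '<' or '>' anywhere in the piece.
static bool is_english_piece(const std::string & piece) {
    int  english_char_count = 0;
    bool has_angle_bracket  = false;
    for (char c : piece) {
        if (c >= 'a' && c <= 'z') {
            english_char_count++;
        }
        if (c == '<' || c == '>') {
            has_angle_bracket = true;
        }
    }
    return english_char_count >= 2 && !has_angle_bracket;
}

struct Candidate {
    std::string piece;
    float       logit;
};

int main() {
    std::vector<Candidate> cur_p = {
        { "hello",      1.2f }, // masked: five lowercase letters
        { "<|im_end|>", 0.9f }, // kept: contains angle brackets
        { "日本語",      0.7f }, // kept: no bytes in the a-z range
        { "A",          0.5f }, // kept: uppercase is not counted
    };

    // Mirrors the loop the patch adds to llama_sampling_prepare_impl:
    // masked candidates end up with probability zero after softmax.
    for (auto & cand : cur_p) {
        if (is_english_piece(cand.piece)) {
            cand.logit = -INFINITY;
        }
    }

    for (const auto & cand : cur_p) {
        printf("%-12s logit = %f\n", cand.piece.c_str(), cand.logit);
    }
    return 0;
}

Two properties of the heuristic are worth noting. The angle-bracket exemption presumably keeps special and chat-template tokens (which commonly contain '<' and '>') samplable even though they often contain lowercase letters. And because only lowercase a-z is counted, single letters and all-caps tokens pass through the filter, so the mask is a best-effort bias against English rather than a guarantee.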