sampling : add XTC sampler (#9742)
* Initial XTC commit Adds XTC sampler, not activated by default, but recommended settings by default. * Cleanup * Simplified chances calculation To be more in line with the original implementation, chance is calculated once at the beginning. * First fixes by comments Still need to look into sorting * Fixed trailing backspaces * Fixed RNG to be reproducible Thanks to @slaren for directions * Fixed forgotten header * Moved `min_keep` Moved from conditions to a simple check at the end. * Fixed broken randomization Thanks to @slaren for explanation * Swapped sorting for a custom algorithm Shifts tokens to remove the penalized ones, then puts the penalized at the back. Should make `min_keep` still viable. * Algorithm rework 1. Scan tokens from top till the first non-penalizable 2. Remove the last captured token (the least probable above threshold) 3. Shift all tokens to overwrite the remaining penalizable 4. Penalize and put them at the bottom. * Added XTC to `test-sampling` * Simplified algorithm and more tests * Updated info in common and args * Merged back lost commits in common and arg * Update dump info in common * Fixed incorrect min_keep check * Added XTC to README * Renamed parameters, fixed info and defaults * probability is at 0 by default, but XTC is included in sampling queue * threshold higher than 0.5 switches XTC off * Initial server support * Added XTC to server UIs * Fixed labels in old server UI * Made algorithm safer and more readable * Removed xtc_threshold_max * Fixed arg after update * Quick fixes by comments * Simplified algorithm since threshold_max is removed * Renamed random distribution * Fixed tests and outdated README * Small fixes
This commit is contained in:
parent
dcdd535302
commit
fbc98b748e
11 changed files with 195 additions and 10 deletions
|
@ -947,6 +947,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.sparams.tfs_z = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-probability"}, "N",
|
||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sparams.xtc_probability = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-threshold"}, "N",
|
||||
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sparams.xtc_threshold = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--typical"}, "N",
|
||||
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
|
||||
|
|
|
@ -2104,6 +2104,8 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
|
|||
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
|
||||
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
|
||||
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
|
||||
fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
|
||||
fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold);
|
||||
fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
|
||||
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
|
||||
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
|
||||
|
|
|
@ -90,6 +90,8 @@ enum common_sampler_type {
|
|||
COMMON_SAMPLER_TYPE_TFS_Z = 4,
|
||||
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
|
||||
COMMON_SAMPLER_TYPE_XTC = 7,
|
||||
|
||||
};
|
||||
|
||||
// dimensionality reduction methods, used by cvector-generator
|
||||
|
@ -108,6 +110,8 @@ struct common_sampler_params {
|
|||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float xtc_probability = 0.00f; // 0.0 = disabled
|
||||
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
||||
float tfs_z = 1.00f; // 1.0 = disabled
|
||||
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
|
@ -124,12 +128,14 @@ struct common_sampler_params {
|
|||
bool ignore_eos = false;
|
||||
bool no_perf = false; // disable performance metrics
|
||||
|
||||
|
||||
std::vector<enum common_sampler_type> samplers = {
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
COMMON_SAMPLER_TYPE_TFS_Z,
|
||||
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
||||
COMMON_SAMPLER_TYPE_TOP_P,
|
||||
COMMON_SAMPLER_TYPE_MIN_P,
|
||||
COMMON_SAMPLER_TYPE_XTC,
|
||||
COMMON_SAMPLER_TYPE_TEMPERATURE
|
||||
};
|
||||
|
||||
|
|
|
@ -130,10 +130,10 @@ std::string common_sampler_params::print() const {
|
|||
|
||||
snprintf(result, sizeof(result),
|
||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||
top_k, tfs_z, top_p, min_p, typ_p, temp,
|
||||
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
||||
mirostat, mirostat_eta, mirostat_tau);
|
||||
|
||||
return std::string(result);
|
||||
|
@ -184,6 +184,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TFS_Z:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
||||
break;
|
||||
|
@ -372,6 +375,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|||
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||
default : return '?';
|
||||
}
|
||||
}
|
||||
|
@ -384,6 +388,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|||
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||
default : return "";
|
||||
}
|
||||
}
|
||||
|
@ -396,6 +401,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
|
||||
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||
};
|
||||
|
||||
// since samplers names are written multiple ways
|
||||
|
@ -441,7 +447,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
|
||||
};
|
||||
|
||||
std::vector<common_sampler_type> samplers;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue