llama : add DRY sampler (#9702)
* sampling : add DRY sampler (post-refactor) * DRY: Trying to fix coauthors, removed unneeded line * DRY: Fixed redundant code * DRY: Fixed crash issue due to DRY being in chain but uninitialized --------- Co-authored-by: l3utterfly <gc.pthzfoldr@gmail.com> Co-authored-by: pi6am <34464159+pi6am@users.noreply.github.com>
This commit is contained in:
parent
d80fb71f8b
commit
ff252ea48e
17 changed files with 713 additions and 63 deletions
|
@ -130,9 +130,11 @@ std::string common_sampler_params::print() const {
|
|||
|
||||
snprintf(result, sizeof(result),
|
||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
||||
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
||||
mirostat, mirostat_eta, mirostat_tau);
|
||||
|
||||
|
@ -174,6 +176,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
if (params.mirostat == 0) {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char*> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto& str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
|
@ -358,6 +371,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
|
|||
|
||||
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY: return 'd';
|
||||
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
|
||||
case COMMON_SAMPLER_TYPE_TFS_Z: return 'f';
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
||||
|
@ -372,6 +386,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|||
|
||||
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY: return "dry";
|
||||
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
|
||||
case COMMON_SAMPLER_TYPE_TFS_Z: return "tfs_z";
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
||||
|
@ -386,6 +401,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|||
|
||||
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
||||
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
|
||||
{ "dry", COMMON_SAMPLER_TYPE_DRY },
|
||||
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
|
||||
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
|
@ -434,6 +450,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||
|
||||
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
|
||||
std::unordered_map<char, common_sampler_type> sampler_name_map = {
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z), COMMON_SAMPLER_TYPE_TFS_Z },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue