llama : add infill sampler (#9896)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-15 16:35:33 +03:00 committed by GitHub
parent 223c25a72f
commit 755a9b2bf0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 300 additions and 29 deletions

View file

@ -91,7 +91,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
COMMON_SAMPLER_TYPE_XTC = 7,
COMMON_SAMPLER_TYPE_INFILL = 8,
};
// dimensionality reduction methods, used by cvector-generator
@ -136,7 +136,7 @@ struct common_sampler_params {
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC,
COMMON_SAMPLER_TYPE_TEMPERATURE
COMMON_SAMPLER_TYPE_TEMPERATURE,
};
std::string grammar; // optional BNF-like grammar to constrain sampling

View file

@ -196,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
@ -376,6 +379,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
default : return '?';
}
}
@ -389,6 +393,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
default : return "";
}
}
@ -402,6 +407,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
};
// since samplers names are written multiple ways
@ -448,7 +454,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
};
std::vector<common_sampler_type> samplers;