Merge pull request #29 from wwoodsTM/test-dry-sampler
Add DRY sampling parameters to gpt_params and server_context
This commit is contained in:
commit
d1676a10f9
2 changed files with 79 additions and 24 deletions
|
@ -555,6 +555,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
sparams.penalty_present = std::stof(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dry-multiplier") {
|
||||
CHECK_ARG
|
||||
sparams.dry_multiplier = std::stof(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dry-base") {
|
||||
CHECK_ARG
|
||||
sparams.dry_base = std::stof(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dry-allowed-length") {
|
||||
CHECK_ARG
|
||||
sparams.dry_allowed_length = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dry-penalty-last-n") {
|
||||
CHECK_ARG
|
||||
sparams.dry_penalty_last_n = std::stoi(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "--dynatemp-range") {
|
||||
CHECK_ARG
|
||||
sparams.dynatemp_range = std::stof(argv[i]);
|
||||
|
@ -1471,6 +1491,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||
options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
|
||||
options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
|
||||
options.push_back({ "*", " --frequency-penalty N", "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq });
|
||||
options.push_back({ "*", " --dry-multiplier N", "DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)sparams.dry_multiplier });
|
||||
options.push_back({ "*", " --dry-base N", "DRY sampling base (default: %.1f)", (double)sparams.dry_base });
|
||||
options.push_back({ "*", " --dry-allowed-length N", "DRY sampling allowed length (default: %d)", sparams.dry_allowed_length });
|
||||
options.push_back({ "*", " --dry-penalty-last-n N", "DRY sampling penalty last n tokens (-1 = context size, default: %d)", sparams.dry_penalty_last_n });
|
||||
|
||||
options.push_back({ "*", " --dynatemp-range N", "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range });
|
||||
options.push_back({ "*", " --dynatemp-exp N", "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent });
|
||||
options.push_back({ "*", " --mirostat N", "use Mirostat sampling.\n"
|
||||
|
|
|
@ -898,30 +898,55 @@ struct server_context {
|
|||
slot.oaicompat_model = "";
|
||||
}
|
||||
|
||||
slot.params.stream = json_value(data, "stream", false);
|
||||
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
|
||||
slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
|
||||
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
|
||||
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
||||
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
||||
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
||||
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
|
||||
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
|
||||
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
|
||||
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
||||
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
||||
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
||||
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
|
||||
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
slot.params.stream = json_value(data, "stream", false);
|
||||
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
|
||||
slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
|
||||
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
||||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
|
||||
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
||||
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
||||
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
||||
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
|
||||
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
|
||||
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
|
||||
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
|
||||
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
|
||||
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
|
||||
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
|
||||
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
||||
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
||||
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
||||
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
|
||||
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
|
||||
// sequence breakers for DRY
|
||||
{
|
||||
auto dry_seq_breakers = data.find("dry_seq_breakers");
|
||||
if (dry_seq_breakers != data.end()) {
|
||||
try {
|
||||
if (dry_seq_breakers->is_array()) {
|
||||
slot.sparams.dry_seq_breakers = dry_seq_breakers->get<std::vector<std::string>>();
|
||||
} else if (dry_seq_breakers->is_string()) {
|
||||
slot.sparams.dry_seq_breakers = json::parse(dry_seq_breakers->get<std::string>()).get<std::vector<std::string>>();
|
||||
} else {
|
||||
send_error(task, "\"dry_seq_breakers\": Expected an array of strings or a JSON-encoded array of strings.", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
send_error(task, std::string("\"dry_seq_breakers\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// process "json_schema" and "grammar"
|
||||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
|
@ -1339,6 +1364,11 @@ struct server_context {
|
|||
{"frequency_penalty", slot.sparams.penalty_freq},
|
||||
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
|
||||
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
|
||||
{"dry_multiplier", slot.sparams.dry_multiplier},
|
||||
{"dry_base", slot.sparams.dry_base},
|
||||
{"dry_allowed_length", slot.sparams.dry_allowed_length},
|
||||
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
|
||||
{"dry_seq_breakers", slot.sparams.dry_seq_breakers},
|
||||
{"mirostat", slot.sparams.mirostat},
|
||||
{"mirostat_tau", slot.sparams.mirostat_tau},
|
||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue