This commit is contained in:
Philipp Emanuel Weidmann 2024-03-15 17:11:14 -05:00 committed by GitHub
commit 7dbc4c9c9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 145 additions and 37 deletions

View file

@ -376,6 +376,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
sparams.min_p = std::stof(argv[i]);
} else if (arg == "--p-step") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.p_step = std::stof(argv[i]);
} else if (arg == "--temp") {
if (++i >= argc) {
invalid_param = true;
@ -1020,6 +1026,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
printf(" --p-step N p-step sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.p_step);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
@ -1199,6 +1206,7 @@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::
{"top_p", llama_sampler_type::TOP_P},
{"typical_p", llama_sampler_type::TYPICAL_P},
{"min_p", llama_sampler_type::MIN_P},
{"p_step", llama_sampler_type::P_STEP},
{"tfs_z", llama_sampler_type::TFS_Z},
{"temperature", llama_sampler_type::TEMPERATURE}
};
@ -1212,6 +1220,7 @@ std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::
{"typical-p", llama_sampler_type::TYPICAL_P},
{"typical", llama_sampler_type::TYPICAL_P},
{"min-p", llama_sampler_type::MIN_P},
{"p-step", llama_sampler_type::P_STEP},
{"tfs-z", llama_sampler_type::TFS_Z},
{"tfs", llama_sampler_type::TFS_Z},
{"temp", llama_sampler_type::TEMPERATURE}
@ -1247,6 +1256,7 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
{'p', llama_sampler_type::TOP_P},
{'y', llama_sampler_type::TYPICAL_P},
{'m', llama_sampler_type::MIN_P},
{'s', llama_sampler_type::P_STEP},
{'f', llama_sampler_type::TFS_Z},
{'t', llama_sampler_type::TEMPERATURE}
};
@ -1269,6 +1279,7 @@ std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
case llama_sampler_type::TYPICAL_P: return "typical_p";
case llama_sampler_type::TOP_P: return "top_p";
case llama_sampler_type::MIN_P: return "min_p";
case llama_sampler_type::P_STEP: return "p_step";
case llama_sampler_type::TEMPERATURE: return "temperature";
default : return "";
}
@ -1841,6 +1852,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
fprintf(stream, "p_step: %f # default: 0.0\n", sparams.p_step);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");

View file

@ -98,10 +98,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, p_step = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.p_step, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau);
return std::string(result);
@ -135,6 +135,7 @@ static void sampler_queue(
const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float p_step = params.p_step;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
@ -146,6 +147,7 @@ static void sampler_queue(
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case llama_sampler_type::P_STEP : llama_sample_p_step (ctx_main, &cur_p, p_step, min_keep); break;
case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);

View file

@ -13,6 +13,7 @@ enum class llama_sampler_type : char {
TOP_K = 'k',
TOP_P = 'p',
MIN_P = 'm',
P_STEP = 's',
TFS_Z = 'f',
TYPICAL_P = 'y',
TEMPERATURE = 't'
@ -26,6 +27,7 @@ typedef struct llama_sampling_params {
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float p_step = 0.00f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
@ -46,6 +48,7 @@ typedef struct llama_sampling_params {
llama_sampler_type::TYPICAL_P,
llama_sampler_type::TOP_P,
llama_sampler_type::MIN_P,
llama_sampler_type::P_STEP,
llama_sampler_type::TEMPERATURE
};

View file

@ -11098,6 +11098,34 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}
void llama_sample_p_step(struct llama_context * ctx, llama_token_data_array * candidates, float step, size_t min_keep) {
    // Sampler disabled (step <= 0) or nothing to truncate: leave candidates untouched.
    if (step <= 0.0f || candidates->size <= 1) {
        return;
    }

    // Sort candidates by probability (descending) and populate their .p fields.
    llama_sample_softmax(nullptr, candidates);

    const int64_t t_start_sample_us = ggml_time_us();

    // Scan the sorted list for the first "step": a token whose probability falls
    // below `step` times its predecessor's. Truncate there, but never drop below
    // `min_keep` retained tokens.
    bool past_step = false;
    for (size_t idx = 1; idx < candidates->size; ++idx) {
        past_step = past_step || candidates->data[idx].p < step * candidates->data[idx - 1].p;
        if (past_step && idx >= min_keep) {
            // Keep only the tokens before the step
            candidates->size = idx;
            break;
        }
    }

    if (ctx) {
        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) {
return;

View file

@ -840,6 +840,13 @@ extern "C" {
float p,
size_t min_keep);
/// @details P-Step sampling: sorts candidates by probability and truncates the list at the first token whose probability drops below `step` times that of the preceding token, always keeping at least `min_keep` tokens. A step of 0.0 disables the sampler.
LLAMA_API void llama_sample_p_step(
struct llama_context * ctx,
llama_token_data_array * candidates,
float step,
size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,

View file

@ -101,6 +101,27 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
}
}
static void test_p_step(const std::vector<float> & probs, const std::vector<float> & expected_probs, float step) {
const size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p);
llama_sample_p_step(nullptr, &candidates_p, step, 1);
DUMP(&candidates_p);
llama_sample_softmax(nullptr, &candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) {
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
}
}
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates;
@ -149,7 +170,7 @@ static void test_repetition_penalties(
}
static void test_sampler_queue(
const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p, const float p_step
) {
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -164,14 +185,15 @@ static void test_sampler_queue(
const llama_token max_token_id = n_vocab-1;
for (auto s : samplers_sequence) {
switch (s){
case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
default : GGML_ASSERT(false && "Unknown sampler"); break;
switch (s) {
case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
case 's': llama_sample_p_step (nullptr, &candidates_p, p_step, 1); break;
case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
default : GGML_ASSERT(false && "Unknown sampler"); break;
}
llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
@ -218,6 +240,18 @@ static void test_sampler_queue(
min_token_id = std::max(min_token_id, (llama_token)(n_vocab - size));
min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
GGML_ASSERT(size == expected_size);
GGML_ASSERT(candidates_p.data[0].id == max_token_id);
GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
} else if (s == 's') {
min_token_id = n_vocab;
int expected_size = 0;
do { // do-while because always at least one token is sampled
min_token_id--;
expected_size++;
} while (candidates_p.data[expected_size].p >= p_step * candidates_p.data[expected_size - 1].p);
GGML_ASSERT(size == expected_size);
GGML_ASSERT(candidates_p.data[0].id == max_token_id);
GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
@ -226,8 +260,8 @@ static void test_sampler_queue(
}
}
printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f p_step=%f\n",
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p, p_step);
}
int main(void) {
@ -252,6 +286,17 @@ int main(void) {
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.0f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.5f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.6f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.7f);
test_p_step({0.2f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.7f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
// Disabled because of floating point nonsense: 0.3f < 0.75f * 0.4f is true!
//test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.75f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
test_p_step({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
@ -267,33 +312,44 @@ int main(void) {
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12, 1.0f);
test_sampler_queue(10000, "s", 10000, 1.0f, 1.0f, 1.0f);
test_sampler_queue(10000, "s", 10000, 1.0f, 1.0f, 1e-12);
test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);
test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f, 1.0f);
test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f, 1.0f);
test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f, 1.0f);
test_sampler_queue(10000, "s", 10000, 1.0000f, 1.0f, 9997.9f/9999.0f);
test_sampler_queue(10000, "s", 10000, 1.0000f, 1.0f, 0.1f);
test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "km", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f);
test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "km", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f, 1.0f);
test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "ks", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "sk", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "sp", 100, 0.8f, 1.0f, 9997.9f/9999.0f);
test_sampler_queue(10000, "sp", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f, 1.0f);
test_sampler_queue(10000, "ksp", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "skp", 100, 0.8f, 1.0f, 0.1f);
test_sampler_queue(10000, "spk", 100, 0.8f, 1.0f, 0.1f);
printf("OK\n");