Transform Min P into a proper CLI option

This commit is contained in:
kalomaze 2023-10-28 20:49:17 -05:00
parent a9e2b74f1a
commit a235a0d226
5 changed files with 75 additions and 106 deletions

View file

@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
sparams.top_p = std::stof(argv[i]); sparams.top_p = std::stof(argv[i]);
} else if (arg == "--min-p") { // Adding min_p argument
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.min_p = std::stof(argv[i]); // Parsing and setting the min_p value from command line
} else if (arg == "--temp") { } else if (arg == "--temp") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency());
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.05\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
} }

View file

@ -89,10 +89,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
snprintf(result, sizeof(result), snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n" "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau); params.mirostat, params.mirostat_eta, params.mirostat_tau);
return std::string(result); return std::string(result);
@ -110,6 +110,7 @@ llama_token llama_sampling_sample(
const float temp = params.temp; const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p; const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z; const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p; const float typical_p = params.typical_p;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
@ -190,6 +191,7 @@ llama_token llama_sampling_sample(
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
llama_sample_temp (ctx_main, &cur_p, temp); llama_sample_temp (ctx_main, &cur_p, temp);
id = llama_sample_token(ctx_main, &cur_p); id = llama_sample_token(ctx_main, &cur_p);

View file

@ -14,6 +14,7 @@ typedef struct llama_sampling_params {
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled

159
llama.cpp
View file

@ -7328,112 +7328,8 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
} }
} }
void read_or_write_base_min_p(float* base_min_p) {
// Define the filename
const std::string filename = "SamplerBaseMinP.txt";
// Check if the file exists
std::ifstream infile(filename);
if (!infile.good()) {
// File doesn't exist, create it with default values
std::ofstream outfile(filename);
if (outfile.is_open()) {
// Set the default value for base_min_p
*base_min_p = 0.05f;
outfile << "base_min_p = " << *base_min_p << "\n";
outfile.close();
} else {
// Handle the error during file opening or writing here (optional)
// For example, you might want to set a safe default value or notify the user
}
} else {
// File exists, read the values from it
std::string line;
while (getline(infile, line)) {
std::istringstream iss(line);
std::string key;
float value;
char equals; // Using char to read the '=' character
// Read each line in the format key = value
if (iss >> key >> equals >> value && equals == '=' && key == "base_min_p") {
*base_min_p = value;
}
// Note: If the key doesn't match, or the format is incorrect,
// you might want to handle the error or set a default value
}
infile.close();
}
}
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
// Variables for the special mode
float base_min_p; // This will hold the base minimum probability value
float multiplied_min_p; // This will hold the adjusted minimum probability threshold
// If p is 1.0, we switch to a different sampling mode.
if (p >= 1.0f) { if (p >= 1.0f) {
printf("returning, top p set to 1.0");
return;
}
// If p is ~0.02, we switch to the Min P sampler
if (p >= 0.01f && p <= 0.03f) {
printf("USING MIN P SAMPLING MODE\n");
// Ensure the probabilities are calculated.
llama_sample_softmax(ctx, candidates);
// Print the top tokens before filtering
printf("Top tokens before filtering:\n");
for (size_t i = 0; i < candidates->size && i < 10; ++i) {
printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage
}
base_min_p = 0.05; // For example, 5% as the base minimum probability before reading the text value
read_or_write_base_min_p(&base_min_p);
// Calculate the multiplication factor based on the highest scoring token.
float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted
printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor);
// Calculate the dynamic threshold.
multiplied_min_p = base_min_p * multiplication_factor;
printf("Base min_p value: %f\n", base_min_p);
printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p);
// Store the tokens that meet the threshold in a new list.
std::vector<llama_token_data> filtered_candidates;
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations
// Variable to count how many tokens meet the condition
int count_qualifying_tokens = 0;
for (size_t i = 0; i < candidates->size; ++i) {
// If a token's probability is above the threshold, we keep it.
if (candidates->data[i].p >= multiplied_min_p) {
filtered_candidates.push_back(candidates->data[i]);
++count_qualifying_tokens; // Increase count
}
}
// Debug information about how many tokens were retained
printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens);
llama_sample_softmax(ctx, candidates); // re-normalize after pruning
// Print the top tokens after filtering
printf("Tokens after filtering:\n");
for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display
printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage
}
// Now we replace the original candidates with the filtered list.
std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data);
candidates->size = filtered_candidates.size();
// Since we're not actually sampling below a certain 'p', we return from the function after this.
return; return;
} }
@ -7464,6 +7360,61 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
} }
} }
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
float base_min_p = p; // This will hold the base minimum probability value
float multiplied_min_p; // This will hold the adjusted minimum probability threshold
printf("\nUSING MIN P SAMPLING MODE\n\n");
// Ensure the probabilities are calculated.
llama_sample_softmax(ctx, candidates);
// Print the top tokens before filtering
printf("Top tokens before filtering:\n");
for (size_t i = 0; i < candidates->size && i < 10; ++i) {
printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage
}
// Calculate the multiplication factor based on the highest scoring token.
float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted
printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor);
// Calculate the dynamic threshold.
multiplied_min_p = base_min_p * multiplication_factor;
printf("Base min_p value: %f\n", base_min_p);
printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p);
// Store the tokens that meet the threshold in a new list.
std::vector<llama_token_data> filtered_candidates;
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations
// Variable to count how many tokens meet the condition
int count_qualifying_tokens = 0;
for (size_t i = 0; i < candidates->size; ++i) {
// If a token's probability is above the threshold, we keep it.
if (candidates->data[i].p >= multiplied_min_p) {
filtered_candidates.push_back(candidates->data[i]);
++count_qualifying_tokens; // Increase count
}
}
// Debug information about how many tokens were retained
printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens);
// Print the top tokens after filtering
printf("Tokens after filtering:\n\n");
for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display
printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage
}
// Now we replace the original candidates with the filtered list.
std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data);
candidates->size = filtered_candidates.size();
return;
}
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) { if (z >= 1.0f || candidates->size <= 2) {
return; return;

View file

@ -600,6 +600,13 @@ extern "C" {
float p, float p,
size_t min_keep); size_t min_keep);
/// @details Minimum P sampling by Kalomaze
LLAMA_API void llama_sample_min_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
float p,
size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free( LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx, struct llama_context * ctx,