Super hacky starting implementation of Min P

This commit is contained in:
kalomaze 2023-10-28 17:23:06 -05:00
parent ff3bad83e2
commit a9e2b74f1a

104
llama.cpp
View file

@ -7328,11 +7328,115 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
} }
} }
// Loads the Min P base threshold from "SamplerBaseMinP.txt" in the current
// working directory, creating that file with a default value when it is absent.
//
// base_min_p: out-parameter receiving the value. A null pointer is ignored.
//
// Behavior:
//   * File cannot be opened for reading: *base_min_p is set to the default
//     (0.05) and a best-effort attempt is made to create the file containing
//     "base_min_p = 0.05". The default is applied even if creation fails
//     (e.g. read-only directory), so the caller always gets a usable value.
//   * File can be read: every line of the form "base_min_p = <float>" (spaces
//     around '=' are optional) updates *base_min_p; the last matching line
//     wins. If no line matches, *base_min_p keeps the value the caller passed.
void read_or_write_base_min_p(float* base_min_p) {
    if (base_min_p == nullptr) {
        return;
    }

    const std::string filename = "SamplerBaseMinP.txt";
    const std::string wanted_key = "base_min_p";
    const float default_base_min_p = 0.05f;

    std::ifstream infile(filename);
    if (!infile.good()) {
        // Apply the default first so a failed file creation cannot leave the
        // caller with an unset value.
        *base_min_p = default_base_min_p;

        std::ofstream outfile(filename);
        if (outfile.is_open()) {
            outfile << wanted_key << " = " << *base_min_p << "\n";
        }
        // Streams close themselves on scope exit.
        return;
    }

    std::string line;
    while (std::getline(infile, line)) {
        // Split on the first '=' so that both "key = value" and "key=value"
        // are accepted.
        const std::string::size_type eq_pos = line.find('=');
        if (eq_pos == std::string::npos) {
            continue;
        }

        // The left-hand side must be exactly one token equal to the key.
        std::istringstream key_stream(line.substr(0, eq_pos));
        std::string key;
        std::string extra;
        if (!(key_stream >> key) || key != wanted_key || (key_stream >> extra)) {
            continue;
        }

        // The right-hand side must start with a parseable float.
        std::istringstream value_stream(line.substr(eq_pos + 1));
        float value = 0.0f;
        if (value_stream >> value) {
            *base_min_p = value;
        }
    }
}
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
// Variables for the special mode
float base_min_p; // This will hold the base minimum probability value
float multiplied_min_p; // This will hold the adjusted minimum probability threshold
// If p is 1.0, we switch to a different sampling mode.
if (p >= 1.0f) { if (p >= 1.0f) {
printf("returning, top p set to 1.0");
return; return;
} }
// If p is ~0.02, we switch to the Min P sampler
if (p >= 0.01f && p <= 0.03f) {
printf("USING MIN P SAMPLING MODE\n");
// Ensure the probabilities are calculated.
llama_sample_softmax(ctx, candidates);
// Print the top tokens before filtering
printf("Top tokens before filtering:\n");
for (size_t i = 0; i < candidates->size && i < 10; ++i) {
printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage
}
base_min_p = 0.05; // For example, 5% as the base minimum probability before reading the text value
read_or_write_base_min_p(&base_min_p);
// Calculate the multiplication factor based on the highest scoring token.
float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted
printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor);
// Calculate the dynamic threshold.
multiplied_min_p = base_min_p * multiplication_factor;
printf("Base min_p value: %f\n", base_min_p);
printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p);
// Store the tokens that meet the threshold in a new list.
std::vector<llama_token_data> filtered_candidates;
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations
// Variable to count how many tokens meet the condition
int count_qualifying_tokens = 0;
for (size_t i = 0; i < candidates->size; ++i) {
// If a token's probability is above the threshold, we keep it.
if (candidates->data[i].p >= multiplied_min_p) {
filtered_candidates.push_back(candidates->data[i]);
++count_qualifying_tokens; // Increase count
}
}
// Debug information about how many tokens were retained
printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens);
llama_sample_softmax(ctx, candidates); // re-normalize after pruning
// Print the top tokens after filtering
printf("Tokens after filtering:\n");
for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display
printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage
}
// Now we replace the original candidates with the filtered list.
std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data);
candidates->size = filtered_candidates.size();
// Since we're not actually sampling below a certain 'p', we return from the function after this.
return;
}
llama_sample_softmax(ctx, candidates); llama_sample_softmax(ctx, candidates);
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();