Super hacky starting implementation of Min P
This commit is contained in:
parent
ff3bad83e2
commit
a9e2b74f1a
1 changed files with 104 additions and 0 deletions
104
llama.cpp
104
llama.cpp
|
@ -7328,8 +7328,112 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
|
|||
}
|
||||
}
|
||||
|
||||
// Loads the base Min P value from "SamplerBaseMinP.txt" in the current working
// directory, creating that file with a default entry when it does not exist.
//
// File format: one `key = value` pair per line. Only the key `base_min_p` is
// recognised; whitespace around the key and around `=` is optional. If the key
// appears several times, the last valid entry wins.
//
// Parameters:
//   base_min_p - out-parameter receiving the value. A null pointer is ignored.
//
// Behaviour:
//   - File missing: *base_min_p is set to the default (0.05) and the file is
//     created with that value. If the file cannot be created, the default is
//     still applied and a warning is printed to stderr.
//   - File present: *base_min_p is overwritten by each valid `base_min_p`
//     entry. Malformed lines and values outside [0, 1] (including NaN) are
//     skipped, leaving the caller's value in place.
void read_or_write_base_min_p(float* base_min_p) {
    if (base_min_p == nullptr) {
        return; // nothing to fill in
    }

    const std::string filename           = "SamplerBaseMinP.txt";
    const std::string wanted_key         = "base_min_p";
    const std::string whitespace         = " \t\r";
    const float       default_base_min_p = 0.05f;

    std::ifstream infile(filename);
    if (!infile.good()) {
        // Apply the default first so the caller always gets a defined value,
        // even when the file cannot be written (read-only directory, etc.).
        *base_min_p = default_base_min_p;

        std::ofstream outfile(filename);
        if (outfile.is_open()) {
            outfile << wanted_key << " = " << default_base_min_p << "\n";
        } else {
            fprintf(stderr, "warning: could not create %s, using default base_min_p = %f\n",
                    filename.c_str(), default_base_min_p);
        }
        return;
    }

    std::string line;
    while (std::getline(infile, line)) {
        // Split on the first '=' so that "key=value", "key =value" and
        // "key = value" are all accepted.
        const std::string::size_type eq_pos = line.find('=');
        if (eq_pos == std::string::npos) {
            continue;
        }

        const std::string raw_key = line.substr(0, eq_pos);
        const std::string::size_type key_first = raw_key.find_first_not_of(whitespace);
        if (key_first == std::string::npos) {
            continue; // empty key
        }
        const std::string::size_type key_last = raw_key.find_last_not_of(whitespace);
        if (raw_key.substr(key_first, key_last - key_first + 1) != wanted_key) {
            continue;
        }

        std::istringstream value_stream(line.substr(eq_pos + 1));
        float value = 0.0f;
        if (!(value_stream >> value)) {
            continue; // not a number
        }

        // The comparison form also rejects NaN, which fails both tests.
        if (value >= 0.0f && value <= 1.0f) {
            *base_min_p = value;
        } else {
            fprintf(stderr, "warning: ignoring out-of-range base_min_p in %s (expected 0..1)\n",
                    filename.c_str());
        }
    }
}
|
||||
|
||||
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
// Variables for the special mode
|
||||
float base_min_p; // This will hold the base minimum probability value
|
||||
float multiplied_min_p; // This will hold the adjusted minimum probability threshold
|
||||
|
||||
// If p is 1.0, we switch to a different sampling mode.
|
||||
if (p >= 1.0f) {
|
||||
printf("returning, top p set to 1.0");
|
||||
return;
|
||||
}
|
||||
|
||||
// If p is ~0.02, we switch to the Min P sampler
|
||||
if (p >= 0.01f && p <= 0.03f) {
|
||||
printf("USING MIN P SAMPLING MODE\n");
|
||||
|
||||
// Ensure the probabilities are calculated.
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
// Print the top tokens before filtering
|
||||
printf("Top tokens before filtering:\n");
|
||||
for (size_t i = 0; i < candidates->size && i < 10; ++i) {
|
||||
printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage
|
||||
}
|
||||
|
||||
base_min_p = 0.05; // For example, 5% as the base minimum probability before reading the text value
|
||||
|
||||
read_or_write_base_min_p(&base_min_p);
|
||||
|
||||
// Calculate the multiplication factor based on the highest scoring token.
|
||||
float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted
|
||||
printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor);
|
||||
|
||||
// Calculate the dynamic threshold.
|
||||
multiplied_min_p = base_min_p * multiplication_factor;
|
||||
printf("Base min_p value: %f\n", base_min_p);
|
||||
printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p);
|
||||
|
||||
// Store the tokens that meet the threshold in a new list.
|
||||
std::vector<llama_token_data> filtered_candidates;
|
||||
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations
|
||||
|
||||
// Variable to count how many tokens meet the condition
|
||||
int count_qualifying_tokens = 0;
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
// If a token's probability is above the threshold, we keep it.
|
||||
if (candidates->data[i].p >= multiplied_min_p) {
|
||||
filtered_candidates.push_back(candidates->data[i]);
|
||||
++count_qualifying_tokens; // Increase count
|
||||
}
|
||||
}
|
||||
|
||||
// Debug information about how many tokens were retained
|
||||
printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens);
|
||||
|
||||
llama_sample_softmax(ctx, candidates); // re-normalize after pruning
|
||||
|
||||
// Print the top tokens after filtering
|
||||
printf("Tokens after filtering:\n");
|
||||
for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display
|
||||
printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage
|
||||
}
|
||||
|
||||
// Now we replace the original candidates with the filtered list.
|
||||
std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data);
|
||||
candidates->size = filtered_candidates.size();
|
||||
|
||||
// Since we're not actually sampling below a certain 'p', we return from the function after this.
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue