Super hacky starting implementation of Min P
This commit is contained in:
parent
ff3bad83e2
commit
a9e2b74f1a
1 changed file with 104 additions and 0 deletions
104
llama.cpp
104
llama.cpp
|
@ -7328,11 +7328,115 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reads the user-configurable base Min P value from "SamplerBaseMinP.txt" in
// the working directory, creating the file seeded with a default of 0.05 if
// it does not exist.
//
// @param base_min_p  Out-parameter. Always assigned on return: set to the
//                    default (0.05f) first, then overwritten with the value
//                    parsed from the file when a valid "base_min_p = <value>"
//                    line is found. This guarantees the caller never reads a
//                    stale/uninitialized value, even when file I/O fails.
//
// NOTE(review): this does file I/O on every call; if it sits on the sampling
// hot path, the value should be cached after the first successful read.
void read_or_write_base_min_p(float* base_min_p) {
    const std::string filename = "SamplerBaseMinP.txt";

    // Establish the default unconditionally so every exit path (including
    // failure to create or read the file) leaves *base_min_p valid.
    // The original only set it inside the successful-write branch.
    *base_min_p = 0.05f;

    std::ifstream infile(filename);
    if (!infile.good()) {
        // File doesn't exist (or can't be opened for reading): create it
        // seeded with the default so the user has something to edit.
        std::ofstream outfile(filename);
        if (outfile.is_open()) {
            outfile << "base_min_p = " << *base_min_p << "\n";
        }
        // On write failure we keep the default silently; the sampler still
        // works, it just can't persist the setting. Streams close via RAII.
        return;
    }

    // File exists: scan each line for the "base_min_p = <value>" format.
    std::string line;
    while (std::getline(infile, line)) {
        std::istringstream iss(line);
        std::string key;
        char equals;  // consumes the '=' separator between key and value
        float value;
        if (iss >> key >> equals >> value && equals == '=' && key == "base_min_p") {
            *base_min_p = value;
        }
        // Lines that don't match the expected format are ignored, matching
        // the original's lenient behavior.
    }
    // infile closes automatically at scope exit (RAII).
}
|
||||||
|
|
||||||
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||||
|
// Variables for the special mode
|
||||||
|
float base_min_p; // This will hold the base minimum probability value
|
||||||
|
float multiplied_min_p; // This will hold the adjusted minimum probability threshold
|
||||||
|
|
||||||
|
// If p is 1.0, we switch to a different sampling mode.
|
||||||
if (p >= 1.0f) {
|
if (p >= 1.0f) {
|
||||||
|
printf("returning, top p set to 1.0");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If p is ~0.02, we switch to the Min P sampler
|
||||||
|
if (p >= 0.01f && p <= 0.03f) {
|
||||||
|
printf("USING MIN P SAMPLING MODE\n");
|
||||||
|
|
||||||
|
// Ensure the probabilities are calculated.
|
||||||
|
llama_sample_softmax(ctx, candidates);
|
||||||
|
|
||||||
|
// Print the top tokens before filtering
|
||||||
|
printf("Top tokens before filtering:\n");
|
||||||
|
for (size_t i = 0; i < candidates->size && i < 10; ++i) {
|
||||||
|
printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage
|
||||||
|
}
|
||||||
|
|
||||||
|
base_min_p = 0.05; // For example, 5% as the base minimum probability before reading the text value
|
||||||
|
|
||||||
|
read_or_write_base_min_p(&base_min_p);
|
||||||
|
|
||||||
|
// Calculate the multiplication factor based on the highest scoring token.
|
||||||
|
float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted
|
||||||
|
printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor);
|
||||||
|
|
||||||
|
// Calculate the dynamic threshold.
|
||||||
|
multiplied_min_p = base_min_p * multiplication_factor;
|
||||||
|
printf("Base min_p value: %f\n", base_min_p);
|
||||||
|
printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p);
|
||||||
|
|
||||||
|
// Store the tokens that meet the threshold in a new list.
|
||||||
|
std::vector<llama_token_data> filtered_candidates;
|
||||||
|
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations
|
||||||
|
|
||||||
|
// Variable to count how many tokens meet the condition
|
||||||
|
int count_qualifying_tokens = 0;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
// If a token's probability is above the threshold, we keep it.
|
||||||
|
if (candidates->data[i].p >= multiplied_min_p) {
|
||||||
|
filtered_candidates.push_back(candidates->data[i]);
|
||||||
|
++count_qualifying_tokens; // Increase count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug information about how many tokens were retained
|
||||||
|
printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens);
|
||||||
|
|
||||||
|
llama_sample_softmax(ctx, candidates); // re-normalize after pruning
|
||||||
|
|
||||||
|
// Print the top tokens after filtering
|
||||||
|
printf("Tokens after filtering:\n");
|
||||||
|
for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display
|
||||||
|
printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we replace the original candidates with the filtered list.
|
||||||
|
std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data);
|
||||||
|
candidates->size = filtered_candidates.size();
|
||||||
|
|
||||||
|
// Since we're not actually sampling below a certain 'p', we return from the function after this.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sample_softmax(ctx, candidates);
|
llama_sample_softmax(ctx, candidates);
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue