From a9e2b74f1a0f4ca195fb087dbd2b6377f7f9025e Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sat, 28 Oct 2023 17:23:06 -0500 Subject: [PATCH] Super hacky starting implementation of Min P --- llama.cpp | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/llama.cpp b/llama.cpp index 3d431ee7b..47b411bd6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7328,11 +7328,115 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can } } +void read_or_write_base_min_p(float* base_min_p) { + // Define the filename + const std::string filename = "SamplerBaseMinP.txt"; + + // Check if the file exists + std::ifstream infile(filename); + if (!infile.good()) { + // File doesn't exist, create it with default values + std::ofstream outfile(filename); + if (outfile.is_open()) { + // Set the default value for base_min_p + *base_min_p = 0.05f; + outfile << "base_min_p = " << *base_min_p << "\n"; + outfile.close(); + } else { + // Handle the error during file opening or writing here (optional) + // For example, you might want to set a safe default value or notify the user + } + } else { + // File exists, read the values from it + std::string line; + while (getline(infile, line)) { + std::istringstream iss(line); + std::string key; + float value; + char equals; // Using char to read the '=' character + + // Read each line in the format key = value + if (iss >> key >> equals >> value && equals == '=' && key == "base_min_p") { + *base_min_p = value; + } + // Note: If the key doesn't match, or the format is incorrect, + // you might want to handle the error or set a default value + } + infile.close(); + } +} + void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + // Variables for the special mode + float base_min_p; // This will hold the base minimum probability value + float multiplied_min_p; // This will hold the adjusted minimum probability threshold + + // If p is 1.0, we switch to a different sampling mode. if (p >= 1.0f) { + printf("returning, top p set to 1.0"); return; } + // If p is ~0.02, we switch to the Min P sampler + if (p >= 0.01f && p <= 0.03f) { + printf("USING MIN P SAMPLING MODE\n"); + + // Ensure the probabilities are calculated. + llama_sample_softmax(ctx, candidates); + + // Print the top tokens before filtering + printf("Top tokens before filtering:\n"); + for (size_t i = 0; i < candidates->size && i < 10; ++i) { + printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage + } + + base_min_p = 0.05; // For example, 5% as the base minimum probability before reading the text value + + read_or_write_base_min_p(&base_min_p); + + // Calculate the multiplication factor based on the highest scoring token. + float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted + printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor); + + // Calculate the dynamic threshold. + multiplied_min_p = base_min_p * multiplication_factor; + printf("Base min_p value: %f\n", base_min_p); + printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p); + + // Store the tokens that meet the threshold in a new list. + std::vector filtered_candidates; + filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations + + // Variable to count how many tokens meet the condition + int count_qualifying_tokens = 0; + + for (size_t i = 0; i < candidates->size; ++i) { + // If a token's probability is above the threshold, we keep it. + if (candidates->data[i].p >= multiplied_min_p) { + filtered_candidates.push_back(candidates->data[i]); + ++count_qualifying_tokens; // Increase count + } + } + + // Debug information about how many tokens were retained + printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens); + + llama_sample_softmax(ctx, candidates); // re-normalize after pruning + + // Print the top tokens after filtering + printf("Tokens after filtering:\n"); + for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display + printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage + } + + // Now we replace the original candidates with the filtered list. + std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data); + candidates->size = filtered_candidates.size(); + + // Since we're not actually sampling below a certain 'p', we return from the function after this. + return; + } + llama_sample_softmax(ctx, candidates); const int64_t t_start_sample_us = ggml_time_us();