llama : move sampling code into llama-sampling
ggml-ci
This commit is contained in:
parent
081fe431aa
commit
0ddc8e361c
7 changed files with 758 additions and 699 deletions
9
Makefile
9
Makefile
|
@ -876,6 +876,7 @@ OBJ_GGML += \
|
|||
|
||||
OBJ_LLAMA = \
|
||||
src/llama.o \
|
||||
src/llama-sampling.o \
|
||||
src/unicode.o \
|
||||
src/unicode-data.o
|
||||
|
||||
|
@ -1055,6 +1056,7 @@ src/unicode-data.o: \
|
|||
|
||||
src/llama.o: \
|
||||
src/llama.cpp \
|
||||
src/llama-impl.h \
|
||||
src/unicode.h \
|
||||
include/llama.h \
|
||||
ggml/include/ggml-cuda.h \
|
||||
|
@ -1064,6 +1066,13 @@ src/llama.o: \
|
|||
ggml/include/ggml-backend.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
src/llama-sampling.o: \
|
||||
src/llama-sampling.cpp \
|
||||
src/llama-sampling.h \
|
||||
src/llama-impl.h \
|
||||
include/llama.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
$(LIB_LLAMA): \
|
||||
$(OBJ_LLAMA) \
|
||||
$(LIB_GGML)
|
||||
|
|
|
@ -1084,12 +1084,6 @@ extern "C" {
|
|||
llama_token_data_array * candidates,
|
||||
float temp);
|
||||
|
||||
/// @details Apply constraints from grammar
|
||||
LLAMA_API void llama_sample_grammar(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const struct llama_grammar * grammar);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
|
@ -1127,6 +1121,12 @@ extern "C" {
|
|||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates);
|
||||
|
||||
/// @details Apply constraints from grammar
|
||||
LLAMA_API void llama_sample_grammar(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const struct llama_grammar * grammar);
|
||||
|
||||
/// @details Accepts the sampled token into the grammar
|
||||
LLAMA_API void llama_grammar_accept_token(
|
||||
struct llama_context * ctx,
|
||||
|
|
|
@ -14,6 +14,7 @@ endif()
|
|||
add_library(llama
|
||||
../include/llama.h
|
||||
llama.cpp
|
||||
llama-sampling.cpp
|
||||
unicode.h
|
||||
unicode.cpp
|
||||
unicode-data.cpp
|
||||
|
|
50
src/llama-impl.h
Normal file
50
src/llama-impl.h
Normal file
|
@ -0,0 +1,50 @@
|
|||
#pragma once
|
||||
|
||||
#define LLAMA_API_INTERNAL
|
||||
#include "llama.h"
|
||||
|
||||
#include <array>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <cstdint>
|
||||
#include <random>
|
||||
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/mman.h>
|
||||
#include <fcntl.h>
|
||||
#endif
|
||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||
#include <sys/resource.h>
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_NODES 8192
|
||||
#define LLAMA_MAX_LAYERS 256
|
||||
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
#else
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
#else
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...)
|
||||
#endif
|
||||
|
||||
//
|
||||
// logging
|
||||
//
|
||||
|
||||
LLAMA_ATTRIBUTE_FORMAT(2, 3)
|
||||
void llama_log_internal (ggml_log_level level, const char * format, ...);
|
||||
void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
||||
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
||||
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
635
src/llama-sampling.cpp
Normal file
635
src/llama-sampling.cpp
Normal file
|
@ -0,0 +1,635 @@
|
|||
#include "llama-sampling.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <cfloat>
|
||||
#include <numeric>
|
||||
#include <unordered_map>
|
||||
|
||||
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
seed = time(NULL);
|
||||
}
|
||||
|
||||
llama_get_sampling(ctx)->rng.seed(seed);
|
||||
}
|
||||
|
||||
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
candidates->sorted = true;
|
||||
}
|
||||
|
||||
float max_l = candidates->data[0].logit;
|
||||
float cum_sum = 0.0f;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
float p = expf(candidates->data[i].logit - max_l);
|
||||
candidates->data[i].p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
candidates->data[i].p /= cum_sum;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
||||
// if (k >= (int32_t)candidates->size) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
if (k <= 0) {
|
||||
k = candidates->size;
|
||||
}
|
||||
|
||||
k = std::max(k, (int) min_keep);
|
||||
k = std::min(k, (int) candidates->size);
|
||||
|
||||
// Sort scores in descending order
|
||||
if (!candidates->sorted) {
|
||||
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
};
|
||||
if (k <= 128) {
|
||||
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
|
||||
} else {
|
||||
constexpr int nbuckets = 128;
|
||||
constexpr float bucket_low = -10.0f;
|
||||
constexpr float bucket_high = 10.0f;
|
||||
constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
|
||||
constexpr float bucker_inter = -bucket_low * bucket_scale;
|
||||
|
||||
std::vector<int> bucket_idx(candidates->size);
|
||||
std::vector<int> histo(nbuckets, 0);
|
||||
|
||||
for (int i = 0; i < (int)candidates->size; ++i) {
|
||||
const float val = candidates->data[i].logit;
|
||||
int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
||||
ib = std::max(0, std::min(nbuckets-1, ib));
|
||||
bucket_idx[i] = ib;
|
||||
++histo[ib];
|
||||
}
|
||||
int nhave = 0;
|
||||
int ib = nbuckets - 1;
|
||||
for ( ; ib >= 0; --ib) {
|
||||
nhave += histo[ib];
|
||||
if (nhave >= k) break;
|
||||
}
|
||||
std::vector<llama_token_data> tmp_tokens(nhave);
|
||||
auto ptr = tmp_tokens.data();
|
||||
std::vector<llama_token_data*> bucket_ptrs;
|
||||
bucket_ptrs.reserve(nbuckets - ib);
|
||||
for (int j = nbuckets - 1; j >= ib; --j) {
|
||||
bucket_ptrs.push_back(ptr);
|
||||
ptr += histo[j];
|
||||
}
|
||||
for (int i = 0; i < (int)candidates->size; ++i) {
|
||||
int j = bucket_idx[i];
|
||||
if (j >= ib) {
|
||||
*bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
|
||||
}
|
||||
}
|
||||
|
||||
ptr = tmp_tokens.data();
|
||||
int ndone = 0;
|
||||
for (int j = nbuckets-1; j > ib; --j) {
|
||||
std::sort(ptr, ptr + histo[j], comp);
|
||||
ptr += histo[j];
|
||||
ndone += histo[j];
|
||||
}
|
||||
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
|
||||
|
||||
std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
|
||||
|
||||
}
|
||||
candidates->sorted = true;
|
||||
}
|
||||
candidates->size = k;
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
if (p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = candidates->size;
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
cum_sum += candidates->data[i].p;
|
||||
|
||||
// Check if the running sum is at least p or if we have kept at least min_keep tokens
|
||||
// we set the last index to i+1 to indicate that the current iterate should be included in the set
|
||||
if (cum_sum >= p && i + 1 >= min_keep) {
|
||||
last_idx = i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the top-p tokens
|
||||
candidates->size = last_idx;
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
if (p <= 0.0f || !candidates->size) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
bool min_p_applied = false;
|
||||
|
||||
// if the candidates aren't sorted, try the unsorted implementation first
|
||||
if (!candidates->sorted) {
|
||||
std::vector<llama_token_data> filtered_tokens;
|
||||
|
||||
float max_logit = -FLT_MAX;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
max_logit = std::max(max_logit, candidates->data[i].logit);
|
||||
}
|
||||
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].logit >= min_logit) {
|
||||
filtered_tokens.push_back(candidates->data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// if we have enough values the operation was a success
|
||||
if (filtered_tokens.size() >= min_keep) {
|
||||
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
||||
candidates->size = filtered_tokens.size();
|
||||
min_p_applied = true;
|
||||
}
|
||||
}
|
||||
|
||||
// if the candidates are sorted or the unsorted implementation failed, use this implementation
|
||||
if (!min_p_applied) {
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
candidates->sorted = true;
|
||||
}
|
||||
|
||||
const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
|
||||
size_t i = 1; // first token always matches
|
||||
|
||||
for (; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].logit < min_logit && i >= min_keep) {
|
||||
break; // prob too small
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the matching tokens
|
||||
candidates->size = i;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||
if (z >= 1.0f || candidates->size <= 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute the first and second derivatives
|
||||
std::vector<float> first_derivatives(candidates->size - 1);
|
||||
std::vector<float> second_derivatives(candidates->size - 2);
|
||||
|
||||
for (size_t i = 0; i < first_derivatives.size(); ++i) {
|
||||
first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
|
||||
}
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
|
||||
}
|
||||
|
||||
// Calculate absolute value of second derivatives
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
second_derivatives[i] = std::abs(second_derivatives[i]);
|
||||
}
|
||||
|
||||
// Normalize the second derivatives
|
||||
{
|
||||
const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
|
||||
|
||||
if (second_derivatives_sum > 1e-6f) {
|
||||
for (float & value : second_derivatives) {
|
||||
value /= second_derivatives_sum;
|
||||
}
|
||||
} else {
|
||||
for (float & value : second_derivatives) {
|
||||
value = 1.0f / second_derivatives.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = candidates->size;
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
cum_sum += second_derivatives[i];
|
||||
|
||||
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
||||
if (cum_sum > z && i >= min_keep) {
|
||||
last_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the tokens above the tail location
|
||||
candidates->size = last_idx;
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
// Reference implementation:
|
||||
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
||||
if (p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute the softmax of logits and calculate entropy
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
float entropy = 0.0f;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
entropy += -candidates->data[i].p * logf(candidates->data[i].p);
|
||||
}
|
||||
|
||||
// Compute the absolute difference between negative log probability and entropy for each candidate
|
||||
std::vector<float> shifted_scores;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
|
||||
shifted_scores.push_back(shifted_score);
|
||||
}
|
||||
|
||||
// Sort tokens based on the shifted_scores and their corresponding indices
|
||||
std::vector<size_t> indices(candidates->size);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
|
||||
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
|
||||
return shifted_scores[a] < shifted_scores[b];
|
||||
});
|
||||
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = indices.size();
|
||||
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
size_t idx = indices[i];
|
||||
cum_sum += candidates->data[idx].p;
|
||||
|
||||
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
||||
if (cum_sum > p && i >= min_keep - 1) {
|
||||
last_idx = i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the locally typical tokens
|
||||
std::vector<llama_token_data> new_candidates;
|
||||
for (size_t i = 0; i < last_idx; ++i) {
|
||||
size_t idx = indices[i];
|
||||
new_candidates.push_back(candidates->data[idx]);
|
||||
}
|
||||
|
||||
// Replace the data in candidates with the new_candidates data
|
||||
std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
|
||||
candidates->size = new_candidates.size();
|
||||
candidates->sorted = false;
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// no need to do anything if there is only one (or zero) candidates
|
||||
if(candidates_p->size <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate maximum possible entropy
|
||||
float max_entropy = -logf(1.0f / candidates_p->size);
|
||||
|
||||
llama_sample_softmax(nullptr, candidates_p);
|
||||
|
||||
// Calculate entropy of the softmax probabilities
|
||||
float entropy = 0.0f;
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
float prob = candidates_p->data[i].p;
|
||||
if (prob > 0.0f) { // Ensure no log(0)
|
||||
entropy -= prob * logf(prob);
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize the entropy (max_entropy cannot be 0 here because we checked candidates_p->size != 1 above)
|
||||
float normalized_entropy = entropy / max_entropy;
|
||||
|
||||
// Map the normalized entropy to the desired temperature range using the power function
|
||||
float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
|
||||
|
||||
#ifdef DEBUG
|
||||
LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
|
||||
LLAMA_LOG_INFO("Entropy: %f\n", entropy);
|
||||
LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
|
||||
LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
|
||||
LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
|
||||
LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
|
||||
#endif
|
||||
|
||||
// Apply the dynamically calculated temperature scaling
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
candidates_p->data[i].logit /= dyn_temp;
|
||||
}
|
||||
|
||||
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
||||
double max_l_double = candidates_p->data[0].logit;
|
||||
double cum_sum_double = 0.0;
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
double p = exp(candidates_p->data[i].logit - max_l_double);
|
||||
candidates_p->data[i].p = p; // Store the scaled probability
|
||||
cum_sum_double += p;
|
||||
}
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
||||
}
|
||||
|
||||
#ifdef DEBUG
|
||||
// Print the updated top 25 probabilities after temperature scaling
|
||||
LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
|
||||
for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) {
|
||||
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
candidates_p->data[i].logit /= temp;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalties(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t penalty_last_n,
|
||||
float penalty_repeat,
|
||||
float penalty_freq,
|
||||
float penalty_present) {
|
||||
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Create a frequency map to count occurrences of each token in last_tokens
|
||||
std::unordered_map<llama_token, int> token_count;
|
||||
for (size_t i = 0; i < penalty_last_n; ++i) {
|
||||
token_count[last_tokens[i]]++;
|
||||
}
|
||||
|
||||
// Apply frequency and presence penalties to the candidates
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
const auto token_iter = token_count.find(candidates->data[i].id);
|
||||
if (token_iter == token_count.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int count = token_iter->second;
|
||||
|
||||
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||
if (candidates->data[i].logit <= 0) {
|
||||
candidates->data[i].logit *= penalty_repeat;
|
||||
} else {
|
||||
candidates->data[i].logit /= penalty_repeat;
|
||||
}
|
||||
|
||||
candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
|
||||
}
|
||||
|
||||
candidates->sorted = false;
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_log_softmax(float * array, size_t size) {
|
||||
float max_l = *std::max_element(array, array + size);
|
||||
float sum = 0.f;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
float p = expf(array[i] - max_l);
|
||||
sum += p;
|
||||
array[i] = p;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
array[i] = logf(array[i] / sum);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_apply_guidance(
|
||||
struct llama_context * ctx,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
float scale) {
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
const auto t_start_sample_us = ggml_time_us();
|
||||
const auto n_vocab = llama_get_sampling(ctx)->n_vocab;
|
||||
|
||||
llama_log_softmax(logits, n_vocab);
|
||||
llama_log_softmax(logits_guidance, n_vocab);
|
||||
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
auto & l = logits[i];
|
||||
const auto & g = logits_guidance[i];
|
||||
|
||||
l = scale * (l - g) + g;
|
||||
}
|
||||
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
const int32_t n_vocab = float(llama_get_sampling(ctx)->n_vocab);
|
||||
|
||||
int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
|
||||
// Estimate s_hat using the most probable m tokens
|
||||
float s_hat = 0.0;
|
||||
float sum_ti_bi = 0.0;
|
||||
float sum_ti_sq = 0.0;
|
||||
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
|
||||
float t_i = logf(float(i + 2) / float(i + 1));
|
||||
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
|
||||
sum_ti_bi += t_i * b_i;
|
||||
sum_ti_sq += t_i * t_i;
|
||||
}
|
||||
s_hat = sum_ti_bi / sum_ti_sq;
|
||||
|
||||
// Compute k from the estimated s_hat and target surprise value
|
||||
float epsilon_hat = s_hat - 1;
|
||||
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
|
||||
|
||||
// Sample the next word X using top-k sampling
|
||||
llama_sample_top_k(nullptr, candidates, int(k), 1);
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
llama_token X = llama_sample_token(ctx, candidates);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
||||
int64_t t_start_sample_us;
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
// Truncate the words with surprise values greater than mu
|
||||
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return -log2f(candidate.p) > *mu;
|
||||
}));
|
||||
|
||||
if (candidates->size == 0) {
|
||||
candidates->size = 1;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
// Normalize the probabilities of the remaining words
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
// Sample the next word X from the remaining words
|
||||
llama_token X = llama_sample_token(ctx, candidates);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Find max element
|
||||
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit < b.logit;
|
||||
});
|
||||
|
||||
llama_token result = max_iter->id;
|
||||
if (ctx) {
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
llama_get_sampling(ctx)->n_sample++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
|
||||
std::vector<float> probs;
|
||||
probs.reserve(candidates->size);
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
probs.push_back(candidates->data[i].p);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
llama_token result = candidates->data[idx].id;
|
||||
|
||||
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
llama_get_sampling(ctx)->n_sample++;
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
return llama_sample_token_with_rng(ctx, candidates, llama_get_sampling(ctx)->rng);
|
||||
}
|
||||
|
21
src/llama-sampling.h
Normal file
21
src/llama-sampling.h
Normal file
|
@ -0,0 +1,21 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama-impl.h"
|
||||
|
||||
struct llama_sampling {
|
||||
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
int64_t t_sample_us = 0;
|
||||
|
||||
int32_t n_sample = 0;
|
||||
int32_t n_vocab = 0;
|
||||
|
||||
void reset_timings() {
|
||||
t_sample_us = 0;
|
||||
n_sample = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct llama_sampling * llama_get_sampling(struct llama_context * ctx);
|
729
src/llama.cpp
729
src/llama.cpp
|
@ -1,5 +1,5 @@
|
|||
#define LLAMA_API_INTERNAL
|
||||
#include "llama.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-sampling.h"
|
||||
|
||||
#include "unicode.h"
|
||||
|
||||
|
@ -34,19 +34,6 @@
|
|||
// TODO: replace with ggml API call
|
||||
#define QK_K 256
|
||||
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/mman.h>
|
||||
#include <fcntl.h>
|
||||
#endif
|
||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||
#include <sys/resource.h>
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
|
@ -90,7 +77,6 @@
|
|||
#include <mutex>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
|
@ -102,33 +88,6 @@
|
|||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
#else
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
#else
|
||||
#define LLAMA_ATTRIBUTE_FORMAT(...)
|
||||
#endif
|
||||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_NODES 8192
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
|
||||
|
||||
//
|
||||
// logging
|
||||
//
|
||||
|
||||
LLAMA_ATTRIBUTE_FORMAT(2, 3)
|
||||
static void llama_log_internal (ggml_log_level level, const char * format, ...);
|
||||
static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
||||
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
||||
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
|
||||
//
|
||||
// helpers
|
||||
//
|
||||
|
@ -2737,7 +2696,8 @@ struct llama_model {
|
|||
};
|
||||
|
||||
struct llama_context {
|
||||
llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
|
||||
llama_context(const llama_model & model) : model(model), sampling(llama_n_vocab(&model)), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
|
||||
|
||||
~llama_context() {
|
||||
ggml_backend_sched_free(sched);
|
||||
|
||||
|
@ -2748,7 +2708,14 @@ struct llama_context {
|
|||
ggml_backend_buffer_free(buf_output);
|
||||
}
|
||||
|
||||
llama_cparams cparams;
|
||||
const struct llama_model & model;
|
||||
|
||||
struct llama_cparams cparams;
|
||||
struct llama_sampling sampling;
|
||||
struct llama_kv_cache kv_self;
|
||||
struct llama_control_vector cvec;
|
||||
|
||||
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
|
||||
|
||||
std::vector<ggml_backend_t> backends;
|
||||
#ifdef GGML_USE_METAL
|
||||
|
@ -2759,26 +2726,16 @@ struct llama_context {
|
|||
#endif
|
||||
ggml_backend_t backend_cpu = nullptr;
|
||||
|
||||
|
||||
const llama_model & model;
|
||||
|
||||
// key + value cache for the self attention
|
||||
struct llama_kv_cache kv_self;
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
bool has_evaluated_once = false;
|
||||
|
||||
int64_t t_start_us;
|
||||
int64_t t_load_us;
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_p_eval_us = 0;
|
||||
int64_t t_eval_us = 0;
|
||||
|
||||
int64_t t_compute_start_us = 0;
|
||||
int64_t n_queued_tokens = 0;
|
||||
|
||||
int32_t n_sample = 0; // number of tokens sampled
|
||||
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
||||
int32_t n_eval = 0; // number of eval calls
|
||||
|
||||
|
@ -2834,12 +2791,6 @@ struct llama_context {
|
|||
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
||||
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
||||
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
||||
|
||||
// control vectors
|
||||
struct llama_control_vector cvec;
|
||||
|
||||
// lora adapters and scales
|
||||
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
|
||||
};
|
||||
|
||||
struct llama_lora_weight {
|
||||
|
@ -17047,469 +16998,7 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar)
|
|||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// sampling
|
||||
//
|
||||
|
||||
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
seed = time(NULL);
|
||||
}
|
||||
ctx->rng.seed(seed);
|
||||
}
|
||||
|
||||
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
GGML_ASSERT(candidates->size > 0);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
candidates->sorted = true;
|
||||
}
|
||||
|
||||
float max_l = candidates->data[0].logit;
|
||||
float cum_sum = 0.0f;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
float p = expf(candidates->data[i].logit - max_l);
|
||||
candidates->data[i].p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
candidates->data[i].p /= cum_sum;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
||||
// if (k >= (int32_t)candidates->size) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
if (k <= 0) {
|
||||
k = candidates->size;
|
||||
}
|
||||
|
||||
k = std::max(k, (int) min_keep);
|
||||
k = std::min(k, (int) candidates->size);
|
||||
|
||||
// Sort scores in descending order
|
||||
if (!candidates->sorted) {
|
||||
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
};
|
||||
if (k <= 128) {
|
||||
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
|
||||
} else {
|
||||
constexpr int nbuckets = 128;
|
||||
constexpr float bucket_low = -10.0f;
|
||||
constexpr float bucket_high = 10.0f;
|
||||
constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
|
||||
constexpr float bucker_inter = -bucket_low * bucket_scale;
|
||||
|
||||
std::vector<int> bucket_idx(candidates->size);
|
||||
std::vector<int> histo(nbuckets, 0);
|
||||
|
||||
for (int i = 0; i < (int)candidates->size; ++i) {
|
||||
const float val = candidates->data[i].logit;
|
||||
int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
||||
ib = std::max(0, std::min(nbuckets-1, ib));
|
||||
bucket_idx[i] = ib;
|
||||
++histo[ib];
|
||||
}
|
||||
int nhave = 0;
|
||||
int ib = nbuckets - 1;
|
||||
for ( ; ib >= 0; --ib) {
|
||||
nhave += histo[ib];
|
||||
if (nhave >= k) break;
|
||||
}
|
||||
std::vector<llama_token_data> tmp_tokens(nhave);
|
||||
auto ptr = tmp_tokens.data();
|
||||
std::vector<llama_token_data*> bucket_ptrs;
|
||||
bucket_ptrs.reserve(nbuckets - ib);
|
||||
for (int j = nbuckets - 1; j >= ib; --j) {
|
||||
bucket_ptrs.push_back(ptr);
|
||||
ptr += histo[j];
|
||||
}
|
||||
for (int i = 0; i < (int)candidates->size; ++i) {
|
||||
int j = bucket_idx[i];
|
||||
if (j >= ib) {
|
||||
*bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
|
||||
}
|
||||
}
|
||||
|
||||
ptr = tmp_tokens.data();
|
||||
int ndone = 0;
|
||||
for (int j = nbuckets-1; j > ib; --j) {
|
||||
std::sort(ptr, ptr + histo[j], comp);
|
||||
ptr += histo[j];
|
||||
ndone += histo[j];
|
||||
}
|
||||
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
|
||||
|
||||
std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
|
||||
|
||||
}
|
||||
candidates->sorted = true;
|
||||
}
|
||||
candidates->size = k;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
if (p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = candidates->size;
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
cum_sum += candidates->data[i].p;
|
||||
|
||||
// Check if the running sum is at least p or if we have kept at least min_keep tokens
|
||||
// we set the last index to i+1 to indicate that the current iterate should be included in the set
|
||||
if (cum_sum >= p && i + 1 >= min_keep) {
|
||||
last_idx = i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the top-p tokens
|
||||
candidates->size = last_idx;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
if (p <= 0.0f || !candidates->size) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
bool min_p_applied = false;
|
||||
|
||||
// if the candidates aren't sorted, try the unsorted implementation first
|
||||
if (!candidates->sorted) {
|
||||
std::vector<llama_token_data> filtered_tokens;
|
||||
|
||||
float max_logit = -FLT_MAX;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
max_logit = std::max(max_logit, candidates->data[i].logit);
|
||||
}
|
||||
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].logit >= min_logit) {
|
||||
filtered_tokens.push_back(candidates->data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// if we have enough values the operation was a success
|
||||
if (filtered_tokens.size() >= min_keep) {
|
||||
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
||||
candidates->size = filtered_tokens.size();
|
||||
min_p_applied = true;
|
||||
}
|
||||
}
|
||||
|
||||
// if the candidates are sorted or the unsorted implementation failed, use this implementation
|
||||
if (!min_p_applied) {
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
candidates->sorted = true;
|
||||
}
|
||||
|
||||
const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
|
||||
size_t i = 1; // first token always matches
|
||||
|
||||
for (; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].logit < min_logit && i >= min_keep) {
|
||||
break; // prob too small
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the matching tokens
|
||||
candidates->size = i;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||
if (z >= 1.0f || candidates->size <= 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute the first and second derivatives
|
||||
std::vector<float> first_derivatives(candidates->size - 1);
|
||||
std::vector<float> second_derivatives(candidates->size - 2);
|
||||
|
||||
for (size_t i = 0; i < first_derivatives.size(); ++i) {
|
||||
first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
|
||||
}
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
|
||||
}
|
||||
|
||||
// Calculate absolute value of second derivatives
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
second_derivatives[i] = std::abs(second_derivatives[i]);
|
||||
}
|
||||
|
||||
// Normalize the second derivatives
|
||||
{
|
||||
const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
|
||||
|
||||
if (second_derivatives_sum > 1e-6f) {
|
||||
for (float & value : second_derivatives) {
|
||||
value /= second_derivatives_sum;
|
||||
}
|
||||
} else {
|
||||
for (float & value : second_derivatives) {
|
||||
value = 1.0f / second_derivatives.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = candidates->size;
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
cum_sum += second_derivatives[i];
|
||||
|
||||
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
||||
if (cum_sum > z && i >= min_keep) {
|
||||
last_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the tokens above the tail location
|
||||
candidates->size = last_idx;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||
// Reference implementation:
|
||||
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
||||
if (p >= 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute the softmax of logits and calculate entropy
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
float entropy = 0.0f;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
entropy += -candidates->data[i].p * logf(candidates->data[i].p);
|
||||
}
|
||||
|
||||
// Compute the absolute difference between negative log probability and entropy for each candidate
|
||||
std::vector<float> shifted_scores;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
|
||||
shifted_scores.push_back(shifted_score);
|
||||
}
|
||||
|
||||
// Sort tokens based on the shifted_scores and their corresponding indices
|
||||
std::vector<size_t> indices(candidates->size);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
|
||||
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
|
||||
return shifted_scores[a] < shifted_scores[b];
|
||||
});
|
||||
|
||||
// Compute the cumulative probabilities
|
||||
float cum_sum = 0.0f;
|
||||
size_t last_idx = indices.size();
|
||||
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
size_t idx = indices[i];
|
||||
cum_sum += candidates->data[idx].p;
|
||||
|
||||
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
||||
if (cum_sum > p && i >= min_keep - 1) {
|
||||
last_idx = i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the locally typical tokens
|
||||
std::vector<llama_token_data> new_candidates;
|
||||
for (size_t i = 0; i < last_idx; ++i) {
|
||||
size_t idx = indices[i];
|
||||
new_candidates.push_back(candidates->data[idx]);
|
||||
}
|
||||
|
||||
// Replace the data in candidates with the new_candidates data
|
||||
std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
|
||||
candidates->size = new_candidates.size();
|
||||
candidates->sorted = false;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// no need to do anything if there is only one (or zero) candidates
|
||||
if(candidates_p->size <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate maximum possible entropy
|
||||
float max_entropy = -logf(1.0f / candidates_p->size);
|
||||
|
||||
llama_sample_softmax(nullptr, candidates_p);
|
||||
|
||||
// Calculate entropy of the softmax probabilities
|
||||
float entropy = 0.0f;
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
float prob = candidates_p->data[i].p;
|
||||
if (prob > 0.0f) { // Ensure no log(0)
|
||||
entropy -= prob * logf(prob);
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize the entropy (max_entropy cannot be 0 here because we checked candidates_p->size != 1 above)
|
||||
float normalized_entropy = entropy / max_entropy;
|
||||
|
||||
// Map the normalized entropy to the desired temperature range using the power function
|
||||
float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
|
||||
|
||||
#ifdef DEBUG
|
||||
LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
|
||||
LLAMA_LOG_INFO("Entropy: %f\n", entropy);
|
||||
LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
|
||||
LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
|
||||
LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
|
||||
LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
|
||||
#endif
|
||||
|
||||
// Apply the dynamically calculated temperature scaling
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
candidates_p->data[i].logit /= dyn_temp;
|
||||
}
|
||||
|
||||
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
||||
double max_l_double = candidates_p->data[0].logit;
|
||||
double cum_sum_double = 0.0;
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
double p = exp(candidates_p->data[i].logit - max_l_double);
|
||||
candidates_p->data[i].p = p; // Store the scaled probability
|
||||
cum_sum_double += p;
|
||||
}
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
||||
}
|
||||
|
||||
#ifdef DEBUG
|
||||
// Print the updated top 25 probabilities after temperature scaling
|
||||
LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
|
||||
for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) {
|
||||
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
for (size_t i = 0; i < candidates_p->size; ++i) {
|
||||
candidates_p->data[i].logit /= temp;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalties(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t penalty_last_n,
|
||||
float penalty_repeat,
|
||||
float penalty_freq,
|
||||
float penalty_present) {
|
||||
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Create a frequency map to count occurrences of each token in last_tokens
|
||||
std::unordered_map<llama_token, int> token_count;
|
||||
for (size_t i = 0; i < penalty_last_n; ++i) {
|
||||
token_count[last_tokens[i]]++;
|
||||
}
|
||||
|
||||
// Apply frequency and presence penalties to the candidates
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
const auto token_iter = token_count.find(candidates->data[i].id);
|
||||
if (token_iter == token_count.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int count = token_iter->second;
|
||||
|
||||
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||
if (candidates->data[i].logit <= 0) {
|
||||
candidates->data[i].logit *= penalty_repeat;
|
||||
} else {
|
||||
candidates->data[i].logit /= penalty_repeat;
|
||||
}
|
||||
|
||||
candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
|
||||
}
|
||||
|
||||
candidates->sorted = false;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: rename to llama_grammar_...
|
||||
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
|
||||
GGML_ASSERT(ctx);
|
||||
int64_t t_start_sample_us = ggml_time_us();
|
||||
|
@ -17549,7 +17038,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
|
|||
candidates->data[reject.index].logit = -INFINITY;
|
||||
}
|
||||
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
// TODO: change to t_grammar_us
|
||||
ctx->sampling.t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
static void llama_log_softmax(float * array, size_t size) {
|
||||
|
@ -17566,158 +17056,6 @@ static void llama_log_softmax(float * array, size_t size) {
|
|||
}
|
||||
}
|
||||
|
||||
void llama_sample_apply_guidance(
|
||||
struct llama_context * ctx,
|
||||
float * logits,
|
||||
float * logits_guidance,
|
||||
float scale) {
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
const auto t_start_sample_us = ggml_time_us();
|
||||
const auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
llama_log_softmax(logits, n_vocab);
|
||||
llama_log_softmax(logits_guidance, n_vocab);
|
||||
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
auto & l = logits[i];
|
||||
const auto & g = logits_guidance[i];
|
||||
|
||||
l = scale * (l - g) + g;
|
||||
}
|
||||
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
auto N = float(llama_n_vocab(llama_get_model(ctx)));
|
||||
int64_t t_start_sample_us;
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
|
||||
// Estimate s_hat using the most probable m tokens
|
||||
float s_hat = 0.0;
|
||||
float sum_ti_bi = 0.0;
|
||||
float sum_ti_sq = 0.0;
|
||||
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
|
||||
float t_i = logf(float(i + 2) / float(i + 1));
|
||||
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
|
||||
sum_ti_bi += t_i * b_i;
|
||||
sum_ti_sq += t_i * t_i;
|
||||
}
|
||||
s_hat = sum_ti_bi / sum_ti_sq;
|
||||
|
||||
// Compute k from the estimated s_hat and target surprise value
|
||||
float epsilon_hat = s_hat - 1;
|
||||
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
|
||||
|
||||
// Sample the next word X using top-k sampling
|
||||
llama_sample_top_k(nullptr, candidates, int(k), 1);
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
llama_token X = llama_sample_token(ctx, candidates);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
||||
int64_t t_start_sample_us;
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
// Truncate the words with surprise values greater than mu
|
||||
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return -log2f(candidate.p) > *mu;
|
||||
}));
|
||||
|
||||
if (candidates->size == 0) {
|
||||
candidates->size = 1;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
// Normalize the probabilities of the remaining words
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
// Sample the next word X from the remaining words
|
||||
llama_token X = llama_sample_token(ctx, candidates);
|
||||
t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Compute error as the difference between observed surprise and target surprise value
|
||||
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||
return candidate.id == X;
|
||||
}));
|
||||
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||
float e = observed_surprise - tau;
|
||||
|
||||
// Update mu using the learning rate and error
|
||||
*mu = *mu - eta * e;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
// Find max element
|
||||
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit < b.logit;
|
||||
});
|
||||
|
||||
llama_token result = max_iter->id;
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->n_sample++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||
GGML_ASSERT(ctx);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
llama_sample_softmax(nullptr, candidates);
|
||||
|
||||
std::vector<float> probs;
|
||||
probs.reserve(candidates->size);
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
probs.push_back(candidates->data[i].p);
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
int idx = dist(rng);
|
||||
|
||||
llama_token result = candidates->data[idx].id;
|
||||
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->n_sample++;
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
|
||||
return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
|
||||
}
|
||||
|
||||
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
|
@ -17743,7 +17081,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
|
|||
grammar->partial_utf8 = decoded.second;
|
||||
GGML_ASSERT(!grammar->stacks.empty());
|
||||
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
ctx->sampling.t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -19131,8 +18469,8 @@ struct llama_context * llama_new_context_with_model(
|
|||
ctx->abort_callback = params.abort_callback;
|
||||
ctx->abort_callback_data = params.abort_callback_data;
|
||||
|
||||
ctx->rng = std::mt19937(params.seed);
|
||||
ctx->logits_all = params.logits_all;
|
||||
ctx->sampling.rng = std::mt19937(params.seed);
|
||||
ctx->logits_all = params.logits_all;
|
||||
|
||||
uint32_t kv_size = cparams.n_ctx;
|
||||
ggml_type type_k = params.type_k;
|
||||
|
@ -19408,10 +18746,14 @@ void llama_free(struct llama_context * ctx) {
|
|||
delete ctx;
|
||||
}
|
||||
|
||||
const llama_model * llama_get_model(const struct llama_context * ctx) {
|
||||
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
|
||||
return &ctx->model;
|
||||
}
|
||||
|
||||
struct llama_sampling * llama_get_sampling(struct llama_context * ctx) {
|
||||
return &ctx->sampling;
|
||||
}
|
||||
|
||||
uint32_t llama_n_ctx(const struct llama_context * ctx) {
|
||||
return ctx->cparams.n_ctx;
|
||||
}
|
||||
|
@ -20000,7 +19342,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
|
|||
// copy rng
|
||||
{
|
||||
std::ostringstream rng_ss;
|
||||
rng_ss << ctx->rng;
|
||||
rng_ss << ctx->sampling.rng;
|
||||
|
||||
const std::string & rng_str = rng_ss.str();
|
||||
const size_t rng_size = rng_str.size();
|
||||
|
@ -20166,7 +19508,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
|||
std::string rng_str((const char *)inp, rng_size); inp += rng_size;
|
||||
|
||||
std::istringstream rng_ss(rng_str);
|
||||
rng_ss >> ctx->rng;
|
||||
rng_ss >> ctx->sampling.rng;
|
||||
|
||||
GGML_ASSERT(!rng_ss.fail());
|
||||
}
|
||||
|
@ -21737,11 +21079,11 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) {
|
|||
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
|
||||
/*.t_end_ms =*/ 1.00 * ggml_time_ms(),
|
||||
/*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
|
||||
/*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us,
|
||||
/*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
|
||||
/*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
|
||||
/*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
|
||||
|
||||
/*.n_sample =*/ std::max(1, ctx->n_sample),
|
||||
/*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
|
||||
/*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
|
||||
/*.n_eval =*/ std::max(1, ctx->n_eval),
|
||||
};
|
||||
|
@ -21764,10 +21106,11 @@ void llama_print_timings(struct llama_context * ctx) {
|
|||
}
|
||||
|
||||
void llama_reset_timings(struct llama_context * ctx) {
|
||||
ctx->t_start_us = ggml_time_us();
|
||||
ctx->t_sample_us = ctx->n_sample = 0;
|
||||
ctx->t_start_us = ggml_time_us();
|
||||
ctx->t_eval_us = ctx->n_eval = 0;
|
||||
ctx->t_p_eval_us = ctx->n_p_eval = 0;
|
||||
|
||||
ctx->sampling.reset_timings();
|
||||
}
|
||||
|
||||
const char * llama_print_system_info(void) {
|
||||
|
@ -21814,20 +21157,20 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
|
|||
fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n",
|
||||
1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
|
||||
fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n",
|
||||
1.0e-3 * ctx->t_sample_us / ctx->n_sample);
|
||||
1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
|
||||
fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
|
||||
fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
|
||||
fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample);
|
||||
fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample);
|
||||
fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
|
||||
fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
|
||||
fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
|
||||
fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us);
|
||||
fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
|
||||
fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
|
||||
1.0e6 * ctx->n_eval / ctx->t_eval_us);
|
||||
fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",
|
||||
1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
|
||||
fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n",
|
||||
1.0e6 * ctx->n_sample / ctx->t_sample_us);
|
||||
1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
|
||||
}
|
||||
|
||||
// For internal test use
|
||||
|
@ -21866,14 +21209,14 @@ static void llama_log_internal_v(ggml_log_level level, const char * format, va_l
|
|||
va_end(args_copy);
|
||||
}
|
||||
|
||||
static void llama_log_internal(ggml_log_level level, const char * format, ...) {
|
||||
void llama_log_internal(ggml_log_level level, const char * format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
llama_log_internal_v(level, format, args);
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
|
||||
void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
|
||||
(void) level;
|
||||
(void) user_data;
|
||||
fputs(text, stderr);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue