llama : move sampling rngs from common to llama

ggml-ci
commit f866cb9342 (parent 938943cdbf)
Author: Georgi Gerganov
Date:   2024-07-23 10:23:44 +03:00

22 changed files with 342 additions and 344 deletions

@@ -1,11 +1,13 @@
-#define LLAMA_API_INTERNAL
 #include "sampling.h"

 #include <random>

-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) {
     struct llama_sampling_context * result = new llama_sampling_context();

     result->params  = params;
+    result->seq_id  = seq_id;
+    result->ctx     = ctx;
     result->grammar = nullptr;

     // if there is a grammar, parse it

@@ -81,7 +83,7 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s
     if (seed == LLAMA_DEFAULT_SEED) {
         seed = std::random_device{}();
     }
-    ctx->rng.seed(seed);
+    llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id);
 }

 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {

@@ -271,10 +273,10 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;

     const float temp         = params.temp;
     const int   mirostat     = params.mirostat;
     const float mirostat_tau = params.mirostat_tau;
     const float mirostat_eta = params.mirostat_eta;

     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);

@@ -304,7 +306,7 @@ static llama_token llama_sampling_sample_impl(
         sampler_queue(ctx_main, params, cur_p, min_keep);

-        id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+        id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id);

         //{
         //    const int n_top = 10;
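
Note that a sampling context is now bound to a llama_context and a llama_seq_id at creation time, and re-seeding is forwarded into the library instead of touching a local std::mt19937. A minimal sketch of the resulting call pattern (the names sparams, ctx, seed and n_seq are assumptions; error handling omitted):

    // one sampling context per parallel sequence; seq_id selects the RNG inside ctx
    for (llama_seq_id s = 0; s < n_seq; ++s) {
        struct llama_sampling_context * ctx_smpl = llama_sampling_init(sparams, ctx, s);
        llama_sampling_set_rng_seed(ctx_smpl, seed); // forwards to llama_set_rng_seed_seq(ctx, seed, s)
        // ... sample tokens for sequence s ...
        llama_sampling_free(ctx_smpl);
    }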

@@ -70,9 +70,12 @@ struct llama_sampling_context {
     // parameters that will be used for sampling
     llama_sampling_params params;

+    llama_seq_id seq_id;
+
     // mirostat sampler state
     float mirostat_mu;

+    llama_context * ctx; // TMP
     llama_grammar * grammar;

     // internal

@@ -81,15 +84,14 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
-    size_t n_valid; // Number of correct top tokens with correct probabilities.

-    std::mt19937 rng;
+    size_t n_valid; // Number of correct top tokens with correct probabilities.
 };

 #include "common.h"

 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id);

 void llama_sampling_free(struct llama_sampling_context * ctx);

@@ -1,8 +1,7 @@
-#define LLAMA_API_INTERNAL
 #include "grammar-parser.h"
 #include "ggml.h"
 #include "llama.h"
+#include "llama-impl.h"
 #include "unicode.h"

 #include <cstdio>

@@ -346,7 +346,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0);

     while (n_remain != 0 || params.interactive) {
         // predict

@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     LOG_TEE("\n");

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->ctx_llama, 0);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0);

     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);

@@ -106,7 +106,7 @@ int main(int argc, char ** argv){
     bool has_eos = false;

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0);

     std::vector<llama_token> draft;

@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.ctx_sampling = llama_sampling_init(params.sparams, ctx, i);
     }

     std::vector<llama_token> tokens_system;

@@ -1,7 +1,7 @@
-#define LLAMA_API_INTERNAL
 #include "common.h"
 #include "ggml.h"
 #include "llama.h"
+#include "llama-impl.h"

 #include <algorithm>
 #include <cassert>

@@ -1090,7 +1090,7 @@ struct server_context {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);
             }
-            slot.ctx_sampling = llama_sampling_init(slot.sparams);
+            slot.ctx_sampling = llama_sampling_init(slot.sparams, ctx, slot.id);
             if (slot.ctx_sampling == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);

@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
     bool has_eos = false;

     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx_tgt, 0);

     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);

@@ -186,7 +186,7 @@ int main(int argc, char ** argv) {
     }

     for (int s = 0; s < n_seq_dft; ++s) {
-        drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+        drafts[s].ctx_sampling = llama_sampling_init(params.sparams, ctx_dft, s);
     }

     llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);

@@ -40,7 +40,7 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 7
+#define LLAMA_SESSION_VERSION 8

 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 1

@@ -1031,6 +1031,9 @@ extern "C" {
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);

+    LLAMA_API DEPRECATED(void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id),
+            "temporary API, until llama_sampling_context is implemented, do not use");
+
     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
     LLAMA_API void llama_sample_repetition_penalties(

@@ -1137,11 +1140,18 @@ extern "C" {
             struct llama_context * ctx,
             llama_token_data_array * candidates);

-    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
+    /// @details Randomly selects a token from the candidates based on their probabilities using RNG[0] of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
             llama_token_data_array * candidates);

+    /// @details Same as llama_sample_token, but uses a sequence-specific RNG[seq_id].
+    LLAMA_API DEPRECATED(llama_token llama_sample_token_seq(
+            struct llama_context * ctx,
+            llama_token_data_array * candidates,
+            llama_seq_id seq_id),
+            "temporary API, until llama_sampling_context is implemented, do not use");
+
     //
     // Model split
     //
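
Both new entry points are declared DEPRECATED from day one: they are a stopgap until a proper llama_sampling_context lands in the library itself. A hedged usage sketch (assuming a valid ctx and a candidates array cur_p prepared by the caller):

    // seed and sample through the RNG of a specific sequence instead of RNG[0]
    llama_set_rng_seed_seq(ctx, 1234, /*seq_id=*/2);
    llama_token tok = llama_sample_token_seq(ctx, &cur_p, /*seq_id=*/2);
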
@@ -1175,59 +1185,4 @@
 }
 #endif

-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-
-#ifdef LLAMA_API_INTERNAL
-
-#include <random>
-#include <string>
-#include <vector>
-
-struct ggml_tensor;
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
-    size_t             index;
-    const uint32_t   * code_points;
-    llama_partial_utf8 partial_utf8;
-};
-
-using llama_grammar_rule  = std::vector<      llama_grammar_element>;
-using llama_grammar_stack = std::vector<const llama_grammar_element *>;
-
-using llama_grammar_rules      = std::vector<llama_grammar_rule>;
-using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
-using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
-
-const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
-      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
-
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-              llama_grammar_stacks & new_stacks);
-
-std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
-        const llama_grammar_rules      & rules,
-        const llama_grammar_stack      & stack,
-        const llama_grammar_candidates & candidates);
-
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8 partial_start);
-
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
-
-#endif // LLAMA_API_INTERNAL
-
 #endif // LLAMA_H

@@ -445,15 +445,15 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
     delete grammar;
 }

-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) {
+    llama_grammar * result = new llama_grammar{ grammar.rules, grammar.stacks, grammar.partial_utf8 };

     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
+                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
                         result->stacks[is][ie] = &result->rules[ir0][ir1];
                     }
                 }

@@ -464,14 +464,9 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & gram
     return result;
 }

-void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(grammar);
-    GGML_ASSERT(vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
-
+void llama_grammar_sample_impl(const struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token_data_array * candidates) {
     bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {
             allow_eog = true;
             break;

@@ -486,33 +481,29 @@ void llama_grammar_sample_impl(const struct llama_grammar & grammar, const struc
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id      = candidates->data[i].id;
-        const std::string & piece = vocab->cache_token_to_piece.at(id);
+        const std::string & piece = vocab.cache_token_to_piece.at(id);

-        if (llama_token_is_eog_impl(*vocab, id)) {
+        if (llama_token_is_eog_impl(vocab, id)) {
             if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }

-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
     for (const auto & reject : rejects) {
         candidates->data[reject.index].logit = -INFINITY;
     }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }

-void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    if (llama_token_is_eog_impl(*vocab, token)) {
-        for (const auto & stack : grammar->stacks) {
+void llama_grammar_accept_token_impl(struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token token) {
+    if (llama_token_is_eog_impl(vocab, token)) {
+        for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
                 return;
             }

@@ -520,20 +511,18 @@ void llama_grammar_accept_token_impl(struct llama_grammar & grammar, const struc
         GGML_ASSERT(false);
     }

-    const std::string & piece = vocab->cache_token_to_piece.at(token);
+    const std::string & piece = vocab.cache_token_to_piece.at(token);

     // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
+    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
     const auto & code_points = decoded.first;

     llama_grammar_stacks tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        grammar->stacks = tmp_new_stacks;
+        llama_grammar_accept(grammar.rules, grammar.stacks, *it, tmp_new_stacks);
+        grammar.stacks = tmp_new_stacks;
     }

-    grammar->partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar->stacks.empty());
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    grammar.partial_utf8 = decoded.second;
+    GGML_ASSERT(!grammar.stacks.empty());
 }
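
With the smpl parameter gone, the grammar implementation no longer updates sampling timings; the public wrappers in llama.cpp take that over via the time_meas helper introduced further down. A sketch of the new reference-based call shape (apply_grammar is a hypothetical wrapper; valid grammar and vocab objects are assumed):

    // constrain candidates to grammar-legal continuations, then advance the
    // grammar state with the token the caller eventually picked
    static void apply_grammar(llama_grammar & grammar, const llama_vocab & vocab,
                              llama_token_data_array * cur_p, llama_token chosen) {
        llama_grammar_sample_impl(grammar, vocab, cur_p);
        // ... caller samples `chosen` from cur_p ...
        llama_grammar_accept_token_impl(grammar, vocab, chosen);
    }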

@@ -26,16 +26,14 @@ struct llama_grammar * llama_grammar_init_impl(

 void llama_grammar_free_impl(struct llama_grammar * grammar);

-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar);

 void llama_grammar_sample_impl(
-        const struct llama_grammar * grammar,
-        const struct llama_vocab * vocab,
-        const struct llama_sampling * smpl,
+        const struct llama_grammar & grammar,
+        const struct llama_vocab & vocab,
         llama_token_data_array * candidates);

 void llama_grammar_accept_token_impl(
-        struct llama_grammar * grammar,
-        const struct llama_vocab * vocab,
-        const struct llama_sampling * smpl,
+        struct llama_grammar & grammar,
+        const struct llama_vocab & vocab,
         llama_token token);

@@ -1,8 +1,11 @@
 #pragma once

-#define LLAMA_API_INTERNAL
 #include "llama.h"

+#include <random>
+#include <string>
+#include <vector>
+
 #ifdef __GNUC__
 #ifdef __MINGW32__
 #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))

@@ -24,3 +27,43 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
 #define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
+struct llama_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct llama_grammar_candidate {
+    size_t             index;
+    const uint32_t   * code_points;
+    llama_partial_utf8 partial_utf8;
+};
+
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules      = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
+
+void llama_grammar_accept(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t chr,
+              llama_grammar_stacks & new_stacks);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stack      & stack,
+        const llama_grammar_candidates & candidates);
+
+std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+        const std::string & src,
+        llama_partial_utf8 partial_start);

@@ -21,19 +21,17 @@ static void llama_log_softmax(float * array, size_t size) {
     }
 }

-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
+void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
         seed = time(NULL);
     }

-    smpl->rng.seed(seed);
+    smpl.rng.seed(seed);
 }

-void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+void llama_sample_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
     GGML_ASSERT(candidates->size > 0);

-    const int64_t t_start_sample_us = ggml_time_us();
-
     // Sort the logits in descending order
     if (!candidates->sorted) {
         std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {

@@ -44,28 +42,24 @@ void llama_sample_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_dat
     float max_l   = candidates->data[0].logit;
     float cum_sum = 0.0f;

     for (size_t i = 0; i < candidates->size; ++i) {
         float p = expf(candidates->data[i].logit - max_l);
         candidates->data[i].p = p;
         cum_sum += p;
     }

     for (size_t i = 0; i < candidates->size; ++i) {
         candidates->data[i].p /= cum_sum;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

-void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+void llama_sample_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
     // if (k >= (int32_t)candidates->size) {
     //     return;
     // }

-    const int64_t t_start_sample_us = ggml_time_us();
-
     if (k <= 0) {
         k = candidates->size;
     }

@@ -133,21 +127,15 @@ void llama_sample_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_
         candidates->sorted = true;
     }
     candidates->size = k;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

-void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sample_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
     if (p >= 1.0f) {
         return;
     }

     llama_sample_softmax_impl(smpl, candidates);

-    const int64_t t_start_sample_us = ggml_time_us();
-
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
     size_t last_idx = candidates->size;
@@ -165,19 +153,13 @@ void llama_sample_top_p_impl(struct llama_sampling & smpl, llama_token_data_arra

     // Resize the output vector to keep only the top-p tokens
     candidates->size = last_idx;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

-void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sample_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) {
     if (p <= 0.0f || !candidates->size) {
         return;
     }

-    const int64_t t_start_sample_us = ggml_time_us();
-
     bool min_p_applied = false;

     // if the candidates aren't sorted, try the unsorted implementation first

@@ -226,19 +208,14 @@ void llama_sample_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_
         // Resize the output vector to keep only the matching tokens
         candidates->size = i;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
+void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
     }

-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax_impl(smpl, candidates);

     // Compute the first and second derivatives
     std::vector<float> first_derivatives(candidates->size - 1);
@@ -285,13 +262,9 @@ void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_

     // Resize the output vector to keep only the tokens above the tail location
     candidates->size = last_idx;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

-void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
     if (p >= 1.0f) {

@@ -299,9 +272,7 @@ void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_ar
     }

     // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax_impl(smpl, candidates);

     float entropy = 0.0f;
     for (size_t i = 0; i < candidates->size; ++i) {

@@ -349,15 +320,9 @@ void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_ar
     std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
     candidates->size = new_candidates.size();
     candidates->sorted = false;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

-void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
     // no need to do anything if there is only one (or zero) candidates
     if(candidates->size <= 1) {
         return;

@@ -366,7 +331,7 @@ void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_ar
     // Calculate maximum possible entropy
     float max_entropy = -logf(1.0f / candidates->size);

-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+    llama_sample_softmax_impl(smpl, candidates);

     // Calculate entropy of the softmax probabilities
     float entropy = 0.0f;
@@ -416,38 +381,26 @@ void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_ar
         LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
     }
 #endif
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

-void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+void llama_sample_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) {
     for (size_t i = 0; i < candidates->size; ++i) {
         candidates->data[i].logit /= temp;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

 void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
+        struct llama_sampling & /*smpl*/,
         llama_token_data_array * candidates,
         const llama_token * last_tokens,
         size_t penalty_last_n,
         float penalty_repeat,
         float penalty_freq,
         float penalty_present) {
     if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
         return;
     }

-    const int64_t t_start_sample_us = ggml_time_us();
-
     // Create a frequency map to count occurrences of each token in last_tokens
     std::unordered_map<llama_token, int> token_count;
     for (size_t i = 0; i < penalty_last_n; ++i) {

@@ -475,21 +428,14 @@ void llama_sample_repetition_penalties_impl(
     }

     candidates->sorted = false;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }

 void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
+        struct llama_sampling & smpl,
         float * logits,
         float * logits_guidance,
         float   scale) {
-    GGML_ASSERT(smpl);
-
-    const auto t_start_sample_us = ggml_time_us();
-    const auto n_vocab = smpl->n_vocab;
+    const auto n_vocab = smpl.n_vocab;

     llama_log_softmax(logits, n_vocab);
     llama_log_softmax(logits_guidance, n_vocab);
@@ -500,18 +446,12 @@ void llama_sample_apply_guidance_impl(
         l = scale * (l - g) + g;
     }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }

-llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    GGML_ASSERT(smpl);
-
-    const int32_t n_vocab = float(smpl->n_vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
+    const int32_t n_vocab = float(smpl.n_vocab);
+
+    llama_sample_softmax_impl(smpl, candidates);

     // Estimate s_hat using the most probable m tokens
     float s_hat = 0.0;

@@ -530,10 +470,8 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama
     float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);

     // Sample the next word X using top-k sampling
-    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    llama_sample_top_k_impl(smpl, candidates, int(k), 1);
     llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();

     // Compute error as the difference between observed surprise and target surprise value
     size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {

@@ -545,14 +483,10 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama
     // Update mu using the learning rate and error
     *mu = *mu - eta * e;

-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-
     return X;
 }

-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
-
+llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
     llama_sample_softmax_impl(smpl, candidates);

@@ -564,16 +498,11 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, ll
         candidates->size = 1;
     }

-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-
     // Normalize the probabilities of the remaining words
     llama_sample_softmax_impl(smpl, candidates);

     // Sample the next word X from the remaining words
     llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();

     // Compute error as the difference between observed surprise and target surprise value
     size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@@ -585,33 +514,22 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, ll
     // Update mu using the learning rate and error
     *mu = *mu - eta * e;

-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-
     return X;
 }

-llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+llama_token llama_sample_token_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
     // Find max element
     auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
         return a.logit < b.logit;
     });

     llama_token result = max_iter->id;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-        smpl->n_sample++;
-    }
-
     return result;
 }

-llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
-    GGML_ASSERT(smpl);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+llama_token llama_sample_token_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
+    llama_sample_softmax_impl(smpl, candidates);

     std::vector<float> probs;
     probs.reserve(candidates->size);

@@ -624,12 +542,9 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling & smpl, llama

     llama_token result = candidates->data[idx].id;

-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    smpl->n_sample++;
-
     return result;
 }

-llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+llama_token llama_sample_token_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng_impl(smpl, candidates, smpl.rng);
 }

@@ -3,38 +3,30 @@
 #include "llama-impl.h"

 struct llama_sampling {
-    llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
+    llama_sampling(uint32_t seed, int32_t n_vocab) : rng(seed), n_vocab(n_vocab) {}

     std::mt19937 rng;

-    int32_t n_vocab = 0;
-
-    mutable int64_t t_sample_us = 0;
-    mutable int32_t n_sample = 0;
-
-    void reset_timings() const {
-        t_sample_us = 0;
-        n_sample = 0;
-    }
+    const int32_t n_vocab;
 };

 //
 // internal API
 //

-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
+void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed);

-void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates);
-void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
-void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
-void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
-void llama_sample_temp_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
+void llama_sample_softmax_impl  (struct llama_sampling & smpl, llama_token_data_array * candidates);
+void llama_sample_top_k_impl    (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
+void llama_sample_top_p_impl    (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_min_p_impl    (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep);
+void llama_sample_typical_impl  (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_entropy_impl  (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
+void llama_sample_temp_impl     (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp);

 void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
+        struct llama_sampling & smpl,
         llama_token_data_array * candidates,
         const llama_token * last_tokens,
         size_t penalty_last_n,

@@ -43,14 +35,14 @@ void llama_sample_repetition_penalties_impl(
         float penalty_present);

 void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
+        struct llama_sampling & smpl,
         float * logits,
         float * logits_guidance,
         float   scale);

-llama_token llama_sample_token_mirostat_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
-llama_token llama_sample_token_greedy_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates);
-llama_token llama_sample_token_with_rng_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
-llama_token llama_sample_token_impl            (struct llama_sampling * smpl, llama_token_data_array * candidates);
+llama_token llama_sample_token_mirostat_impl   (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
+llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
+llama_token llama_sample_token_greedy_impl     (struct llama_sampling & smpl, llama_token_data_array * candidates);
+llama_token llama_sample_token_with_rng_impl   (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng);
+llama_token llama_sample_token_impl            (struct llama_sampling & smpl, llama_token_data_array * candidates);

@@ -163,6 +163,19 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }

+struct time_meas {
+    time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {}
+
+    ~time_meas() {
+        t_acc += ggml_time_us() - t_start_us;
+    }
+
+    const int64_t t_start_us;
+
+    int64_t & t_acc;
+};
+
 LLAMA_ATTRIBUTE_FORMAT(1, 2)
 static std::string format(const char * fmt, ...) {
     va_list ap;
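
time_meas is a small RAII helper: it reads ggml_time_us() on construction and adds the elapsed microseconds to the referenced accumulator in its destructor, so every exit path of the wrappers below is accounted for. A sketch of the pattern, reusing the struct defined above (timed_op and nothing_to_do are hypothetical):

    void timed_op(llama_context * ctx) {
        time_meas tm(ctx->t_sample_us); // clock starts here
        if (nothing_to_do(ctx)) {
            return;                     // destructor still adds the elapsed time
        }
        // ... actual work ...
    }                                   // and it fires here on the normal path too
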
@@ -2656,7 +2669,6 @@ struct llama_model {
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
-        , sampling(llama_n_vocab(&model))
         , grammar()
         , t_start_us(model.t_start_us)
         , t_load_us(model.t_load_us) {}

@@ -2674,11 +2686,12 @@ struct llama_context {
     const struct llama_model & model;

     struct llama_cparams        cparams;
-    struct llama_sampling       sampling;
     struct llama_grammar        grammar;
     struct llama_kv_cache       kv_self;
     struct llama_control_vector cvec;

+    std::vector<struct llama_sampling> sampling; // sampling context for each sequence
+
     std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;

     std::vector<ggml_backend_t> backends;

@@ -2692,16 +2705,18 @@ struct llama_context {
     bool has_evaluated_once = false;

-    int64_t t_start_us;
-    int64_t t_load_us;
-    int64_t t_p_eval_us = 0;
-    int64_t t_eval_us   = 0;
+    mutable int64_t t_start_us;
+    mutable int64_t t_load_us;
+    mutable int64_t t_sample_us = 0;
+    mutable int64_t t_p_eval_us = 0;
+    mutable int64_t t_eval_us   = 0;

-    int64_t t_compute_start_us = 0;
-    int64_t n_queued_tokens = 0;
+    mutable int64_t t_compute_start_us = 0;
+    mutable int64_t n_queued_tokens = 0;

-    int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-    int32_t n_eval   = 0; // number of eval calls
+    mutable int32_t n_sample = 0;
+    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
+    mutable int32_t n_eval   = 0; // number of eval calls

     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_t buf_output = nullptr;
@@ -16527,8 +16542,12 @@ struct llama_context * llama_new_context_with_model(
         ctx->abort_callback      = params.abort_callback;
         ctx->abort_callback_data = params.abort_callback_data;

-        ctx->sampling.rng = std::mt19937(params.seed);
-        ctx->logits_all   = params.logits_all;
+        ctx->sampling.reserve(cparams.n_seq_max);
+        for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
+            ctx->sampling.emplace_back(params.seed, llama_n_vocab(model));
+        }
+
+        ctx->logits_all = params.logits_all;

         uint32_t kv_size = cparams.n_ctx;
         ggml_type type_k = params.type_k;

@@ -17303,6 +17322,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_n_rng       = sizeof(uint32_t);
     const size_t s_rng_size    = sizeof(size_t);
     const size_t s_rng         = LLAMA_MAX_RNG_STATE;
     const size_t s_n_outputs   = sizeof(size_t);

@@ -17322,8 +17342,8 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
     const size_t s_kv_cells    = ctx->kv_self.size * s_kv_cell;

     const size_t s_total = (
-        + s_rng_size
-        + s_rng
+        + s_n_rng
+        + cparams.n_seq_max*(s_rng_size + s_rng)
         + s_n_outputs
         + s_output_pos
         + s_logits_size

@@ -17340,7 +17360,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
     );

     // on session change it is very likely that the state size has changed - so we need to update this function
-    static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
+    static_assert(LLAMA_SESSION_VERSION == 8, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");

     return s_total;
 }
@@ -17401,18 +17421,24 @@ struct llama_data_file_context : llama_data_context {
 static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     llama_synchronize(ctx);

-    // copy rng
+    // copy rngs
     {
-        std::ostringstream rng_ss;
-        rng_ss << ctx->sampling.rng;
+        const uint32_t n_rng = ctx->sampling.size();

-        const std::string & rng_str  = rng_ss.str();
-        const size_t        rng_size = rng_str.size();
+        data_ctx->write(&n_rng, sizeof(n_rng));

-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+        for (const auto & smpl : ctx->sampling) {
+            std::ostringstream rng_ss;
+            rng_ss << smpl.rng;

-        data_ctx->write(&rng_size,      sizeof(rng_size));
-        data_ctx->write(rng_str.data(), rng_size);
+            const std::string & rng_str  = rng_ss.str();
+            const size_t        rng_size = rng_str.size();
+
+            GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+
+            data_ctx->write(&rng_size,      sizeof(rng_size));
+            data_ctx->write(rng_str.data(), rng_size);
+        }
     }

     // copy outputs

@@ -17560,19 +17586,26 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
     const uint8_t * inp = src;

-    // set rng
+    // set rngs
     {
-        size_t rng_size;
-        memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
+        uint32_t n_rng;
+        memcpy(&n_rng, inp, sizeof(n_rng)); inp += sizeof(n_rng);

-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+        GGML_ASSERT(n_rng == ctx->cparams.n_seq_max);

-        std::string rng_str((const char *)inp, rng_size); inp += rng_size;
+        for (auto & smpl : ctx->sampling) {
+            size_t rng_size;
+            memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);

-        std::istringstream rng_ss(rng_str);
-        rng_ss >> ctx->sampling.rng;
+            GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);

-        GGML_ASSERT(!rng_ss.fail());
+            std::string rng_str((const char *)inp, rng_size); inp += rng_size;
+
+            std::istringstream rng_ss(rng_str);
+            rng_ss >> smpl.rng;
+
+            GGML_ASSERT(!rng_ss.fail());
+        }
     }

     // set output ids
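
The serialized RNG section of the session state changes shape accordingly, which is what the LLAMA_SESSION_VERSION bump from 7 to 8 guards against. Informally, the layout written and read back by the loops above:

    // uint32_t n_rng;                // number of per-sequence RNGs (must equal n_seq_max on load)
    // repeated n_rng times:
    //   size_t rng_size;             // <= LLAMA_MAX_RNG_STATE
    //   char   rng_str[rng_size];    // std::mt19937 state serialized via operator<<
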
@@ -18930,18 +18963,24 @@ struct llama_grammar * llama_grammar_init(
 }

 void llama_grammar_free(struct llama_grammar * grammar) {
+    if (grammar == nullptr) {
+        return;
+    }
+
     llama_grammar_free_impl(grammar);
 }

 struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    return llama_grammar_copy_impl(grammar);
+    return llama_grammar_copy_impl(*grammar);
 }

 void llama_grammar_sample(
         const struct llama_grammar * grammar,
         const struct llama_context * ctx,
         llama_token_data_array * candidates) {
-    llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
+    time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling
+
+    llama_grammar_sample_impl(*grammar, ctx->model.vocab, candidates);
 }

 void llama_sample_grammar(

@@ -18955,7 +18994,9 @@ void llama_grammar_accept_token(
         struct llama_grammar * grammar,
         struct llama_context * ctx,
         llama_token token) {
-    llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
+    time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling
+
+    llama_grammar_accept_token_impl(*grammar, ctx->model.vocab, token);
 }

 //
@ -18963,39 +19004,59 @@ void llama_grammar_accept_token(
// //
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
llama_set_rng_seed_impl(&ctx->sampling, seed); llama_set_rng_seed_impl(ctx->sampling[0], seed);
}
void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id seq_id) {
llama_set_rng_seed_impl(ctx->sampling[seq_id], seed);
} }
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates); time_meas tm(ctx->t_sample_us);
llama_sample_softmax_impl(ctx->sampling[0], candidates);
} }
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep); time_meas tm(ctx->t_sample_us);
llama_sample_top_k_impl(ctx->sampling[0], candidates, k, min_keep);
} }
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); time_meas tm(ctx->t_sample_us);
llama_sample_top_p_impl(ctx->sampling[0], candidates, p, min_keep);
} }
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); time_meas tm(ctx->t_sample_us);
llama_sample_min_p_impl(ctx->sampling[0], candidates, p, min_keep);
} }
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep); time_meas tm(ctx->t_sample_us);
llama_sample_tail_free_impl(ctx->sampling[0], candidates, z, min_keep);
} }
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); time_meas tm(ctx->t_sample_us);
llama_sample_typical_impl(ctx->sampling[0], candidates, p, min_keep);
} }
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val); time_meas tm(ctx->t_sample_us);
llama_sample_entropy_impl(ctx->sampling[0], candidates_p, min_temp, max_temp, exponent_val);
} }
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp); time_meas tm(ctx->t_sample_us);
llama_sample_temp_impl(ctx->sampling[0], candidates_p, temp);
} }
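
All of the truncation wrappers above now follow the same shape: start a time_meas guard, then forward to the *_impl counterpart operating on ctx->sampling[0]. For reference, a sketch of chaining them in the usual order on a prepared candidate array cur_p (ctx and cur_p are assumed to come from the caller):

    // Fragment: typical truncation pipeline through the public wrappers above.
    llama_sample_top_k    (ctx, &cur_p, /*k=*/ 40,    /*min_keep=*/ 1);
    llama_sample_tail_free(ctx, &cur_p, /*z=*/ 1.0f,  /*min_keep=*/ 1);
    llama_sample_typical  (ctx, &cur_p, /*p=*/ 1.0f,  /*min_keep=*/ 1);
    llama_sample_top_p    (ctx, &cur_p, /*p=*/ 0.95f, /*min_keep=*/ 1);
    llama_sample_min_p    (ctx, &cur_p, /*p=*/ 0.05f, /*min_keep=*/ 1);
    llama_sample_temp     (ctx, &cur_p, /*temp=*/ 0.8f);
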
void llama_sample_repetition_penalties( void llama_sample_repetition_penalties(
@ -19006,7 +19067,9 @@ void llama_sample_repetition_penalties(
float penalty_repeat, float penalty_repeat,
float penalty_freq, float penalty_freq,
float penalty_present) { float penalty_present) {
llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); time_meas tm(ctx->t_sample_us);
llama_sample_repetition_penalties_impl(ctx->sampling[0], candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
} }
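
llama_sample_repetition_penalties likewise forwards to the impl on sampling state 0. A fragment showing how recent history is typically fed in; the recent vector is an assumed stand-in for the caller's token history:

    // Fragment: penalize the most recent tokens before truncation and sampling.
    std::vector<llama_token> recent; // assumed: tail of the generated history
    llama_sample_repetition_penalties(ctx, &cur_p,
            recent.data(), recent.size(),
            /*penalty_repeat =*/ 1.1f,
            /*penalty_freq   =*/ 0.0f,
            /*penalty_present=*/ 0.0f);
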
void llama_sample_apply_guidance( void llama_sample_apply_guidance(
@ -19014,27 +19077,55 @@ void llama_sample_apply_guidance(
float * logits, float * logits,
float * logits_guidance, float * logits_guidance,
float scale) { float scale) {
llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale); time_meas tm(ctx->t_sample_us);
llama_sample_apply_guidance_impl(ctx->sampling[0], logits, logits_guidance, scale);
} }
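
A hedged fragment for the guidance path; both logit arrays are assumed to hold n_vocab floats for the token being sampled, and ctx_guidance is a hypothetical second context evaluated on the guidance (negative) prompt:

    // Fragment: classifier-free guidance mixes the two logit arrays in place.
    float * logits          = llama_get_logits(ctx);
    float * logits_guidance = llama_get_logits(ctx_guidance); // hypothetical
    llama_sample_apply_guidance(ctx, logits, logits_guidance, /*scale=*/ 1.5f);
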
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu); time_meas tm(ctx->t_sample_us);
auto res = llama_sample_token_mirostat_impl(ctx->sampling[0], candidates, tau, eta, m, mu);
ctx->n_sample++;
return res;
} }
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu); time_meas tm(ctx->t_sample_us);
auto res = llama_sample_token_mirostat_v2_impl(ctx->sampling[0], candidates, tau, eta, mu);
ctx->n_sample++;
return res;
} }
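
Both mirostat variants now bump ctx->n_sample themselves. The mu argument is persistent sampler state that must survive between calls; a fragment with the customary initialization (2*tau is a common convention, not mandated by this diff):

    // Fragment: mirostat v2 across a generation loop; mu persists between calls.
    const float tau = 5.0f;
    const float eta = 0.1f;
    float       mu  = 2.0f * tau; // customary starting point (assumption)
    const llama_token id = llama_sample_token_mirostat_v2(ctx, &cur_p, tau, eta, &mu);
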
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates); time_meas tm(ctx->t_sample_us);
}
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { auto res = llama_sample_token_greedy_impl(ctx->sampling[0], candidates);
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
ctx->n_sample++;
return res;
} }
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng); return llama_sample_token_seq(ctx, candidates, 0);
}
llama_token llama_sample_token_seq(struct llama_context * ctx, llama_token_data_array * candidates, llama_seq_id seq_id) {
GGML_ASSERT(seq_id >= 0 && seq_id < (int32_t) ctx->cparams.n_seq_max);
time_meas tm(ctx->t_sample_us);
auto res = llama_sample_token_impl(ctx->sampling[seq_id], candidates);
ctx->n_sample++;
return res;
} }
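
llama_sample_token_seq is the entry point this commit is built around: each sequence draws from its own RNG stream, with seq_id bounded by n_seq_max via the assert above, and llama_sample_token reduces to the seq_id == 0 case. A sketch of a parallel-decode loop; build_candidates_for is a hypothetical helper that prepares the candidate array for one sequence:

    // Sketch: sample several parallel sequences, each from its own RNG stream.
    for (llama_seq_id s = 0; s < n_seq; ++s) {
        llama_token_data_array cur_p = build_candidates_for(ctx, s); // hypothetical
        const llama_token id = llama_sample_token_seq(ctx, &cur_p, s);
        // ... append `id` to sequence s and continue decoding ...
    }
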
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
@ -19066,11 +19157,11 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) {
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
/*.t_end_ms =*/ 1.00 * ggml_time_ms(), /*.t_end_ms =*/ 1.00 * ggml_time_ms(),
/*.t_load_ms =*/ 1e-3 * ctx->t_load_us, /*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
/*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us, /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us,
/*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
/*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
/*.n_sample =*/ std::max(1, ctx->sampling.n_sample), /*.n_sample =*/ std::max(1, ctx->n_sample),
/*.n_p_eval =*/ std::max(0, ctx->n_p_eval), /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
/*.n_eval =*/ std::max(1, ctx->n_eval), /*.n_eval =*/ std::max(1, ctx->n_eval),
}; };
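
With the counters back on llama_context, per-token sampling cost can be read straight off the public struct. A fragment mirroring how the fields are populated above:

    // Fragment: derive ms/token for sampling from the public timings struct.
    const llama_timings t = llama_get_timings(ctx);
    const double ms_per_sample = t.t_sample_ms / t.n_sample; // n_sample clamped to >= 1 above
    printf("sampling: %8.2f ms / token over %d tokens\n", ms_per_sample, t.n_sample);
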
@ -19095,9 +19186,8 @@ void llama_print_timings(struct llama_context * ctx) {
void llama_reset_timings(struct llama_context * ctx) { void llama_reset_timings(struct llama_context * ctx) {
ctx->t_start_us = ggml_time_us(); ctx->t_start_us = ggml_time_us();
ctx->t_eval_us = ctx->n_eval = 0; ctx->t_eval_us = ctx->n_eval = 0;
ctx->t_sample_us = ctx->n_sample = 0;
ctx->t_p_eval_us = ctx->n_p_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0;
ctx->sampling.reset_timings();
} }
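
llama_reset_timings now clears the sampling counters directly instead of delegating to the removed ctx->sampling.reset_timings(). Typical benchmarking usage, with run_generation standing in for the caller's workload:

    // Fragment: bracket a workload with the consolidated timing counters.
    llama_reset_timings(ctx);
    run_generation(ctx);      // hypothetical workload
    llama_print_timings(ctx); // reports from ctx->t_sample_us / ctx->n_sample
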
const char * llama_print_system_info(void) { const char * llama_print_system_info(void) {
@ -19144,20 +19234,20 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n",
1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n",
1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample); 1.0e-3 * ctx->t_sample_us / ctx->n_sample);
fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample); fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample);
fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us); fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us);
fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
1.0e6 * ctx->n_eval / ctx->t_eval_us); 1.0e6 * ctx->n_eval / ctx->t_eval_us);
fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",
1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n",
1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us); 1.0e6 * ctx->n_sample / ctx->t_sample_us);
} }
// For internal test use // For internal test use

@ -2,13 +2,13 @@
#undef NDEBUG #undef NDEBUG
#endif #endif
#define LLAMA_API_INTERNAL
#include "ggml.h" #include "ggml.h"
#include "llama.h" #include "llama.h"
#include "llama-impl.h"
#include "unicode.h"
#include "grammar-parser.h" #include "grammar-parser.h"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "unicode.h"
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <vector> #include <vector>

@ -2,8 +2,8 @@
#undef NDEBUG #undef NDEBUG
#endif #endif
#define LLAMA_API_INTERNAL
#include "llama.h" #include "llama.h"
#include "llama-impl.h"
#include "grammar-parser.h" #include "grammar-parser.h"
#include <cassert> #include <cassert>

@ -1,5 +1,5 @@
#include "ggml.h" #include "ggml.h"
#include "llama.h" #include "llama-sampling.h"
#ifdef NDEBUG #ifdef NDEBUG
#undef NDEBUG #undef NDEBUG
@ -20,6 +20,8 @@ static void dump(const llama_token_data_array * candidates) {
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) { static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -28,9 +30,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax(nullptr, &candidates_p); llama_sample_softmax_impl(smpl, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_top_k(nullptr, &candidates_p, k, 1); llama_sample_top_k_impl(smpl, &candidates_p, k, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -41,6 +43,8 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -49,9 +53,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax(nullptr, &candidates_p); llama_sample_softmax_impl(smpl, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_top_p(nullptr, &candidates_p, p, 1); llama_sample_top_p_impl(smpl, &candidates_p, p, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -62,6 +66,8 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) { static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -71,7 +77,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_tail_free(nullptr, &candidates_p, z, 1); llama_sample_tail_free_impl(smpl, &candidates_p, z, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -82,6 +88,8 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -91,9 +99,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_min_p(nullptr, &candidates_p, p, 1); llama_sample_min_p_impl(smpl, &candidates_p, p, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_softmax(nullptr, &candidates_p); llama_sample_softmax_impl(smpl, &candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) { for (size_t i = 0; i < candidates_p.size; i++) {
@ -103,6 +111,8 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -112,7 +122,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_typical(nullptr, &candidates_p, p, 1); llama_sample_typical_impl(smpl, &candidates_p, p, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -128,6 +138,8 @@ static void test_repetition_penalties(
GGML_ASSERT(probs.size() == expected_probs.size()); GGML_ASSERT(probs.size() == expected_probs.size());
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -136,10 +148,10 @@ static void test_repetition_penalties(
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax(nullptr, &candidates_p); llama_sample_softmax_impl(smpl, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); llama_sample_repetition_penalties_impl(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
llama_sample_softmax(nullptr, &candidates_p); llama_sample_softmax_impl(smpl, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -148,9 +160,10 @@ static void test_repetition_penalties(
} }
} }
static void test_sampler_queue( static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
) { ) {
llama_sampling smpl(1234, n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@ -165,16 +178,16 @@ static void test_sampler_queue(
for (auto s : samplers_sequence) { for (auto s : samplers_sequence) {
switch (s){ switch (s){
case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break; case 'k': llama_sample_top_k_impl(smpl, &candidates_p, top_k, 1); break;
case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break; case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
case 'y': GGML_ASSERT(false && "typical test not implemented"); break; case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break; case 'p': llama_sample_top_p_impl(smpl, &candidates_p, top_p, 1); break;
case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break; case 'm': llama_sample_min_p_impl(smpl, &candidates_p, min_p, 1); break;
case 't': GGML_ASSERT(false && "temperature test not implemented"); break; case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
default : GGML_ASSERT(false && "Unknown sampler"); break; default : GGML_ASSERT(false && "Unknown sampler"); break;
} }
llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests llama_sample_softmax_impl(smpl, &candidates_p); // make sure tokens are sorted for tests
const int size = candidates_p.size; const int size = candidates_p.size;
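
Taken together, the test edits follow one pattern: construct a llama_sampling state directly and call the *_impl entry points, instead of threading nullptr through the public context-taking wrappers. A condensed, self-contained sketch of that pattern; the header name, the (seed, n_vocab) constructor order, and the *_impl signatures follow the diff, while the rest is illustrative:

    #include "ggml.h"
    #include "llama-sampling.h"

    #include <vector>

    // Condensed sketch of the new test style: no llama_context needed, just a
    // llama_sampling state plus a candidate array built from raw logits.
    static void smoke_test_top_k() {
        const size_t n_vocab = 4;
        llama_sampling smpl(1234, n_vocab); // (seed, n_vocab), as in the tests above

        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
            candidates.push_back({ token_id, /*logit=*/ (float) token_id, /*p=*/ 0.0f });
        }

        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        llama_sample_softmax_impl(smpl, &candidates_p);                          // sort + normalize
        llama_sample_top_k_impl  (smpl, &candidates_p, /*k=*/ 2, /*min_keep=*/ 1);

        GGML_ASSERT(candidates_p.size == 2);
    }

    int main() {
        smoke_test_top_k();
        return 0;
    }
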