llama : move sampling rngs from common to llama
ggml-ci
This commit is contained in:
parent 938943cdbf
commit f866cb9342

22 changed files with 342 additions and 344 deletions
@@ -1,11 +1,13 @@
+#define LLAMA_API_INTERNAL
 #include "sampling.h"
 
+#include <random>
+
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
     result->params  = params;
+    result->seq_id  = seq_id;
+    result->ctx     = ctx;
     result->grammar = nullptr;
 
     // if there is a grammar, parse it
@@ -81,7 +83,7 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s
     if (seed == LLAMA_DEFAULT_SEED) {
         seed = std::random_device{}();
     }
-    ctx->rng.seed(seed);
+    llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id);
 }
 
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
@@ -271,10 +273,10 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const float   temp         = params.temp;
-    const int     mirostat     = params.mirostat;
-    const float   mirostat_tau = params.mirostat_tau;
-    const float   mirostat_eta = params.mirostat_eta;
+    const float temp         = params.temp;
+    const int   mirostat     = params.mirostat;
+    const float mirostat_tau = params.mirostat_tau;
+    const float mirostat_eta = params.mirostat_eta;
 
     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
@@ -304,7 +306,7 @@ static llama_token llama_sampling_sample_impl(
 
         sampler_queue(ctx_main, params, cur_p, min_keep);
 
-        id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+        id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id);
 
         //{
         //    const int n_top = 10;
@@ -70,9 +70,12 @@ struct llama_sampling_context {
     // parameters that will be used for sampling
     llama_sampling_params params;
 
+    llama_seq_id seq_id;
+
     // mirostat sampler state
     float mirostat_mu;
 
+    llama_context * ctx; // TMP
     llama_grammar * grammar;
 
     // internal
@@ -81,15 +84,14 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
 
-    std::mt19937 rng;
+    size_t n_valid; // Number of correct top tokens with correct probabilities.
 };

@@ ... @@
 #include "common.h"
 
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id);
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
@@ -1,8 +1,7 @@
-#define LLAMA_API_INTERNAL
 
 #include "grammar-parser.h"
 #include "ggml.h"
-#include "llama.h"
+#include "llama-impl.h"
 #include "unicode.h"
 
 #include <cstdio>
@@ -346,7 +346,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0);
 
     while (n_remain != 0 || params.interactive) {
         // predict
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     LOG_TEE("\n");
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->ctx_llama, 0);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0);
 
     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);
@@ -106,7 +106,7 @@ int main(int argc, char ** argv){
 
     bool has_eos = false;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0);
 
     std::vector<llama_token> draft;
 
@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.ctx_sampling = llama_sampling_init(params.sparams, ctx, i);
     }
 
     std::vector<llama_token> tokens_system;
@@ -1,7 +1,7 @@
 #define LLAMA_API_INTERNAL
 #include "common.h"
 #include "ggml.h"
-#include "llama.h"
+#include "llama-impl.h"
 
 #include <algorithm>
 #include <cassert>
@@ -1090,7 +1090,7 @@ struct server_context {
             if (slot.ctx_sampling != nullptr) {
                 llama_sampling_free(slot.ctx_sampling);
             }
-            slot.ctx_sampling = llama_sampling_init(slot.sparams);
+            slot.ctx_sampling = llama_sampling_init(slot.sparams, ctx, slot.id);
             if (slot.ctx_sampling == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
     bool has_eos = false;
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx_tgt, 0);
 
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
@@ -186,7 +186,7 @@ int main(int argc, char ** argv) {
     }
 
     for (int s = 0; s < n_seq_dft; ++s) {
-        drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+        drafts[s].ctx_sampling = llama_sampling_init(params.sparams, ctx_dft, s);
     }
 
     llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
@@ -40,7 +40,7 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 7
+#define LLAMA_SESSION_VERSION 8
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 1
@@ -1031,6 +1031,9 @@ extern "C" {
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
 
+    LLAMA_API DEPRECATED(void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id),
+        "temporary API, until llama_sampling_context is implemented, do not use");
+
     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
     LLAMA_API void llama_sample_repetition_penalties(
@@ -1137,11 +1140,18 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
+    /// @details Randomly selects a token from the candidates based on their probabilities using RNG[0] of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
+    /// @details Same as llama_sample_token, but uses a sequence-specific RNG[seq_id].
+    LLAMA_API DEPRECATED(llama_token llama_sample_token_seq(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+                     llama_seq_id   seq_id),
+        "temporary API, until llama_sampling_context is implemented, do not use");
+
     //
     // Model split
     //
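The two deprecated entry points above exist so that the common sampling code can route each sequence to its own RNG until llama_sampling_context moves into the library. A minimal caller-side sketch of the intended usage (the function name, seeds, and candidate arrays below are illustrative; only llama_set_rng_seed_seq and llama_sample_token_seq come from this header):

// Hypothetical caller, not part of this commit: two parallel sequences,
// each seeded and sampled through its own RNG slot, so their draws are
// independent and reproducible.
void sample_two_seqs(struct llama_context * ctx,
                     llama_token_data_array * cand0,
                     llama_token_data_array * cand1) {
    llama_set_rng_seed_seq(ctx, 1234, 0); // seeds RNG[0]
    llama_set_rng_seed_seq(ctx, 5678, 1); // seeds RNG[1]

    const llama_token t0 = llama_sample_token_seq(ctx, cand0, 0); // draws from RNG[0]
    const llama_token t1 = llama_sample_token_seq(ctx, cand1, 1); // draws from RNG[1]
    (void) t0; (void) t1;
}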
@@ -1175,59 +1185,4 @@ extern "C" {
 }
 #endif
 
-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-#ifdef LLAMA_API_INTERNAL
-
-#include <random>
-#include <string>
-#include <vector>
-
-struct ggml_tensor;
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
-    size_t               index;
-    const uint32_t     * code_points;
-    llama_partial_utf8   partial_utf8;
-};
-
-using llama_grammar_rule  = std::vector<      llama_grammar_element>;
-using llama_grammar_stack = std::vector<const llama_grammar_element *>;
-
-using llama_grammar_rules      = std::vector<llama_grammar_rule>;
-using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
-using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
-
-const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
-      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
-
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-              llama_grammar_stacks & new_stacks);
-
-std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
-        const llama_grammar_rules & rules,
-        const llama_grammar_stack & stack,
-        const llama_grammar_candidates & candidates);
-
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8 partial_start);
-
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
-
-#endif // LLAMA_API_INTERNAL
-
 #endif // LLAMA_H
@@ -445,15 +445,15 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
     delete grammar;
 }
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) {
+    llama_grammar * result = new llama_grammar{ grammar.rules, grammar.stacks, grammar.partial_utf8 };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
+                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
                         result->stacks[is][ie] = &result->rules[ir0][ir1];
                     }
                 }
@@ -464,14 +464,9 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
     return result;
 }
 
-void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(grammar);
-    GGML_ASSERT(vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
-
+void llama_grammar_sample_impl(const struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token_data_array * candidates) {
     bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {
             allow_eog = true;
             break;
@@ -486,33 +481,29 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id      = candidates->data[i].id;
-        const std::string & piece = vocab->cache_token_to_piece.at(id);
+        const std::string & piece = vocab.cache_token_to_piece.at(id);
 
-        if (llama_token_is_eog_impl(*vocab, id)) {
+        if (llama_token_is_eog_impl(vocab, id)) {
             if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
 
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
     for (const auto & reject : rejects) {
         candidates->data[reject.index].logit = -INFINITY;
     }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
-void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    if (llama_token_is_eog_impl(*vocab, token)) {
-        for (const auto & stack : grammar->stacks) {
+void llama_grammar_accept_token_impl(struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token token) {
+    if (llama_token_is_eog_impl(vocab, token)) {
+        for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
                 return;
             }
@@ -520,20 +511,18 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
         GGML_ASSERT(false);
     }
 
-    const std::string & piece = vocab->cache_token_to_piece.at(token);
+    const std::string & piece = vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
+    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
     const auto & code_points = decoded.first;
 
     llama_grammar_stacks tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        grammar->stacks = tmp_new_stacks;
+        llama_grammar_accept(grammar.rules, grammar.stacks, *it, tmp_new_stacks);
+        grammar.stacks = tmp_new_stacks;
     }
 
-    grammar->partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar->stacks.empty());
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    grammar.partial_utf8 = decoded.second;
+    GGML_ASSERT(!grammar.stacks.empty());
 }
@@ -26,16 +26,14 @@ struct llama_grammar * llama_grammar_init_impl(
 
 void llama_grammar_free_impl(struct llama_grammar * grammar);
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar);
 
 void llama_grammar_sample_impl(
-        const struct llama_grammar * grammar,
-        const struct llama_vocab * vocab,
-        const struct llama_sampling * smpl,
+        const struct llama_grammar & grammar,
+        const struct llama_vocab & vocab,
         llama_token_data_array * candidates);
 
 void llama_grammar_accept_token_impl(
-        struct llama_grammar * grammar,
-        const struct llama_vocab * vocab,
-        const struct llama_sampling * smpl,
+        struct llama_grammar & grammar,
+        const struct llama_vocab & vocab,
         llama_token token);
@@ -1,8 +1,11 @@
 #pragma once
 
+#define LLAMA_API_INTERNAL
 #include "llama.h"
 
+#include <random>
+#include <string>
+#include <vector>
+
 #ifdef __GNUC__
 #ifdef __MINGW32__
 #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
@@ -24,3 +27,43 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
 #define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
+struct llama_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct llama_grammar_candidate {
+    size_t               index;
+    const uint32_t     * code_points;
+    llama_partial_utf8   partial_utf8;
+};
+
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules      = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
+
+void llama_grammar_accept(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t chr,
+              llama_grammar_stacks & new_stacks);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        const llama_grammar_candidates & candidates);
+
+std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+        const std::string & src,
+        llama_partial_utf8 partial_start);
@@ -21,19 +21,17 @@ static void llama_log_softmax(float * array, size_t size) {
     }
 }
 
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
+void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
         seed = time(NULL);
     }
 
-    smpl->rng.seed(seed);
+    smpl.rng.seed(seed);
 }
 
-void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
+void llama_sample_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
     GGML_ASSERT(candidates->size > 0);
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     // Sort the logits in descending order
     if (!candidates->sorted) {
         std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
@@ -44,28 +42,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar
 
     float max_l   = candidates->data[0].logit;
     float cum_sum = 0.0f;
 
     for (size_t i = 0; i < candidates->size; ++i) {
         float p = expf(candidates->data[i].logit - max_l);
         candidates->data[i].p = p;
         cum_sum += p;
     }
 
     for (size_t i = 0; i < candidates->size; ++i) {
         candidates->data[i].p /= cum_sum;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+void llama_sample_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
     // if (k >= (int32_t)candidates->size) {
     //     return;
     // }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     if (k <= 0) {
         k = candidates->size;
     }
@@ -133,21 +127,15 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
         candidates->sorted = true;
     }
     candidates->size = k;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sample_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
     if (p >= 1.0f) {
         return;
     }
 
     llama_sample_softmax_impl(smpl, candidates);
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
     size_t last_idx = candidates->size;
@@ -165,19 +153,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra
 
     // Resize the output vector to keep only the top-p tokens
     candidates->size = last_idx;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sample_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) {
     if (p <= 0.0f || !candidates->size) {
         return;
     }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     bool min_p_applied = false;
 
     // if the candidates aren't sorted, try the unsorted implementation first
@@ -226,19 +208,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
         // Resize the output vector to keep only the matching tokens
         candidates->size = i;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
+void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
     }
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax_impl(smpl, candidates);
 
     // Compute the first and second derivatives
     std::vector<float> first_derivatives(candidates->size - 1);
@@ -285,13 +262,9 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
 
     // Resize the output vector to keep only the tokens above the tail location
     candidates->size = last_idx;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
     if (p >= 1.0f) {
@@ -299,9 +272,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
     }
 
     // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax_impl(smpl, candidates);
 
     float entropy = 0.0f;
     for (size_t i = 0; i < candidates->size; ++i) {
@@ -349,15 +320,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
     std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
     candidates->size = new_candidates.size();
     candidates->sorted = false;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
     // no need to do anything if there is only one (or zero) candidates
     if (candidates->size <= 1) {
         return;
@@ -366,7 +331,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
     // Calculate maximum possible entropy
     float max_entropy = -logf(1.0f / candidates->size);
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+    llama_sample_softmax_impl(smpl, candidates);
 
     // Calculate entropy of the softmax probabilities
     float entropy = 0.0f;
@@ -416,38 +381,26 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
         LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
     }
 #endif
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+void llama_sample_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) {
     for (size_t i = 0; i < candidates->size; ++i) {
         candidates->data[i].logit /= temp;
     }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
 void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
+        struct llama_sampling & /*smpl*/,
        llama_token_data_array * candidates,
            const llama_token * last_tokens,
                        size_t   penalty_last_n,
-                        float   penalty_repeat,
-                        float   penalty_freq,
-                        float   penalty_present) {
+                         float   penalty_repeat,
+                         float   penalty_freq,
+                         float   penalty_present) {
     if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
         return;
     }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     // Create a frequency map to count occurrences of each token in last_tokens
     std::unordered_map<llama_token, int> token_count;
     for (size_t i = 0; i < penalty_last_n; ++i) {
@@ -475,21 +428,14 @@ void llama_sample_repetition_penalties_impl(
     }
 
     candidates->sorted = false;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
 void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
+        struct llama_sampling & smpl,
                        float * logits,
                        float * logits_guidance,
                        float   scale) {
-    GGML_ASSERT(smpl);
-
-    const auto t_start_sample_us = ggml_time_us();
-    const auto n_vocab = smpl->n_vocab;
+    const auto n_vocab = smpl.n_vocab;
 
     llama_log_softmax(logits, n_vocab);
     llama_log_softmax(logits_guidance, n_vocab);
@@ -500,18 +446,12 @@ void llama_sample_apply_guidance_impl(
 
         l = scale * (l - g) + g;
     }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
-llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    GGML_ASSERT(smpl);
+llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
+    const int32_t n_vocab = float(smpl.n_vocab);
 
-    const int32_t n_vocab = float(smpl->n_vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+    llama_sample_softmax_impl(smpl, candidates);
 
     // Estimate s_hat using the most probable m tokens
     float s_hat = 0.0;
@@ -530,10 +470,8 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
     float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
 
     // Sample the next word X using top-k sampling
-    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    llama_sample_top_k_impl(smpl, candidates, int(k), 1);
     llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
 
     // Compute error as the difference between observed surprise and target surprise value
     size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@@ -545,14 +483,10 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama
     // Update mu using the learning rate and error
     *mu = *mu - eta * e;
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
     return X;
 }
 
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
-
+llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
     llama_sample_softmax_impl(smpl, candidates);
 
     // Truncate the words with surprise values greater than mu
@@ -564,16 +498,11 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
         candidates->size = 1;
     }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-
     // Normalize the probabilities of the remaining words
     llama_sample_softmax_impl(smpl, candidates);
 
     // Sample the next word X from the remaining words
     llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
 
     // Compute error as the difference between observed surprise and target surprise value
     size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@@ -585,33 +514,22 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll
     // Update mu using the learning rate and error
     *mu = *mu - eta * e;
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
     return X;
 }
 
-llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+llama_token llama_sample_token_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
     // Find max element
     auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
         return a.logit < b.logit;
     });
 
     llama_token result = max_iter->id;
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-        smpl->n_sample++;
-    }
-
     return result;
 }
 
-llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
-    GGML_ASSERT(smpl);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+llama_token llama_sample_token_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
+    llama_sample_softmax_impl(smpl, candidates);
 
     std::vector<float> probs;
     probs.reserve(candidates->size);
@@ -624,12 +542,9 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama
 
     llama_token result = candidates->data[idx].id;
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    smpl->n_sample++;
-
     return result;
 }
 
-llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+llama_token llama_sample_token_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng_impl(smpl, candidates, smpl.rng);
 }
@@ -3,38 +3,30 @@
 #include "llama-impl.h"
 
 struct llama_sampling {
-    llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
+    llama_sampling(uint32_t seed, int32_t n_vocab) : rng(seed), n_vocab(n_vocab) {}
 
     std::mt19937 rng;
 
-    int32_t n_vocab = 0;
-
-    mutable int64_t t_sample_us = 0;
-    mutable int32_t n_sample = 0;
-
-    void reset_timings() const {
-        t_sample_us = 0;
-        n_sample    = 0;
-    }
+    const int32_t n_vocab;
 };
 
 //
 // internal API
 //
 
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
+void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed);
 
-void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates);
-void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
-void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
-void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
-void llama_sample_temp_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
+void llama_sample_softmax_impl  (struct llama_sampling & smpl, llama_token_data_array * candidates);
+void llama_sample_top_k_impl    (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
+void llama_sample_top_p_impl    (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_min_p_impl    (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep);
+void llama_sample_typical_impl  (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
+void llama_sample_entropy_impl  (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
+void llama_sample_temp_impl     (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp);
 
 void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
+        struct llama_sampling & smpl,
        llama_token_data_array * candidates,
            const llama_token * last_tokens,
                        size_t   penalty_last_n,
@@ -43,14 +35,14 @@ void llama_sample_repetition_penalties_impl(
                         float   penalty_present);
 
 void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
+        struct llama_sampling & smpl,
                        float * logits,
                        float * logits_guidance,
                        float   scale);
 
-llama_token llama_sample_token_mirostat_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
-llama_token llama_sample_token_greedy_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates);
-llama_token llama_sample_token_with_rng_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
-llama_token llama_sample_token_impl            (struct llama_sampling * smpl, llama_token_data_array * candidates);
+llama_token llama_sample_token_mirostat_impl   (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
+llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
+llama_token llama_sample_token_greedy_impl     (struct llama_sampling & smpl, llama_token_data_array * candidates);
+llama_token llama_sample_token_with_rng_impl   (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng);
+llama_token llama_sample_token_impl            (struct llama_sampling & smpl, llama_token_data_array * candidates);
src/llama.cpp (210 changed lines)
@@ -163,6 +163,19 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }
 
+struct time_meas {
+    time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {}
+
+    ~time_meas() {
+        t_acc += ggml_time_us() - t_start_us;
+    }
+
+    const int64_t t_start_us;
+
+    int64_t & t_acc;
+};
+
 LLAMA_ATTRIBUTE_FORMAT(1, 2)
 static std::string format(const char * fmt, ...) {
     va_list ap;
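The new time_meas helper replaces the scattered t_start_sample_us bookkeeping with RAII: the constructor records a start time and the destructor adds the elapsed microseconds to the referenced accumulator when the scope exits, so early returns are measured too. A minimal standalone sketch of the pattern (g_t_sample_us and sample_first are illustrative names, not part of this commit):

// Standalone illustration of the RAII timing pattern used by the wrappers below.
static int64_t g_t_sample_us = 0;

llama_token sample_first(llama_token_data_array * candidates) {
    time_meas tm(g_t_sample_us); // clock starts here

    if (candidates->size == 0) {
        return -1; // early return is still measured: tm's destructor runs now
    }

    return candidates->data[0].id; // tm's destructor adds the elapsed time to g_t_sample_us
}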
@@ -2656,7 +2669,6 @@ struct llama_model {
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
-        , sampling(llama_n_vocab(&model))
         , grammar()
         , t_start_us(model.t_start_us)
         , t_load_us(model.t_load_us) {}
@@ -2674,11 +2686,12 @@ struct llama_context {
     const struct llama_model & model;
 
     struct llama_cparams        cparams;
-    struct llama_sampling       sampling;
     struct llama_grammar        grammar;
     struct llama_kv_cache       kv_self;
     struct llama_control_vector cvec;
 
+    std::vector<struct llama_sampling> sampling; // sampling context for each sequence
+
     std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 
     std::vector<ggml_backend_t> backends;
@@ -2692,16 +2705,18 @@ struct llama_context {
 
     bool has_evaluated_once = false;
 
-    int64_t t_start_us;
-    int64_t t_load_us;
-    int64_t t_p_eval_us = 0;
-    int64_t t_eval_us   = 0;
+    mutable int64_t t_start_us;
+    mutable int64_t t_load_us;
+    mutable int64_t t_sample_us = 0;
+    mutable int64_t t_p_eval_us = 0;
+    mutable int64_t t_eval_us   = 0;
 
-    int64_t t_compute_start_us = 0;
-    int64_t n_queued_tokens = 0;
+    mutable int64_t t_compute_start_us = 0;
+    mutable int64_t n_queued_tokens = 0;
 
-    int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-    int32_t n_eval   = 0; // number of eval calls
+    mutable int32_t n_sample = 0;
+    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
+    mutable int32_t n_eval   = 0; // number of eval calls
 
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_t buf_output = nullptr;
@@ -16527,8 +16542,12 @@ struct llama_context * llama_new_context_with_model(
         ctx->abort_callback      = params.abort_callback;
         ctx->abort_callback_data = params.abort_callback_data;
 
-        ctx->sampling.rng = std::mt19937(params.seed);
-        ctx->logits_all   = params.logits_all;
+        ctx->sampling.reserve(cparams.n_seq_max);
+        for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
+            ctx->sampling.emplace_back(params.seed, llama_n_vocab(model));
+        }
+
+        ctx->logits_all = params.logits_all;
 
         uint32_t kv_size = cparams.n_ctx;
         ggml_type type_k = params.type_k;
@@ -17303,6 +17322,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
 
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_n_rng          = sizeof(uint32_t);
     const size_t s_rng_size       = sizeof(size_t);
     const size_t s_rng            = LLAMA_MAX_RNG_STATE;
     const size_t s_n_outputs      = sizeof(size_t);
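The comment above pins the reserved buffer to an empirical measurement: streaming a std::mt19937 writes its full engine state as text, and the size is deterministic for a fixed seed. A self-contained sketch that reproduces the measurement (assumes only the standard library; the 6701-byte figure is the one quoted in the diff's comment):

#include <iostream>
#include <random>
#include <sstream>

int main() {
    std::mt19937 rng(1337);
    std::ostringstream ss;
    ss << rng; // operator<< serializes the engine state as space-separated integers
    std::cout << ss.str().size() << "\n"; // expected to print 6701 per the comment above
}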
@@ -17322,8 +17342,8 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
     const size_t s_kv_cells       = ctx->kv_self.size * s_kv_cell;
 
     const size_t s_total = (
-        + s_rng_size
-        + s_rng
+        + s_n_rng
+        + cparams.n_seq_max*(s_rng_size + s_rng)
         + s_n_outputs
         + s_output_pos
         + s_logits_size
@@ -17340,7 +17360,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
     );
 
     // on session change it is very likely that the state size has changed - so we need to update this function
-    static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
+    static_assert(LLAMA_SESSION_VERSION == 8, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
 
     return s_total;
 }
@@ -17401,18 +17421,24 @@ struct llama_data_file_context : llama_data_context {
 static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     llama_synchronize(ctx);
 
-    // copy rng
+    // copy rngs
     {
-        std::ostringstream rng_ss;
-        rng_ss << ctx->sampling.rng;
+        const uint32_t n_rng = ctx->sampling.size();
 
-        const std::string & rng_str  = rng_ss.str();
-        const size_t        rng_size = rng_str.size();
+        data_ctx->write(&n_rng, sizeof(n_rng));
 
-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+        for (const auto & smpl : ctx->sampling) {
+            std::ostringstream rng_ss;
+            rng_ss << smpl.rng;
 
-        data_ctx->write(&rng_size,      sizeof(rng_size));
-        data_ctx->write(rng_str.data(), rng_size);
+            const std::string & rng_str  = rng_ss.str();
+            const size_t        rng_size = rng_str.size();
+
+            GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+
+            data_ctx->write(&rng_size,      sizeof(rng_size));
+            data_ctx->write(rng_str.data(), rng_size);
+        }
     }
 
     // copy outputs
@@ -17560,19 +17586,26 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
 
     const uint8_t * inp = src;
 
-    // set rng
+    // set rngs
     {
-        size_t rng_size;
-        memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
+        uint32_t n_rng;
+        memcpy(&n_rng, inp, sizeof(n_rng)); inp += sizeof(n_rng);
 
-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+        GGML_ASSERT(n_rng == ctx->cparams.n_seq_max);
 
-        std::string rng_str((const char *)inp, rng_size); inp += rng_size;
+        for (auto & smpl : ctx->sampling) {
+            size_t rng_size;
+            memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
 
-        std::istringstream rng_ss(rng_str);
-        rng_ss >> ctx->sampling.rng;
+            GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
 
-        GGML_ASSERT(!rng_ss.fail());
+            std::string rng_str((const char *)inp, rng_size); inp += rng_size;
+
+            std::istringstream rng_ss(rng_str);
+            rng_ss >> smpl.rng;
+
+            GGML_ASSERT(!rng_ss.fail());
+        }
     }
 
     // set output ids
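The save/restore pair above is the standard stream round-trip for std::mt19937. A self-contained sketch of the same idea, useful for sanity-checking that a restored engine continues the exact sequence (the variable names are local to the sketch):

#include <cassert>
#include <random>
#include <sstream>
#include <string>

int main() {
    std::mt19937 rng(42);
    rng(); // advance a little

    std::ostringstream out;
    out << rng; // serialize, as in the "copy rngs" path above
    const std::string state = out.str();

    const std::uint32_t expected = rng();

    std::mt19937 restored;
    std::istringstream in(state);
    in >> restored; // deserialize, as in the "set rngs" path
    assert(!in.fail());

    assert(restored() == expected); // the restored engine continues identically
}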
@@ -18930,18 +18963,24 @@ struct llama_grammar * llama_grammar_init(
 }
 
 void llama_grammar_free(struct llama_grammar * grammar) {
+    if (grammar == nullptr) {
+        return;
+    }
+
     llama_grammar_free_impl(grammar);
 }
 
 struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    return llama_grammar_copy_impl(grammar);
+    return llama_grammar_copy_impl(*grammar);
 }
 
 void llama_grammar_sample(
         const struct llama_grammar * grammar,
         const struct llama_context * ctx,
             llama_token_data_array * candidates) {
-    llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
+    time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling
+
+    llama_grammar_sample_impl(*grammar, ctx->model.vocab, candidates);
 }
 
 void llama_sample_grammar(
@@ -18955,7 +18994,9 @@ void llama_grammar_accept_token(
         struct llama_grammar * grammar,
         struct llama_context * ctx,
                  llama_token   token) {
-    llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
+    time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling
+
+    llama_grammar_accept_token_impl(*grammar, ctx->model.vocab, token);
 }
 
 //
@@ -18963,39 +19004,59 @@ void llama_grammar_accept_token(
 //
 
 void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
-    llama_set_rng_seed_impl(&ctx->sampling, seed);
+    llama_set_rng_seed_impl(ctx->sampling[0], seed);
+}
+
+void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id seq_id) {
+    llama_set_rng_seed_impl(ctx->sampling[seq_id], seed);
 }
 
 void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
-    llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_softmax_impl(ctx->sampling[0], candidates);
 }
 
 void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_top_k_impl(ctx->sampling[0], candidates, k, min_keep);
 }
 
 void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_top_p_impl(ctx->sampling[0], candidates, p, min_keep);
 }
 
 void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_min_p_impl(ctx->sampling[0], candidates, p, min_keep);
 }
 
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
-    llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_tail_free_impl(ctx->sampling[0], candidates, z, min_keep);
 }
 
 void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_typical_impl(ctx->sampling[0], candidates, p, min_keep);
 }
 
 void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
-    llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_entropy_impl(ctx->sampling[0], candidates_p, min_temp, max_temp, exponent_val);
 }
 
 void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
-    llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_temp_impl(ctx->sampling[0], candidates_p, temp);
 }
 
 void llama_sample_repetition_penalties(
@@ -19006,7 +19067,9 @@ void llama_sample_repetition_penalties(
                          float   penalty_repeat,
                          float   penalty_freq,
                          float   penalty_present) {
-    llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_repetition_penalties_impl(ctx->sampling[0], candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
 }
 
 void llama_sample_apply_guidance(
@@ -19014,27 +19077,55 @@ void llama_sample_apply_guidance(
                  float * logits,
                  float * logits_guidance,
                  float   scale) {
-    llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
+    time_meas tm(ctx->t_sample_us);
+
+    llama_sample_apply_guidance_impl(ctx->sampling[0], logits, logits_guidance, scale);
 }
 
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
+    time_meas tm(ctx->t_sample_us);
+
+    auto res = llama_sample_token_mirostat_impl(ctx->sampling[0], candidates, tau, eta, m, mu);
+
+    ctx->n_sample++;
+
+    return res;
 }
 
 llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
+    time_meas tm(ctx->t_sample_us);
+
+    auto res = llama_sample_token_mirostat_v2_impl(ctx->sampling[0], candidates, tau, eta, mu);
+
+    ctx->n_sample++;
+
+    return res;
 }
 
 llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
+    time_meas tm(ctx->t_sample_us);
 
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
+    auto res = llama_sample_token_greedy_impl(ctx->sampling[0], candidates);
+
+    ctx->n_sample++;
+
+    return res;
 }
 
 llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
+    return llama_sample_token_seq(ctx, candidates, 0);
+}
+
+llama_token llama_sample_token_seq(struct llama_context * ctx, llama_token_data_array * candidates, llama_seq_id seq_id) {
+    GGML_ASSERT(seq_id >= 0 && seq_id < (int32_t) ctx->cparams.n_seq_max);
+
+    time_meas tm(ctx->t_sample_us);
+
+    auto res = llama_sample_token_impl(ctx->sampling[seq_id], candidates);
+
+    ctx->n_sample++;
+
+    return res;
 }
 
 int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
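With the wrappers above, llama_sample_token is now just llama_sample_token_seq with seq_id 0, and every path bottoms out in llama_sample_token_with_rng_impl: softmax has already filled candidates[i].p, so the draw is a probability-weighted pick from the candidate array. A standalone restatement of that final step (pick_token is a name local to this sketch; the real implementation likewise draws via std::discrete_distribution):

#include <random>
#include <vector>

// Returns the index of the sampled candidate, weighted by probs.
int pick_token(const std::vector<float> & probs, std::mt19937 & rng) {
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return dist(rng);
}

Because the distribution consumes the per-sequence rng, two sequences seeded differently draw independent tokens, which is the race-free behavior the per-sequence RNG array is meant to provide.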
@@ -19066,11 +19157,11 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) {
         /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
         /*.t_end_ms    =*/ 1.00 * ggml_time_ms(),
         /*.t_load_ms   =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
+        /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us,
         /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
         /*.t_eval_ms   =*/ 1e-3 * ctx->t_eval_us,
 
-        /*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
+        /*.n_sample =*/ std::max(1, ctx->n_sample),
         /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
         /*.n_eval   =*/ std::max(1, ctx->n_eval),
     };
@ -19095,9 +19186,8 @@ void llama_print_timings(struct llama_context * ctx) {
|
|||
void llama_reset_timings(struct llama_context * ctx) {
|
||||
ctx->t_start_us = ggml_time_us();
|
||||
ctx->t_eval_us = ctx->n_eval = 0;
|
||||
ctx->t_sample_us = ctx->n_sample = 0;
|
||||
ctx->t_p_eval_us = ctx->n_p_eval = 0;
|
||||
|
||||
ctx->sampling.reset_timings();
|
||||
}
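
Since the sampling counters now reset together with the eval counters, multi-pass measurements need no extra bookkeeping; a sketch (the workload callback is hypothetical):

#include "llama.h"

// Sketch, not part of this commit: measure each benchmark pass in isolation.
static void bench(llama_context * ctx, int n_pass, void (*run_pass)(llama_context *)) {
    for (int i = 0; i < n_pass; ++i) {
        llama_reset_timings(ctx);
        run_pass(ctx);            // hypothetical workload
        llama_print_timings(ctx); // reports only this pass
    }
}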

const char * llama_print_system_info(void) {

@@ -19144,20 +19234,20 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
    fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n",
            1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
    fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n",
            1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
            1.0e-3 * ctx->t_sample_us / ctx->n_sample);
    fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
    fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
    fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample);
    fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample);
    fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
    fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
    fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
    fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
    fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us);
    fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
            1.0e6 * ctx->n_eval / ctx->t_eval_us);
    fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",
            1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
    fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n",
            1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
            1.0e6 * ctx->n_sample / ctx->t_sample_us);
}
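
The YAML keys are unchanged; only their source fields move from ctx->sampling to the context itself. Usage stays as before, e.g. (stream choice is the caller's):

#include <cstdio>

#include "llama.h"

// Append run statistics to any open stream; stderr is just an example.
static void log_timings(llama_context * ctx) {
    llama_dump_timing_info_yaml(stderr, ctx);
}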

// For internal test use

@@ -2,13 +2,13 @@
#undef NDEBUG
#endif

#define LLAMA_API_INTERNAL

#include "ggml.h"
#include "llama.h"
#include "llama-impl.h"
#include "unicode.h"
#include "grammar-parser.h"
#include "json-schema-to-grammar.h"
#include "unicode.h"

#include <cassert>
#include <string>
#include <vector>

@@ -2,8 +2,8 @@
#undef NDEBUG
#endif

#define LLAMA_API_INTERNAL
#include "llama.h"
#include "llama-impl.h"
#include "grammar-parser.h"

#include <cassert>

@@ -1,5 +1,5 @@
#include "ggml.h"
#include "llama.h"
#include "llama-sampling.h"

#ifdef NDEBUG
#undef NDEBUG

@@ -20,6 +20,8 @@ static void dump(const llama_token_data_array * candidates) {

static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
    const size_t n_vocab = probs.size();
    llama_sampling smpl(1234, n_vocab);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {

@@ -28,9 +30,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    llama_sample_softmax_impl(smpl, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_top_k(nullptr, &candidates_p, k, 1);
    llama_sample_top_k_impl(smpl, &candidates_p, k, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());
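
The migration pattern in these tests is mechanical: a stack-allocated llama_sampling replaces the former nullptr context argument, and each sampler gains an _impl suffix. A condensed, self-contained sketch of the new shape (values illustrative only):

#include "ggml.h"
#include "llama-sampling.h" // internal header, as included above

#include <vector>

// Sketch, not part of this commit: exercise top-k through the _impl
// interface without a full llama_context. Constructor arguments
// (seed, n_vocab) follow the tests above.
static void smoke_top_k() {
    const size_t n_vocab = 4;
    llama_sampling smpl(1234, n_vocab);

    std::vector<llama_token_data> candidates;
    for (llama_token id = 0; id < (llama_token) n_vocab; id++) {
        candidates.push_back({ id, float(id), 0.0f }); // ascending logits
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    llama_sample_softmax_impl(smpl, &candidates_p);
    llama_sample_top_k_impl  (smpl, &candidates_p, /*k =*/ 2, /*min_keep =*/ 1);

    GGML_ASSERT(candidates_p.size == 2); // only the two most probable tokens remain
}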

@@ -41,6 +43,8 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float

static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
    const size_t n_vocab = probs.size();
    llama_sampling smpl(1234, n_vocab);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {

@@ -49,9 +53,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    llama_sample_softmax_impl(smpl, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_top_p(nullptr, &candidates_p, p, 1);
    llama_sample_top_p_impl(smpl, &candidates_p, p, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@@ -62,6 +66,8 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float

static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
    const size_t n_vocab = probs.size();
    llama_sampling smpl(1234, n_vocab);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {

@@ -71,7 +77,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sample_tail_free(nullptr, &candidates_p, z, 1);
    llama_sample_tail_free_impl(smpl, &candidates_p, z, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@@ -82,6 +88,8 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>

static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
    const size_t n_vocab = probs.size();
    llama_sampling smpl(1234, n_vocab);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {

@@ -91,9 +99,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sample_min_p(nullptr, &candidates_p, p, 1);
    llama_sample_min_p_impl(smpl, &candidates_p, p, 1);
    DUMP(&candidates_p);
    llama_sample_softmax(nullptr, &candidates_p);
    llama_sample_softmax_impl(smpl, &candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {

@@ -103,6 +111,8 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float

static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
    const size_t n_vocab = probs.size();
    llama_sampling smpl(1234, n_vocab);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {

@@ -112,7 +122,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sample_typical(nullptr, &candidates_p, p, 1);
    llama_sample_typical_impl(smpl, &candidates_p, p, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@@ -128,6 +138,8 @@ static void test_repetition_penalties(
    GGML_ASSERT(probs.size() == expected_probs.size());

    const size_t n_vocab = probs.size();
    llama_sampling smpl(1234, n_vocab);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {

@@ -136,10 +148,10 @@ static void test_repetition_penalties(
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    llama_sample_softmax_impl(smpl, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
    llama_sample_softmax(nullptr, &candidates_p);
    llama_sample_repetition_penalties_impl(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
    llama_sample_softmax_impl(smpl, &candidates_p);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@@ -148,9 +160,10 @@ static void test_repetition_penalties(
    }
}

static void test_sampler_queue(
    const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
) {
    llama_sampling smpl(1234, n_vocab);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {

@@ -165,16 +178,16 @@ static void test_sampler_queue(

    for (auto s : samplers_sequence) {
        switch (s){
            case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
            case 'k': llama_sample_top_k_impl(smpl, &candidates_p, top_k, 1); break;
            case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
            case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
            case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
            case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
            case 'p': llama_sample_top_p_impl(smpl, &candidates_p, top_p, 1); break;
            case 'm': llama_sample_min_p_impl(smpl, &candidates_p, min_p, 1); break;
            case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
            default : GGML_ASSERT(false && "Unknown sampler"); break;
        }

        llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
        llama_sample_softmax_impl(smpl, &candidates_p); // make sure tokens are sorted for tests

        const int size = candidates_p.size;