cont : llama_sampling_init() use llama_sampling_params
ggml-ci
This commit is contained in:
parent
c5734f1274
commit
d352b01af9
17 changed files with 103 additions and 45 deletions
|
@ -6,7 +6,43 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
|
||||||
struct llama_sampling_context * result = new llama_sampling_context();
|
struct llama_sampling_context * result = new llama_sampling_context();
|
||||||
|
|
||||||
result->params = params;
|
result->params = params;
|
||||||
result->smpl = llama_sampling_init(model, params.grammar.c_str(), "root");
|
|
||||||
|
{
|
||||||
|
auto lp = llama_sampling_default_params();
|
||||||
|
|
||||||
|
lp.seed = params.seed;
|
||||||
|
lp.n_prev = params.n_prev;
|
||||||
|
lp.n_probs = params.n_probs;
|
||||||
|
lp.min_keep = params.min_keep;
|
||||||
|
lp.top_k = params.top_k;
|
||||||
|
lp.top_p = params.top_p;
|
||||||
|
lp.min_p = params.min_p;
|
||||||
|
lp.tfs_z = params.tfs_z;
|
||||||
|
lp.typical_p = params.typical_p;
|
||||||
|
lp.temp = params.temp;
|
||||||
|
lp.dynatemp_range = params.dynatemp_range;
|
||||||
|
lp.dynatemp_exponent = params.dynatemp_exponent;
|
||||||
|
lp.penalty_last_n = params.penalty_last_n;
|
||||||
|
lp.penalty_repeat = params.penalty_repeat;
|
||||||
|
lp.penalty_freq = params.penalty_freq;
|
||||||
|
lp.penalty_present = params.penalty_present;
|
||||||
|
lp.mirostat = params.mirostat;
|
||||||
|
lp.mirostat_tau = params.mirostat_tau;
|
||||||
|
lp.mirostat_eta = params.mirostat_eta;
|
||||||
|
lp.penalize_nl = params.penalize_nl;
|
||||||
|
lp.ignore_eos = params.ignore_eos;
|
||||||
|
|
||||||
|
lp.grammar = params.grammar.c_str();
|
||||||
|
lp.grammar_root = "root";
|
||||||
|
|
||||||
|
lp.cfg_prompt = params.cfg_negative_prompt.c_str();
|
||||||
|
lp.cfg_scale = params.cfg_scale;
|
||||||
|
|
||||||
|
lp.n_logit_bias = params.logit_bias.size();
|
||||||
|
lp.logit_bias = params.logit_bias.data();
|
||||||
|
|
||||||
|
result->smpl = llama_sampling_init(model, lp);
|
||||||
|
}
|
||||||
|
|
||||||
result->prev.resize(params.n_prev);
|
result->prev.resize(params.n_prev);
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ enum class llama_sampler_type : char {
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
typedef struct gpt_sampling_params {
|
typedef struct gpt_sampling_params {
|
||||||
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||||
int32_t n_prev = 64; // number of previous tokens to remember
|
int32_t n_prev = 64; // number of previous tokens to remember
|
||||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||||
|
@ -37,7 +38,6 @@ typedef struct gpt_sampling_params {
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
|
||||||
|
|
||||||
std::vector<llama_sampler_type> samplers_sequence = {
|
std::vector<llama_sampler_type> samplers_sequence = {
|
||||||
llama_sampler_type::TOP_K,
|
llama_sampler_type::TOP_K,
|
||||||
|
|
|
@ -50,7 +50,7 @@ defer {
|
||||||
llama_free(context)
|
llama_free(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
let smpl = llama_sampling_init(model, nil, nil)
|
let smpl = llama_sampling_init(model, llama_sampling_default_params())
|
||||||
guard smpl != nil else {
|
guard smpl != nil else {
|
||||||
print("Failed to initialize sampling")
|
print("Failed to initialize sampling")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
|
@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
|
||||||
ctx_params.n_batch = std::max(n_predict, n_parallel);
|
ctx_params.n_batch = std::max(n_predict, n_parallel);
|
||||||
|
|
||||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||||
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
|
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
|
||||||
|
|
||||||
if (ctx == NULL) {
|
if (ctx == NULL) {
|
||||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||||
|
|
|
@ -172,7 +172,7 @@ int main(int argc, char * argv[]) {
|
||||||
// create generation context
|
// create generation context
|
||||||
llama_context * ctx = llama_new_context_with_model(model, cparams);
|
llama_context * ctx = llama_new_context_with_model(model, cparams);
|
||||||
|
|
||||||
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
|
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
|
||||||
|
|
||||||
// ### Embedding/Representation ###
|
// ### Embedding/Representation ###
|
||||||
// samples taken from: https://github.com/ContextualAI/gritlm#basic
|
// samples taken from: https://github.com/ContextualAI/gritlm#basic
|
||||||
|
|
|
@ -43,7 +43,7 @@ actor LlamaContext {
|
||||||
self.tokens_list = []
|
self.tokens_list = []
|
||||||
self.batch = llama_batch_init(512, 0, 1)
|
self.batch = llama_batch_init(512, 0, 1)
|
||||||
self.temporary_invalid_cchars = []
|
self.temporary_invalid_cchars = []
|
||||||
self.sampling = llama_sampling_init(context, nil, nil);
|
self.sampling = llama_sampling_init(context, llama_sampling_default_params())
|
||||||
}
|
}
|
||||||
|
|
||||||
deinit {
|
deinit {
|
||||||
|
|
|
@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
|
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<llama_token> tokens_list;
|
std::vector<llama_token> tokens_list;
|
||||||
|
|
|
@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
|
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
|
||||||
|
|
||||||
// tokenize prompt
|
// tokenize prompt
|
||||||
auto tokens = llama_tokenize(ctx, params.prompt, true);
|
auto tokens = llama_tokenize(ctx, params.prompt, true);
|
||||||
|
@ -97,7 +97,7 @@ int main(int argc, char ** argv) {
|
||||||
// make new context
|
// make new context
|
||||||
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||||
|
|
||||||
llama_sampling * smpl2 = llama_sampling_init(model, nullptr, nullptr);
|
llama_sampling * smpl2 = llama_sampling_init(model, llama_sampling_default_params());
|
||||||
|
|
||||||
printf("\nsecond run: %s", params.prompt.c_str());
|
printf("\nsecond run: %s", params.prompt.c_str());
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
|
||||||
// make new context
|
// make new context
|
||||||
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||||
|
|
||||||
llama_sampling * smpl3 = llama_sampling_init(model, nullptr, nullptr);
|
llama_sampling * smpl3 = llama_sampling_init(model, llama_sampling_default_params());
|
||||||
|
|
||||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);
|
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
|
|
||||||
|
|
|
@ -328,7 +328,8 @@ extern "C" {
|
||||||
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
|
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
|
||||||
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
|
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
|
||||||
|
|
||||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||||
|
// TODO: move at the end of the struct
|
||||||
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||||
bool embeddings; // if true, extract embeddings (together with logits)
|
bool embeddings; // if true, extract embeddings (together with logits)
|
||||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||||
|
@ -382,14 +383,20 @@ extern "C" {
|
||||||
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
float mirostat_tau; // target entropy
|
float mirostat_tau; // target entropy
|
||||||
float mirostat_eta; // learning rate
|
float mirostat_eta; // learning rate
|
||||||
bool penalize_nl; // consider newlines as a repeatable token
|
|
||||||
bool ignore_eos; // ignore the end-of-sequence token
|
|
||||||
|
|
||||||
|
// https://github.com/ggerganov/llama.cpp/pull/1773
|
||||||
const char * grammar;
|
const char * grammar;
|
||||||
const char * grammar_root;
|
const char * grammar_root;
|
||||||
|
|
||||||
|
const char * cfg_prompt; // string to help guidance in negative direction
|
||||||
|
float cfg_scale; // how strong is guidance
|
||||||
|
|
||||||
int32_t n_logit_bias;
|
int32_t n_logit_bias;
|
||||||
const llama_logit_bias * logit_bias;
|
const llama_logit_bias * logit_bias;
|
||||||
|
|
||||||
|
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||||
|
bool penalize_nl; // consider newlines as a repeatable token
|
||||||
|
bool ignore_eos; // ignore the end-of-sequence token
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// performance timing information
|
// performance timing information
|
||||||
|
@ -1006,9 +1013,8 @@ extern "C" {
|
||||||
// Sampling functions
|
// Sampling functions
|
||||||
//
|
//
|
||||||
|
|
||||||
// TODO: args become llama_sampling_params
|
|
||||||
// TODO: llama_model should become llama_vocab
|
// TODO: llama_model should become llama_vocab
|
||||||
LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root);
|
LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params);
|
||||||
|
|
||||||
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
|
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
|
||||||
|
|
||||||
|
|
|
@ -24,13 +24,7 @@ static void llama_log_softmax(float * array, size_t size) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling::llama_sampling(uint32_t n_vocab) : n_vocab(n_vocab) {
|
llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) {
|
||||||
}
|
|
||||||
|
|
||||||
llama_sampling::llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) : n_vocab(vocab.n_vocab) {
|
|
||||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
|
||||||
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling::~llama_sampling() {
|
llama_sampling::~llama_sampling() {
|
||||||
|
@ -39,8 +33,16 @@ llama_sampling::~llama_sampling() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
|
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
|
||||||
return new llama_sampling(vocab, grammar_str, grammar_root);
|
auto * result = new llama_sampling(vocab);
|
||||||
|
|
||||||
|
// TODO: store params
|
||||||
|
|
||||||
|
if (params.grammar != nullptr && params.grammar[0] != '\0') {
|
||||||
|
result->grammar = llama_grammar_init_impl(result->vocab, params.grammar, params.grammar_root);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampling_free_impl(struct llama_sampling * sampling) {
|
void llama_sampling_free_impl(struct llama_sampling * sampling) {
|
||||||
|
@ -48,7 +50,7 @@ void llama_sampling_free_impl(struct llama_sampling * sampling) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
|
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
|
||||||
auto * result = new llama_sampling(smpl.n_vocab);
|
auto * result = new llama_sampling(smpl.vocab);
|
||||||
|
|
||||||
if (smpl.grammar) {
|
if (smpl.grammar) {
|
||||||
result->grammar = llama_grammar_copy_impl(*smpl.grammar);
|
result->grammar = llama_grammar_copy_impl(*smpl.grammar);
|
||||||
|
@ -493,7 +495,7 @@ void llama_sampling_apply_guidance_impl(
|
||||||
float * logits,
|
float * logits,
|
||||||
float * logits_guidance,
|
float * logits_guidance,
|
||||||
float scale) {
|
float scale) {
|
||||||
const auto n_vocab = smpl.n_vocab;
|
const auto n_vocab = smpl.vocab.n_vocab;
|
||||||
|
|
||||||
llama_log_softmax(logits, n_vocab);
|
llama_log_softmax(logits, n_vocab);
|
||||||
llama_log_softmax(logits_guidance, n_vocab);
|
llama_log_softmax(logits_guidance, n_vocab);
|
||||||
|
@ -507,7 +509,7 @@ void llama_sampling_apply_guidance_impl(
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||||
const int32_t n_vocab = float(smpl.n_vocab);
|
const float n_vocab = float(smpl.vocab.n_vocab);
|
||||||
|
|
||||||
llama_sampling_softmax_impl(smpl, candidates);
|
llama_sampling_softmax_impl(smpl, candidates);
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,15 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "llama-impl.h"
|
|
||||||
#include "llama-grammar.h"
|
#include "llama-grammar.h"
|
||||||
|
|
||||||
struct llama_vocab;
|
struct llama_vocab;
|
||||||
struct llama_grammar;
|
struct llama_grammar;
|
||||||
|
|
||||||
struct llama_sampling {
|
struct llama_sampling {
|
||||||
llama_sampling(uint32_t n_vocab);
|
llama_sampling(const struct llama_vocab & vocab);
|
||||||
llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
|
|
||||||
~llama_sampling();
|
~llama_sampling();
|
||||||
|
|
||||||
const uint32_t n_vocab;
|
const llama_vocab & vocab;
|
||||||
|
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
|
||||||
|
@ -26,7 +24,7 @@ struct llama_sampling {
|
||||||
// internal API
|
// internal API
|
||||||
//
|
//
|
||||||
|
|
||||||
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
|
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params);
|
||||||
|
|
||||||
void llama_sampling_free_impl(struct llama_sampling * sampling);
|
void llama_sampling_free_impl(struct llama_sampling * sampling);
|
||||||
|
|
||||||
|
|
|
@ -93,6 +93,18 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_vocab llama_vocab::for_tests(uint32_t n_vocab) {
|
||||||
|
llama_vocab vocab;
|
||||||
|
vocab.n_vocab = n_vocab;
|
||||||
|
vocab.token_to_id.reserve(n_vocab);
|
||||||
|
vocab.id_to_token.reserve(n_vocab);
|
||||||
|
for (uint32_t i = 0; i < n_vocab; i++) {
|
||||||
|
vocab.token_to_id[format("token_%u", i)] = i;
|
||||||
|
vocab.id_to_token.push_back({ format("token_%u", i), 0.0f, LLAMA_TOKEN_ATTR_NORMAL });
|
||||||
|
}
|
||||||
|
return vocab;
|
||||||
|
}
|
||||||
|
|
||||||
static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
|
static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
|
||||||
return vocab.type;
|
return vocab.type;
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,6 +62,8 @@ struct llama_vocab {
|
||||||
std::vector<char> precompiled_charsmap;
|
std::vector<char> precompiled_charsmap;
|
||||||
|
|
||||||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||||
|
|
||||||
|
static llama_vocab for_tests(uint32_t n_vocab);
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
|
@ -16511,12 +16511,14 @@ struct llama_sampling_params llama_sampling_default_params() {
|
||||||
/*.mirostat =*/ 0,
|
/*.mirostat =*/ 0,
|
||||||
/*.mirostat_tau =*/ 5.00f,
|
/*.mirostat_tau =*/ 5.00f,
|
||||||
/*.mirostat_eta =*/ 0.10f,
|
/*.mirostat_eta =*/ 0.10f,
|
||||||
/*.penalize_nl =*/ false,
|
|
||||||
/*.ignore_eos =*/ false,
|
|
||||||
/*.grammar =*/ nullptr,
|
/*.grammar =*/ nullptr,
|
||||||
/*.grammar_root =*/ nullptr,
|
/*.grammar_root =*/ nullptr,
|
||||||
|
/*.cfg_prompt =*/ nullptr,
|
||||||
|
/*.cfg_scale =*/ 1.00f,
|
||||||
/*.n_logit_bias =*/ 0,
|
/*.n_logit_bias =*/ 0,
|
||||||
/*.logit_bias =*/ nullptr,
|
/*.logit_bias =*/ nullptr,
|
||||||
|
/*.penalize_nl =*/ false,
|
||||||
|
/*.ignore_eos =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -19091,8 +19093,8 @@ int32_t llama_chat_apply_template(
|
||||||
// sampling
|
// sampling
|
||||||
//
|
//
|
||||||
|
|
||||||
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
|
struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) {
|
||||||
return llama_sampling_init_impl(model->vocab, grammar_str, grammar_root);
|
return llama_sampling_init_impl(model->vocab, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampling_free(struct llama_sampling * smpl) {
|
void llama_sampling_free(struct llama_sampling * smpl) {
|
||||||
|
|
|
@ -2470,7 +2470,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void usage(char ** argv) {
|
static void usage(char ** argv) {
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
#include "llama-vocab.h"
|
||||||
#include "llama-sampling.h"
|
#include "llama-sampling.h"
|
||||||
|
|
||||||
#ifdef NDEBUG
|
#ifdef NDEBUG
|
||||||
|
@ -21,7 +22,7 @@ static void dump(const llama_token_data_array * candidates) {
|
||||||
|
|
||||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
||||||
const size_t n_vocab = probs.size();
|
const size_t n_vocab = probs.size();
|
||||||
llama_sampling smpl(n_vocab);
|
llama_sampling smpl(llama_vocab::for_tests(n_vocab));
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
@ -44,7 +45,7 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
||||||
|
|
||||||
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
const size_t n_vocab = probs.size();
|
const size_t n_vocab = probs.size();
|
||||||
llama_sampling smpl(n_vocab);
|
llama_sampling smpl(llama_vocab::for_tests(n_vocab));
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
@ -67,7 +68,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
||||||
|
|
||||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
||||||
const size_t n_vocab = probs.size();
|
const size_t n_vocab = probs.size();
|
||||||
llama_sampling smpl(n_vocab);
|
llama_sampling smpl(llama_vocab::for_tests(n_vocab));
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
@ -89,7 +90,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
||||||
|
|
||||||
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
const size_t n_vocab = probs.size();
|
const size_t n_vocab = probs.size();
|
||||||
llama_sampling smpl(n_vocab);
|
llama_sampling smpl(llama_vocab::for_tests(n_vocab));
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
@ -112,7 +113,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
|
||||||
|
|
||||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
const size_t n_vocab = probs.size();
|
const size_t n_vocab = probs.size();
|
||||||
llama_sampling smpl(n_vocab);
|
llama_sampling smpl(llama_vocab::for_tests(n_vocab));
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
@ -139,7 +140,7 @@ static void test_repetition_penalties(
|
||||||
GGML_ASSERT(probs.size() == expected_probs.size());
|
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||||
|
|
||||||
const size_t n_vocab = probs.size();
|
const size_t n_vocab = probs.size();
|
||||||
llama_sampling smpl(n_vocab);
|
llama_sampling smpl(llama_vocab::for_tests(n_vocab));
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
@ -163,7 +164,7 @@ static void test_repetition_penalties(
|
||||||
|
|
||||||
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||||
) {
|
) {
|
||||||
llama_sampling smpl(n_vocab);
|
llama_sampling smpl(llama_vocab::for_tests(n_vocab));
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue