cont : llama_sampling_init() use llama_sampling_params

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-08-12 17:14:56 +03:00
parent c5734f1274
commit d352b01af9
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
17 changed files with 103 additions and 45 deletions

View file

@ -6,7 +6,43 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
struct llama_sampling_context * result = new llama_sampling_context(); struct llama_sampling_context * result = new llama_sampling_context();
result->params = params; result->params = params;
result->smpl = llama_sampling_init(model, params.grammar.c_str(), "root");
{
auto lp = llama_sampling_default_params();
lp.seed = params.seed;
lp.n_prev = params.n_prev;
lp.n_probs = params.n_probs;
lp.min_keep = params.min_keep;
lp.top_k = params.top_k;
lp.top_p = params.top_p;
lp.min_p = params.min_p;
lp.tfs_z = params.tfs_z;
lp.typical_p = params.typical_p;
lp.temp = params.temp;
lp.dynatemp_range = params.dynatemp_range;
lp.dynatemp_exponent = params.dynatemp_exponent;
lp.penalty_last_n = params.penalty_last_n;
lp.penalty_repeat = params.penalty_repeat;
lp.penalty_freq = params.penalty_freq;
lp.penalty_present = params.penalty_present;
lp.mirostat = params.mirostat;
lp.mirostat_tau = params.mirostat_tau;
lp.mirostat_eta = params.mirostat_eta;
lp.penalize_nl = params.penalize_nl;
lp.ignore_eos = params.ignore_eos;
lp.grammar = params.grammar.c_str();
lp.grammar_root = "root";
lp.cfg_prompt = params.cfg_negative_prompt.c_str();
lp.cfg_scale = params.cfg_scale;
lp.n_logit_bias = params.logit_bias.size();
lp.logit_bias = params.logit_bias.data();
result->smpl = llama_sampling_init(model, lp);
}
result->prev.resize(params.n_prev); result->prev.resize(params.n_prev);

View file

@ -17,6 +17,7 @@ enum class llama_sampler_type : char {
// sampling parameters // sampling parameters
typedef struct gpt_sampling_params { typedef struct gpt_sampling_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
int32_t n_prev = 64; // number of previous tokens to remember int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
@ -37,7 +38,6 @@ typedef struct gpt_sampling_params {
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false; bool ignore_eos = false;
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
std::vector<llama_sampler_type> samplers_sequence = { std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K, llama_sampler_type::TOP_K,

View file

@ -50,7 +50,7 @@ defer {
llama_free(context) llama_free(context)
} }
let smpl = llama_sampling_init(model, nil, nil) let smpl = llama_sampling_init(model, llama_sampling_default_params())
guard smpl != nil else { guard smpl != nil else {
print("Failed to initialize sampling") print("Failed to initialize sampling")
exit(1) exit(1)

View file

@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
ctx_params.n_batch = std::max(n_predict, n_parallel); ctx_params.n_batch = std::max(n_predict, n_parallel);
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr); llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

View file

@ -172,7 +172,7 @@ int main(int argc, char * argv[]) {
// create generation context // create generation context
llama_context * ctx = llama_new_context_with_model(model, cparams); llama_context * ctx = llama_new_context_with_model(model, cparams);
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr); llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
// ### Embedding/Representation ### // ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic // samples taken from: https://github.com/ContextualAI/gritlm#basic

View file

@ -43,7 +43,7 @@ actor LlamaContext {
self.tokens_list = [] self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1) self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = [] self.temporary_invalid_cchars = []
self.sampling = llama_sampling_init(context, nil, nil); self.sampling = llama_sampling_init(context, llama_sampling_default_params())
} }
deinit { deinit {

View file

@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr); llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> tokens_list; std::vector<llama_token> tokens_list;

View file

@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr); llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
// tokenize prompt // tokenize prompt
auto tokens = llama_tokenize(ctx, params.prompt, true); auto tokens = llama_tokenize(ctx, params.prompt, true);
@ -97,7 +97,7 @@ int main(int argc, char ** argv) {
// make new context // make new context
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl2 = llama_sampling_init(model, nullptr, nullptr); llama_sampling * smpl2 = llama_sampling_init(model, llama_sampling_default_params());
printf("\nsecond run: %s", params.prompt.c_str()); printf("\nsecond run: %s", params.prompt.c_str());
@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
// make new context // make new context
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl3 = llama_sampling_init(model, nullptr, nullptr); llama_sampling * smpl3 = llama_sampling_init(model, llama_sampling_default_params());
printf("\nsingle seq run: %s", params.prompt.c_str()); printf("\nsingle seq run: %s", params.prompt.c_str());

View file

@ -55,7 +55,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr); llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
// tokenize the prompt // tokenize the prompt

View file

@ -328,7 +328,8 @@ extern "C" {
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
// Keep the booleans together to avoid misalignment during copy-by-value. // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
// TODO: move at the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits) bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
@ -382,14 +383,20 @@ extern "C" {
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau; // target entropy float mirostat_tau; // target entropy
float mirostat_eta; // learning rate float mirostat_eta; // learning rate
bool penalize_nl; // consider newlines as a repeatable token
bool ignore_eos; // ignore the end-of-sequence token
// https://github.com/ggerganov/llama.cpp/pull/1773
const char * grammar; const char * grammar;
const char * grammar_root; const char * grammar_root;
const char * cfg_prompt; // string to help guidance in negative direction
float cfg_scale; // how strong is guidance
int32_t n_logit_bias; int32_t n_logit_bias;
const llama_logit_bias * logit_bias; const llama_logit_bias * logit_bias;
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool penalize_nl; // consider newlines as a repeatable token
bool ignore_eos; // ignore the end-of-sequence token
} llama_sampling_params; } llama_sampling_params;
// performance timing information // performance timing information
@ -1006,9 +1013,8 @@ extern "C" {
// Sampling functions // Sampling functions
// //
// TODO: args become llama_sampling_params
// TODO: llama_model should become llama_vocab // TODO: llama_model should become llama_vocab
LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root); LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params);
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);

View file

@ -24,13 +24,7 @@ static void llama_log_softmax(float * array, size_t size) {
} }
} }
llama_sampling::llama_sampling(uint32_t n_vocab) : n_vocab(n_vocab) { llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) {
}
llama_sampling::llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) : n_vocab(vocab.n_vocab) {
if (grammar_str != nullptr && grammar_str[0] != '\0') {
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root);
}
} }
llama_sampling::~llama_sampling() { llama_sampling::~llama_sampling() {
@ -39,8 +33,16 @@ llama_sampling::~llama_sampling() {
} }
} }
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) {
return new llama_sampling(vocab, grammar_str, grammar_root); auto * result = new llama_sampling(vocab);
// TODO: store params
if (params.grammar != nullptr && params.grammar[0] != '\0') {
result->grammar = llama_grammar_init_impl(result->vocab, params.grammar, params.grammar_root);
}
return result;
} }
void llama_sampling_free_impl(struct llama_sampling * sampling) { void llama_sampling_free_impl(struct llama_sampling * sampling) {
@ -48,7 +50,7 @@ void llama_sampling_free_impl(struct llama_sampling * sampling) {
} }
struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) { struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) {
auto * result = new llama_sampling(smpl.n_vocab); auto * result = new llama_sampling(smpl.vocab);
if (smpl.grammar) { if (smpl.grammar) {
result->grammar = llama_grammar_copy_impl(*smpl.grammar); result->grammar = llama_grammar_copy_impl(*smpl.grammar);
@ -493,7 +495,7 @@ void llama_sampling_apply_guidance_impl(
float * logits, float * logits,
float * logits_guidance, float * logits_guidance,
float scale) { float scale) {
const auto n_vocab = smpl.n_vocab; const auto n_vocab = smpl.vocab.n_vocab;
llama_log_softmax(logits, n_vocab); llama_log_softmax(logits, n_vocab);
llama_log_softmax(logits_guidance, n_vocab); llama_log_softmax(logits_guidance, n_vocab);
@ -507,7 +509,7 @@ void llama_sampling_apply_guidance_impl(
} }
llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
const int32_t n_vocab = float(smpl.n_vocab); const float n_vocab = float(smpl.vocab.n_vocab);
llama_sampling_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);

View file

@ -1,17 +1,15 @@
#pragma once #pragma once
#include "llama-impl.h"
#include "llama-grammar.h" #include "llama-grammar.h"
struct llama_vocab; struct llama_vocab;
struct llama_grammar; struct llama_grammar;
struct llama_sampling { struct llama_sampling {
llama_sampling(uint32_t n_vocab); llama_sampling(const struct llama_vocab & vocab);
llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
~llama_sampling(); ~llama_sampling();
const uint32_t n_vocab; const llama_vocab & vocab;
std::mt19937 rng; std::mt19937 rng;
@ -26,7 +24,7 @@ struct llama_sampling {
// internal API // internal API
// //
struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params);
void llama_sampling_free_impl(struct llama_sampling * sampling); void llama_sampling_free_impl(struct llama_sampling * sampling);

View file

@ -93,6 +93,18 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
return it->second; return it->second;
} }
llama_vocab llama_vocab::for_tests(uint32_t n_vocab) {
llama_vocab vocab;
vocab.n_vocab = n_vocab;
vocab.token_to_id.reserve(n_vocab);
vocab.id_to_token.reserve(n_vocab);
for (uint32_t i = 0; i < n_vocab; i++) {
vocab.token_to_id[format("token_%u", i)] = i;
vocab.id_to_token.push_back({ format("token_%u", i), 0.0f, LLAMA_TOKEN_ATTR_NORMAL });
}
return vocab;
}
static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
return vocab.type; return vocab.type;
} }

View file

@ -62,6 +62,8 @@ struct llama_vocab {
std::vector<char> precompiled_charsmap; std::vector<char> precompiled_charsmap;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
static llama_vocab for_tests(uint32_t n_vocab);
}; };
// //

View file

@ -16511,12 +16511,14 @@ struct llama_sampling_params llama_sampling_default_params() {
/*.mirostat =*/ 0, /*.mirostat =*/ 0,
/*.mirostat_tau =*/ 5.00f, /*.mirostat_tau =*/ 5.00f,
/*.mirostat_eta =*/ 0.10f, /*.mirostat_eta =*/ 0.10f,
/*.penalize_nl =*/ false,
/*.ignore_eos =*/ false,
/*.grammar =*/ nullptr, /*.grammar =*/ nullptr,
/*.grammar_root =*/ nullptr, /*.grammar_root =*/ nullptr,
/*.cfg_prompt =*/ nullptr,
/*.cfg_scale =*/ 1.00f,
/*.n_logit_bias =*/ 0, /*.n_logit_bias =*/ 0,
/*.logit_bias =*/ nullptr, /*.logit_bias =*/ nullptr,
/*.penalize_nl =*/ false,
/*.ignore_eos =*/ false,
}; };
return result; return result;
@ -19091,8 +19093,8 @@ int32_t llama_chat_apply_template(
// sampling // sampling
// //
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) {
return llama_sampling_init_impl(model->vocab, grammar_str, grammar_root); return llama_sampling_init_impl(model->vocab, params);
} }
void llama_sampling_free(struct llama_sampling * smpl) { void llama_sampling_free(struct llama_sampling * smpl) {

View file

@ -2470,7 +2470,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
} }
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
return false;
} }
static void usage(char ** argv) { static void usage(char ** argv) {

View file

@ -1,5 +1,6 @@
#include "ggml.h" #include "ggml.h"
#include "llama.h" #include "llama.h"
#include "llama-vocab.h"
#include "llama-sampling.h" #include "llama-sampling.h"
#ifdef NDEBUG #ifdef NDEBUG
@ -21,7 +22,7 @@ static void dump(const llama_token_data_array * candidates) {
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) { static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(n_vocab); llama_sampling smpl(llama_vocab::for_tests(n_vocab));
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -44,7 +45,7 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(n_vocab); llama_sampling smpl(llama_vocab::for_tests(n_vocab));
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -67,7 +68,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) { static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(n_vocab); llama_sampling smpl(llama_vocab::for_tests(n_vocab));
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -89,7 +90,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(n_vocab); llama_sampling smpl(llama_vocab::for_tests(n_vocab));
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -112,7 +113,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(n_vocab); llama_sampling smpl(llama_vocab::for_tests(n_vocab));
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -139,7 +140,7 @@ static void test_repetition_penalties(
GGML_ASSERT(probs.size() == expected_probs.size()); GGML_ASSERT(probs.size() == expected_probs.size());
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(n_vocab); llama_sampling smpl(llama_vocab::for_tests(n_vocab));
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -163,7 +164,7 @@ static void test_repetition_penalties(
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
) { ) {
llama_sampling smpl(n_vocab); llama_sampling smpl(llama_vocab::for_tests(n_vocab));
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);