style : rearrange code + add comments and TODOs
ggml-ci
This commit is contained in:
parent
4a4530b7ff
commit
4b27235624
4 changed files with 115 additions and 48 deletions
@@ -136,17 +136,6 @@ std::string gpt_sampler_params::print() const {
     return std::string(result);
 }
 
-std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
-    std::string result = "\tlogits ";
-
-    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
-        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
-    }
-
-    return result;
-}
-
 struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
@@ -232,17 +221,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
     }
 }
 
-struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
-    return new gpt_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev   = */ gsmpl->prev,
-        /* .cur    = */ gsmpl->cur,
-        /* .cur_p  = */ gsmpl->cur_p,
-    };
-}
-
 void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
     if (accept_grammar) {
         llama_sampler_accept(gsmpl->grmr, token);
@@ -259,12 +237,15 @@ void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
     llama_sampler_reset(gsmpl->chain);
 }
 
-llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
-    return &gsmpl->cur_p;
-}
-
-llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
-    return gsmpl->prev.rat(0);
-}
+struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
+    return new gpt_sampler {
+        /* .params = */ gsmpl->params,
+        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
+        /* .prev   = */ gsmpl->prev,
+        /* .cur    = */ gsmpl->cur,
+        /* .cur_p  = */ gsmpl->cur_p,
+    };
+}
 
 void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
@@ -279,12 +260,11 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 }
 
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
-    auto & grmr  = gsmpl->grmr;
-    auto & chain = gsmpl->chain;
-
     gsmpl->set_logits(ctx, idx);
 
-    auto & cur_p = gsmpl->cur_p;
+    auto & grmr  = gsmpl->grmr;
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
     if (grammar_first) {
         llama_sampler_apply(grmr, &cur_p);
@@ -307,24 +287,45 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
 
         llama_sampler_apply(grmr, &single_token_data_array);
 
-        // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
         const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
         if (is_valid) {
             return id;
         }
     }
 
-    // if the token is not valid, sample again, first apply the grammar samplers and then sample
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
     gsmpl->set_logits(ctx, idx);
 
     llama_sampler_apply(grmr,  &cur_p);
     llama_sampler_apply(chain, &cur_p);
 
-    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
 
     return cur_p.data[cur_p.selected].id;
 }
 
+// helpers
+
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
+    return &gsmpl->cur_p;
+}
+
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
+    return gsmpl->prev.rat(0);
+}
+
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
+    std::string result = "\tlogits ";
+
+    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+    }
+
+    return result;
+}
+
 std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
     n = std::min(n, (int) gsmpl->prev.size());
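For context, here is a minimal sketch (not part of this commit) of how a caller typically drives the sample/accept pair rearranged above. The n_decoded/n_predict counters, the model/ctx/gsmpl handles, and the elided batch setup are assumptions for illustration; gpt_sampler_sample, gpt_sampler_accept, llama_token_is_eog and llama_decode are the existing API used elsewhere in llama.cpp.

    // illustrative generation loop using the rearranged gpt_sampler API
    while (n_decoded < n_predict) {
        // sample from the logits of the last decoded token (idx = -1)
        const llama_token id = gpt_sampler_sample(gsmpl, ctx, -1, /*grammar_first =*/ false);

        // record the token both in the sampler history and in the grammar state
        gpt_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);

        if (llama_token_is_eog(model, id)) {
            break;
        }

        // feed the sampled token back for the next decode step (batch setup elided)
        // llama_decode(ctx, ...);

        n_decoded += 1;
    }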
@@ -61,24 +61,41 @@ struct gpt_sampler_params {
 //
 //  - grammar support
 //  - custom sampler logic based on the parameters
+//  - history of the last accepted tokens
+//  - performance metrics
+//
+// The goal is to have a common implementation of the sampling logic shared across the examples.
+// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
+// complex (top-k, top-p, etc).
+//
+// Another example is related to the grammar. In general, the grammar constraints applied on the full
+// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
+// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
+// grammar constraints are applied to the full vocabulary and the token is resampled.
+//
+// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
+// be moved into the core llama library.
+//
+// For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
+// This can be used to access the probabilities of the rest of the non-sampled tokens.
 //
 // TODO: measure grammar performance
 //
 
 struct gpt_sampler;
 
+// llama_sampler API overloads
+
 struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
 
 void gpt_sampler_free(struct gpt_sampler * gsmpl);
 
-struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);
-
+// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
 void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
 void gpt_sampler_reset (struct gpt_sampler * gsmpl);
+struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
 
-llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
-
-llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
-
+// arguments can be nullptr to skip printing
 void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl);
 
 // extended sampling implementation:
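As a reading aid (not part of the diff), the fields touched by gpt_sampler_clone and gpt_sampler_accept in this commit suggest roughly the following internal layout behind the opaque gpt_sampler handle. The exact member types, in particular the ring_buffer history and the std::vector scratch buffer, are assumptions inferred from how the members are used in sampling.cpp.

    // rough sketch of the state behind gpt_sampler, inferred from this commit;
    // member types are assumptions, not the authoritative definition
    struct gpt_sampler {
        gpt_sampler_params     params; // parameters used to build the chain
        struct llama_sampler * grmr;   // grammar sampler, applied lazily to the sampled token
        struct llama_sampler * chain;  // main sampling chain (top-k, top-p, temp, ...)

        ring_buffer<llama_token>      prev;  // history of the last accepted tokens
        std::vector<llama_token_data> cur;   // scratch buffer for the candidates
        llama_token_data_array        cur_p; // current candidates, filled by set_logits
    };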
@@ -89,12 +106,18 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 //   - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
 // if grammar_first is true, the grammar is applied before the samplers (slower)
-// useful in cases where all the resulting candidates must fit the grammar
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
 //
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
 // helpers
 
+// access the internal list of current candidate tokens
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
+
+// get the last accepted token
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
+
 // print the sampler chain into a string
 std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
 
@@ -206,6 +206,7 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
     };
 
+    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id;    // token id
         float       logit; // log-odds of the token
@@ -216,7 +217,7 @@ extern "C" {
         // TODO: consider SoA
         llama_token_data * data;
         size_t size;
-        int64_t selected;
+        int64_t selected; // this is the index in the data array (i.e. not the token id)
         bool sorted;
     } llama_token_data_array;
 
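A small illustration of what the new comment on selected means (added here for clarity, not part of the commit): selected indexes into the data array, so the sampled token id is read indirectly through it, as the sampling code in this commit does. The chain/cur_p names are assumed to be in scope.

    // 'selected' is an index into cur_p.data, not a token id
    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "the chain did not select a token");

    const llama_token id    = cur_p.data[cur_p.selected].id;    // the sampled token id
    const float       logit = cur_p.data[cur_p.selected].logit; // its (possibly modified) logit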
@@ -979,9 +980,38 @@ extern "C" {
     //
     // Sampling API
     //
-    // In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // Sample usage:
+    //
+    //    // prepare the sampling chain at the start
+    //    auto sparams = llama_sampler_chain_default_params();
+    //
+    //    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    //
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
+    //
+    //    ...
+    //
+    //    // decoding loop:
+    //    while (...) {
+    //        ...
+    //
+    //        llama_decode(ctx, batch);
+    //
+    //        // sample from the logits of the last token in the batch
+    //        const llama_token id = llama_sampler_sample(smpl, ctx, -1);
+    //
+    //        ...
+    //    }
+    //
+    //    llama_sampler_free(smpl);
+    //
+    //
+    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //
-    // TODO: in the future, the entire API that uses llama_model should start using llama_vocab
 
     typedef void * llama_sampler_context_t;
 
@@ -1003,6 +1033,7 @@ extern "C" {
         llama_sampler_context_t ctx;
     };
 
+    // mirror of llama_sampler_i:
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
     LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
     LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
@@ -1011,7 +1042,8 @@ extern "C" {
     // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
     LLAMA_API void                   llama_sampler_free  (      struct llama_sampler * smpl);
 
-    // llama_sampler_chain is a type of llama_sampler that can contain multiple llama_samplers
+    // llama_sampler_chain
+    // a type of llama_sampler that can chain multiple samplers one after another
 
     LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
 
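To make the ownership comment above concrete, here is a minimal sketch (not from the commit) of the intended lifetime rule: samplers handed to llama_sampler_chain_add are owned by the chain, so only the chain itself is freed. The specific top_k/dist values are placeholders.

    // minimal ownership sketch: once added to a chain, individual samplers must not be
    // freed separately; freeing the chain releases them
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_dist (1234)); // seed

    // ... use the chain for sampling ...

    llama_sampler_free(chain); // frees the chain together with the samplers it owns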
@@ -1089,6 +1121,15 @@ extern "C" {
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
+    // Shorthand for:
+    //
+    //    const auto * logits = llama_get_logits_ith(ctx, idx);
+    //    llama_token_data_array cur_p = { ... init from logits ... };
+    //    llama_sampler_apply(smpl, &cur_p);
+    //    return cur_p.data[cur_p.selected].id;
+    //
+    // At this point, this is mostly a convenience function.
+    //
     LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
 
     // TODO: extend in the future
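The "... init from logits ..." step is elided in the comment above; a hedged expansion of what that initialization typically looks like follows. It uses llama_get_logits_ith and llama_n_vocab from the existing API, and assumes a model handle is in scope; it is an illustrative sketch, not the authoritative implementation behind llama_sampler_sample.

    // sketch of the elided candidate-array initialization
    const int     n_vocab = llama_n_vocab(model);
    const float * logits  = llama_get_logits_ith(ctx, idx);

    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur.push_back(llama_token_data{ token_id, logits[token_id], 0.0f });
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), /*selected =*/ -1, /*sorted =*/ false };

    llama_sampler_apply(smpl, &cur_p);

    return cur_p.data[cur_p.selected].id;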
@@ -1,5 +1,7 @@
 #pragma once
 
+// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
+
 #include "llama-grammar.h"
 
 #include <unordered_map>