style : rearrange code + add comments and TODOs
ggml-ci
This commit is contained in:
parent
4a4530b7ff
commit
4b27235624
4 changed files with 115 additions and 48 deletions
|
@ -136,17 +136,6 @@ std::string gpt_sampler_params::print() const {
|
|||
return std::string(result);
|
||||
}
|
||||
|
||||
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
|
||||
std::string result = "\tlogits ";
|
||||
|
||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
|
||||
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
||||
|
||||
|
@ -232,17 +221,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
|
|||
}
|
||||
}
|
||||
|
||||
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
|
||||
return new gpt_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
};
|
||||
}
|
||||
|
||||
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
if (accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
|
@ -259,12 +237,15 @@ void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
|
|||
llama_sampler_reset(gsmpl->chain);
|
||||
}
|
||||
|
||||
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
|
||||
return &gsmpl->cur_p;
|
||||
}
|
||||
|
||||
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
|
||||
return gsmpl->prev.rat(0);
|
||||
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
|
||||
return new gpt_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
};
|
||||
}
|
||||
|
||||
void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
|
||||
|
@ -279,12 +260,11 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
|
|||
}
|
||||
|
||||
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||
auto & grmr = gsmpl->grmr;
|
||||
auto & chain = gsmpl->chain;
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
auto & cur_p = gsmpl->cur_p;
|
||||
auto & grmr = gsmpl->grmr;
|
||||
auto & chain = gsmpl->chain;
|
||||
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
||||
|
||||
if (grammar_first) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
|
@ -307,24 +287,45 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
|
|||
|
||||
llama_sampler_apply(grmr, &single_token_data_array);
|
||||
|
||||
// check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
|
||||
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
||||
if (is_valid) {
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
// if the token is not valid, sample again, first apply the grammar samplers and then sample
|
||||
// resampling:
|
||||
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
|
||||
|
||||
return cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
|
||||
// helpers
|
||||
|
||||
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
|
||||
return &gsmpl->cur_p;
|
||||
}
|
||||
|
||||
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
|
||||
return gsmpl->prev.rat(0);
|
||||
}
|
||||
|
||||
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
|
||||
std::string result = "\tlogits ";
|
||||
|
||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
|
||||
n = std::min(n, (int) gsmpl->prev.size());
|
||||
|
||||
|
|
|
@ -61,24 +61,41 @@ struct gpt_sampler_params {
|
|||
//
|
||||
// - grammar support
|
||||
// - custom sampler logic based on the parameters
|
||||
// - history of the last accepted tokens
|
||||
// - performance metrics
|
||||
//
|
||||
// This goal is to have a common implementation of the sampling logic shared across the examples.
|
||||
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
|
||||
// complex (top-k, top-p, etc).
|
||||
//
|
||||
// Another example is related to the grammar. In general, the grammar constraints applied on the full
|
||||
// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
|
||||
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
|
||||
// grammar constraints are applied to the full vocabulary and the token is resampled.
|
||||
//
|
||||
// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
|
||||
// be moved into the core llama library.
|
||||
//
|
||||
// For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
|
||||
// This can be used to access the probabilities of the rest of the non-sampled tokens.
|
||||
//
|
||||
// TODO: measure grammar performance
|
||||
//
|
||||
|
||||
struct gpt_sampler;
|
||||
|
||||
// llama_sampler API overloads
|
||||
|
||||
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
|
||||
|
||||
void gpt_sampler_free(struct gpt_sampler * gsmpl);
|
||||
|
||||
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);
|
||||
|
||||
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
|
||||
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
|
||||
|
||||
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
|
||||
|
||||
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
|
||||
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
||||
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
|
||||
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
|
||||
struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
|
||||
|
||||
// arguments can be nullptr to skip printing
|
||||
void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl);
|
||||
|
||||
// extended sampling implementation:
|
||||
|
@ -89,12 +106,18 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
|
|||
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
||||
//
|
||||
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
||||
// useful in cases where all the resulting candidates must fit the grammar
|
||||
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
||||
//
|
||||
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
||||
|
||||
// helpers
|
||||
|
||||
// access the internal list of current candidate tokens
|
||||
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
|
||||
|
||||
// get the last accepted token
|
||||
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
|
||||
|
||||
// print the sampler chain into a string
|
||||
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
|
||||
|
||||
|
|
|
@ -206,6 +206,7 @@ extern "C" {
|
|||
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
|
||||
};
|
||||
|
||||
// TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
float logit; // log-odds of the token
|
||||
|
@ -216,7 +217,7 @@ extern "C" {
|
|||
// TODO: consider SoA
|
||||
llama_token_data * data;
|
||||
size_t size;
|
||||
int64_t selected;
|
||||
int64_t selected; // this is the index in the data array (i.e. not the token id)
|
||||
bool sorted;
|
||||
} llama_token_data_array;
|
||||
|
||||
|
@ -979,9 +980,38 @@ extern "C" {
|
|||
//
|
||||
// Sampling API
|
||||
//
|
||||
// In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
|
||||
// Sample usage:
|
||||
//
|
||||
// // prepare the sampling chain at the start
|
||||
// auto sparams = llama_sampler_chain_default_params();
|
||||
//
|
||||
// llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
//
|
||||
// llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
|
||||
// llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
|
||||
// llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
|
||||
// llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
|
||||
//
|
||||
// ...
|
||||
//
|
||||
// // decoding loop:
|
||||
// while (...) {
|
||||
// ...
|
||||
//
|
||||
// llama_decode(ctx, batch);
|
||||
//
|
||||
// // sample from the logits of the last token in the batch
|
||||
// const llama_token id = llama_sampler_sample(smpl, ctx, -1);
|
||||
//
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
// llama_sampler_free(smpl);
|
||||
//
|
||||
//
|
||||
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
|
||||
// TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
|
||||
//
|
||||
// TODO: in the future, the entire API that uses llama_model should start using llama_vocab
|
||||
|
||||
typedef void * llama_sampler_context_t;
|
||||
|
||||
|
@ -1003,6 +1033,7 @@ extern "C" {
|
|||
llama_sampler_context_t ctx;
|
||||
};
|
||||
|
||||
// mirror of llama_sampler_i:
|
||||
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
|
||||
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
|
||||
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
||||
|
@ -1011,7 +1042,8 @@ extern "C" {
|
|||
// important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
|
||||
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
|
||||
|
||||
// llama_sampler_chain is a type of llama_sampler that can contain multiple llama_samplers
|
||||
// llama_sampler_chain
|
||||
// a type of llama_sampler that can chain multiple samplers one after another
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
|
||||
|
||||
|
@ -1089,6 +1121,15 @@ extern "C" {
|
|||
int32_t n_logit_bias,
|
||||
const llama_logit_bias * logit_bias);
|
||||
|
||||
// Shorthand for:
|
||||
//
|
||||
// const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
// llama_token_data_array cur_p = { ... init from logits ... };
|
||||
// llama_sampler_apply(smpl, &cur_p);
|
||||
// return cur_p.data[cur_p.selected].id;
|
||||
//
|
||||
// At this point, this is mostly a convenience function.
|
||||
//
|
||||
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
|
||||
|
||||
// TODO: extend in the future
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
|
||||
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue