style : rearrange code + add comments and TODOs

ggml-ci
Georgi Gerganov 2024-09-07 12:22:27 +03:00
parent 4a4530b7ff
commit 4b27235624
4 changed files with 115 additions and 48 deletions

common/sampling.cpp

@@ -136,17 +136,6 @@ std::string gpt_sampler_params::print() const {
     return std::string(result);
 }
 
-std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
-    std::string result = "\tlogits ";
-
-    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
-        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
-    }
-
-    return result;
-}
-
 struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -232,17 +221,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
     }
 }
 
-struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
-    return new gpt_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev   = */ gsmpl->prev,
-        /* .cur    = */ gsmpl->cur,
-        /* .cur_p  = */ gsmpl->cur_p,
-    };
-}
-
 void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
     if (accept_grammar) {
         llama_sampler_accept(gsmpl->grmr, token);
@@ -259,12 +237,15 @@ void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
     llama_sampler_reset(gsmpl->chain);
 }
 
-llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
-    return &gsmpl->cur_p;
-}
-
-llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
-    return gsmpl->prev.rat(0);
+struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
+    return new gpt_sampler {
+        /* .params = */ gsmpl->params,
+        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
+        /* .prev   = */ gsmpl->prev,
+        /* .cur    = */ gsmpl->cur,
+        /* .cur_p  = */ gsmpl->cur_p,
+    };
 }
 
 void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
@@ -279,12 +260,11 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 }
 
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
-    auto & grmr  = gsmpl->grmr;
-    auto & chain = gsmpl->chain;
-
     gsmpl->set_logits(ctx, idx);
 
-    auto & cur_p = gsmpl->cur_p;
+    auto & grmr  = gsmpl->grmr;
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
     if (grammar_first) {
         llama_sampler_apply(grmr, &cur_p);
@@ -307,24 +287,45 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
         llama_sampler_apply(grmr, &single_token_data_array);
 
+        // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
         const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
         if (is_valid) {
             return id;
         }
     }
 
-    // if the token is not valid, sample again, first apply the grammar samplers and then sample
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
     gsmpl->set_logits(ctx, idx);
 
     llama_sampler_apply(grmr,  &cur_p);
     llama_sampler_apply(chain, &cur_p);
 
-    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
 
     return cur_p.data[cur_p.selected].id;
 }
 
+// helpers
+
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
+    return &gsmpl->cur_p;
+}
+
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
+    return gsmpl->prev.rat(0);
+}
+
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
+    std::string result = "\tlogits ";
+
+    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+    }
+
+    return result;
+}
+
 std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
     n = std::min(n, (int) gsmpl->prev.size());

common/sampling.h

@@ -61,24 +61,41 @@ struct gpt_sampler_params {
 //
 // - grammar support
 // - custom sampler logic based on the parameters
+// - history of the last accepted tokens
+// - performance metrics
+//
+// The goal is to have a common implementation of the sampling logic shared across the examples.
+// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
+// complex (top-k, top-p, etc).
+//
+// Another example is related to the grammar. In general, the grammar constraints applied on the full
+// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
+// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
+// grammar constraints are applied to the full vocabulary and the token is resampled.
+//
+// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
+// be moved into the core llama library.
+//
+// For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
+// This can be used to access the probabilities of the rest of the non-sampled tokens.
 //
 // TODO: measure grammar performance
 //
 struct gpt_sampler;
 
+// llama_sampler API overloads
+
 struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
 
 void gpt_sampler_free(struct gpt_sampler * gsmpl);
 
-struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);
-
-void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
-void gpt_sampler_reset (struct gpt_sampler * gsmpl);
-
-llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
-
-llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
+// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
+void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
+void gpt_sampler_reset (struct gpt_sampler * gsmpl);
+struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
 
+// arguments can be nullptr to skip printing
 void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl);
 
 // extended sampling implementation:
@@ -89,12 +106,18 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
 // if grammar_first is true, the grammar is applied before the samplers (slower)
-// useful in cases where all the resulting candidates must fit the grammar
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
 //
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
 // helpers
 
+// access the internal list of current candidate tokens
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
+
+// get the last accepted token
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
+
 // print the sampler chain into a string
 std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
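
For orientation, here is a minimal sketch of how the gpt_sampler API above is typically driven from a decode loop. The include path, the model/context/batch setup and the stopping logic are assumptions for illustration, not part of the commit:

#include "sampling.h" // assumed include path for common/sampling.h

static void generate(llama_model * model, llama_context * ctx, llama_batch & batch, int n_predict) {
    gpt_sampler_params sparams;                                    // illustrative: default sampling parameters
    struct gpt_sampler * gsmpl = gpt_sampler_init(model, sparams); // builds the grammar + sampler chain

    for (int i = 0; i < n_predict; ++i) {
        llama_decode(ctx, batch);

        // fast path first: sample with the chain, then verify the token against the grammar
        const llama_token id = gpt_sampler_sample(gsmpl, ctx, /*idx =*/ -1);

        // accept into both the grammar and the sampler history
        gpt_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);

        // ... rebuild `batch` with `id` and stop on end-of-generation (omitted) ...
    }

    gpt_sampler_free(gsmpl);
}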

include/llama.h

@@ -206,6 +206,7 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW   = 2, // split rows across GPUs
     };
 
+    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id;    // token id
         float       logit; // log-odds of the token
@@ -216,7 +217,7 @@ extern "C" {
         // TODO: consider SoA
         llama_token_data * data;
         size_t size;
-        int64_t selected;
+        int64_t selected; // this is the index in the data array (i.e. not the token id)
         bool sorted;
     } llama_token_data_array;
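
A tiny hedged helper illustrating the comment above: the sampled token id has to be read through the data array, since selected is only an index into it. It assumes a cur_p that a sampler chain has already been applied to; the helper name is illustrative:

#include "llama.h"

// returns the sampled token id, or -1 if nothing was selected
static llama_token selected_token_id(const llama_token_data_array & cur_p) {
    if (cur_p.selected == -1) {
        return -1;
    }
    // cur_p.selected indexes cur_p.data; the token id lives inside the element
    return cur_p.data[cur_p.selected].id;
}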
@@ -979,9 +980,38 @@ extern "C" {
     //
     // Sampling API
     //
-    // In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // Sample usage:
+    //
+    //    // prepare the sampling chain at the start
+    //    auto sparams = llama_sampler_chain_default_params();
+    //
+    //    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    //
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
+    //
+    //    ...
+    //
+    //    // decoding loop:
+    //    while (...) {
+    //        ...
+    //
+    //        llama_decode(ctx, batch);
+    //
+    //        // sample from the logits of the last token in the batch
+    //        const llama_token id = llama_sampler_sample(smpl, ctx, -1);
+    //
+    //        ...
+    //    }
+    //
+    //    llama_sampler_free(smpl);
+    //
+    //
+    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //
-    // TODO: in the future, the entire API that uses llama_model should start using llama_vocab
 
     typedef void * llama_sampler_context_t;
@@ -1003,6 +1033,7 @@ extern "C" {
         llama_sampler_context_t ctx;
     };
 
+    // mirror of llama_sampler_i:
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
     LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
     LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
@@ -1011,7 +1042,8 @@ extern "C" {
     // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
     LLAMA_API void                   llama_sampler_free  (      struct llama_sampler * smpl);
 
-    // llama_sampler_chain is a type of llama_sampler that can contain multiple llama_samplers
+    // llama_sampler_chain
+    // a type of llama_sampler that can chain multiple samplers one after another
 
     LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
@@ -1089,6 +1121,15 @@ extern "C" {
                              int32_t   n_logit_bias,
              const llama_logit_bias * logit_bias);
 
+    // Shorthand for:
+    //
+    //    const auto * logits = llama_get_logits_ith(ctx, idx);
+    //    llama_token_data_array cur_p = { ... init from logits ... };
+    //    llama_sampler_apply(smpl, &cur_p);
+    //    return cur_p.data[cur_p.selected].id;
+    //
+    // At this point, this is mostly a convenience function.
+    //
     LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
 
     // TODO: extend in the future
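
Expanding the "Shorthand for" comment above into a self-contained, hedged sketch of doing the same thing by hand with the public API; the helper name and the exact candidate-array initialization are illustrative:

#include "llama.h"

#include <vector>

// rough manual equivalent of llama_sampler_sample(smpl, ctx, idx), per the comment above
static llama_token sample_by_hand(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
    const float * logits  = llama_get_logits_ith(ctx, idx);
    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx));

    // build the candidate array from the raw logits
    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
        cur.push_back(llama_token_data{ token_id, logits[token_id], 0.0f });
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), /* selected */ -1, /* sorted */ false };

    // the sampler (or sampler chain) sets cur_p.selected
    llama_sampler_apply(smpl, &cur_p);

    return cur_p.data[cur_p.selected].id;
}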

src/llama-sampling.h

@@ -1,5 +1,7 @@
 #pragma once
 
+// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
+
 #include "llama-grammar.h"
 
 #include <unordered_map>