Export function to fetch/create default sampler states

Code formatting cleanups and add some comments

Silence a warning about id not being used when logging is disabled
This commit is contained in:
KerfuffleV2 2023-10-08 11:59:07 -06:00
parent 52def09a31
commit 01bef02900
2 changed files with 36 additions and 9 deletions

View file

@ -9,7 +9,9 @@ llama_sampling_state::~llama_sampling_state() {
}
}
llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar) {
llama_sampling_state llama_sampling_state_init(
const struct gpt_params & params,
llama_grammar * grammar) {
llama_sampling_state result;
result.params = params.sampling_params;
@ -17,8 +19,10 @@ llama_sampling_state llama_sampling_state_init(const struct gpt_params & params,
return result;
}
// Creates the state if it doesn't exist, so this always return something.
static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling_state & state, const llama_seq_id seq) {
// Note: Creates the state if it doesn't exist, so this always return something.
llama_sampler_sequence_state & llama_sampling_get_sequence_state(
llama_sampling_state & state,
const llama_seq_id seq) {
const auto it = state.sequence_states.find(seq);
if (it != state.sequence_states.end()) {
return it->second;
@ -30,7 +34,9 @@ static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling
return state.sequence_states.insert({seq, new_state}).first->second;
}
bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq) {
bool llama_sampling_state_reset(
llama_sampling_state & state,
const llama_seq_id seq) {
const auto it = state.sequence_states.find(seq);
if (it == state.sequence_states.end()) return false;
if (it->second.grammar != NULL) {
@ -109,7 +115,7 @@ llama_token llama_sample_token(
}
}
llama_sampler_sequence_state & seq_state = sampling_get_sequence_state(state, seq);
llama_sampler_sequence_state & seq_state = llama_sampling_get_sequence_state(state, seq);
if (seq_state.grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, seq_state.grammar);
@ -141,6 +147,7 @@ llama_token llama_sample_token(
for (int i = 0; i < n_top; i++) {
const llama_token id = cur_p.data[i].id;
(void)id; // To avoid a warning that id is unused when logging is disabled.
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
}
}
@ -150,7 +157,6 @@ llama_token llama_sample_token(
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
}
}
// printf("`%d`", candidates_p.size);
if (seq_state.grammar != NULL) {
llama_grammar_accept_token(ctx, seq_state.grammar, id);

View file

@ -44,19 +44,40 @@ typedef struct llama_sampler_sequence_state {
typedef struct llama_sampling_state {
~llama_sampling_state();
// parameters that will be used for sampling and when creating
// new llama_sampler_sequence_state instances
llama_sampling_params params;
// map of sequence ids to sampler states
std::unordered_map<llama_seq_id, llama_sampler_sequence_state> sequence_states;
// when non-NULL, new instances of llama_sampler_sequence_state
// will get a copy of the grammar here
// note: only the pointer is stored here, it is not a copy of
// the grammar and shouldn't be freed
llama_grammar * grammar;
} llama_sampling_state;
#include "common.h"
// Create a new sampling state instance.
llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar);
llama_sampling_state llama_sampling_state_init(
const struct gpt_params & params,
llama_grammar * grammar = NULL);
bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq = 0);
// Fetches the sampler state for the specified sequence id (defaults to 0).
// If the state for that sequence id doesn't already exist, it will be created with
// default values based on the parameters in the state argument.
llama_sampler_sequence_state & llama_sampling_get_sequence_state(
llama_sampling_state & state,
const llama_seq_id seq = 0);
// Reset the sampler states for the supplied sequence id (defaults to 0).
// This is necessary to reuse a sequence id or free memory used by sequences
// that are no longer required.
bool llama_sampling_state_reset(
llama_sampling_state & state,
const llama_seq_id seq = 0);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
@ -73,7 +94,7 @@ bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id
// - grammar: grammar to use for sampling, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits_ith(ctx, idx)
// - seq: sequence id to associate sampler state with (currently only used by mirostat)
// - seq: sequence id to associate sampler state with
//
// returns:
// - token: sampled token