Export function to fetch/create default sampler states
Code formatting cleanups and add some comments Silence a warning about id not being used when logging is disabled
This commit is contained in:
parent
52def09a31
commit
01bef02900
2 changed files with 36 additions and 9 deletions
|
@ -9,7 +9,9 @@ llama_sampling_state::~llama_sampling_state() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar) {
|
llama_sampling_state llama_sampling_state_init(
|
||||||
|
const struct gpt_params & params,
|
||||||
|
llama_grammar * grammar) {
|
||||||
llama_sampling_state result;
|
llama_sampling_state result;
|
||||||
|
|
||||||
result.params = params.sampling_params;
|
result.params = params.sampling_params;
|
||||||
|
@ -17,8 +19,10 @@ llama_sampling_state llama_sampling_state_init(const struct gpt_params & params,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates the state if it doesn't exist, so this always return something.
|
// Note: Creates the state if it doesn't exist, so this always return something.
|
||||||
static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling_state & state, const llama_seq_id seq) {
|
llama_sampler_sequence_state & llama_sampling_get_sequence_state(
|
||||||
|
llama_sampling_state & state,
|
||||||
|
const llama_seq_id seq) {
|
||||||
const auto it = state.sequence_states.find(seq);
|
const auto it = state.sequence_states.find(seq);
|
||||||
if (it != state.sequence_states.end()) {
|
if (it != state.sequence_states.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
|
@ -30,7 +34,9 @@ static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling
|
||||||
return state.sequence_states.insert({seq, new_state}).first->second;
|
return state.sequence_states.insert({seq, new_state}).first->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq) {
|
bool llama_sampling_state_reset(
|
||||||
|
llama_sampling_state & state,
|
||||||
|
const llama_seq_id seq) {
|
||||||
const auto it = state.sequence_states.find(seq);
|
const auto it = state.sequence_states.find(seq);
|
||||||
if (it == state.sequence_states.end()) return false;
|
if (it == state.sequence_states.end()) return false;
|
||||||
if (it->second.grammar != NULL) {
|
if (it->second.grammar != NULL) {
|
||||||
|
@ -109,7 +115,7 @@ llama_token llama_sample_token(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler_sequence_state & seq_state = sampling_get_sequence_state(state, seq);
|
llama_sampler_sequence_state & seq_state = llama_sampling_get_sequence_state(state, seq);
|
||||||
|
|
||||||
if (seq_state.grammar != NULL) {
|
if (seq_state.grammar != NULL) {
|
||||||
llama_sample_grammar(ctx, &cur_p, seq_state.grammar);
|
llama_sample_grammar(ctx, &cur_p, seq_state.grammar);
|
||||||
|
@ -141,6 +147,7 @@ llama_token llama_sample_token(
|
||||||
|
|
||||||
for (int i = 0; i < n_top; i++) {
|
for (int i = 0; i < n_top; i++) {
|
||||||
const llama_token id = cur_p.data[i].id;
|
const llama_token id = cur_p.data[i].id;
|
||||||
|
(void)id; // To avoid a warning that id is unused when logging is disabled.
|
||||||
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
|
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,7 +157,6 @@ llama_token llama_sample_token(
|
||||||
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
|
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// printf("`%d`", candidates_p.size);
|
|
||||||
|
|
||||||
if (seq_state.grammar != NULL) {
|
if (seq_state.grammar != NULL) {
|
||||||
llama_grammar_accept_token(ctx, seq_state.grammar, id);
|
llama_grammar_accept_token(ctx, seq_state.grammar, id);
|
||||||
|
|
|
@ -44,19 +44,40 @@ typedef struct llama_sampler_sequence_state {
|
||||||
typedef struct llama_sampling_state {
|
typedef struct llama_sampling_state {
|
||||||
~llama_sampling_state();
|
~llama_sampling_state();
|
||||||
|
|
||||||
|
// parameters that will be used for sampling and when creating
|
||||||
|
// new llama_sampler_sequence_state instances
|
||||||
llama_sampling_params params;
|
llama_sampling_params params;
|
||||||
|
|
||||||
|
// map of sequence ids to sampler states
|
||||||
std::unordered_map<llama_seq_id, llama_sampler_sequence_state> sequence_states;
|
std::unordered_map<llama_seq_id, llama_sampler_sequence_state> sequence_states;
|
||||||
|
|
||||||
|
// when non-NULL, new instances of llama_sampler_sequence_state
|
||||||
|
// will get a copy of the grammar here
|
||||||
|
// note: only the pointer is stored here, it is not a copy of
|
||||||
|
// the grammar and shouldn't be freed
|
||||||
llama_grammar * grammar;
|
llama_grammar * grammar;
|
||||||
} llama_sampling_state;
|
} llama_sampling_state;
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
// Create a new sampling state instance.
|
// Create a new sampling state instance.
|
||||||
llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar);
|
llama_sampling_state llama_sampling_state_init(
|
||||||
|
const struct gpt_params & params,
|
||||||
|
llama_grammar * grammar = NULL);
|
||||||
|
|
||||||
bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq = 0);
|
// Fetches the sampler state for the specified sequence id (defaults to 0).
|
||||||
|
// If the state for that sequence id doesn't already exist, it will be created with
|
||||||
|
// default values based on the parameters in the state argument.
|
||||||
|
llama_sampler_sequence_state & llama_sampling_get_sequence_state(
|
||||||
|
llama_sampling_state & state,
|
||||||
|
const llama_seq_id seq = 0);
|
||||||
|
|
||||||
|
// Reset the sampler states for the supplied sequence id (defaults to 0).
|
||||||
|
// This is necessary to reuse a sequence id or free memory used by sequences
|
||||||
|
// that are no longer required.
|
||||||
|
bool llama_sampling_state_reset(
|
||||||
|
llama_sampling_state & state,
|
||||||
|
const llama_seq_id seq = 0);
|
||||||
|
|
||||||
// this is a common sampling function used across the examples for convenience
|
// this is a common sampling function used across the examples for convenience
|
||||||
// it can serve as a starting point for implementing your own sampling function
|
// it can serve as a starting point for implementing your own sampling function
|
||||||
|
@ -73,7 +94,7 @@ bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id
|
||||||
// - grammar: grammar to use for sampling, ignore if NULL
|
// - grammar: grammar to use for sampling, ignore if NULL
|
||||||
// - last_tokens: needed for repetition penalty, ignore if empty
|
// - last_tokens: needed for repetition penalty, ignore if empty
|
||||||
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
||||||
// - seq: sequence id to associate sampler state with (currently only used by mirostat)
|
// - seq: sequence id to associate sampler state with
|
||||||
//
|
//
|
||||||
// returns:
|
// returns:
|
||||||
// - token: sampled token
|
// - token: sampled token
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue