Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().

This commit is contained in:
Matt Pulver 2023-08-25 09:18:24 -04:00
parent c4269e0200
commit abe0829984
2 changed files with 9 additions and 0 deletions

View file

@ -4326,6 +4326,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
ctx->t_sample_us += ggml_time_us() - t_start_sample_us; ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
} }
//
// Beam search
//
struct llama_beam { struct llama_beam {
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
float p; // Cumulative beam probability (renormalized relative to all beams) float p; // Cumulative beam probability (renormalized relative to all beams)

View file

@ -465,6 +465,10 @@ extern "C" {
/// @details Accepts the sampled token into the grammar /// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
//
// Beam search
//
struct llama_beam_view { struct llama_beam_view {
llama_token const* tokens; llama_token const* tokens;
size_t n_tokens; size_t n_tokens;
@ -482,6 +486,7 @@ extern "C" {
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
bool last_call; // True iff this is the last callback invocation. bool last_call; // True iff this is the last callback invocation.
}; };
// Type of pointer to the beam_search_callback function. // Type of pointer to the beam_search_callback function.
// void* callback_data is any custom data passed to llama_beam_search, that is subsequently // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
// passed back to beam_search_callback. This avoids having to use global variables in the callback. // passed back to beam_search_callback. This avoids having to use global variables in the callback.