llama : add llama_beam_search() (#2267)

* Add llama_beam_search().

* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().

* Add space around * pointers and & references.

* Add spaces around comparison and assignment operators.

* Prefer west const.

* Use llama_ prefix for structs in global namespace.

* Delete obsolete comment from an earlier revision.

* Change eos to eob in llama_beam and llama_beam_view structs.
This commit is contained in:
Matt Pulver 2023-08-25 11:18:48 -04:00 committed by GitHub
parent 28b2c996ca
commit c82742ac9c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 563 additions and 13 deletions

37
llama.h
View file

@ -469,6 +469,43 @@ extern "C" {
/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
//
// Beam search
//
struct llama_beam_view {
const llama_token * tokens;
size_t n_tokens;
float p; // Cumulative beam probability (renormalized relative to all beams)
bool eob; // Callback should set this to true when a beam is at end-of-beam.
};
// Passed to beam_search_callback function.
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
// These pointers are valid only during the synchronous callback, so should not be saved.
struct llama_beams_state {
llama_beam_view * beam_views;
size_t n_beams; // Number of elements in beam_views[].
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
bool last_call; // True iff this is the last callback invocation.
};
// Type of pointer to the beam_search_callback function.
// void* callback_data is any custom data passed to llama_beam_search, that is subsequently
// passed back to beam_search_callback. This avoids having to use global variables in the callback.
typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);
/// @details Deterministically returns entire sentence constructed by a beam search.
/// @param ctx Pointer to the llama_context.
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
/// @param callback_data A pointer that is simply passed back to callback.
/// @param n_beams Number of beams to use.
/// @param n_past Number of tokens already evaluated.
/// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
/// @param n_threads Number of threads as passed to llama_eval().
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
// Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
LLAMA_API void llama_print_timings(struct llama_context * ctx);