From c4269e020002107181ccc2c7837e20abf6191c1d Mon Sep 17 00:00:00 2001
From: Matt Pulver
Date: Tue, 18 Jul 2023 14:33:34 -0400
Subject: [PATCH] Add llama_beam_search().

---
 common/common.h                      |   1 +
 examples/CMakeLists.txt              |   1 +
 examples/beam_search/CMakeLists.txt  |   8 +
 examples/beam_search/beam_search.cpp | 186 ++++++++++++++++++++
 examples/server/server.cpp           |  91 ++++++++--
 llama.cpp                            | 247 +++++++++++++++++++++++++++
 llama.h                              |  33 ++++
 7 files changed, 554 insertions(+), 13 deletions(-)
 create mode 100644 examples/beam_search/CMakeLists.txt
 create mode 100644 examples/beam_search/beam_search.cpp

diff --git a/common/common.h b/common/common.h
index 17d271e67..ce61265f8 100644
--- a/common/common.h
+++ b/common/common.h
@@ -28,6 +28,7 @@ struct gpt_params {
     int32_t main_gpu = 0;                          // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_probs = 0;                           // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t n_beams = 0;                           // if non-zero then use beam search of given width.
     float   rope_freq_base  = 10000.0f;            // RoPE base frequency
     float   rope_freq_scale = 1.0f;                // RoPE frequency scaling factor
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index d2176c910..94b785224 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -25,6 +25,7 @@ else()
     add_subdirectory(simple)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
+    add_subdirectory(beam_search)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
diff --git a/examples/beam_search/CMakeLists.txt b/examples/beam_search/CMakeLists.txt
new file mode 100644
index 000000000..b29e01092
--- /dev/null
+++ b/examples/beam_search/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(TARGET beam_search)
+add_executable(${TARGET} beam_search.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp
new file mode 100644
index 000000000..2bc0a378b
--- /dev/null
+++ b/examples/beam_search/beam_search.cpp
@@ -0,0 +1,186 @@
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include "common.h"
+#include "llama.h"
+#include "build-info.h"
+
+#include <cassert>
+#include <cinttypes>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#include <signal.h>
+#endif
+
+// Used for debugging to print out beam tokens.
+struct ostream_beam_view {
+    llama_context * ctx;
+    llama_beam_view beam_view;
+};
+std::ostream& operator<<(std::ostream& os, ostream_beam_view const& obv) {
+    os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
+    for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
+        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
+    }
+    return os << ')';
+}
+
+// Put here anything you want back in beam_search_callback().
+struct beam_search_callback_data {
+    llama_context * ctx;
+    std::vector<llama_token> response;
+};
+
+bool is_at_eos(beam_search_callback_data const& callback_data, llama_token const* tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// This custom callback example is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beam_views[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
+void beam_search_callback(void* callback_data_ptr, llama_beams_state beams_state) {
+    auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(callback_data, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        callback_data.response.resize(callback_data.response.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        std::copy(tokens, tokens + n, callback_data.response.end() - n);
+        printf("%zu", n);
+    }
+    fflush(stdout);
+#if 1 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams[" << i << "]: " << ostream_beam_view{callback_data.ctx, beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+int main(int argc, char ** argv)
+{
+    gpt_params params;
+
+    //---------------------------------
+    // Print help :
+    //---------------------------------
+
+    if ( argc < 2 || argv[1][0] == '-' )
+    {
+        printf( "usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
+        return 1 ;
+    }
+
+    //---------------------------------
+    // Load parameters :
+    //---------------------------------
+
+    params.model = argv[1];
+
+    params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;
+
+    if ( argc > 3 )
+    {
+        params.prompt = argv[3];
+    }
+
+    if ( params.prompt.empty() )
+    {
+        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
+    }
+
+    //---------------------------------
+    // Init LLM :
+    //---------------------------------
+
+    llama_backend_init(params.numa);
+
+    llama_model * model;
+    llama_context * ctx;
+
+    std::tie(model, ctx) = llama_init_from_gpt_params( params );
+
+    if ( model == NULL )
+    {
+        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
+        return 1;
+    }
+
+    //---------------------------------
+    // Tokenize the prompt :
+    //---------------------------------
+
+    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);
+
+    const size_t max_context_size     = llama_n_ctx( ctx );
+    const size_t max_tokens_list_size = max_context_size - 4 ;
+
+    if (tokens_list.size() > max_tokens_list_size)
+    {
+        fprintf( stderr , "%s: error: prompt too long (%zu tokens, max %zu)\n" ,
+                 __func__ , tokens_list.size() , max_tokens_list_size );
+        return 1;
+    }
+
+    fprintf( stderr, "\n\n" );
+
+    // Print the tokens from the prompt :
+
+    for( auto id : tokens_list )
+    {
+        std::cout << llama_token_to_str(ctx, id);
+    }
+    std::cout << std::flush;
+
+    int n_past = llama_get_kv_cache_token_count(ctx);
+    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
+    {
+        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
+        return 1;
+    }
+    n_past += tokens_list.size();
+
+    beam_search_callback_data callback_data{ctx, {}};
+    size_t const beam_width = static_cast<size_t>(params.n_beams);
+    int const n_predict = 256;
+    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
+
+    std::cout << "\n\n";
+    for (llama_token const token_id : callback_data.response) {
+        std::cout << llama_token_to_str(ctx,token_id);
+    }
+    std::cout << std::endl;
+
+    llama_free( ctx );
+    llama_free_model( model );
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 025b385cc..7985392fe 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1209,6 +1209,63 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
+bool is_at_eos(llama_server_context& server_context, llama_token const* tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// This custom callback example is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens if any.
+//  * When all beams converge to a common prefix, they are made available in beams_state.beam_views[0].
+//    This is also called when the stop condition is met.
+//    Collect tokens into the std::vector<completion_token_output> generated_token_probs owned by the
+//    llama_server_context pointed to by callback_data.
+void beam_search_callback(void* callback_data, llama_beams_state beams_state) {
+    auto& llama = *static_cast<llama_server_context*>(callback_data);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(",");  // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        auto const map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%zu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams[" << i << "]: " << ostream_beam_view{llama.ctx, beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
diff --git a/llama.cpp b/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
+struct llama_beam {
+    std::vector<llama_token> tokens;
+    float p;  // Cumulative beam probability (renormalized relative to all beams)
+    bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
+    // Sort beams by probability. In case of ties, prefer beams at eos.
+    bool operator<(llama_beam const& rhs) const {
+        return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
+    }
+    // Shift off first n tokens and discard them.
+    void shift_tokens(size_t const n) {
+        if (n) {
+            std::copy(tokens.begin() + n, tokens.end(), tokens.begin());
+            tokens.resize(tokens.size() - n);
+        }
+    }
+    llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; }
+};
+
+// A struct for calculating logit-related info.
+struct logit_info {
+    float const* const logits;
+    int const n_vocab;
+    float const max_l;
+    float const normalizer;
+    struct sum_exp {
+        float max_l;
+        float operator()(float sum, float l) const { return sum + std::exp(l - max_l); }
+    };
+    logit_info(llama_context* ctx)
+      : logits(llama_get_logits(ctx))
+      , n_vocab(llama_n_vocab(ctx))
+      , max_l(*std::max_element(logits, logits + n_vocab))
+      , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l}))
+      { }
+    llama_token_data get_token_data(llama_token const token_id) const {
+        constexpr auto p = std::numeric_limits<float>::quiet_NaN(); // never used
+        return {token_id, logits[token_id], p};
+    }
+    // Return top k token_data by logit.
+    std::vector<llama_token_data> top_k(size_t k) {
+        std::vector<llama_token_data> min_heap; // min-heap by logit
+        llama_token const k_min = std::min(static_cast<llama_token>(k), n_vocab);
+        min_heap.reserve(k_min);
+        for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) {
+            min_heap.push_back(get_token_data(token_id));
+        }
+        auto comp = [](llama_token_data const& a, llama_token_data const& b) { return a.logit > b.logit; };
+        std::make_heap(min_heap.begin(), min_heap.end(), comp);
+        for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) {
+            if (min_heap.front().logit < logits[token_id]) {
+                std::pop_heap(min_heap.begin(), min_heap.end(), comp);
+                min_heap.back() = get_token_data(token_id);
+                std::push_heap(min_heap.begin(), min_heap.end(), comp);
+            }
+        }
+        return min_heap;
+    }
+    float probability_from_logit(float const logit) {
+        return normalizer * std::exp(logit - max_l);
+    }
+};
+
+struct beam_search {
+    llama_context * ctx;
+    size_t const n_beams;
+    int n_past;
+    int const n_predict;
+    int const n_threads;
+    std::vector<llama_beam> beams;
+    std::vector<llama_beam> next_beams;
+
+    // Re-calculated on each loop iteration
+    size_t common_prefix_length;
+
+    // Used to communicate to/from callback on beams state.
+    std::vector<llama_beam_view> beam_views;
+
+    beam_search(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads)
+      : ctx(ctx)
+      , n_beams(n_beams)
+      , n_past(n_past)
+      , n_predict(n_predict)
+      , n_threads(n_threads)
+      , beam_views(n_beams) {
+        beams.reserve(n_beams);
+        next_beams.reserve(n_beams);
+    }
+
+    // Collapse beams to a single beam given by index.
+    void collapse_beams(size_t const beam_idx) {
+        if (0u < beam_idx) {
+            std::swap(beams[0], beams[beam_idx]);
+        }
+        beams.resize(1);
+    }
+
+    // Min-heaps are used to efficiently collect the top-k elements (k=n_beams).
+    // The repetitive patterns below reflect the 2 stages of heaps:
+    //  * Gather elements until the vector is full, then call std::make_heap() on it.
+    //  * If the heap is full and a new element is found that should be included, pop the
+    //    least element to the back(), replace it with the new, then push it into the heap.
+    void fill_next_beams_by_top_probabilities(llama_beam& beam) {
+        // Min-heaps use a greater-than comparator.
+        auto const comp = [](llama_beam const& a, llama_beam const& b) { return a.p > b.p; };
+        if (beam.eos) {
+            // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
+            if (next_beams.size() < n_beams) {
+                next_beams.push_back(std::move(beam));
+                if (next_beams.size() == n_beams) {
+                    std::make_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            } else if (next_beams.front().p < beam.p) {
+                std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                next_beams.back() = std::move(beam);
+                std::push_heap(next_beams.begin(), next_beams.end(), comp);
+            }
+        } else {
+            // beam is not at end-of-sentence, so branch with next top_k tokens.
+            if (!beam.tokens.empty()) {
+                llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads);
+            }
+            logit_info logit_info(ctx);
+            std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams);
+            size_t i = 0;
+            if (next_beams.size() < n_beams) {
+                for (; next_beams.size() < n_beams ; ++i) {
+                    llama_beam next_beam = beam;
+                    next_beam.tokens.push_back(next_tokens[i].id);
+                    next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit);
+                    next_beams.push_back(std::move(next_beam));
+                }
+                std::make_heap(next_beams.begin(), next_beams.end(), comp);
+            } else {
+                for (; next_beams.front().p == 0.0f ; ++i) {
+                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                    next_beams.back() = beam;
+                    next_beams.back().tokens.push_back(next_tokens[i].id);
+                    next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit);
+                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            }
+            for (; i < n_beams ; ++i) {
+                float const next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit);
+                if (next_beams.front().p < next_p) {
+                    std::pop_heap(next_beams.begin(), next_beams.end(), comp);
+                    next_beams.back() = beam;
+                    next_beams.back().tokens.push_back(next_tokens[i].id);
+                    next_beams.back().p = next_p;
+                    std::push_heap(next_beams.begin(), next_beams.end(), comp);
+                }
+            }
+        }
+    }
+
+    // Find common_prefix_length based on beams.
+    // Requires beams is not empty.
+    size_t find_common_prefix_length() {
+        size_t common_prefix_length = beams[0].tokens.size();
+        for (size_t i = 1 ; i < beams.size() ; ++i) {
+            common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size());
+            for (size_t j = 0 ; j < common_prefix_length ; ++j) {
+                if (beams[0].tokens[j] != beams[i].tokens[j]) {
+                    common_prefix_length = j;
+                    break;
+                }
+            }
+        }
+        return common_prefix_length;
+    }
+
+    // Construct beams_state to send back to caller via the callback function.
+    // Side effect: set common_prefix_length = find_common_prefix_length();
+    llama_beams_state get_beams_state(bool const last_call) {
+        for (size_t i = 0 ; i < beams.size() ; ++i) {
+            beam_views[i] = beams[i].view();
+        }
+        common_prefix_length = find_common_prefix_length();
+        return {beam_views.data(), beams.size(), common_prefix_length, last_call};
+    }
+
+    // Loop:
+    //  * while i < n_predict, AND
+    //  * any of the beams have not yet reached end-of-sentence, AND
+    //  * the highest probability beams (plural in case of ties) are not at end-of-sentence
+    //    (since all other beam probabilities can only decrease)
+    void loop(llama_beam_search_callback_fn_t const callback, void* const callback_data) {
+        beams.push_back({{}, 1.0f, false});  // Start with one empty beam w/ probability = 1.0 and !eos.
+        auto const not_eos = [](llama_beam const& beam) { return !beam.eos; };
+        for (int i = 0 ; i < n_predict && std::any_of(beams.begin(), beams.end(), not_eos) &&
+                         !beams[top_beam_index()].eos ; ++i) {
+            callback(callback_data, get_beams_state(false));  // Sets common_prefix_length
+            update_beams_from_beam_views();  // Update values (p,eos) that callback may have changed.
+            if (common_prefix_length) {
+                llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
+                n_past += common_prefix_length;
+            }
+            // Zero-out next_beam probabilities to place them last in following min-heap.
+            std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam& beam) { beam.p = 0.0f; });
+            for (llama_beam& beam : beams) {
+                beam.shift_tokens(common_prefix_length);
+                fill_next_beams_by_top_probabilities(beam);
+            }
+            // next_beams become the beams of next/final iteration. Swap them to re-use memory.
+            beams.swap(next_beams);
+            renormalize_beam_probabilities(beams);
+        }
+        collapse_beams(top_beam_index());
+        callback(callback_data, get_beams_state(true));
+    }
+
+    // As beams grow, the cumulative probabilities decrease.
+    // Renormalize them to avoid floating point underflow.
+    static void renormalize_beam_probabilities(std::vector<llama_beam>& beams) {
+        auto const sum_p = [](float sum, llama_beam& beam) { return sum + beam.p; };
+        float const inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p);
+        std::for_each(beams.begin(), beams.end(), [=](llama_beam& beam) { beam.p *= inv_sum; });
+    }
+
+    // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering.
+    size_t top_beam_index() {
+        return std::max_element(beams.begin(), beams.end()) - beams.begin();
+    }
+
+    // Copy (p,eos) for each beam which may have been changed by the callback.
+    void update_beams_from_beam_views() {
+        for (size_t i = 0 ; i < beams.size() ; ++i) {
+            beams[i].p   = beam_views[i].p;
+            beams[i].eos = beam_views[i].eos;
+        }
+    }
+};
+
+void llama_beam_search(llama_context * ctx,
+                       llama_beam_search_callback_fn_t callback, void* callback_data,
+                       size_t n_beams, int n_past, int n_predict, int n_threads) {
+    assert(ctx);
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    beam_search beam_search(ctx, n_beams, n_past, n_predict, n_threads);
+
+    beam_search.loop(callback, callback_data);
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    ctx->n_sample++;
+}
+
 //
 // quantization
 //
diff --git a/llama.h b/llama.h
index 2bcf94e0f..e88a45078 100644
--- a/llama.h
+++ b/llama.h
@@ -465,6 +465,39 @@ extern "C" {
     /// @details Accepts the sampled token into the grammar
     LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
+    struct llama_beam_view {
+        llama_token const* tokens;
+        size_t n_tokens;
+        float p;  // Cumulative beam probability (renormalized relative to all beams)
+        bool eos; // Callback should set this to true when a beam is at end-of-sentence.
+    };
+
+    // Passed to beam_search_callback function.
+    // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+    // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+    // These pointers are valid only during the synchronous callback, so should not be saved.
+    struct llama_beams_state {
+        llama_beam_view* beam_views;
+        size_t n_beams;              // Number of elements in beam_views[].
+        size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
+        bool last_call;              // True iff this is the last callback invocation.
+    };
+
+    // Type of pointer to the beam_search_callback function.
+    // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
+    // passed back to beam_search_callback. This avoids having to use global variables in the callback.
+    typedef void (*llama_beam_search_callback_fn_t)(void* callback_data, llama_beams_state);
+
+    /// @details Deterministically returns entire sentence constructed by a beam search.
+    /// @param ctx Pointer to the llama_context.
+    /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
+    /// @param callback_data A pointer that is simply passed back to callback.
+    /// @param n_beams Number of beams to use.
+    /// @param n_past Number of tokens already evaluated.
+    /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
+    /// @param n_threads Number of threads as passed to llama_eval().
+    LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void* callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
     LLAMA_API void llama_print_timings(struct llama_context * ctx);
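
For reference, below is a minimal sketch (not part of the patch) of a caller-side callback that satisfies the contract documented in llama.h above. The names my_beam_data and my_callback are hypothetical; only the types and functions introduced by this patch are assumed.

#include "llama.h"
#include <vector>

struct my_beam_data {                 // hypothetical caller-side state
    llama_context * ctx;
    std::vector<llama_token> out;     // tokens common to all beams so far
};

// Matches llama_beam_search_callback_fn_t.
static void my_callback(void * callback_data, llama_beams_state state) {
    auto & data = *static_cast<my_beam_data *>(callback_data);
    // Tokens shared by every beam are final: copy them before they are shifted off the beams.
    if (const size_t n = state.common_prefix_length) {
        const llama_token * tokens = state.beam_views[0].tokens;
        data.out.insert(data.out.end(), tokens, tokens + n);
    }
    // When state.last_call is true, the beams have been collapsed and data.out holds the full result.
}

// Usage (assuming n_past prompt tokens were already evaluated with llama_eval):
//     my_beam_data data{ctx, {}};
//     llama_beam_search(ctx, my_callback, &data, /*n_beams=*/4, n_past, /*n_predict=*/256, n_threads);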