diff --git a/examples/beam_search/beam_search.cpp b/examples/beam_search/beam_search.cpp index 1d0d077d1..1c04fabc2 100644 --- a/examples/beam_search/beam_search.cpp +++ b/examples/beam_search/beam_search.cpp @@ -33,7 +33,7 @@ struct ostream_beam_view { llama_beam_view beam_view; }; std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) { - os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens("; + os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens("; for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) { os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]); } @@ -46,7 +46,9 @@ struct beam_search_callback_data { std::vector response; }; -bool is_at_eos(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) { +// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the case. +// For example, eob can be flagged due to maximum token length, stop words, etc. +bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) { return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx); } @@ -61,8 +63,8 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_stat // Mark beams as EOS as needed. 
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) { llama_beam_view& beam_view = beams_state.beam_views[i]; - if (!beam_view.eos && is_at_eos(callback_data, beam_view.tokens, beam_view.n_tokens)) { - beam_view.eos = true; + if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) { + beam_view.eob = true; } } printf(","); // Show progress diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 94a029bbf..3300553f9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1209,7 +1209,7 @@ static void log_server_request(const Request &req, const Response &res) }); } -bool is_at_eos(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) { +bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) { return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx); } @@ -1223,9 +1223,9 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) { auto & llama = *static_cast(callback_data); // Mark beams as EOS as needed. for (size_t i = 0 ; i < beams_state.n_beams ; ++i) { - llama_beam_view & beam_view = beams_state.beam_views[i]; - if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) { - beam_view.eos = true; + llama_beam_view& beam_view = beams_state.beam_views[i]; + if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) { + beam_view.eob = true; } } printf(","); // Show progress diff --git a/llama.cpp b/llama.cpp index 9f23a6a9d..5b8e3bbaa 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4333,10 +4333,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar struct llama_beam { std::vector tokens; float p; // Cumulative beam probability (renormalized relative to all beams) - bool eos; // Initialize end-of-sentence to false. Callback sets this to true. - // Sort beams by probability. 
In case of ties, prefer beams at eos. + bool eob; // Initialize end-of-beam to false. Callback sets this to true. + // Sort beams by probability. In case of ties, prefer beams at eob. bool operator<(const llama_beam & rhs) const { - return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos); + return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob); } // Shift off first n tokens and discard them. void shift_tokens(const size_t n) { @@ -4345,7 +4345,7 @@ struct llama_beam { tokens.resize(tokens.size() - n); } } - llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; } + llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; } }; // A struct for calculating logit-related info. @@ -4435,7 +4435,7 @@ struct llama_beam_search_data { void fill_next_beams_by_top_probabilities(llama_beam & beam) { // Min-heaps use a greater-than comparator. const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; }; - if (beam.eos) { + if (beam.eob) { // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough. if (next_beams.size() < n_beams) { next_beams.push_back(std::move(beam)); @@ -4513,16 +4513,16 @@ struct llama_beam_search_data { // Loop: // * while i < n_predict, AND - // * any of the beams have not yet reached end-of-sentence, AND + // * any of the beams have not yet reached end-of-beam (eob), AND // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence // (since all other beam probabilities can only decrease) void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) { - beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eos. 
- const auto not_eos = [](const llama_beam & beam) { return !beam.eos; }; - for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eos) && - !beams[top_beam_index()].eos ; ++i) { + beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob. + const auto not_eob = [](const llama_beam & beam) { return !beam.eob; }; + for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) && + !beams[top_beam_index()].eob ; ++i) { callback(callback_data, get_beams_state(false)); // Sets common_prefix_length - update_beams_from_beam_views(); // Update values (p,eos) that callback may have changed. + update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. if (common_prefix_length) { llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); n_past += common_prefix_length; @@ -4554,11 +4554,11 @@ struct llama_beam_search_data { return std::max_element(beams.begin(), beams.end()) - beams.begin(); } - // Copy (p,eos) for each beam which may have been changed by the callback. + // Copy (p,eob) for each beam which may have been changed by the callback. void update_beams_from_beam_views() { for (size_t i = 0 ; i < beams.size() ; ++i) { beams[i].p = beam_views[i].p; - beams[i].eos = beam_views[i].eos; + beams[i].eob = beam_views[i].eob; } } }; diff --git a/llama.h b/llama.h index 47e7a2ebe..cca803181 100644 --- a/llama.h +++ b/llama.h @@ -473,7 +473,7 @@ extern "C" { const llama_token * tokens; size_t n_tokens; float p; // Cumulative beam probability (renormalized relative to all beams) - bool eos; // Callback should set this to true when a beam is at end-of-sentence. + bool eob; // Callback should set this to true when a beam is at end-of-beam. }; // Passed to beam_search_callback function.