Change eos to eob in llama_beam and llama_beam_view structs.
This commit is contained in:
parent
b619cfc059
commit
5fa1ea2c38
4 changed files with 24 additions and 22 deletions
|
@ -33,7 +33,7 @@ struct ostream_beam_view {
|
||||||
llama_beam_view beam_view;
|
llama_beam_view beam_view;
|
||||||
};
|
};
|
||||||
std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
|
std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
|
||||||
os << "p(" << obv.beam_view.p << ") eos(" << std::boolalpha << obv.beam_view.eos << ") tokens(";
|
os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
|
||||||
for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
|
for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
|
||||||
os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
|
os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,9 @@ struct beam_search_callback_data {
|
||||||
std::vector<llama_token> response;
|
std::vector<llama_token> response;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool is_at_eos(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
|
// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
|
||||||
|
// For example, eob can be flagged due to maximum token length, stop words, etc.
|
||||||
|
bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
|
||||||
return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
|
return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,8 +63,8 @@ void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_stat
|
||||||
// Mark beams as EOS as needed.
|
// Mark beams as EOS as needed.
|
||||||
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
|
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
|
||||||
llama_beam_view& beam_view = beams_state.beam_views[i];
|
llama_beam_view& beam_view = beams_state.beam_views[i];
|
||||||
if (!beam_view.eos && is_at_eos(callback_data, beam_view.tokens, beam_view.n_tokens)) {
|
if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
|
||||||
beam_view.eos = true;
|
beam_view.eob = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf(","); // Show progress
|
printf(","); // Show progress
|
||||||
|
|
|
@ -1209,7 +1209,7 @@ static void log_server_request(const Request &req, const Response &res)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_at_eos(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
|
bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
|
||||||
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
|
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1223,9 +1223,9 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
|
||||||
auto & llama = *static_cast<llama_server_context*>(callback_data);
|
auto & llama = *static_cast<llama_server_context*>(callback_data);
|
||||||
// Mark beams as EOS as needed.
|
// Mark beams as EOS as needed.
|
||||||
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
|
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
|
||||||
llama_beam_view & beam_view = beams_state.beam_views[i];
|
llama_beam_view& beam_view = beams_state.beam_views[i];
|
||||||
if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
|
if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
|
||||||
beam_view.eos = true;
|
beam_view.eob = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf(","); // Show progress
|
printf(","); // Show progress
|
||||||
|
|
26
llama.cpp
26
llama.cpp
|
@ -4333,10 +4333,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
|
||||||
struct llama_beam {
|
struct llama_beam {
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
float p; // Cumulative beam probability (renormalized relative to all beams)
|
float p; // Cumulative beam probability (renormalized relative to all beams)
|
||||||
bool eos; // Initialize end-of-sentence to false. Callback sets this to true.
|
bool eob; // Initialize end-of-beam to false. Callback sets this to true.
|
||||||
// Sort beams by probability. In case of ties, prefer beams at eos.
|
// Sort beams by probability. In case of ties, prefer beams at eob.
|
||||||
bool operator<(const llama_beam & rhs) const {
|
bool operator<(const llama_beam & rhs) const {
|
||||||
return std::make_tuple(p, eos) < std::make_tuple(rhs.p, rhs.eos);
|
return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob);
|
||||||
}
|
}
|
||||||
// Shift off first n tokens and discard them.
|
// Shift off first n tokens and discard them.
|
||||||
void shift_tokens(const size_t n) {
|
void shift_tokens(const size_t n) {
|
||||||
|
@ -4345,7 +4345,7 @@ struct llama_beam {
|
||||||
tokens.resize(tokens.size() - n);
|
tokens.resize(tokens.size() - n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eos}; }
|
llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// A struct for calculating logit-related info.
|
// A struct for calculating logit-related info.
|
||||||
|
@ -4435,7 +4435,7 @@ struct llama_beam_search_data {
|
||||||
void fill_next_beams_by_top_probabilities(llama_beam & beam) {
|
void fill_next_beams_by_top_probabilities(llama_beam & beam) {
|
||||||
// Min-heaps use a greater-than comparator.
|
// Min-heaps use a greater-than comparator.
|
||||||
const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
|
const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; };
|
||||||
if (beam.eos) {
|
if (beam.eob) {
|
||||||
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
|
// beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough.
|
||||||
if (next_beams.size() < n_beams) {
|
if (next_beams.size() < n_beams) {
|
||||||
next_beams.push_back(std::move(beam));
|
next_beams.push_back(std::move(beam));
|
||||||
|
@ -4513,16 +4513,16 @@ struct llama_beam_search_data {
|
||||||
|
|
||||||
// Loop:
|
// Loop:
|
||||||
// * while i < n_predict, AND
|
// * while i < n_predict, AND
|
||||||
// * any of the beams have not yet reached end-of-sentence, AND
|
// * any of the beams have not yet reached end-of-beam (eob), AND
|
||||||
// * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
|
// * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence
|
||||||
// (since all other beam probabilities can only decrease)
|
// (since all other beam probabilities can only decrease)
|
||||||
void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
|
void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) {
|
||||||
beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eos.
|
beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob.
|
||||||
const auto not_eos = [](const llama_beam & beam) { return !beam.eos; };
|
const auto not_eob = [](const llama_beam & beam) { return !beam.eob; };
|
||||||
for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eos) &&
|
for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) &&
|
||||||
!beams[top_beam_index()].eos ; ++i) {
|
!beams[top_beam_index()].eob ; ++i) {
|
||||||
callback(callback_data, get_beams_state(false)); // Sets common_prefix_length
|
callback(callback_data, get_beams_state(false)); // Sets common_prefix_length
|
||||||
update_beams_from_beam_views(); // Update values (p,eos) that callback may have changed.
|
update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed.
|
||||||
if (common_prefix_length) {
|
if (common_prefix_length) {
|
||||||
llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
|
llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads);
|
||||||
n_past += common_prefix_length;
|
n_past += common_prefix_length;
|
||||||
|
@ -4554,11 +4554,11 @@ struct llama_beam_search_data {
|
||||||
return std::max_element(beams.begin(), beams.end()) - beams.begin();
|
return std::max_element(beams.begin(), beams.end()) - beams.begin();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy (p,eos) for each beam which may have been changed by the callback.
|
// Copy (p,eob) for each beam which may have been changed by the callback.
|
||||||
void update_beams_from_beam_views() {
|
void update_beams_from_beam_views() {
|
||||||
for (size_t i = 0 ; i < beams.size() ; ++i) {
|
for (size_t i = 0 ; i < beams.size() ; ++i) {
|
||||||
beams[i].p = beam_views[i].p;
|
beams[i].p = beam_views[i].p;
|
||||||
beams[i].eos = beam_views[i].eos;
|
beams[i].eob = beam_views[i].eob;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
2
llama.h
2
llama.h
|
@ -473,7 +473,7 @@ extern "C" {
|
||||||
const llama_token * tokens;
|
const llama_token * tokens;
|
||||||
size_t n_tokens;
|
size_t n_tokens;
|
||||||
float p; // Cumulative beam probability (renormalized relative to all beams)
|
float p; // Cumulative beam probability (renormalized relative to all beams)
|
||||||
bool eos; // Callback should set this to true when a beam is at end-of-sentence.
|
bool eob; // Callback should set this to true when a beam is at end-of-beam.
|
||||||
};
|
};
|
||||||
|
|
||||||
// Passed to beam_search_callback function.
|
// Passed to beam_search_callback function.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue