only sample full tokens (no peeking or truncation)

Evan Jones, 2023-07-19 23:28:57 -04:00
commit c047e8aec2, parent 8d37755bdc
3 changed files with 119 additions and 87 deletions


@@ -613,7 +613,7 @@ int main(int argc, char ** argv) {
             // printf("`%d`", candidates_p.size);
             if (grammar != NULL) {
-                id = llama_grammar_accept_token(ctx, grammar, id);
+                llama_grammar_accept_token(ctx, grammar, id);
             }
             last_n_tokens.erase(last_n_tokens.begin());

llama.cpp (180 lines changed)

@@ -1900,20 +1900,32 @@ struct llama_grammar {
     std::vector<std::vector<const llama_grammar_element *>> stacks;
 };
 
+struct llama_grammar_candidate {
+    size_t           index;
+    const uint32_t * code_points;
+};
+
 // NOTE: assumes valid utf8 (but checks for overrun)
-std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+// adds a terminating 0 for use as pointer
+std::vector<uint32_t> decode_utf8(const char * src) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-    uint8_t  first_byte = static_cast<uint8_t>(*src);
-    uint8_t  highbits   = first_byte >> 4;
-    int      len        = lookup[highbits];
-    uint8_t  mask       = (1 << (8 - len)) - 1;
-    uint32_t value      = first_byte & mask;
-    const char * end    = src + len; // may overrun!
-    const char * pos    = src + 1;   // may overrun!
-    for ( ; pos < end && *pos; pos++) {
-        value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
-    }
-    return std::make_pair(value, pos);
+    const char * pos = src;
+    std::vector<uint32_t> code_points;
+    while (*pos != 0) {
+        uint8_t  first_byte = static_cast<uint8_t>(*pos);
+        uint8_t  highbits   = first_byte >> 4;
+        int      len        = lookup[highbits];
+        uint8_t  mask       = (1 << (8 - len)) - 1;
+        uint32_t value      = first_byte & mask;
+        const char * end    = pos + len; // may overrun!
+        ++pos;
+        for ( ; pos < end && *pos != 0; ++pos) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+        }
+        code_points.push_back(value);
+    }
+    code_points.push_back(0);
+    return code_points;
 }
 
 // returns true iff pos points to the end of one of the definitions of a rule
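The reworked decode_utf8 decodes the whole token string up front and appends a terminating 0, so the result can be passed around as a bare const uint32_t * the way a C string would be (which is exactly how llama_grammar_candidate uses it). A minimal standalone check of that contract; the function body is copied from the hunk above, the test values are my own:

#include <cassert>
#include <cstdint>
#include <vector>

// copied from the diff: decode a whole UTF-8 string into code points,
// appending a terminating 0 so the buffer can travel as a raw pointer
// NOTE: assumes valid utf8 (but checks for overrun)
static std::vector<uint32_t> decode_utf8(const char * src) {
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    const char * pos = src;
    std::vector<uint32_t> code_points;
    while (*pos != 0) {
        uint8_t  first_byte = static_cast<uint8_t>(*pos);
        uint8_t  highbits   = first_byte >> 4;
        int      len        = lookup[highbits];
        uint8_t  mask       = (1 << (8 - len)) - 1;
        uint32_t value      = first_byte & mask;
        const char * end    = pos + len; // may overrun!
        ++pos;
        for ( ; pos < end && *pos != 0; ++pos) {
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
        }
        code_points.push_back(value);
    }
    code_points.push_back(0);
    return code_points;
}

int main() {
    // " é" is 0x20 followed by U+00E9 (0xC3 0xA9 in UTF-8)
    const auto cps = decode_utf8(" \xC3\xA9");
    assert(cps.size() == 3); // two code points plus the terminating 0
    assert(cps[0] == 0x20);
    assert(cps[1] == 0xE9);
    assert(cps[2] == 0);
}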
@@ -2037,23 +2049,72 @@ static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_acc
     return new_stacks;
 }
 
-// returns `true` if one of the pushdown stacks can accept the given char.
-static bool llama_grammar_peek(
-        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t chr) {
-    for (const auto & stack : stacks) {
-        if (stack.empty()) {
-            if (!chr) {
-                return true;
-            }
-        } else if (llama_grammar_match_char(stack.back(), chr).first) {
-            return true;
-        }
-    }
-    return false;
-}
+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
+        const std::vector<std::vector<llama_grammar_element>>         & rules,
+        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
+        const std::vector<llama_grammar_candidate>                    & candidates);
+
+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const std::vector<std::vector<llama_grammar_element>> & rules,
+        const std::vector<const llama_grammar_element *>      & stack,
+        const std::vector<llama_grammar_candidate>            & candidates) {
+
+    std::vector<llama_grammar_candidate> rejects;
+
+    if (stack.empty()) {
+        // accept nothing; EOS is handled elsewhere
+        rejects.insert(rejects.end(), candidates.begin(), candidates.end());
+        return rejects;
+    }
+
+    const llama_grammar_element * stack_pos = stack.back();
+
+    std::vector<llama_grammar_candidate> next_candidates;
+    for (auto tok : candidates) {
+        if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
+            if (tok.code_points[1] != 0) {
+                next_candidates.push_back({ tok.index, tok.code_points + 1 });
+            }
+        } else {
+            rejects.push_back(tok);
+        }
+    }
+
+    auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    std::vector<const llama_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    std::vector<std::vector<const llama_grammar_element *>> next_stacks;
+    llama_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (auto tok : next_rejects) {
+        rejects.push_back({ tok.index, tok.code_points - 1 });
+    }
+
+    return rejects;
+}
+
+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
+        const std::vector<std::vector<llama_grammar_element>>         & rules,
+        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
+        const std::vector<llama_grammar_candidate>                    & candidates) {
+    LLAMA_ASSERT(!stacks.empty()); // REVIEW
+
+    if (candidates.empty()) {
+        return std::vector<llama_grammar_candidate>();
+    }
+
+    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
 
 //
 // grammar - external
@@ -2383,32 +2444,37 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
     const llama_token eos = llama_token_eos();
 
-    // since many llama tokens are prefixed with a single space, special case a lookahead on ' '
-    const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, U' ');
+    bool allow_eos = false;
+    for (const auto & stack : grammar->stacks) {
+        if (stack.empty()) {
+            allow_eos = true;
+            break;
+        }
+    }
+
+    std::vector<std::vector<uint32_t>>   decoded_candidates;
+    std::vector<llama_grammar_candidate> grammar_candidates;
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id  = candidates->data[i].id;
         const char *      str = llama_token_to_str(ctx, id);
-
-        // prune tokens based on first char only - in `llama_grammar_accept_token` we will find the
-        // full matching prefix of the selected token
-        bool valid = false;
         if (id == eos) {
-            valid = llama_grammar_peek(grammar->stacks, 0);
-        } else {
-            const auto     decoded = decode_utf8(str);
-            const uint32_t chr     = decoded.first;
-            if (chr == U' ') {
-                const char * next = decoded.second;
-                valid = llama_grammar_peek(stacks_after_space, decode_utf8(next).first);
-            } else if (chr != 0) {
-                valid = llama_grammar_peek(grammar->stacks, chr);
+            if (!allow_eos) {
+                candidates->data[i].logit = -INFINITY;
             }
-        }
-        if (!valid) {
+        } else if (*str == 0) {
             candidates->data[i].logit = -INFINITY;
+        } else {
+            decoded_candidates.push_back(decode_utf8(str));
+            grammar_candidates.push_back({ i, decoded_candidates.back().data() });
        }
     }
 
+    auto rejects =
+        llama_grammar_reject_candidates(grammar->rules, grammar->stacks, grammar_candidates);
+    for (auto reject : rejects) {
+        candidates->data[reject.index].logit = -INFINITY;
+    }
+
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
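The allow_eos scan at the top of the new llama_sample_grammar asks one question: does any pushdown stack already accept end-of-input? An equivalent formulation as a sketch (a factoring of my own, not part of the commit):

#include <algorithm>
#include <vector>

// sketch only: EOS is allowed iff at least one pushdown stack is empty,
// i.e. the grammar has a complete parse and may legally stop here
template <typename element>
static bool grammar_allows_eos(const std::vector<std::vector<const element *>> & stacks) {
    return std::any_of(stacks.begin(), stacks.end(),
            [](const std::vector<const element *> & stack) { return stack.empty(); });
}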
@@ -2599,61 +2665,27 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     return result;
 }
 
-llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
+void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();
 
     if (token == llama_token_eos()) {
         for (const auto & stack : grammar->stacks) {
             if (stack.empty()) {
-                return token;
+                return;
             }
         }
         LLAMA_ASSERT(false);
     }
 
     const char * str = llama_token_to_str(ctx, token);
-    const char * suffix = str;
-
-    // Find prefix of selected token that matches grammar, expecting at least 1 char
-    auto decoded = decode_utf8(suffix);
-    auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first);
-    LLAMA_ASSERT(!new_stacks.empty());
-    if (*suffix) {
-        suffix = decoded.second;
-        for ( ; *suffix; suffix = decoded.second) {
-            decoded = decode_utf8(suffix);
-            new_stacks = llama_grammar_accept(grammar->rules, new_stacks, decoded.first);
-            if (new_stacks.empty()) {
-                break;
-            }
-        }
-    }
-
-    // if full token is matched, accept new stacks
-    if (!(*suffix)) {
-        grammar->stacks = new_stacks;
-        return token;
-    }
-
-    // otherwise, tokenize the string prefix that did match
-    llama_token tokens[32]; // TODO - determine actual max token size
-    const std::string prefix_str(str, suffix - str);
-    int n_tokens = llama_tokenize(ctx, prefix_str.c_str(), tokens, 32, false);
-    if (n_tokens < 1) {
-        return token; // REVIEW
-    }
-
-    // accept the first token of the matching prefix into the grammar
-    llama_token first_prefix_token = tokens[0];
-    const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token);
-    for ( ; *first_prefix_str; first_prefix_str = decoded.second) {
-        decoded = decode_utf8(first_prefix_str);
-        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first);
-        LLAMA_ASSERT(!grammar->stacks.empty());
-    }
+    // Note terminating 0 in decoded string
+    auto code_points = decode_utf8(str);
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+    }
+    LLAMA_ASSERT(!grammar->stacks.empty());
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    return first_prefix_token;
 }
 
 //

llama.h

@@ -404,8 +404,8 @@ extern "C" {
     /// @details Randomly selects a token from the candidates based on their probabilities.
     LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
 
-    /// @details Accepts the sampled token into the grammar, possibly transforming to a new token
-    LLAMA_API llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
+    /// @details Accepts the sampled token into the grammar
+    LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
 
     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
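With the signature change, callers no longer inspect a returned token; the sampled id is final. A hedged sketch of the resulting call pattern, mirroring the main-loop hunk at the top of this commit (assumes llama.h is included and that llama_sample_grammar takes (ctx, candidates, grammar), as its definition above suggests):

// sketch of a caller after this change; the grammar no longer rewrites the
// sampled token, so the returned id is used verbatim
static llama_token sample_with_grammar(
        struct llama_context   * ctx,
        struct llama_grammar   * grammar,
        llama_token_data_array * candidates) {
    if (grammar != NULL) {
        llama_sample_grammar(ctx, candidates, grammar); // bans tokens the grammar rejects
    }
    const llama_token id = llama_sample_token(ctx, candidates);
    if (grammar != NULL) {
        llama_grammar_accept_token(ctx, grammar, id);   // now returns void
    }
    return id;
}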