only sample full tokens (no peeking or truncation)
This commit is contained in:
parent
8d37755bdc
commit
c047e8aec2
3 changed files with 119 additions and 87 deletions
|
@ -613,7 +613,7 @@ int main(int argc, char ** argv) {
|
||||||
// printf("`%d`", candidates_p.size);
|
// printf("`%d`", candidates_p.size);
|
||||||
|
|
||||||
if (grammar != NULL) {
|
if (grammar != NULL) {
|
||||||
id = llama_grammar_accept_token(ctx, grammar, id);
|
llama_grammar_accept_token(ctx, grammar, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
last_n_tokens.erase(last_n_tokens.begin());
|
last_n_tokens.erase(last_n_tokens.begin());
|
||||||
|
|
186
llama.cpp
186
llama.cpp
|
@ -1900,20 +1900,32 @@ struct llama_grammar {
|
||||||
std::vector<std::vector<const llama_grammar_element *>> stacks;
|
std::vector<std::vector<const llama_grammar_element *>> stacks;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct llama_grammar_candidate {
|
||||||
|
size_t index;
|
||||||
|
const uint32_t * code_points;
|
||||||
|
};
|
||||||
|
|
||||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
// NOTE: assumes valid utf8 (but checks for overrun)
|
||||||
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
// adds a terminating 0 for use as pointer
|
||||||
|
std::vector<uint32_t> decode_utf8(const char * src) {
|
||||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||||
uint8_t first_byte = static_cast<uint8_t>(*src);
|
const char * pos = src;
|
||||||
|
std::vector<uint32_t> code_points;
|
||||||
|
while (*pos != 0) {
|
||||||
|
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
||||||
uint8_t highbits = first_byte >> 4;
|
uint8_t highbits = first_byte >> 4;
|
||||||
int len = lookup[highbits];
|
int len = lookup[highbits];
|
||||||
uint8_t mask = (1 << (8 - len)) - 1;
|
uint8_t mask = (1 << (8 - len)) - 1;
|
||||||
uint32_t value = first_byte & mask;
|
uint32_t value = first_byte & mask;
|
||||||
const char * end = src + len; // may overrun!
|
const char * end = pos + len; // may overrun!
|
||||||
const char * pos = src + 1; // may overrun!
|
++pos;
|
||||||
for ( ; pos < end && *pos; pos++) {
|
for ( ; pos < end && *pos != 0; ++pos) {
|
||||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||||
}
|
}
|
||||||
return std::make_pair(value, pos);
|
code_points.push_back(value);
|
||||||
|
}
|
||||||
|
code_points.push_back(0);
|
||||||
|
return code_points;
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns true iff pos points to the end of one of the definitions of a rule
|
// returns true iff pos points to the end of one of the definitions of a rule
|
||||||
|
@ -2037,23 +2049,72 @@ static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_acc
|
||||||
return new_stacks;
|
return new_stacks;
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns `true` if one of the pushdown stacks can accept the given char.
|
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
|
||||||
static bool llama_grammar_peek(
|
const std::vector<std::vector<llama_grammar_element>> & rules,
|
||||||
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
||||||
const uint32_t chr) {
|
const std::vector<llama_grammar_candidate> & candidates);
|
||||||
|
|
||||||
|
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||||
|
const std::vector<std::vector<llama_grammar_element>> & rules,
|
||||||
|
const std::vector<const llama_grammar_element *> & stack,
|
||||||
|
const std::vector<llama_grammar_candidate> & candidates) {
|
||||||
|
|
||||||
|
std::vector<llama_grammar_candidate> rejects;
|
||||||
|
|
||||||
for (const auto & stack : stacks) {
|
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
if (!chr) {
|
// accept nothing; EOS is handled elsewhere
|
||||||
return true;
|
rejects.insert(rejects.end(), candidates.begin(), candidates.end());
|
||||||
}
|
return rejects;
|
||||||
} else if (llama_grammar_match_char(stack.back(), chr).first) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const llama_grammar_element * stack_pos = stack.back();
|
||||||
|
|
||||||
|
std::vector<llama_grammar_candidate> next_candidates;
|
||||||
|
for (auto tok : candidates) {
|
||||||
|
if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
|
||||||
|
if (tok.code_points[1] != 0) {
|
||||||
|
next_candidates.push_back({ tok.index, tok.code_points + 1 });
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
|
||||||
|
|
||||||
|
// update top of stack to next element, if any
|
||||||
|
std::vector<const llama_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
|
||||||
|
stack_after.push_back(stack_pos_after);
|
||||||
|
}
|
||||||
|
std::vector<std::vector<const llama_grammar_element *>> next_stacks;
|
||||||
|
llama_grammar_advance_stack(rules, stack_after, next_stacks);
|
||||||
|
|
||||||
|
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||||
|
for (auto tok : next_rejects) {
|
||||||
|
rejects.push_back({ tok.index, tok.code_points - 1 });
|
||||||
|
}
|
||||||
|
|
||||||
|
return rejects;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
|
||||||
|
const std::vector<std::vector<llama_grammar_element>> & rules,
|
||||||
|
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
||||||
|
const std::vector<llama_grammar_candidate> & candidates) {
|
||||||
|
LLAMA_ASSERT(!stacks.empty()); // REVIEW
|
||||||
|
|
||||||
|
if (candidates.empty()) {
|
||||||
|
return std::vector<llama_grammar_candidate>();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
|
||||||
|
|
||||||
|
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
||||||
|
rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
||||||
|
}
|
||||||
|
return rejects;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// grammar - external
|
// grammar - external
|
||||||
|
@ -2383,32 +2444,37 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
|
||||||
assert(ctx);
|
assert(ctx);
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
const llama_token eos = llama_token_eos();
|
const llama_token eos = llama_token_eos();
|
||||||
// since many llama tokens are prefixed with a single space, special case a lookahead on ' '
|
|
||||||
const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, U' ');
|
bool allow_eos = false;
|
||||||
|
for (const auto & stack : grammar->stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
allow_eos = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<uint32_t>> decoded_candidates;
|
||||||
|
std::vector<llama_grammar_candidate> grammar_candidates;
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
const llama_token id = candidates->data[i].id;
|
const llama_token id = candidates->data[i].id;
|
||||||
const char * str = llama_token_to_str(ctx, id);
|
const char * str = llama_token_to_str(ctx, id);
|
||||||
|
|
||||||
// prune tokens based on first char only - in `llama_grammar_accept_token` we will find the
|
|
||||||
// full matching prefix of the selected token
|
|
||||||
bool valid = false;
|
|
||||||
if (id == eos) {
|
if (id == eos) {
|
||||||
valid = llama_grammar_peek(grammar->stacks, 0);
|
if (!allow_eos) {
|
||||||
} else {
|
|
||||||
const auto decoded = decode_utf8(str);
|
|
||||||
const uint32_t chr = decoded.first;
|
|
||||||
if (chr == U' ') {
|
|
||||||
const char * next = decoded.second;
|
|
||||||
valid = llama_grammar_peek(stacks_after_space, decode_utf8(next).first);
|
|
||||||
} else if (chr != 0) {
|
|
||||||
valid = llama_grammar_peek(grammar->stacks, chr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!valid) {
|
|
||||||
candidates->data[i].logit = -INFINITY;
|
candidates->data[i].logit = -INFINITY;
|
||||||
}
|
}
|
||||||
|
} else if (*str == 0) {
|
||||||
|
candidates->data[i].logit = -INFINITY;
|
||||||
|
} else {
|
||||||
|
decoded_candidates.push_back(decode_utf8(str));
|
||||||
|
grammar_candidates.push_back({ i, decoded_candidates.back().data() });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rejects =
|
||||||
|
llama_grammar_reject_candidates(grammar->rules, grammar->stacks, grammar_candidates);
|
||||||
|
for (auto reject : rejects) {
|
||||||
|
candidates->data[reject.index].logit = -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
@ -2599,61 +2665,27 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
|
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
if (token == llama_token_eos()) {
|
if (token == llama_token_eos()) {
|
||||||
for (const auto & stack : grammar->stacks) {
|
for (const auto & stack : grammar->stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
return token;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LLAMA_ASSERT(false);
|
LLAMA_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * str = llama_token_to_str(ctx, token);
|
const char * str = llama_token_to_str(ctx, token);
|
||||||
const char * suffix = str;
|
// Note terminating 0 in decoded string
|
||||||
|
auto code_points = decode_utf8(str);
|
||||||
// Find prefix of selected token that matches grammar, expecting at least 1 char
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
auto decoded = decode_utf8(suffix);
|
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
||||||
auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first);
|
|
||||||
LLAMA_ASSERT(!new_stacks.empty());
|
|
||||||
if (*suffix) {
|
|
||||||
suffix = decoded.second;
|
|
||||||
for ( ; *suffix; suffix = decoded.second) {
|
|
||||||
decoded = decode_utf8(suffix);
|
|
||||||
new_stacks = llama_grammar_accept(grammar->rules, new_stacks, decoded.first);
|
|
||||||
if (new_stacks.empty() ) {
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if full token is matched, accept new stacks
|
|
||||||
if (!(*suffix)) {
|
|
||||||
grammar->stacks = new_stacks;
|
|
||||||
return token;
|
|
||||||
}
|
|
||||||
|
|
||||||
// otherwise, tokenize the string prefix that did match
|
|
||||||
llama_token tokens[32]; // TODO - determine actual max token size
|
|
||||||
const std::string prefix_str(str, suffix - str);
|
|
||||||
int n_tokens = llama_tokenize(ctx, prefix_str.c_str(), tokens, 32, false);
|
|
||||||
if (n_tokens < 1) {
|
|
||||||
return token; // REVIEW
|
|
||||||
}
|
|
||||||
|
|
||||||
// accept the first token of the matching prefix into the grammar
|
|
||||||
llama_token first_prefix_token = tokens[0];
|
|
||||||
const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token);
|
|
||||||
for ( ; *first_prefix_str; first_prefix_str = decoded.second) {
|
|
||||||
decoded = decode_utf8(first_prefix_str);
|
|
||||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first);
|
|
||||||
LLAMA_ASSERT(!grammar->stacks.empty());
|
LLAMA_ASSERT(!grammar->stacks.empty());
|
||||||
}
|
|
||||||
|
|
||||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
return first_prefix_token;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
4
llama.h
4
llama.h
|
@ -404,8 +404,8 @@ extern "C" {
|
||||||
/// @details Randomly selects a token from the candidates based on their probabilities.
|
/// @details Randomly selects a token from the candidates based on their probabilities.
|
||||||
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||||
|
|
||||||
/// @details Accepts the sampled token into the grammar, possibly transforming to a new token
|
/// @details Accepts the sampled token into the grammar
|
||||||
LLAMA_API llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
|
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
|
||||||
|
|
||||||
// Performance information
|
// Performance information
|
||||||
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue