diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 035afbe28..ed86b5022 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -613,7 +613,7 @@ int main(int argc, char ** argv) {
                 //  printf("`%d`", candidates_p.size);

                 if (grammar != NULL) {
-                    id = llama_grammar_accept_token(ctx, grammar, id);
+                    llama_grammar_accept_token(ctx, grammar, id);
                 }

                 last_n_tokens.erase(last_n_tokens.begin());
diff --git a/llama.cpp b/llama.cpp
index 0f6eddb50..d8c67a662 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1900,20 +1900,32 @@ struct llama_grammar {
     std::vector<std::vector<const llama_grammar_element *>> stacks;
 };

+struct llama_grammar_candidate {
+    size_t           index;
+    const uint32_t * code_points;
+};
+
 // NOTE: assumes valid utf8 (but checks for overrun)
-std::pair<uint32_t, const char *> decode_utf8(const char * src) {
-    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-    uint8_t  first_byte = static_cast<uint8_t>(*src);
-    uint8_t  highbits = first_byte >> 4;
-    int      len      = lookup[highbits];
-    uint8_t  mask     = (1 << (8 - len)) - 1;
-    uint32_t value    = first_byte & mask;
-    const char * end  = src + len; // may overrun!
-    const char * pos  = src + 1;   // may overrun!
-    for ( ; pos < end && *pos; pos++) {
-        value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+// adds a terminating 0 for use as pointer
+std::vector<uint32_t> decode_utf8(const char * src) {
+    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    const char * pos = src;
+    std::vector<uint32_t> code_points;
+    while (*pos != 0) {
+        uint8_t  first_byte = static_cast<uint8_t>(*pos);
+        uint8_t  highbits = first_byte >> 4;
+        int      len      = lookup[highbits];
+        uint8_t  mask     = (1 << (8 - len)) - 1;
+        uint32_t value    = first_byte & mask;
+        const char * end  = pos + len; // may overrun!
+        ++pos;
+        for ( ; pos < end && *pos != 0; ++pos) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+        }
+        code_points.push_back(value);
     }
-    return std::make_pair(value, pos);
+    code_points.push_back(0);
+    return code_points;
 }

 // returns true iff pos points to the end of one of the definitions of a rule
@@ -2037,23 +2049,72 @@ static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_acc
     return new_stacks;
 }

-// returns `true` if one of the pushdown stacks can accept the given char.
-static bool llama_grammar_peek(
+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
+        const std::vector<std::vector<llama_grammar_element>>         & rules,
         const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t chr) {
+        const std::vector<llama_grammar_candidate>                    & candidates);

-    for (const auto & stack : stacks) {
-        if (stack.empty()) {
-            if (!chr) {
-                return true;
+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const std::vector<std::vector<llama_grammar_element>> & rules,
+        const std::vector<const llama_grammar_element *>      & stack,
+        const std::vector<llama_grammar_candidate>            & candidates) {
+
+    std::vector<llama_grammar_candidate> rejects;
+
+    if (stack.empty()) {
+        // accept nothing; EOS is handled elsewhere
+        rejects.insert(rejects.end(), candidates.begin(), candidates.end());
+        return rejects;
+    }
+
+    const llama_grammar_element * stack_pos = stack.back();
+
+    std::vector<llama_grammar_candidate> next_candidates;
+    for (auto tok : candidates) {
+        if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
+            if (tok.code_points[1] != 0) {
+                next_candidates.push_back({ tok.index, tok.code_points + 1 });
             }
-        } else if (llama_grammar_match_char(stack.back(), chr).first) {
-            return true;
+        } else {
+            rejects.push_back(tok);
         }
     }
-    return false;
+
+    auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    std::vector<const llama_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    std::vector<std::vector<const llama_grammar_element *>> next_stacks;
+    llama_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (auto tok : next_rejects) {
+        rejects.push_back({ tok.index, tok.code_points - 1 });
+    }
+
+    return rejects;
 }

+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
+        const std::vector<std::vector<llama_grammar_element>>         & rules,
+        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
+        const std::vector<llama_grammar_candidate>                    & candidates) {
+    LLAMA_ASSERT(!stacks.empty()); // REVIEW
+
+    if (candidates.empty()) {
+        return std::vector<llama_grammar_candidate>();
+    }
+
+    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}

 //
 // grammar - external
@@ -2383,34 +2444,39 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     assert(ctx);
     const int64_t t_start_sample_us = ggml_time_us();

     const llama_token eos = llama_token_eos();
-    // since many llama tokens are prefixed with a single space, special case a lookahead on ' '
-    const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, U' ');
+
+    bool allow_eos = false;
+    for (const auto & stack : grammar->stacks) {
+        if (stack.empty()) {
+            allow_eos = true;
+            break;
+        }
+    }
+
+    std::vector<std::vector<uint32_t>>   decoded_candidates;
+    std::vector<llama_grammar_candidate> grammar_candidates;

     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id = candidates->data[i].id;
-        const char * str = llama_token_to_str(ctx, id);
-
-        // prune tokens based on first char only - in `llama_grammar_accept_token` we will find the
-        // full matching prefix of the selected token
-        bool valid = false;
+        const llama_token id  = candidates->data[i].id;
+        const char *      str = llama_token_to_str(ctx, id);
         if (id == eos) {
-            valid = llama_grammar_peek(grammar->stacks, 0);
-        } else {
-            const auto decoded = decode_utf8(str);
-            const uint32_t chr = decoded.first;
-            if (chr == U' ') {
-                const char * next = decoded.second;
-                valid = llama_grammar_peek(stacks_after_space, decode_utf8(next).first);
-            } else if (chr != 0) {
-                valid = llama_grammar_peek(grammar->stacks, chr);
+            if (!allow_eos) {
+                candidates->data[i].logit = -INFINITY;
             }
-        }
-
-        if (!valid) {
+        } else if (*str == 0) {
             candidates->data[i].logit = -INFINITY;
+        } else {
+            decoded_candidates.push_back(decode_utf8(str));
+            grammar_candidates.push_back({ i, decoded_candidates.back().data() });
         }
     }

+    auto rejects =
+        llama_grammar_reject_candidates(grammar->rules, grammar->stacks, grammar_candidates);
+    for (auto reject : rejects) {
+        candidates->data[reject.index].logit = -INFINITY;
+    }
+
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
 }

@@ -2599,61 +2665,27 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     return result;
 }

-llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
+void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();

     if (token == llama_token_eos()) {
         for (const auto & stack : grammar->stacks) {
             if (stack.empty()) {
-                return token;
+                return;
             }
         }
         LLAMA_ASSERT(false);
     }

-    const char * str = llama_token_to_str(ctx, token);
-    const char * suffix = str;
-
-    // Find prefix of selected token that matches grammar, expecting at least 1 char
-    auto decoded = decode_utf8(suffix);
-    auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first);
-    LLAMA_ASSERT(!new_stacks.empty());
-    if (*suffix) {
-        suffix = decoded.second;
-        for ( ; *suffix; suffix = decoded.second) {
-            decoded = decode_utf8(suffix);
-            new_stacks = llama_grammar_accept(grammar->rules, new_stacks, decoded.first);
-            if (new_stacks.empty() ) {
-                break;
-            }
-        }
-    }
-
-    // if full token is matched, accept new stacks
-    if (!(*suffix)) {
-        grammar->stacks = new_stacks;
-        return token;
-    }
-
-    // otherwise, tokenize the string prefix that did match
-    llama_token tokens[32]; // TODO - determine actual max token size
-    const std::string prefix_str(str, suffix - str);
-    int n_tokens = llama_tokenize(ctx, prefix_str.c_str(), tokens, 32, false);
-    if (n_tokens < 1) {
-        return token; // REVIEW
-    }
-
-    // accept the first token of the matching prefix into the grammar
-    llama_token first_prefix_token = tokens[0];
-    const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token);
-    for ( ; *first_prefix_str; first_prefix_str = decoded.second) {
-        decoded = decode_utf8(first_prefix_str);
-        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first);
-        LLAMA_ASSERT(!grammar->stacks.empty());
+    const char * str = llama_token_to_str(ctx, token);
+    // Note terminating 0 in decoded string
+    auto code_points = decode_utf8(str);
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
     }
+    LLAMA_ASSERT(!grammar->stacks.empty());

     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    return first_prefix_token;
 }

 //
diff --git a/llama.h b/llama.h
index c9b221038..25cd001ad 100644
--- a/llama.h
+++ b/llama.h
@@ -404,8 +404,8 @@ extern "C" {
     /// @details Randomly selects a token from the candidates based on their probabilities.
     LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);

-    /// @details Accepts the sampled token into the grammar, possibly transforming to a new token
-    LLAMA_API llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
+    /// @details Accepts the sampled token into the grammar
+    LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);

     // Performance information
     LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);