From 236194fc3d4241b8bfc9dc7656f497076612e99a Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Mon, 14 Aug 2023 04:41:50 -0400 Subject: [PATCH] add more comments --- llama.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/llama.cpp b/llama.cpp index 418105f72..a7e3e03a5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2083,13 +2083,15 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // struct llama_partial_utf8 { - uint32_t value; - int n_remain; + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence }; struct llama_grammar { const std::vector> rules; std::vector> stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; }; @@ -2099,8 +2101,8 @@ struct llama_grammar_candidate { llama_partial_utf8 partial_utf8; }; -// Decodes a UTF-8 string which may end in an incomplete sequence. Otherwise, assumes valid UTF-8. -// adds a terminating 0 for use as pointer +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. std::pair, llama_partial_utf8> decode_utf8( const char * src, llama_partial_utf8 partial_start) { @@ -2114,7 +2116,7 @@ std::pair, llama_partial_utf8> decode_utf8( while (*pos != 0 && n_remain > 0) { uint8_t next_byte = static_cast(*pos); if ((next_byte >> 6) != 2) { - // REVIEW invalid sequence, abort + // invalid sequence, abort code_points.push_back(0); return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); } @@ -2134,7 +2136,7 @@ std::pair, llama_partial_utf8> decode_utf8( n_remain = lookup[highbits] - 1; if (n_remain < 0) { - // REVIEW invalid sequence, abort + // invalid sequence, abort code_points.clear(); code_points.push_back(0); return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); @@ -2191,6 +2193,9 @@ static std::pair llama_grammar_match_char( return std::make_pair(found == is_positive_char, pos); } +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element static bool llama_grammar_match_partial_char( const llama_grammar_element * pos, const llama_partial_utf8 partial_utf8) { @@ -2201,6 +2206,7 @@ static bool llama_grammar_match_partial_char( uint32_t partial_value = partial_utf8.value; int n_remain = partial_utf8.n_remain; + // invalid sequence or 7-bit char split across 2 bytes (overlong) if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { return false; }