add more comments

This commit is contained in:
Evan Jones 2023-08-14 04:41:50 -04:00
parent 961f2ab73f
commit 236194fc3d

View file

@ -2083,13 +2083,15 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
//
struct llama_partial_utf8 {
uint32_t value;
int n_remain;
uint32_t value; // bit value so far (unshifted)
int n_remain; // num bytes remaining; -1 indicates invalid sequence
};
struct llama_grammar {
const std::vector<std::vector<llama_grammar_element>> rules;
std::vector<std::vector<const llama_grammar_element *>> stacks;
// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
};
@ -2099,8 +2101,8 @@ struct llama_grammar_candidate {
llama_partial_utf8 partial_utf8;
};
// Decodes a UTF-8 string which may end in an incomplete sequence. Otherwise, assumes valid UTF-8.
// adds a terminating 0 for use as pointer
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const char * src,
llama_partial_utf8 partial_start) {
@ -2114,7 +2116,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
while (*pos != 0 && n_remain > 0) {
uint8_t next_byte = static_cast<uint8_t>(*pos);
if ((next_byte >> 6) != 2) {
// REVIEW invalid sequence, abort
// invalid sequence, abort
code_points.push_back(0);
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
}
@ -2134,7 +2136,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
n_remain = lookup[highbits] - 1;
if (n_remain < 0) {
// REVIEW invalid sequence, abort
// invalid sequence, abort
code_points.clear();
code_points.push_back(0);
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
@ -2191,6 +2193,9 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
return std::make_pair(found == is_positive_char, pos);
}
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
// range at pos (regular or inverse range)
// asserts that pos is pointing to a char range element
static bool llama_grammar_match_partial_char(
const llama_grammar_element * pos,
const llama_partial_utf8 partial_utf8) {
@ -2201,6 +2206,7 @@ static bool llama_grammar_match_partial_char(
uint32_t partial_value = partial_utf8.value;
int n_remain = partial_utf8.n_remain;
// invalid sequence or 7-bit char split across 2 bytes (overlong)
if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
return false;
}