Improve support for special tokens

Igor Pissolati 2023-06-18 20:11:01 -03:00
parent 93356bdb7a
commit 61a98bc30a
4 changed files with 297 additions and 36 deletions

convert.py

@@ -142,6 +142,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 @dataclass
 class Params:
     n_vocab: int
+    n_vocab_sp: int
     n_embd: int
     n_mult: int
     n_head: int
@@ -169,6 +170,7 @@ class Params:
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp= n_vocab,
             n_embd = n_embd,
             n_mult = 256,
             n_head = n_head,
@@ -191,6 +193,7 @@ class Params:
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp= n_vocab,
             n_embd = n_embd,
             n_mult = n_mult,
             n_head = n_head,
@@ -215,6 +218,7 @@ class Params:
         return Params(
             n_vocab = n_vocab,
+            n_vocab_sp= n_vocab,
             n_embd = n_embd,
             n_mult = n_mult,
             n_head = n_head,
@@ -239,7 +243,7 @@ class Params:
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
         self.vocabtype = vocabtype
         if self.vocabtype == "bpe":
             self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
@@ -264,35 +268,46 @@ class SentencePieceVocab:
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
+        special_tokens: Dict[str, Dict[str, Any]]
+        if fname_special_tokens is not None:
+            special_tokens = json.load(open(fname_special_tokens))
+        else:
+            special_tokens = {}
+        token_name_to_id = {"unk_token": self.sentencepiece_tokenizer.unk_id(), "bos_token": self.sentencepiece_tokenizer.bos_id(), "eos_token": self.sentencepiece_tokenizer.eos_id(), "pad_token": self.sentencepiece_tokenizer.pad_id()}
+        self.special_tokens_map = {token_name_to_id[token_name]: info["content"] if isinstance(info, dict) else info for token_name, info in special_tokens.items() if token_name in token_name_to_id and token_name_to_id[token_name] != -1}
+        self.vocab_special_size: int = len(self.added_tokens_list) + len(self.special_tokens_map)

     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
         if self.vocabtype == "bpe":
             from transformers.models.gpt2 import tokenization_gpt2
             byte_encoder = tokenization_gpt2.bytes_to_unicode()
             byte_decoder = {v: k for k, v in byte_encoder.items()}
             for i, item in enumerate(tokenizer):
                 text: bytes
                 text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
                 score: float = -i
                 yield text, score
         else:
+            special_tokens = [tokenizer.bos_id(), tokenizer.eos_id(), tokenizer.pad_id()]
             for i in range(tokenizer.vocab_size()):
                 text: bytes
                 if tokenizer.is_unknown(i):
-                    text = " \u2047 ".encode("utf-8")
+                    text = self.special_tokens_map.get(i, " \u2047 ").encode("utf-8")
+                elif i in special_tokens:
+                    text = self.special_tokens_map.get(i, "").encode("utf-8")
                 elif tokenizer.is_control(i):
                     text = b""
                 elif tokenizer.is_byte(i):
                     piece = tokenizer.id_to_piece(i)
                     if len(piece) != 6:
                         raise Exception(f"Invalid token: {piece}")
                     byte_value = int(piece[3:-1], 16)
                     text = struct.pack("B", byte_value)
                 else:
                     text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
                 score: float = tokenizer.get_score(i)
                 yield text, score

     def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
         for text in self.added_tokens_list:
@@ -303,6 +318,12 @@ class SentencePieceVocab:
         yield from self.sentencepiece_tokens()
         yield from self.added_tokens()

+    def all_special_tokens(self) -> Iterable[int]:
+        for token_id in self.special_tokens_map.keys():
+            yield token_id
+        for i in range(len(self.added_tokens_list)):
+            yield self.vocab_size_base + i
+
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
@@ -310,11 +331,16 @@ class SentencePieceVocab:
 class GGMLVocab:
     def __init__(self, tokens: List[Tuple[bytes, float]]):
         self.tokens = tokens
+        self.special_tokens = []
         self.vocab_size = len(tokens)
+        self.vocab_special_size = 0

     def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         return self.tokens

+    def all_special_tokens(self) -> Iterable[int]:
+        return self.special_tokens
+
     def __repr__(self) -> str:
         return f"<GGMLVocab with {self.vocab_size} tokens>"
@@ -1066,8 +1092,9 @@ class OutputFile:
     def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
         self.fout.write(b"ggjt"[::-1])  # magic
         values = [
-            1, # file version
+            4, # file version
             params.n_vocab,
+            params.n_vocab_sp,
             params.n_embd,
             params.n_mult,
             params.n_head,
@@ -1089,11 +1116,14 @@ class OutputFile:
             self.fout.write(struct.pack("i", len(text)))
             self.fout.write(text)
             self.fout.write(struct.pack("f", score))
+        for token_id in vocab.all_special_tokens():
+            self.fout.write(struct.pack("i", token_id))

     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
+        params = Params(n_vocab=vocab.vocab_size, n_vocab_sp=vocab.vocab_special_size, n_embd=0, n_mult=0,
+                        n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
@ -1249,8 +1279,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
f"Could not find tokenizer.model in {path} or its parent; " f"Could not find tokenizer.model in {path} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir") "if it's in another directory, pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json" added_tokens_path = path.parent / "added_tokens.json"
special_tokens_path = path.parent / "special_tokens_map.json"
tokenizer_config_path = path.parent / "tokenizer_config.json"
print(f"Loading vocab file {path}") print(f"Loading vocab file {path}")
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else tokenizer_config_path if tokenizer_config_path.exists() else None,
vocabtype) vocabtype)
@@ -1313,6 +1345,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
         vocab = load_vocab(vocab_dir, args.vocabtype)
         params = Params.load(model_plus)
+        params.n_vocab_sp = vocab.vocab_special_size
         model = model_plus.model
         model = do_necessary_conversions(model, params)
         output_type = pick_output_type(model, args.outtype)

llama-util.h

@@ -14,6 +14,8 @@
 #include <string>
 #include <vector>
+#include <map>
+#include <unordered_map>
 #include <stdexcept>

 #ifdef __has_include
@@ -541,4 +543,166 @@ struct llama_ctx_buffer {
 typedef llama_buffer llama_ctx_buffer;
 #endif

+struct llama_trie_node {
+    llama_trie_node(): is_terminator(false) {}
+
+    std::unordered_map<char, llama_trie_node*> children;
+    bool is_terminator;
+};
+
+// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass.
+// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
+struct llama_trie {
+public:
+    llama_trie(): root_(new llama_trie_node()) {}
+
+    void add(const std::string & word) {
+        if (word.empty()) {
+            return;
+        }
+
+        llama_trie_node *ref = root_;
+        for (char c : word) {
+            if (ref->children.find(c) == ref->children.end()) {
+                ref->children[c] = new llama_trie_node();
+            }
+            ref = ref->children[c];
+        }
+        ref->is_terminator = true;
+    }
+
+    // Looks for the words added to the trie within `text`. The output is the boundaries of the words found.
+    // Note that this trie will match the longest possible word first!
+    std::vector<int> split(const std::string & text) const {
+        std::map<int, llama_trie_node*> states;
+        std::vector<int> offsets{0};
+
+        int skip = 0;
+        for (int current = 0; current < text.size(); current++) {
+            char current_char = text[current];
+            if (skip > 0 && current < skip) {
+                // Prevents the lookahead from matching twice,
+                // like extra_id_100 and id_100
+                continue;
+            }
+
+            // Whenever we find a match, we need to drop everything;
+            // this is a greedy algorithm: it will match on the first found token
+            bool reset = false;
+
+            // In this case, we already have partial matches (but unfinished)
+            for (auto state = states.begin(); state != states.end(); ) {
+                int start = state->first;
+                llama_trie_node *trie_pointer = state->second;
+                if (trie_pointer->is_terminator) {
+                    // This is a final match: we need to reset and
+                    // store the results in `offsets`.
+
+                    // Lookahead to match the longest first.
+                    // Important in case of extra_id_1 vs extra_id_100.
+                    // Here we are also actively looking for other earlier partial
+                    // matches: given "[CLS]", "L", we need to match [CLS] even if L is special.
+                    int end = 0;
+                    for (const auto & look : states) {
+                        int lookstart = look.first;
+                        llama_trie_node *looktrie_pointer = look.second;
+                        int lookahead_index = 0;
+                        if (lookstart > start) {
+                            // This partial match is later, we can stop looking
+                            break;
+                        }
+                        if (lookstart < start) {
+                            // This partial match is earlier; the trie pointer
+                            // was already updated, so the index is + 1
+                            lookahead_index = current + 1;
+                            end = current + 1;
+                        } else {
+                            // Here lookstart == start and
+                            // looktrie_pointer == trie_pointer.
+                            // It wasn't updated yet, so the indices are the current ones
+                            lookahead_index = current;
+                            end = current;
+                        }
+                        char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0';
+                        if (looktrie_pointer->is_terminator) {
+                            start = lookstart;
+                            end = lookahead_index;
+                            skip = lookahead_index;
+                        }
+
+                        auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
+                        while (looktrie_pointer_it != looktrie_pointer->children.end()) {
+                            looktrie_pointer = looktrie_pointer_it->second;
+                            lookahead_index++;
+                            if (looktrie_pointer->is_terminator) {
+                                start = lookstart;
+                                end = lookahead_index;
+                                skip = lookahead_index;
+                            }
+
+                            if (lookahead_index == text.size()) {
+                                // End of string
+                                break;
+                            }
+                            next_char = text[lookahead_index];
+                            looktrie_pointer_it = looktrie_pointer->children.find(next_char);
+                        }
+                    }
+
+                    offsets.push_back(start);
+                    offsets.push_back(end);
+                    reset = true;
+                    break;
+                }
+
+                auto trie_pointer_it = trie_pointer->children.find(current_char);
+                if (trie_pointer_it != trie_pointer->children.end()) {
+                    // The current character being looked at has a match within the trie:
+                    // update the pointer (it will be stored back into states later).
+                    trie_pointer = trie_pointer_it->second;
+                    states[start] = trie_pointer;
+                    ++state;
+                } else {
+                    // The new character has no match in the trie, so we need
+                    // to stop keeping track of this partial match.
+                    state = states.erase(state);
+                }
+            }
+
+            if (reset) {
+                // Clear the full start (we found a real match)
+                states.clear();
+            }
+
+            // If this character is a starting character within the trie,
+            // start keeping track of this partial match.
+            auto children_it = root_->children.find(current_char);
+            if (current >= skip && children_it != root_->children.end()) {
+                states[current] = children_it->second;
+            }
+        }
+
+        // We may still have partial matches at the end of the string.
+        for (const auto & state : states) {
+            int start = state.first;
+            llama_trie_node *trie_pointer = state.second;
+            if (trie_pointer->is_terminator) {
+                // This is a final match: store the results in `offsets`.
+                int end = text.size();
+                offsets.push_back(start);
+                offsets.push_back(end);
+                break;
+            }
+        }
+
+        offsets.push_back(text.size());
+        return offsets;
+    }
+
+private:
+    llama_trie_node *root_;
+};
+
 #endif
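A quick illustration of how llama_trie::split() reports matches: the returned offsets always start with 0 and end with text.size(), and every special token found contributes an extra (start, end) pair, so consuming the list pairwise recovers the alternating plain/special segments. A minimal sketch, assuming the header above is available as llama-util.h:

    #include <cstdio>
    #include <string>
    #include <vector>
    #include "llama-util.h" // assumed location of llama_trie from this commit

    int main() {
        llama_trie trie;
        trie.add("[CLS]");
        trie.add("[SEP]");

        std::string text = "[CLS]hello[SEP]";
        std::vector<int> offsets = trie.split(text);

        // Consume the offsets the same way llama_tokenize() does:
        // adjacent values delimit segments; degenerate (start >= end) pairs are skipped.
        int start = 0;
        for (int end : offsets) {
            if (start >= end) {
                continue;
            }
            printf("segment: \"%.*s\"\n", end - start, text.c_str() + start);
            start = end;
        }
        // Prints: "[CLS]", "hello", "[SEP]"
        return 0;
    }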

llama.cpp

@@ -181,6 +181,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab = 32000;
+    uint32_t n_vocab_sp = 0;
     uint32_t n_ctx = 512;   // this is provided as user input?
     uint32_t n_embd = 4096;
     uint32_t n_mult = 256;
@@ -277,6 +278,11 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;

+    llama_trie special_token_trie;
+    std::unordered_map<token, id> special_token_to_id;
+    std::vector<id> special_tokens;
+    size_t max_special_token_length;
 };

 struct llama_model {
@@ -494,6 +500,7 @@ enum llama_file_version {
     LLAMA_FILE_VERSION_GGJT_V1, // added padding
     LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
     LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format
+    LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens
 };

 struct llama_file_loader {
@@ -531,6 +538,7 @@ struct llama_file_loader {
             case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return;
             case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return;
             case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return;
+            case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return;
         }
     }
@@ -539,6 +547,7 @@ struct llama_file_loader {
     }
     void read_hparams() {
         hparams.n_vocab = file.read_u32();
+        hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0;
         hparams.n_embd = file.read_u32();
         hparams.n_mult = file.read_u32();
         hparams.n_head = file.read_u32();
@@ -566,6 +575,21 @@ struct llama_file_loader {
             tok_score.tok = std::move(word);
             tok_score.score = score;
         }
+
+        vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
+
+        for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
+            uint32_t token_id = file.read_u32();
+            const auto & token = vocab.id_to_token[token_id].tok;
+
+            vocab.special_token_trie.add(token);
+            vocab.special_tokens.push_back(token_id);
+            vocab.special_token_to_id[token] = token_id;
+
+            if (vocab.max_special_token_length < token.size()) {
+                vocab.max_special_token_length = token.size();
+            }
+        }
     }
     void read_tensor_metadata(llama_load_tensors_map & tensors_map) {
         while (file.tell() < file.size) {
@@ -631,6 +655,7 @@ struct llama_file_saver {
     void write_hparams(enum llama_ftype new_ftype) {
         const llama_hparams & hparams = any_file_loader->hparams;
         file.write_u32(hparams.n_vocab);
+        file.write_u32(hparams.n_vocab_sp);
         file.write_u32(hparams.n_embd);
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
@@ -649,6 +674,10 @@ struct llama_file_saver {
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
+        uint32_t n_vocab_sp = any_file_loader->hparams.n_vocab_sp;
+        for (uint32_t i = 0; i < n_vocab_sp; i++) {
+            file.write_u32(any_file_loader->vocab.special_tokens[i]);
+        }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
         switch (new_type) {
@@ -975,7 +1004,8 @@ static const char *llama_file_version_name(llama_file_version version) {
         case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)";
         case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)";
         case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)";
-        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)";
+        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)";
+        case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)";
     }

     return "unknown";
@@ -1960,18 +1990,20 @@ struct llama_sp_bigram {
 struct llama_tokenizer {
     llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}

-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const char * text, size_t len, std::vector<llama_vocab::id> & output) {
+        symbols_.clear();
+
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
-        while (offs < text.size()) {
+        while (offs < len) {
             llama_sp_symbol sym;
-            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
-            sym.text = text.c_str() + offs;
+            size_t char_len = std::min(len - offs, utf8_len(text[offs]));
+            sym.text = text + offs;
             sym.n = char_len;
             offs += char_len;
             sym.prev = index - 1;
-            sym.next = offs == text.size() ? -1 : index + 1;
+            sym.next = offs == len ? -1 : index + 1;
             index++;
             symbols_.emplace_back(sym);
         }
@@ -2074,7 +2106,33 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
         output.push_back(llama_token_bos());
     }

-    tokenizer.tokenize(text, output);
+    if (vocab.special_token_to_id.empty()) {
+        tokenizer.tokenize(text.c_str(), text.size(), output);
+        return output;
+    }
+
+    auto offsets = vocab.special_token_trie.split(text);
+    int start = 0;
+    for (int end : offsets) {
+        if (start >= end) {
+            continue;
+        }
+
+        size_t part_length = end - start;
+        //printf("\"%.*s\"\n", (int) part_length, text.c_str() + start);
+
+        if (vocab.max_special_token_length < part_length) {
+            tokenizer.tokenize(text.c_str() + start, part_length, output);
+        } else {
+            auto token_it = vocab.special_token_to_id.find(std::string(text.c_str() + start, part_length));
+            if (token_it != vocab.special_token_to_id.end()) {
+                output.push_back(token_it->second);
+            } else {
+                tokenizer.tokenize(text.c_str() + start, part_length, output);
+            }
+        }
+        start = end;
+    }
+
     return output;
 }
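Distilled from the branch above, here is a self-contained sketch of the segment-to-id mapping (the trie is the one from llama-util.h above; plain_tokenize_fn is a hypothetical stand-in for the real SentencePiece tokenizer). It highlights the key shortcut: a segment longer than max_special_token_length cannot possibly be a special token, so the hash lookup is skipped for it.

    #include <functional>
    #include <string>
    #include <unordered_map>
    #include <vector>
    #include "llama-util.h" // assumed location of llama_trie from this commit

    // Hypothetical stand-in for the real SentencePiece tokenizer.
    using plain_tokenize_fn = std::function<void(const char *, size_t, std::vector<int> &)>;

    static std::vector<int> tokenize_with_specials(
            const std::string & text,
            const llama_trie & trie,
            const std::unordered_map<std::string, int> & special_token_to_id,
            size_t max_special_token_length,
            const plain_tokenize_fn & plain_tokenize) {
        std::vector<int> output;
        int start = 0;
        for (int end : trie.split(text)) {
            if (start >= end) {
                continue;
            }
            size_t part_length = end - start;
            if (part_length > max_special_token_length) {
                // Too long to be a special token: tokenize normally.
                plain_tokenize(text.c_str() + start, part_length, output);
            } else {
                auto it = special_token_to_id.find(std::string(text.c_str() + start, part_length));
                if (it != special_token_to_id.end()) {
                    output.push_back(it->second); // exact special token: emit its single id
                } else {
                    plain_tokenize(text.c_str() + start, part_length, output);
                }
            }
            start = end;
        }
        return output;
    }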
@@ -4212,6 +4270,10 @@ llama_token llama_token_nl() {
     return 13;
 }

+bool llama_is_special_token(const struct llama_context *ctx, llama_token token) {
+    return std::find(ctx->vocab.special_tokens.begin(), ctx->vocab.special_tokens.end(), token) != ctx->vocab.special_tokens.end();
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,

llama.h

@@ -40,7 +40,7 @@
 #define LLAMA_FILE_MAGIC_GGML        0x67676d6cu // 'ggml'
 #define LLAMA_FILE_MAGIC_GGSN        0x6767736eu // 'ggsn'

-#define LLAMA_FILE_VERSION           3
+#define LLAMA_FILE_VERSION           4
 #define LLAMA_FILE_MAGIC             LLAMA_FILE_MAGIC_GGJT
 #define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML
 #define LLAMA_SESSION_MAGIC          LLAMA_FILE_MAGIC_GGSN
@@ -373,6 +373,8 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos();  // end-of-sentence
     LLAMA_API llama_token llama_token_nl();   // next-line

+    LLAMA_API bool llama_is_special_token(const struct llama_context * ctx, llama_token token);
+
     // Grammar
     //
     LLAMA_API struct llama_grammar * llama_grammar_init(
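On the consumer side, the new predicate makes it easy to keep control markers out of rendered output. A minimal usage sketch, assuming a valid llama_context pointer; llama_token_to_str() is the detokenizer this version of llama.h already exposes:

    #include <cstdio>
    #include "llama.h"

    // Print a sampled token unless it is a special/added token (e.g. "<s>" or "[SEP]"),
    // so control markers never leak into user-visible text.
    static void print_token(const struct llama_context * ctx, llama_token tok) {
        if (llama_is_special_token(ctx, tok)) {
            return; // drop special tokens from the rendered output
        }
        printf("%s", llama_token_to_str(ctx, tok));
    }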