First pass at implementing suggested changes

Author: KerfuffleV2
Date: 2023-08-28 08:56:57 -06:00
Commit: f82aec99a4
Parent: bb6b64d5e5

7 changed files with 125 additions and 127 deletions


@@ -13,8 +13,6 @@ from typing import Any, List
 from pathlib import Path
 from transformers import AutoTokenizer
-from convert import SpecialVocab
-
 def bytes_to_unicode():
     # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
     """
@@ -161,7 +159,7 @@ if Path(dir_model + "/tokenizer.json").is_file():
     gguf_writer.add_token_scores(scores)
     gguf_writer.add_token_types(toktypes)

-    special_vocab = SpecialVocab(Path(dir_model))
+    special_vocab = gguf.SpecialVocab(Path(dir_model))
     special_vocab.add_to_gguf(gguf_writer)

 # TENSORS
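The only substantive change in this converter is where SpecialVocab comes from: rather than importing it from the convert.py script, the class is now reached through the gguf package itself. A minimal sketch of the new call pattern, assuming the gguf package from this commit is importable; the paths and the "falcon" architecture string are illustrative, not taken from the diff:

    from pathlib import Path
    import gguf

    dir_model = 'models/my-model'                          # hypothetical model directory
    gguf_writer = gguf.GGUFWriter('model.gguf', 'falcon')  # arch string is illustrative

    # SpecialVocab now lives in gguf, so converters no longer depend on convert.py:
    special_vocab = gguf.SpecialVocab(Path(dir_model))
    special_vocab.add_to_gguf(gguf_writer)                 # writes merges and special token ids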


@@ -13,8 +13,6 @@ from typing import Any, List
 from pathlib import Path
 from transformers import AutoTokenizer
-from convert import SpecialVocab
-
 # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
@@ -153,7 +151,7 @@ if Path(dir_model + "/tokenizer.json").is_file():
     gguf_writer.add_token_list(tokens)

-    special_vocab = SpecialVocab(Path(dir_model))
+    special_vocab = gguf.SpecialVocab(Path(dir_model))
     special_vocab.add_to_gguf(gguf_writer)

 # TENSORS


@@ -15,8 +15,6 @@ from typing import Any, List, TypeAlias
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
-from convert import SpecialVocab
-
 #NDArray = np.ndarray[Any, Any]
 # compatible with python < 3.9
 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
@@ -182,7 +180,7 @@ if Path(dir_model + "/tokenizer.model").is_file():
     gguf_writer.add_token_scores(scores)
     gguf_writer.add_token_types(toktypes)

-    special_vocab = SpecialVocab(Path(dir_model))
+    special_vocab = gguf.SpecialVocab(Path(dir_model))
     special_vocab.add_to_gguf(gguf_writer)

 # TENSORS


@@ -299,7 +299,7 @@ def handle_metadata(cfg, hp):
         raise ValueError('Unable to load metadata')
     vocab = convert.load_vocab(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, cfg.vocabtype)
     # FIXME: Respect cfg.vocab_dir?
-    svocab = convert.SpecialVocab(cfg.model_metadata_dir)
+    svocab = gguf.SpecialVocab(cfg.model_metadata_dir)
    convert.check_vocab_size(params, vocab)
     return (params, vocab, svocab)
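handle_metadata() in the GGML-to-GGUF converter makes the same switch, with the FIXME noting that cfg.vocab_dir is honored for the main vocab but not yet for the special vocab. A hedged sketch of how the returned triple is consumed downstream; cfg, hp, and gguf_writer are assumed to exist with the attributes used here:

    params, vocab, svocab = handle_metadata(cfg, hp)
    # ... tensor conversion happens in between ...
    svocab.add_to_gguf(gguf_writer)   # merges and special token ids land in the output file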


@@ -13,8 +13,6 @@ from typing import Any, List, Optional, TypeAlias
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
-from convert import SpecialVocab
-
 #NDArray = np.ndarray[Any, Any]
 # compatible with python < 3.9
 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
@@ -191,7 +189,7 @@ if Path(dir_model + "/tokenizer.model").is_file():
     gguf_writer.add_token_scores(scores)
     gguf_writer.add_token_types(toktypes)

-    special_vocab = SpecialVocab(Path(dir_model))
+    special_vocab = gguf.SpecialVocab(Path(dir_model))
     special_vocab.add_to_gguf(gguf_writer)

 # TENSORS


@@ -418,73 +418,6 @@ class SentencePieceVocab:
 Vocab = Union[BpeVocab, SentencePieceVocab]

-class SpecialVocab:
-    merges: List[str] = []
-    special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
-    special_token_ids: Dict[str, int] = {}
-
-    def __init__(self, path: Path, special_token_types: Optional[Tuple[str, ...]] = None):
-        self.special_token_ids = {}
-        if special_token_types is not None:
-            self.special_token_types = special_token_types
-        self.load(path)
-
-    def load(self, path: Path):
-        if not self.try_load_from_tokenizer_json(path):
-            self.try_load_from_config_json(path)
-
-    def try_load_from_tokenizer_json(self, path: Path) -> bool:
-        tokenizer_file = path / 'tokenizer.json'
-        if not tokenizer_file.is_file():
-            return False
-        with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
-            tokenizer = json.load(f)
-        merges = tokenizer.get('model', {}).get('merges')
-        if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
-            self.merges = merges
-        tokenizer_config_file = path / 'tokenizer_config.json'
-        added_tokens = tokenizer.get('added_tokens')
-        if added_tokens is None or not tokenizer_config_file.is_file():
-            return True
-        with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
-            tokenizer_config = json.load(f)
-        for typ in self.special_token_types:
-            tc_content = (tokenizer_config.get(f'{typ}_token') or {}).get('content')
-            if not isinstance(tc_content, str):
-                continue
-            for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content):
-                if isinstance(maybe_token_id, int):
-                    self.special_token_ids[typ] = maybe_token_id
-                break
-        return True
-
-    def try_load_from_config_json(self, path: Path) -> bool:
-        config_file = path / 'config.json'
-        if not config_file.is_file():
-            return False
-        with open(config_file, 'r', encoding = 'utf-8') as f:
-            config = json.load(f)
-        for typ in self.special_token_types:
-            maybe_token_id = config.get(f'{typ}_token_id')
-            if isinstance(maybe_token_id, int):
-                self.special_token_ids[typ] = maybe_token_id
-        return True
-
-    def add_to_gguf(self, gw: gguf.GGUFWriter):
-        if len(self.merges) > 0:
-            print(f'SpecialVocab: Adding {len(self.merges)} merge(s).')
-            gw.add_token_merges(self.merges)
-        for typ, tokid in self.special_token_ids.items():
-            handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
-            if handler is None:
-                print(f'SpecialVocab: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
-                continue
-            print(f'SpecialVocab: Setting special token type {typ} to {tokid}')
-            handler(tokid)
-
-    def __repr__(self):
-        return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>'
-
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
@@ -514,7 +447,7 @@ class Tensor(metaclass=ABCMeta):
     def to_ggml(self) -> 'GGMLCompatibleTensor': ...

-def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray:
+def bf16_to_fp32(bf16_arr: np.ndarray) -> NDArray:
     assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
     fp32_arr = bf16_arr.astype(np.uint32) << 16
     return fp32_arr.view(np.float32)
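Only the return annotation changes here (np.ndarray becomes the NDArray alias used for pre-3.9 compatibility), but the function itself rewards a second look: bfloat16 is exactly the high 16 bits of an IEEE float32, so widening to uint32 and shifting left by 16 reconstructs the full value. A standalone worked example, not part of the diff:

    import numpy as np

    # 0x3F80 and 0x4049 are 1.0 and 3.140625 once placed in the high half of a float32
    bf16 = np.array([0x3F80, 0x4049], dtype=np.uint16)
    fp32 = (bf16.astype(np.uint32) << 16).view(np.float32)
    print(fp32)   # [1.       3.140625]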
@@ -911,7 +844,7 @@ class OutputFile:
         self.gguf.add_token_scores(scores)
         self.gguf.add_token_types(toktypes)

-    def add_meta_special_vocab(self, svocab: SpecialVocab) -> None:
+    def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
         svocab.add_to_gguf(self.gguf)

     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
@@ -932,7 +865,7 @@ class OutputFile:
         self.gguf.close()

     @staticmethod
-    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: SpecialVocab) -> None:
+    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab) -> None:
         check_vocab_size(params, vocab)

         of = OutputFile(fname_out)
@@ -960,7 +893,7 @@ class OutputFile:
             return dt.quantize(arr)

     @staticmethod
-    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
         check_vocab_size(params, vocab)

         of = OutputFile(fname_out)
@@ -1014,7 +947,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
 def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
     tmap = gguf.TensorNameMap(ARCH, params.n_layer)
-    should_skip: Set[gguf.MODEL_TENSOR] = gguf.MODEL_TENSOR_SKIP.get(ARCH, set())
+    should_skip: Set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))

     tmp = model
@@ -1036,7 +969,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
     out: LazyModel = {}
     for name, lazy_tensor in model.items():
-        tensor_type, name_new = tmap.get_both(name, try_suffixes = (".weight", ".bias")) or (None, None)
+        tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
         if name_new is None:
             raise Exception(f"Unexpected tensor name: {name}")
@@ -1190,7 +1123,6 @@ def main(args_in: Optional[List[str]] = None) -> None:
     if not args.vocab_only:
         model_plus = load_some_model(args.model)
     else:
-        # You can no longer use guessed parameters for your vocab only model. Does anyone actually care?
         model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)

     if args.dump:
@@ -1220,7 +1152,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         assert args.outfile, "need --outfile if using --vocab-only"
         # FIXME: Try to respect vocab_dir somehow?
         vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
-        special_vocab = SpecialVocab(model_plus.paths[0].parent)
+        special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent)
         outfile = args.outfile
         OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
         print(f"Wrote {outfile}")
@@ -1232,7 +1164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
     vocab = load_vocab(vocab_dir, args.vocabtype)
     # FIXME: Try to respect vocab_dir somehow?
-    special_vocab = SpecialVocab(model_plus.paths[0].parent)
+    special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent)

     model = model_plus.model
     model = convert_model_names(model, params)


@@ -4,9 +4,13 @@ import sys
 import struct
 import tempfile
 import numpy as np
+import json
+import os
+from pathlib import Path
+import collections.abc as collections_abc

 from enum import IntEnum, auto
-from typing import Any, BinaryIO, IO, Dict, List, Optional, Sequence, Tuple
+from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union

 #
 # constants
@@ -317,7 +321,7 @@ class TensorNameMap:
                 key = key.format(bid = bid)
                 mapping[key] = (tensor, tensor_name)

-    def get_both(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
+    def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
         result = self.mapping.get(key)
         if result is not None:
             return result
@@ -329,11 +333,17 @@ class TensorNameMap:
         return None

     def get_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[str]:
-        result = self.get_both(key, try_suffixes = try_suffixes)
+        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
         if result is None:
             return None
         return result[1]

+    def get_type(self, key: str, try_suffixes: Sequence[str]) -> Optional[MODEL_TENSOR]:
+        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
+        if result is None:
+            return None
+        return result[0]
+
     def __getitem__(self, key: str) -> str:
         try:
             return self.mapping[key][1]
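The rename from get_both() to get_type_and_name() is paired with a new get_type() mirroring the existing get_name(), so callers can ask for either half of the mapping without unpacking tuples themselves. A hedged usage sketch; it assumes this gguf version exposes MODEL_ARCH.LLAMA and that the checkpoint uses LLaMA-style Hugging Face tensor names, neither of which is shown in the diff:

    import gguf

    tmap = gguf.TensorNameMap(gguf.MODEL_ARCH.LLAMA, 32)   # 32 layers, illustrative
    hf_name = 'model.layers.0.self_attn.q_proj.weight'     # hypothetical tensor name

    both = tmap.get_type_and_name(hf_name, try_suffixes = ('.weight', '.bias'))
    if both is not None:
        tensor_type, name_new = both
    # The convenience wrappers return one component, or None if the name is unmapped:
    name_only = tmap.get_name(hf_name, try_suffixes = ('.weight', '.bias'))
    type_only = tmap.get_type(hf_name, try_suffixes = ('.weight', '.bias'))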
@@ -423,9 +433,9 @@ class GGUFWriter:
     ti_data_count = 0
     use_temp_file: bool
     temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
-    tensors: List[Tuple[np.ndarray, int]]
+    tensors: List[Tuple[np.ndarray[Any, Any], int]]

-    def __init__(self, path: str, arch: str, use_temp_file = True):
+    def __init__(self, path: Union[os.PathLike[str], str], arch: str, use_temp_file = True):
         self.fout = open(path, "wb")
         self.arch = arch
         self.add_architecture()
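Two typing refinements: the tensors attribute spells out np.ndarray's type parameters, and the writer's path argument is widened from str to anything path-like, matching what open() already accepted at runtime. A small sketch with illustrative file names:

    from pathlib import Path
    import gguf

    w1 = gguf.GGUFWriter('model.gguf', 'llama')                # plain str, as before
    Path('out').mkdir(exist_ok = True)
    w2 = gguf.GGUFWriter(Path('out') / 'model.gguf', 'llama')  # PathLike now type-checks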
@@ -501,13 +511,26 @@ class GGUFWriter:
         self.add_key(key)
         self.add_val(val, GGUFValueType.STRING)

-    def add_array(self, key: str, val: list):
-        if not isinstance(val, list):
-            raise ValueError("Value must be a list for array type")
+    def add_array(self, key: str, val: Sequence[Any]):
+        if not isinstance(val, collections_abc.Sequence):
+            raise ValueError("Value must be a sequence for array type")
         self.add_key(key)
         self.add_val(val, GGUFValueType.ARRAY)

+    _simple_value_packing = {
+        GGUFValueType.UINT8: "<B",
+        GGUFValueType.INT8: "<b",
+        GGUFValueType.UINT16: "<H",
+        GGUFValueType.INT16: "<h",
+        GGUFValueType.UINT32: "<I",
+        GGUFValueType.INT32: "<i",
+        GGUFValueType.FLOAT32: "<f",
+        GGUFValueType.UINT64: "<Q",
+        GGUFValueType.INT64: "<q",
+        GGUFValueType.FLOAT64: "<d",
+        GGUFValueType.BOOL: "?",
+    }
+
     def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
         if vtype is None:
             vtype = GGUFValueType.get_type(val)
@@ -516,28 +539,9 @@
         self.kv_data += struct.pack("<I", vtype)
         self.kv_data_count += 1

-        if vtype == GGUFValueType.UINT8:
-            self.kv_data += struct.pack("<B", val)
-        elif vtype == GGUFValueType.INT8:
-            self.kv_data += struct.pack("<b", val)
-        elif vtype == GGUFValueType.UINT16:
-            self.kv_data += struct.pack("<H", val)
-        elif vtype == GGUFValueType.INT16:
-            self.kv_data += struct.pack("<h", val)
-        elif vtype == GGUFValueType.UINT32:
-            self.kv_data += struct.pack("<I", val)
-        elif vtype == GGUFValueType.INT32:
-            self.kv_data += struct.pack("<i", val)
-        elif vtype == GGUFValueType.FLOAT32:
-            self.kv_data += struct.pack("<f", val)
-        elif vtype == GGUFValueType.UINT64:
-            self.kv_data += struct.pack("<Q", val)
-        elif vtype == GGUFValueType.INT64:
-            self.kv_data += struct.pack("<q", val)
-        elif vtype == GGUFValueType.FLOAT64:
-            self.kv_data += struct.pack("<d", val)
-        elif vtype == GGUFValueType.BOOL:
-            self.kv_data += struct.pack("?", val)
+        pack_fmt = self._simple_value_packing.get(vtype)
+        if pack_fmt is not None:
+            self.kv_data += struct.pack(pack_fmt, val)
         elif vtype == GGUFValueType.STRING:
             encoded_val = val.encode("utf8") if isinstance(val, str) else val
             self.kv_data += struct.pack("<Q", len(encoded_val))
@@ -556,7 +560,7 @@
     def ggml_pad(x: int, n: int) -> int:
         return ((x + n - 1) // n) * n

-    def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
+    def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: Union[np.dtype[np.float16], np.dtype[np.float32]], tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
         assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"

         encoded_name = name.encode("utf8")
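ggml_pad() is unchanged context here, but it is the alignment workhorse the surrounding code leans on: round x up to the next multiple of n with pure integer arithmetic. Worked examples, not part of the diff:

    def ggml_pad(x: int, n: int) -> int:
        return ((x + n - 1) // n) * n

    assert ggml_pad(100, 32) == 128   # 100 bytes padded up to 32-byte alignment
    assert ggml_pad(128, 32) == 128   # already-aligned sizes are unchanged
    assert ggml_pad(1, 32) == 32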
@@ -575,13 +579,14 @@
         self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
         self.ti_data_count += 1

-    def add_tensor(self, name: str, tensor: np.ndarray, raw_shape: Optional[np.ndarray] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
+    def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Optional[Sequence[int]] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
         if self.use_temp_file and self.temp_file is None:
             fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
             fp.seek(0)
             self.temp_file = fp

-        self.add_tensor_info(name, raw_shape if raw_shape is not None else tensor.shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
+        shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
+        self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)

         pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
@@ -599,7 +604,7 @@
         if pad != 0:
             fp.write(bytes([0] * pad))

-    def write_tensor_data(self, tensor: np.ndarray):
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
         self.write_padding(self.fout, self.fout.tell())
         tensor.tofile(self.fout)
         self.write_padding(self.fout, tensor.nbytes)
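With these changes, raw_shape can be any Sequence[int] (a plain tuple included) instead of requiring an ndarray, and the tensor parameters spell out np.ndarray[Any, Any]. A hedged sketch of the touched API; the arch string and tensor name are illustrative, and the final header/KV/tensor flush calls are omitted:

    import numpy as np
    import gguf

    writer = gguf.GGUFWriter('tiny.gguf', 'llama')
    data = np.zeros((4, 8), dtype = np.float32)
    # raw_shape may now be a plain tuple rather than an ndarray:
    writer.add_tensor('tok_embd.weight', data, raw_shape = (4, 8))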
@@ -720,16 +725,16 @@
     def add_tokenizer_model(self, model: str):
         self.add_string(KEY_TOKENIZER_MODEL, model)

-    def add_token_list(self, tokens: List):
+    def add_token_list(self, tokens: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
         self.add_array(KEY_TOKENIZER_LIST, tokens)

-    def add_token_merges(self, merges: List):
+    def add_token_merges(self, merges: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
         self.add_array(KEY_TOKENIZER_MERGES, merges)

-    def add_token_types(self, types: List[int]):
+    def add_token_types(self, types: Union[Sequence[TokenType], Sequence[int]]):
         self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)

-    def add_token_scores(self, scores: List[float]):
+    def add_token_scores(self, scores: Sequence[float]):
         self.add_array(KEY_TOKENIZER_SCORES, scores)

     def add_bos_token_id(self, id: int):
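The tokenizer setters trade bare List annotations for Sequence-based unions, so tuples and byte-string tokens now type-check without casts; at runtime, add_array() already handled these. A short sketch, reusing a writer like the one above; the values are illustrative:

    writer.add_tokenizer_model('llama')
    writer.add_token_list([b'<s>', b'</s>', b'hello'])   # bytes tokens are accepted
    writer.add_token_scores((0.0, 0.0, -1.5))            # any Sequence[float], tuples included
    writer.add_token_types([1, 1, 6])                    # plain ints or TokenType values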
@@ -748,6 +753,75 @@
         self.add_uint32(KEY_TOKENIZER_PAD_ID, id)

+
+class SpecialVocab:
+    merges: List[str] = []
+    special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
+    special_token_ids: Dict[str, int] = {}
+
+    def __init__(self, path: Path, special_token_types: Optional[Tuple[str, ...]] = None):
+        self.special_token_ids = {}
+        if special_token_types is not None:
+            self.special_token_types = special_token_types
+        self.load(path)
+
+    def load(self, path: Path):
+        if not self.try_load_from_tokenizer_json(path):
+            self.try_load_from_config_json(path)
+
+    def try_load_from_tokenizer_json(self, path: Path) -> bool:
+        tokenizer_file = path / 'tokenizer.json'
+        if not tokenizer_file.is_file():
+            return False
+        with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
+            tokenizer = json.load(f)
+        merges = tokenizer.get('model', {}).get('merges')
+        if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
+            self.merges = merges
+        tokenizer_config_file = path / 'tokenizer_config.json'
+        added_tokens = tokenizer.get('added_tokens')
+        if added_tokens is None or not tokenizer_config_file.is_file():
+            return True
+        with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
+            tokenizer_config = json.load(f)
+        for typ in self.special_token_types:
+            tc_content = (tokenizer_config.get(f'{typ}_token') or {}).get('content')
+            if not isinstance(tc_content, str):
+                continue
+            for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content):
+                if isinstance(maybe_token_id, int):
+                    self.special_token_ids[typ] = maybe_token_id
+                break
+        return True
+
+    def try_load_from_config_json(self, path: Path) -> bool:
+        config_file = path / 'config.json'
+        if not config_file.is_file():
+            return False
+        with open(config_file, 'r', encoding = 'utf-8') as f:
+            config = json.load(f)
+        for typ in self.special_token_types:
+            maybe_token_id = config.get(f'{typ}_token_id')
+            if isinstance(maybe_token_id, int):
+                self.special_token_ids[typ] = maybe_token_id
+        return True
+
+    def add_to_gguf(self, gw: GGUFWriter):
+        # FIXME: Don't always include merges (possibly also don't even load them).
+        if len(self.merges) > 0:
+            print(f'SpecialVocab: Adding {len(self.merges)} merge(s).')
+            gw.add_token_merges(self.merges)
+        for typ, tokid in self.special_token_ids.items():
+            handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
+            if handler is None:
+                print(f'SpecialVocab: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
+                continue
+            print(f'SpecialVocab: Setting special token type {typ} to {tokid}')
+            handler(tokid)
+
+    def __repr__(self):
+        return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>'
+
+
 # Example usage:
 if __name__ == "__main__":
     # Example usage with a file
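SpecialVocab lands here essentially verbatim from convert.py (plus the new FIXME about merges), which is what lets every converter above drop its from convert import SpecialVocab line. The load order matters: tokenizer.json/tokenizer_config.json are tried first because that pair maps special-token content strings back to added-token ids, while config.json is only a fallback since it stores bare ids. A hedged usage sketch; the directory path is hypothetical:

    from pathlib import Path
    import gguf

    sv = gguf.SpecialVocab(Path('models/my-model'))
    print(sv)   # e.g. <SpecialVocab with 0 merges and special tokens {'bos': 1, 'eos': 2}>

    # Restrict the lookup to particular token types:
    sv_bos = gguf.SpecialVocab(Path('models/my-model'), special_token_types = ('bos',))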