convert : vocab inheritance instead of duck typing

Jared Van Bortel 2024-03-27 12:30:49 -04:00
parent 72e95e33a9
commit 9803bb7206

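The commit drops the old "Vocab" TypeAlias union and the scattered hasattr(vocab, "vocab_size") probes in favor of an explicit hierarchy: a BaseVocab protocol (anything carrying a tokenizer_model and a name), a Vocab protocol for real tokenizers, and concrete classes that inherit from them. Below is a condensed, runnable sketch of the resulting shape; ToyVocab and n_tokens are illustrative stand-ins, not code from convert.py.

from __future__ import annotations

from typing import ClassVar, Protocol, runtime_checkable


@runtime_checkable
class BaseVocab(Protocol):
    tokenizer_model: ClassVar[str]  # every vocab, even a stub, names a tokenizer model
    name: ClassVar[str]


@runtime_checkable
class Vocab(BaseVocab, Protocol):
    vocab_size: int  # only real tokenizers carry a size


class NoVocab(BaseVocab):  # explicit inheritance instead of duck typing
    tokenizer_model = "no_vocab"
    name = "no_vocab"


class ToyVocab(Vocab):  # illustrative stand-in for BpeVocab and friends
    tokenizer_model = "gpt2"
    name = "toy"

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size


def n_tokens(vocab: BaseVocab) -> int | None:
    # Before: hasattr(vocab, "vocab_size"). After: an isinstance() check that
    # type checkers can narrow and that @runtime_checkable honors at runtime.
    return vocab.vocab_size if isinstance(vocab, Vocab) else None


assert n_tokens(ToyVocab(32000)) == 32000
assert n_tokens(NoVocab()) is None

In the diff below, BpeVocab, SentencePieceVocab, and HfVocab take the place of ToyVocab.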

@@ -18,11 +18,11 @@ import struct
 import sys
 import time
 import zipfile
-from abc import ABCMeta, abstractmethod
+from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
@@ -331,7 +331,32 @@ class Params:
 # vocab
 #
 
-class BpeVocab:
+@runtime_checkable
+class BaseVocab(Protocol):
+    tokenizer_model: ClassVar[str]
+    name: ClassVar[str]
+
+
+class NoVocab(BaseVocab):
+    tokenizer_model = "no_vocab"
+    name = "no_vocab"
+
+    def __repr__(self) -> str:
+        return "<NoVocab for a model without integrated vocabulary>"
+
+
+@runtime_checkable
+class Vocab(BaseVocab, Protocol):
+    vocab_size: int
+    added_tokens_dict: dict[str, int]
+    added_tokens_list: list[str]
+    fname_tokenizer: Path
+
+    def __init__(self, base_path: Path): ...
+    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
+
+
+class BpeVocab(Vocab):
     tokenizer_model = "gpt2"
     name = "bpe"
@@ -391,7 +416,7 @@ class BpeVocab:
         return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-class SentencePieceVocab:
+class SentencePieceVocab(Vocab):
     tokenizer_model = "llama"
     name = "spm"
@@ -456,7 +481,7 @@ class SentencePieceVocab:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-class HfVocab:
+class HfVocab(Vocab):
     tokenizer_model = "llama"
     name = "hfft"
@@ -559,17 +584,6 @@ class HfVocab:
         return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-class NoVocab:
-    tokenizer_model = "no_vocab"
-    name = "no_vocab"
-
-    def __repr__(self) -> str:
-        return "<NoVocab for a model without integrated vocabulary>"
-
-
-Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
-
-
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
@@ -585,7 +599,7 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
             .reshape(weights.shape))
 
 
-class Tensor(metaclass=ABCMeta):
+class Tensor(ABC):
     data_type: DataType
 
     @abstractmethod
@@ -686,7 +700,7 @@ class ModelPlus:
     model: LazyModel
     paths: list[Path]  # Where this was read from.
     format: Literal['ggml', 'torch', 'safetensors', 'none']
-    vocab: Vocab | None  # For GGML models (which have vocab built in), the vocab.
+    vocab: BaseVocab | None  # For GGML models (which have vocab built in), the vocab.
 
 
 def merge_sharded(models: list[LazyModel]) -> LazyModel:
@@ -945,13 +959,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
             yield result
 
 
-def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
+def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
     # Handle special case where the model's vocab size is not set
     if params.n_vocab == -1:
         raise ValueError(
-            f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
+            f"The model's vocab size is set to -1 in params.json. Please update it manually." +
+            (f' Maybe {vocab.vocab_size}?' if isinstance(vocab, Vocab) else ''),
         )
-    if isinstance(vocab, NoVocab):
+    if not isinstance(vocab, Vocab):
         return  # model has no vocab
 
     # Check for a vocab size mismatch
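The rewritten guard shows the main payoff over hasattr: after isinstance(vocab, Vocab), the runtime and the type checker agree on what vocab is, so protocol-only attributes need no further probing. A small sketch of the same narrowing pattern, assuming the classes above; vocab_info itself is illustrative:

def vocab_info(vocab: BaseVocab) -> str:
    if not isinstance(vocab, Vocab):
        # Only BaseVocab members are available here (the NoVocab case).
        return f"{vocab.name} (no integrated vocabulary)"
    # vocab is narrowed to Vocab: vocab_size and fname_tokenizer are typed.
    return f"{vocab.name}: {vocab.vocab_size} tokens from {vocab.fname_tokenizer}"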
@@ -1031,8 +1046,6 @@ class OutputFile:
         self.gguf.add_file_type(params.ftype)
 
     def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
-        assert not isinstance(vocab, NoVocab)
-
         tokens = []
         scores = []
         toktypes = []
@@ -1132,7 +1145,7 @@ class OutputFile:
 
     @staticmethod
     def write_all(
-        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
     ) -> None:
@@ -1142,11 +1155,11 @@
 
         # meta data
         of.add_meta_arch(params)
-        if isinstance(vocab, NoVocab):
-            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
-        else:
+        if isinstance(vocab, Vocab):
             of.add_meta_vocab(vocab)
             of.add_meta_special_vocab(svocab)
+        else:  # NoVocab
+            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
 
         # tensor info
         for name, lazy_tensor in model.items():
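The flipped branch relies on the same guarantee from the other side: whatever falls through the isinstance(vocab, Vocab) test is still a BaseVocab, so tokenizer_model is available in the else arm without a hasattr probe. A minimal sketch of that dispatch; add_tokenizer_meta is a hypothetical stand-in for the OutputFile logic (assumes Python 3.9+):

def add_tokenizer_meta(vocab: BaseVocab) -> dict[str, object]:
    meta: dict[str, object] = {"tokenizer.model": vocab.tokenizer_model}
    if isinstance(vocab, Vocab):
        # Real tokenizer: also record the token count, as add_meta_vocab would.
        meta["vocab.size"] = vocab.vocab_size
    return meta


assert add_tokenizer_meta(NoVocab()) == {"tokenizer.model": "no_vocab"}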
@@ -1317,9 +1330,9 @@ class VocabFactory:
                 return vtype, path
         raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
 
-    def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
+    def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab:
         load_merges = vocab.name == "bpe"
-        n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
+        n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None
         return gguf.SpecialVocab(
             model_parent_path,
             load_merges=load_merges,
@@ -1327,7 +1340,7 @@
             n_vocab=n_vocab,
         )
 
-    def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
+    def _create_vocab_by_path(self, vocab_types: list[str]) -> BaseVocab:
         vocab_type, path = self._select_file(vocab_types)
         print(f"Loading vocab file {path!r}, type {vocab_type!r}")
@@ -1346,8 +1359,8 @@
             )
         raise ValueError(vocab_type)
 
-    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
-        vocab: Vocab
+    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
+        vocab: BaseVocab
         if len(vocab_types) == 1 and "no_vocab" in vocab_types:
             vocab = NoVocab()
         else:
@@ -1407,7 +1420,7 @@ def main(args_in: list[str] | None = None) -> None:
     args = parser.parse_args(args_in)
     if args.no_vocab:
         if args.vocab_only:
-            raise ValueError("no need to specify --vocab-only if using --no-vocab")
+            raise ValueError("--vocab-only does not make sense with --no-vocab")
         args.vocab_type = "no_vocab"
 
     if args.dump_single:
@@ -1451,6 +1464,7 @@ def main(args_in: list[str] | None = None) -> None:
     vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path)
 
     if args.vocab_only:
+        assert isinstance(vocab, Vocab)
         if not args.outfile:
             raise ValueError("need --outfile if using --vocab-only")
         outfile = args.outfile