convert : vocab inheritance instead of duck typing
This commit is contained in:
parent
72e95e33a9
commit
9803bb7206
1 changed files with 47 additions and 33 deletions
80
convert.py
80
convert.py
|
@ -18,11 +18,11 @@ import struct
|
|||
import sys
|
||||
import time
|
||||
import zipfile
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
@ -331,7 +331,32 @@ class Params:
|
|||
# vocab
|
||||
#
|
||||
|
||||
class BpeVocab:
|
||||
@runtime_checkable
|
||||
class BaseVocab(Protocol):
|
||||
tokenizer_model: ClassVar[str]
|
||||
name: ClassVar[str]
|
||||
|
||||
|
||||
class NoVocab(BaseVocab):
|
||||
tokenizer_model = "no_vocab"
|
||||
name = "no_vocab"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<NoVocab for a model without integrated vocabulary>"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Vocab(BaseVocab, Protocol):
|
||||
vocab_size: int
|
||||
added_tokens_dict: dict[str, int]
|
||||
added_tokens_list: list[str]
|
||||
fname_tokenizer: Path
|
||||
|
||||
def __init__(self, base_path: Path): ...
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
|
||||
|
||||
|
||||
class BpeVocab(Vocab):
|
||||
tokenizer_model = "gpt2"
|
||||
name = "bpe"
|
||||
|
||||
|
@ -391,7 +416,7 @@ class BpeVocab:
|
|||
return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
class SentencePieceVocab:
|
||||
class SentencePieceVocab(Vocab):
|
||||
tokenizer_model = "llama"
|
||||
name = "spm"
|
||||
|
||||
|
@ -456,7 +481,7 @@ class SentencePieceVocab:
|
|||
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
class HfVocab:
|
||||
class HfVocab(Vocab):
|
||||
tokenizer_model = "llama"
|
||||
name = "hfft"
|
||||
|
||||
|
@ -559,17 +584,6 @@ class HfVocab:
|
|||
return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
class NoVocab:
|
||||
tokenizer_model = "no_vocab"
|
||||
name = "no_vocab"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<NoVocab for a model without integrated vocabulary>"
|
||||
|
||||
|
||||
Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
|
||||
|
||||
|
||||
#
|
||||
# data loading
|
||||
# TODO: reuse (probably move to gguf.py?)
|
||||
|
@ -585,7 +599,7 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
|
|||
.reshape(weights.shape))
|
||||
|
||||
|
||||
class Tensor(metaclass=ABCMeta):
|
||||
class Tensor(ABC):
|
||||
data_type: DataType
|
||||
|
||||
@abstractmethod
|
||||
|
@ -686,7 +700,7 @@ class ModelPlus:
|
|||
model: LazyModel
|
||||
paths: list[Path] # Where this was read from.
|
||||
format: Literal['ggml', 'torch', 'safetensors', 'none']
|
||||
vocab: Vocab | None # For GGML models (which have vocab built in), the vocab.
|
||||
vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab.
|
||||
|
||||
|
||||
def merge_sharded(models: list[LazyModel]) -> LazyModel:
|
||||
|
@ -945,13 +959,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
|
|||
yield result
|
||||
|
||||
|
||||
def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
|
||||
def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
|
||||
# Handle special case where the model's vocab size is not set
|
||||
if params.n_vocab == -1:
|
||||
raise ValueError(
|
||||
f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
|
||||
f"The model's vocab size is set to -1 in params.json. Please update it manually." +
|
||||
(f' Maybe {vocab.vocab_size}?' if isinstance(vocab, Vocab) else ''),
|
||||
)
|
||||
if isinstance(vocab, NoVocab):
|
||||
if not isinstance(vocab, Vocab):
|
||||
return # model has no vocab
|
||||
|
||||
# Check for a vocab size mismatch
|
||||
|
@ -1031,8 +1046,6 @@ class OutputFile:
|
|||
self.gguf.add_file_type(params.ftype)
|
||||
|
||||
def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
|
||||
assert not isinstance(vocab, NoVocab)
|
||||
|
||||
tokens = []
|
||||
scores = []
|
||||
toktypes = []
|
||||
|
@ -1132,7 +1145,7 @@ class OutputFile:
|
|||
|
||||
@staticmethod
|
||||
def write_all(
|
||||
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
|
||||
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
|
||||
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
|
||||
pad_vocab: bool = False,
|
||||
) -> None:
|
||||
|
@ -1142,11 +1155,11 @@ class OutputFile:
|
|||
|
||||
# meta data
|
||||
of.add_meta_arch(params)
|
||||
if isinstance(vocab, NoVocab):
|
||||
of.gguf.add_tokenizer_model(vocab.tokenizer_model)
|
||||
else:
|
||||
if isinstance(vocab, Vocab):
|
||||
of.add_meta_vocab(vocab)
|
||||
of.add_meta_special_vocab(svocab)
|
||||
else: # NoVocab
|
||||
of.gguf.add_tokenizer_model(vocab.tokenizer_model)
|
||||
|
||||
# tensor info
|
||||
for name, lazy_tensor in model.items():
|
||||
|
@ -1317,9 +1330,9 @@ class VocabFactory:
|
|||
return vtype, path
|
||||
raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
|
||||
|
||||
def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
|
||||
def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab:
|
||||
load_merges = vocab.name == "bpe"
|
||||
n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
|
||||
n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None
|
||||
return gguf.SpecialVocab(
|
||||
model_parent_path,
|
||||
load_merges=load_merges,
|
||||
|
@ -1327,7 +1340,7 @@ class VocabFactory:
|
|||
n_vocab=n_vocab,
|
||||
)
|
||||
|
||||
def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
|
||||
def _create_vocab_by_path(self, vocab_types: list[str]) -> BaseVocab:
|
||||
vocab_type, path = self._select_file(vocab_types)
|
||||
print(f"Loading vocab file {path!r}, type {vocab_type!r}")
|
||||
|
||||
|
@ -1346,8 +1359,8 @@ class VocabFactory:
|
|||
)
|
||||
raise ValueError(vocab_type)
|
||||
|
||||
def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
|
||||
vocab: Vocab
|
||||
def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]:
|
||||
vocab: BaseVocab
|
||||
if len(vocab_types) == 1 and "no_vocab" in vocab_types:
|
||||
vocab = NoVocab()
|
||||
else:
|
||||
|
@ -1407,7 +1420,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
args = parser.parse_args(args_in)
|
||||
if args.no_vocab:
|
||||
if args.vocab_only:
|
||||
raise ValueError("no need to specify --vocab-only if using --no-vocab")
|
||||
raise ValueError("--vocab-only does not make sense with --no-vocab")
|
||||
args.vocab_type = "no_vocab"
|
||||
|
||||
if args.dump_single:
|
||||
|
@ -1451,6 +1464,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path)
|
||||
|
||||
if args.vocab_only:
|
||||
assert isinstance(vocab, Vocab)
|
||||
if not args.outfile:
|
||||
raise ValueError("need --outfile if using --vocab-only")
|
||||
outfile = args.outfile
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue