convert : various script cleanups/fixes + merges and special token handling (#2842)
* convert: Fix permute calls and method/func definitions * Cleanups for gguf-py * Minor types cleanups. * Initial implementation of handling merges and special tokens * convert: Handle special tokens and merges in vocab only mode convert: Vocab only mode no longer requires loading model tensors * gguf: Refactor tensor name mapping * convert: Fix type hint for special_token_types in SpecialVocab * Use common special vocab handling in various conversion scripts * First pass at implementing suggested changes * Second pass * gguf: SpecialVocab: Fix issue with special token content not in a dict gguf: SpecialVocab: Allow skipping handling of merges * convert-falcon-hf-to-gguf: Support --vocab-only option, bail out if no tokenizer.json * convert-gptneox-hf-to-gguf and convert: Only handle merges for BPE tokenizer * gguf: SpecialVocab: Actually set load_merges in object * Uniform args parsing and vocab only mode for convert examples * convert.py: Set gpt2 as tokenizer model when using BPE * Squish last type warning in gguf.py - yay!
This commit is contained in:
parent
ad9ddcff6e
commit
dc07dc492e
10 changed files with 728 additions and 748 deletions
142
convert.py
142
convert.py
|
@ -25,7 +25,7 @@ import numpy as np
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
|
||||
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)
|
||||
from sentencepiece import SentencePieceProcessor # type: ignore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -299,8 +299,10 @@ class Params:
|
|||
params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
|
||||
elif orig_config_path.exists():
|
||||
params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
|
||||
else:
|
||||
elif model_plus.format != 'none':
|
||||
params = Params.guessed(model_plus.model)
|
||||
else:
|
||||
raise ValueError('Cannot guess params when model format is none')
|
||||
|
||||
params.path_model = model_plus.paths[0].parent
|
||||
|
||||
|
@ -353,7 +355,7 @@ class BpeVocab:
|
|||
yield from self.added_tokens()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
class SentencePieceVocab:
|
||||
|
@ -416,7 +418,6 @@ class SentencePieceVocab:
|
|||
|
||||
Vocab = Union[BpeVocab, SentencePieceVocab]
|
||||
|
||||
|
||||
#
|
||||
# data loading
|
||||
# TODO: reuse (probably move to gguf.py?)
|
||||
|
@ -439,14 +440,14 @@ class Tensor(metaclass=ABCMeta):
|
|||
@abstractmethod
|
||||
def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
|
||||
@abstractmethod
|
||||
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
|
||||
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': ...
|
||||
@abstractmethod
|
||||
def part(self, n_part: int) -> 'UnquantizedTensor': ...
|
||||
@abstractmethod
|
||||
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
|
||||
|
||||
|
||||
def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray:
|
||||
def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
|
||||
assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
|
||||
fp32_arr = bf16_arr.astype(np.uint32) << 16
|
||||
return fp32_arr.view(np.float32)
|
||||
|
@ -467,9 +468,9 @@ class UnquantizedTensor(Tensor):
|
|||
def to_ggml(self) -> 'UnquantizedTensor':
|
||||
return self
|
||||
|
||||
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
|
||||
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
|
||||
r = self.ndarray.shape[0] // 3
|
||||
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head))
|
||||
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
|
||||
|
||||
def part(self, n_part: int) -> 'UnquantizedTensor':
|
||||
r = self.ndarray.shape[0] // 3
|
||||
|
@ -531,7 +532,7 @@ LazyModel = Dict[str, LazyTensor]
|
|||
class ModelPlus:
|
||||
model: LazyModel
|
||||
paths: List[Path] # Where this was read from.
|
||||
format: Literal['ggml', 'torch', 'safetensors']
|
||||
format: Literal['ggml', 'torch', 'safetensors', 'none']
|
||||
vocab: Optional[Vocab] # For GGML models (which have vocab built in), the vocab.
|
||||
|
||||
|
||||
|
@ -597,12 +598,12 @@ def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTe
|
|||
return lazy_tensor.load().permute(n_head, n_head_kv)
|
||||
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
|
||||
|
||||
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
|
||||
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor:
|
||||
def load() -> Tensor:
|
||||
return lazy_tensor.load().permute_part(n_part, n_head)
|
||||
return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv)
|
||||
s = lazy_tensor.shape.copy()
|
||||
s[0] = s[0] // 3
|
||||
return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
|
||||
return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
|
||||
|
||||
def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
|
||||
def load() -> Tensor:
|
||||
|
@ -657,7 +658,7 @@ class LazyUnpickler(pickle.Unpickler):
|
|||
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
|
||||
return LazyStorage(load=load, kind=pid[1], description=description)
|
||||
|
||||
# @staticmethod
|
||||
@staticmethod
|
||||
def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
|
||||
# pyright: ignore[reportSelfClsParameterName]
|
||||
requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
|
||||
|
@ -669,13 +670,15 @@ class LazyUnpickler(pickle.Unpickler):
|
|||
description = f'pickled storage_offset={storage_offset} in {storage.description}'
|
||||
return LazyTensor(load, list(size), storage.kind.data_type, description)
|
||||
|
||||
# @staticmethod
|
||||
@staticmethod
|
||||
def rebuild_from_type_v2(func, new_type, args, state):
|
||||
return func(*args)
|
||||
|
||||
CLASSES: Dict[Any, Any] = {
|
||||
('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2,
|
||||
('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2,
|
||||
CLASSES: Dict[Tuple[str, str], Any] = {
|
||||
# getattr used here as a workaround for mypy not being smart enough to detrmine
|
||||
# the staticmethods have a __func__ attribute.
|
||||
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
|
||||
('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
|
||||
('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
|
||||
('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
|
||||
('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
|
||||
|
@ -751,7 +754,7 @@ def lazy_load_file(path: Path) -> ModelPlus:
|
|||
In = TypeVar('In')
|
||||
Out = TypeVar('Out')
|
||||
|
||||
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
|
||||
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, use_processpool_executor: bool = False) -> Iterable[Out]:
|
||||
'''Parallel map, but with backpressure. If the caller doesn't call `next`
|
||||
fast enough, this will stop calling `func` at some point rather than
|
||||
letting results pile up in memory. Specifically, there is a max of one
|
||||
|
@ -760,7 +763,12 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
|
|||
yield from map(func, iterable)
|
||||
# Not reached.
|
||||
iterable = iter(iterable)
|
||||
with factory(max_workers = max_workers) as executor:
|
||||
executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]]
|
||||
if use_processpool_executor:
|
||||
executor_class = ProcessPoolExecutor
|
||||
else:
|
||||
executor_class = ThreadPoolExecutor
|
||||
with executor_class(max_workers = max_workers) as executor:
|
||||
futures: List[concurrent.futures.Future[Out]] = []
|
||||
done = False
|
||||
for _ in range(concurrency):
|
||||
|
@ -838,11 +846,19 @@ class OutputFile:
|
|||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
|
||||
self.gguf.add_tokenizer_model("llama")
|
||||
if isinstance(vocab, SentencePieceVocab):
|
||||
self.gguf.add_tokenizer_model("llama")
|
||||
elif isinstance(vocab, BpeVocab):
|
||||
self.gguf.add_tokenizer_model("gpt2")
|
||||
else:
|
||||
raise ValueError(f'Unknown vocab type: Not BpeVocab or SentencePieceVocab')
|
||||
self.gguf.add_token_list(tokens)
|
||||
self.gguf.add_token_scores(scores)
|
||||
self.gguf.add_token_types(toktypes)
|
||||
|
||||
def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
|
||||
svocab.add_to_gguf(self.gguf)
|
||||
|
||||
def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
|
||||
n_elements = int(np.prod(tensor.shape))
|
||||
raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
|
||||
|
@ -861,7 +877,7 @@ class OutputFile:
|
|||
self.gguf.close()
|
||||
|
||||
@staticmethod
|
||||
def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
|
||||
def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab) -> None:
|
||||
check_vocab_size(params, vocab)
|
||||
|
||||
of = OutputFile(fname_out)
|
||||
|
@ -869,6 +885,8 @@ class OutputFile:
|
|||
# meta data
|
||||
of.add_meta_arch(params)
|
||||
of.add_meta_vocab(vocab)
|
||||
of.add_meta_special_vocab(svocab)
|
||||
|
||||
of.write_meta()
|
||||
|
||||
of.close()
|
||||
|
@ -887,7 +905,7 @@ class OutputFile:
|
|||
return dt.quantize(arr)
|
||||
|
||||
@staticmethod
|
||||
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
|
||||
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
|
||||
check_vocab_size(params, vocab)
|
||||
|
||||
of = OutputFile(fname_out)
|
||||
|
@ -895,6 +913,7 @@ class OutputFile:
|
|||
# meta data
|
||||
of.add_meta_arch(params)
|
||||
of.add_meta_vocab(vocab)
|
||||
of.add_meta_special_vocab(svocab)
|
||||
|
||||
# tensor info
|
||||
for name, lazy_tensor in model.items():
|
||||
|
@ -906,7 +925,7 @@ class OutputFile:
|
|||
# tensor data
|
||||
ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
|
||||
if ftype == GGMLFileType.MostlyQ8_0:
|
||||
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
|
||||
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True)
|
||||
else:
|
||||
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
|
||||
|
||||
|
@ -939,7 +958,8 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
|
|||
for (name, tensor) in model.items()}
|
||||
|
||||
def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
||||
tmap = gguf.get_tensor_name_map(ARCH, params.n_layer)
|
||||
tmap = gguf.TensorNameMap(ARCH, params.n_layer)
|
||||
should_skip: Set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
|
||||
|
||||
tmp = model
|
||||
|
||||
|
@ -952,8 +972,8 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
|||
#tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
|
||||
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
|
||||
print(f"Unpacking and permuting layer {i}")
|
||||
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
|
||||
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
|
||||
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
|
||||
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
|
||||
tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
|
||||
del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
|
||||
else:
|
||||
|
@ -961,23 +981,16 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
|||
|
||||
out: LazyModel = {}
|
||||
for name, lazy_tensor in model.items():
|
||||
name_new = name
|
||||
|
||||
if name in tmap:
|
||||
name_new = tmap[name]
|
||||
elif name.endswith(".weight") and name[:-7] in tmap:
|
||||
name_new = tmap[name[:-7]] + ".weight"
|
||||
elif name.endswith(".bias") and name[:-5] in tmap:
|
||||
name_new = tmap[name[:-5]] + ".bias"
|
||||
else:
|
||||
tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
|
||||
if name_new is None:
|
||||
raise Exception(f"Unexpected tensor name: {name}")
|
||||
|
||||
if gguf.should_skip_tensor_TMP(ARCH, params.n_layer, name_new):
|
||||
if tensor_type in should_skip:
|
||||
print(f"skipping tensor {name_new}")
|
||||
continue
|
||||
else:
|
||||
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
|
||||
out[name_new] = lazy_tensor
|
||||
|
||||
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
|
||||
out[name_new] = lazy_tensor
|
||||
|
||||
return out
|
||||
|
||||
|
@ -1117,8 +1130,16 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
|||
if args.dump_single:
|
||||
model_plus = lazy_load_file(args.model)
|
||||
do_dump_model(model_plus)
|
||||
return
|
||||
|
||||
model_plus = load_some_model(args.model)
|
||||
if not args.vocab_only:
|
||||
model_plus = load_some_model(args.model)
|
||||
else:
|
||||
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
|
||||
|
||||
if args.dump:
|
||||
do_dump_model(model_plus)
|
||||
return
|
||||
|
||||
params = Params.load(model_plus)
|
||||
if params.n_ctx == -1:
|
||||
|
@ -1140,33 +1161,34 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
|||
|
||||
vocab: Vocab
|
||||
if args.vocab_only:
|
||||
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
|
||||
assert args.outfile, "need --outfile if using --vocab-only"
|
||||
# FIXME: Try to respect vocab_dir somehow?
|
||||
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
|
||||
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
|
||||
outfile = args.outfile
|
||||
OutputFile.write_vocab_only(outfile, params, vocab)
|
||||
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
|
||||
print(f"Wrote {outfile}")
|
||||
return
|
||||
|
||||
if model_plus.vocab is not None and args.vocab_dir is None:
|
||||
vocab = model_plus.vocab
|
||||
else:
|
||||
if args.dump:
|
||||
do_dump_model(model_plus)
|
||||
return
|
||||
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
|
||||
vocab = load_vocab(vocab_dir, args.vocabtype)
|
||||
# FIXME: Try to respect vocab_dir somehow?
|
||||
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
|
||||
|
||||
if model_plus.vocab is not None and args.vocab_dir is None:
|
||||
vocab = model_plus.vocab
|
||||
else:
|
||||
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
|
||||
vocab = load_vocab(vocab_dir, args.vocabtype)
|
||||
model = model_plus.model
|
||||
model = convert_model_names(model, params)
|
||||
ftype = pick_output_type(model, args.outtype)
|
||||
model = convert_to_output_type(model, ftype)
|
||||
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
|
||||
|
||||
model = model_plus.model
|
||||
model = convert_model_names(model, params)
|
||||
ftype = pick_output_type(model, args.outtype)
|
||||
model = convert_to_output_type(model, ftype)
|
||||
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
|
||||
params.ftype = ftype
|
||||
print(f"Writing {outfile}, format {ftype}")
|
||||
|
||||
params.ftype = ftype
|
||||
print(f"Writing {outfile}, format {ftype}")
|
||||
|
||||
OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
|
||||
print(f"Wrote {outfile}")
|
||||
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency)
|
||||
print(f"Wrote {outfile}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue