models without vocabulary, convert.py part

This commit is contained in:
  parent cc3fe18b43
  commit 4f4258fbde

1 changed file with 113 additions and 54 deletions: convert.py (167 changed lines)
@@ -385,6 +385,12 @@ class BpeVocab:
         yield from self.bpe_tokens()
         yield from self.added_tokens()
 
+    def get_tokenizer_model(self) -> str:
+        return "gpt2"
+
+    def get_name(self) -> str:
+        return "bpe"
+
     def __repr__(self) -> str:
         return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
@@ -448,6 +454,12 @@ class SentencePieceVocab:
         yield from self.sentencepiece_tokens()
         yield from self.added_tokens()
 
+    def get_tokenizer_model(self) -> str:
+        return "llama"
+
+    def get_name(self) -> str:
+        return "spm"
+
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
@@ -549,11 +561,28 @@ class HfVocab:
         yield from self.hf_tokens()
         yield from self.added_tokens()
 
+    def get_tokenizer_model(self) -> str:
+        return "llama"
+
+    def get_name(self) -> str:
+        return "hfft"
+
     def __repr__(self) -> str:
         return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab"
+class NoVocab:
+    def get_tokenizer_model(self) -> str:
+        return "no_vocab"
+
+    def get_name(self) -> str:
+        return "no_vocab"
+
+    def __repr__(self) -> str:
+        return "<NoVocab for a model without integrated vocabulary>"
+
+
+Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab"
 
 
 #
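Every vocab class now exposes get_tokenizer_model() and get_name(), and NoVocab answers both with "no_vocab". The point of the shared interface is that downstream code can ask the vocab object directly instead of mapping its concrete type to a string. A minimal, self-contained sketch of the pattern (stand-in classes for illustration, not the convert.py ones):

# Sketch only: each vocab type reports its own tokenizer model,
# so a NoVocab flows through the same code paths as a real vocab.
class SpmLike:
    def get_tokenizer_model(self) -> str:
        return "llama"

    def get_name(self) -> str:
        return "spm"


class NoVocabLike:
    def get_tokenizer_model(self) -> str:
        return "no_vocab"

    def get_name(self) -> str:
        return "no_vocab"


for v in (SpmLike(), NoVocabLike()):
    print(v.get_name(), "->", v.get_tokenizer_model())
# prints: spm -> llama, then: no_vocab -> no_vocab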
@@ -931,13 +960,18 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
         yield result
 
 
-def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
+def check_vocab_size(params: Params, vocab: Vocab) -> None:
     # Handle special case where the model's vocab size is not set
     if params.n_vocab == -1:
         raise ValueError(
-            f"The model's vocab size is set to -1 in params.json. Please update it manually. Maybe {vocab.vocab_size}?"
+            f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
         )
 
+
+def prepare_vocab(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
+    assert not isinstance(vocab, NoVocab)
+    check_vocab_size(params, vocab)
+
     # Check for a vocab size mismatch
     if params.n_vocab == vocab.vocab_size:
         print("Ignoring added_tokens.json since model matches vocab size without it.")
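check_vocab_size() loses its pad_vocab parameter and keeps only the n_vocab == -1 sanity check; the mismatch and padding handling moves into the new prepare_vocab(), which explicitly rejects NoVocab. The nested f-string only offers a size suggestion when the vocab actually carries one, since NoVocab has no vocab_size attribute. A runnable sketch of that guard (stand-in classes; the message text is the one from the diff):

# The hint is appended only when the vocab object has a vocab_size.
class WithSize:
    vocab_size = 32000

class WithoutSize:
    pass

for vocab in (WithSize(), WithoutSize()):
    msg = (
        "The model's vocab size is set to -1 in params.json. Please update it manually."
        f"{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}"
    )
    print(msg)  # second line ends after "manually." with no suggestion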
@@ -1014,20 +1048,6 @@ class OutputFile:
         if params.ftype is not None:
             self.gguf.add_file_type(params.ftype)
 
-    def handle_tokenizer_model(self, vocab: Vocab) -> str:
-        # Map the vocab types to the supported tokenizer models
-        tokenizer_model = {
-            SentencePieceVocab: "llama",
-            HfVocab: "llama",
-            BpeVocab: "gpt2",
-        }.get(type(vocab))
-
-        # Block if vocab type is not predefined
-        if tokenizer_model is None:
-            raise ValueError("Unknown vocab type: Not supported")
-
-        return tokenizer_model
-
     def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]:
         tokens = []
         scores = []
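The deleted handle_tokenizer_model() dispatched on type(vocab) through a dict, which yields None for any class that is not an exact key, so every new vocab class (such as NoVocab) would have needed another entry. The per-class methods added above make each vocab self-describing instead. A small self-contained illustration of why exact-type lookup is brittle:

# Exact-type dict lookup misses subclasses and unlisted classes entirely.
class Base:
    pass

class Sub(Base):
    pass

mapping = {Base: "llama"}
print(mapping.get(type(Sub())))  # None: type(Sub()) is Sub, not Base
print(isinstance(Sub(), Base))   # True: method dispatch would still work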
@@ -1044,11 +1064,8 @@ class OutputFile:
         return tokens, scores, toktypes
 
     def add_meta_vocab(self, vocab: Vocab) -> None:
-        # Handle the tokenizer model
-        tokenizer_model = self.handle_tokenizer_model(vocab)
-
         # Ensure that tokenizer_model is added to the GGUF model
-        self.gguf.add_tokenizer_model(tokenizer_model)
+        self.gguf.add_tokenizer_model(vocab.get_tokenizer_model())
 
         # Extract model vocabulary for model conversion
         tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab)
@@ -1075,6 +1092,26 @@ class OutputFile:
     def write_tensor_info(self) -> None:
         self.gguf.write_ti_data_to_file()
 
+    def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None:
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(
+                OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency,
+                use_processpool_executor=True,
+            )
+        else:
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
+
+        start = time.time()
+        for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
+            size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
+            padi = len(str(len(model)))
+            print(
+                f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
+            )
+            self.gguf.write_tensor_data(ndarray)
+
     def close(self) -> None:
         self.gguf.close()
 
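The tensor-writing loop, including the optional Q8_0 quantization fan-out via bounded_parallel_map, is lifted into a write_tensor_data() method so the two write paths introduced below can share it. A self-contained sketch of the loop's shape, with stand-in producers in place of the real lazy tensors:

# Minimal sketch: tensors are produced lazily, optionally transformed,
# and written one by one with elapsed-time reporting.
import time

def produce():
    for name in ("tok_embd", "output"):
        yield name, [0.0] * 4  # stand-in for a real ndarray

def maybe_quantize(item):
    name, data = item
    return name, data  # identity here; the real script may quantize to Q8_0

start = time.time()
for i, (name, data) in enumerate(map(maybe_quantize, produce())):
    print(f"[{i + 1}/2] Writing tensor {name:10s} | T+{int(time.time() - start)}")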
@@ -1083,7 +1120,7 @@ class OutputFile:
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
         endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False,
     ) -> None:
-        check_vocab_size(params, vocab, pad_vocab = pad_vocab)
+        prepare_vocab(params, vocab, pad_vocab=pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
@@ -1115,7 +1152,7 @@ class OutputFile:
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
     ) -> None:
-        check_vocab_size(params, vocab, pad_vocab=pad_vocab)
+        prepare_vocab(params, vocab, pad_vocab=pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
@@ -1132,24 +1169,33 @@ class OutputFile:
         of.write_tensor_info()
 
         # tensor data
-        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
-        if ftype == GGMLFileType.MostlyQ8_0:
-            ndarrays = bounded_parallel_map(
-                OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency,
-                use_processpool_executor=True,
-            )
-        else:
-            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
-
-        start = time.time()
-        for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
-            elapsed = time.time() - start
-            size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
-            padi = len(str(len(model)))
-            print(
-                f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
-            )
-            of.gguf.write_tensor_data(ndarray)
+        of.write_tensor_data(ftype, model, concurrency)
+
+        of.close()
+
+    @staticmethod
+    def write_without_vocab(
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab,
+        concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+    ) -> None:
+        assert isinstance(vocab, NoVocab)
+        check_vocab_size(params, vocab)
+
+        of = OutputFile(fname_out, endianess=endianess)
+
+        # meta data
+        of.add_meta_arch(params)
+        of.gguf.add_tokenizer_model(vocab.get_tokenizer_model())
+
+        # tensor info
+        for name, lazy_tensor in model.items():
+            of.add_tensor_info(name, lazy_tensor)
+
+        of.write_meta()
+        of.write_tensor_info()
+
+        # tensor data
+        of.write_tensor_data(ftype, model, concurrency)
 
         of.close()
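write_without_vocab() mirrors write_all() but asserts a NoVocab, records only the tokenizer model name, and skips token lists, scores, and special-token metadata entirely. A sketch of the resulting metadata difference; the gguf-style key names here are an assumption for illustration, not taken from this commit:

# Assumed gguf-style key names, for illustration only.
with_vocab = {
    "tokenizer.ggml.model": "llama",
    "tokenizer.ggml.tokens": ["<s>", "</s>"],
    "tokenizer.ggml.scores": [0.0, 0.0],
}
without_vocab = {
    "tokenizer.ggml.model": "no_vocab",  # the one tokenizer key this path emits
}
print(sorted(set(with_vocab) - set(without_vocab)))
# ['tokenizer.ggml.scores', 'tokenizer.ggml.tokens']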
@@ -1310,8 +1356,8 @@ class VocabFactory:
             return vtype, path
         raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}")
 
-    def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab:
-        load_merges = vocabtype == "bpe"
+    def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab:
+        load_merges = vocab.get_name() == "bpe"
         n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
         return gguf.SpecialVocab(
             model_parent_path,
@@ -1320,30 +1366,34 @@ class VocabFactory:
             n_vocab=n_vocab,
         )
 
-    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
+    def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab:
         vocab_type, path = self._select_file(vocab_types)
         print(f"Loading vocab file {path!r}, type {vocab_type!r}")
 
         added_tokens_path = path.parent / "added_tokens.json"
-        vocab: Vocab
         if vocab_type == "bpe":
-            vocab = BpeVocab(
+            return BpeVocab(
                 path, added_tokens_path if added_tokens_path.exists() else None
             )
-        elif vocab_type == "spm":
-            vocab = SentencePieceVocab(
+        if vocab_type == "spm":
+            return SentencePieceVocab(
                 path, added_tokens_path if added_tokens_path.exists() else None
            )
-        elif vocab_type == "hfft":
-            vocab = HfVocab(
+        if vocab_type == "hfft":
+            return HfVocab(
                 path.parent, added_tokens_path if added_tokens_path.exists() else None
             )
+        raise ValueError(vocab_type)
+
+    def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]:
+        vocab: Vocab
+        if len(vocab_types) == 1 and "no_vocab" in vocab_types:
+            vocab = NoVocab()
         else:
-            raise ValueError(vocab_type)
+            vocab = self._create_vocab_by_path(vocab_types)
         # FIXME: Respect --vocab-dir?
         special_vocab = self._create_special_vocab(
             vocab,
-            vocab_type,
             model_parent_path,
         )
         return vocab, special_vocab
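load_vocab() now splits into two steps: a file-based factory (_create_vocab_by_path) and a dispatch that short-circuits to NoVocab when the vocab-type list is exactly ["no_vocab"]. A self-contained sketch of that dispatch:

# Sketch of the new dispatch: a vocab-type list of exactly ["no_vocab"]
# bypasses the tokenizer-file lookup entirely.
def pick(vocab_types: list[str]) -> str:
    if len(vocab_types) == 1 and "no_vocab" in vocab_types:
        return "NoVocab()"
    return "created from tokenizer files on disk"

print(pick(["no_vocab"]))     # NoVocab()
print(pick(["spm", "hfft"]))  # created from tokenizer files on disk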
@@ -1381,6 +1431,7 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab")
     parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft")
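The new --no-vocab flag requests a model file with no embedded tokenizer; an invocation might look like `python convert.py /path/to/model --no-vocab` (the path is a placeholder). A minimal argparse sketch of the flag together with the mutual-exclusion check added in the next hunk (a stand-alone sketch, not convert.py itself):

# Self-contained sketch: --no-vocab and --vocab-only are contradictory.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")

args = parser.parse_args(["--no-vocab"])
if args.no_vocab and args.vocab_only:
    raise ValueError("no need to specify --vocab-only if using --no-vocab")
print(args.no_vocab)  # True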
@@ -1393,6 +1444,10 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
 
     args = parser.parse_args(args_in)
+    if args.no_vocab:
+        if args.vocab_only:
+            raise ValueError("no need to specify --vocab-only if using --no-vocab")
+        args.vocab_type = "no_vocab"
 
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
@@ -1443,7 +1498,7 @@ def main(args_in: list[str] | None = None) -> None:
         print(f"Wrote {outfile}")
         return
 
-    if model_plus.vocab is not None and args.vocab_dir is None:
+    if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
         vocab = model_plus.vocab
 
     print(f"Vocab info: {vocab}")
@@ -1458,8 +1513,12 @@ def main(args_in: list[str] | None = None) -> None:
     params.ftype = ftype
     print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
-                         concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
+    if not args.no_vocab:
+        OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
+                             concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
+    else:
+        OutputFile.write_without_vocab(outfile, ftype, params, model, vocab,
+                                       concurrency=args.concurrency, endianess=endianess)
     print(f"Wrote {outfile}")
 