Move split functionality to new GGUFManager class
parent 72cbd4e014
commit 702a744670

3 changed files with 556 additions and 176 deletions
190 convert.py
@@ -24,13 +24,15 @@ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
+import importlib
+gguf = importlib.import_module("gguf-py.gguf")
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
 
 if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
-import gguf
+# import gguf
 
 if TYPE_CHECKING:
     from typing_extensions import Self, TypeAlias
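A note on the import swap in this hunk: `import gguf` can no longer be used here because the module is resolved out of the hyphenated `gguf-py` directory, and a hyphen is not a valid character in a Python identifier. `importlib.import_module` sidesteps that, since the import machinery deals in plain strings. A minimal sketch of the pattern, using a hypothetical package directory `pkg-dir/`:

```python
import importlib

# `import pkg-dir.mod` is a SyntaxError: hyphens are not legal in identifiers.
# The import machinery only sees strings, though, so a hyphenated directory
# (a hypothetical ./pkg-dir/ containing __init__.py and mod.py) still loads:
mod = importlib.import_module("pkg-dir.mod")
print(mod.__name__)  # "pkg-dir.mod"
```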
@@ -47,15 +49,6 @@ DEFAULT_CONCURRENCY = 8
 ADDED_TOKENS_FILE = 'added_tokens.json'
 FAST_TOKENIZER_FILE = 'tokenizer.json'
 
-LLM_KV_SPLIT_NO = "split.no"
-LLM_KV_SPLIT_COUNT = "split.count"
-LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
-SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
-
-SPLIT_STYLE_NONE = 0
-SPLIT_STYLE_BY_TENSORS = 1
-SPLIT_STYLE_BY_SIZE = 2
-
 #
 # data types
 #
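These split constants are not deleted outright; they reappear in the new gguf-py/gguf/gguf_manager.py below. For reference, the shard naming template zero-pads both indices to five digits; a quick illustration (the stem `ggml-model-f16` is made up):

```python
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"

print(SHARD_NAME_FORMAT.format("ggml-model-f16", 1, 12))
# ggml-model-f16-00001-of-00012.gguf
```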
@@ -1066,8 +1059,8 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False)
 
 
 class OutputFile:
-    def __init__(self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
-        self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
+    def __init__(self, fname_out: Path, args: argparse.Namespace, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
+        self.gguf = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], args, endianess=endianess)
 
     def add_meta_arch(self, params: Params) -> None:
         name = "LLaMA"
@@ -1146,21 +1139,15 @@ class OutputFile:
     def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
         svocab.add_to_gguf(self.gguf)
 
-    def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
-        n_elements = int(np.prod(tensor.shape))
-        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
-        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
-        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
-
     def write_meta(self) -> None:
-        self.gguf.write_header_to_file()
-        self.gguf.write_kv_data_to_file()
+        self.gguf.write_to_file(meta_only=True)
 
-    def write_tensor_info(self) -> None:
-        self.gguf.write_ti_data_to_file()
+    def write_tensors(self, ftype: GGMLFileType, concurrency: int) -> None:
+        self.gguf.write_to_file(ftype=ftype, concurrency=concurrency, write_tensor_data=OutputFile.write_tensor_data)
 
-    def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None:
+    # really awkward with how this is managed with gguf_manager.py: maybe refactor at some point?
+    @staticmethod
+    def write_tensor_data(ftype: GGMLFileType, model: LazyModel, concurrency: int, writer: gguf.GGUFWriter) -> None:
         ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
         if ftype == GGMLFileType.MostlyQ8_0:
             ndarrays = bounded_parallel_map(
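The key change in this hunk is that `write_tensor_data` becomes a `@staticmethod` receiving the target `GGUFWriter` explicitly, so `GGUFManager` can invoke it once per shard instead of `OutputFile` writing through one fixed writer. A minimal sketch of that callback shape (names other than `write_tensor_data` are hypothetical):

```python
from typing import Any, Callable

# shards: (path, tensors, writer) triples, as SplitStrategy builds them below
def write_all_shards(shards: list[tuple[Any, Any, Any]], ftype: int, concurrency: int,
                     write_tensor_data: Callable[[int, dict, int, Any], None]) -> None:
    for _path, tensors, writer in shards:
        if tensors:
            # the injected callback does the actual serialization per shard
            write_tensor_data(ftype, dict(tensors), concurrency, writer)
```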
@@ -1178,7 +1165,7 @@ class OutputFile:
             print(
                 f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
             )
-            self.gguf.write_tensor_data(ndarray)
+            writer.write_tensor_data(ndarray)
 
     def close(self) -> None:
         self.gguf.close()
@@ -1217,45 +1204,11 @@ class OutputFile:
     @staticmethod
     def write_all(
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
-        tensors_per_shard: int, tensors_max_size: int, dry_run: bool = False, concurrency: int = DEFAULT_CONCURRENCY,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, small_first_shard: bool = True,
+        args: argparse.Namespace, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE
     ) -> None:
-        check_vocab_size(params, vocab, pad_vocab=pad_vocab)
-
-        total_tensors = len(model)
-        total_size = sum(get_tensor_size(lazy_tensor) for lazy_tensor in model.values())
-
-        if tensors_per_shard:
-            split_style = SPLIT_STYLE_BY_TENSORS
-        elif tensors_max_size:
-            split_style = SPLIT_STYLE_BY_SIZE
-        else:
-            split_style = SPLIT_STYLE_NONE
-
-        if tensors_per_shard and total_tensors < tensors_per_shard:
-            print("Model has fewer tensors than the split threshold, not splitting")
-            split_style = SPLIT_STYLE_NONE
-
-        if tensors_max_size and total_size < tensors_max_size:
-            print("Model has smaller size than the split threshold, not splitting")
-            split_style = SPLIT_STYLE_NONE
-
-        split_strategy = create_split_strategy(split_style, fname_out, model, tensors_per_shard, tensors_max_size, small_first_shard)
-        total_shards = len(split_strategy)
-
-        print("Writing the following files:")
-        for shard_path, shard_tensors in split_strategy:
-            size = format_n_bytes_to_str(sum(get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
-            print(f"  {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")
-
-        if dry_run:
-            print("Dry run, not writing files")
-            return
-
-        for i, (shard_path, shard_tensors) in enumerate(split_strategy):
-            of = OutputFile(shard_path, endianess=endianess)
-
-            if i == 0:
-                # meta data
-                of.add_meta_arch(params)
-                if isinstance(vocab, Vocab):
+        check_vocab_size(params, vocab, pad_vocab=args.pad_vocab)
+        of = OutputFile(fname_out, args, endianess=endianess)
+
+        # meta data
+        of.add_meta_arch(params)
+        if isinstance(vocab, Vocab):
@@ -1264,111 +1217,15 @@ class OutputFile:
-                else:  # NoVocab
-                    of.gguf.add_tokenizer_model(vocab.tokenizer_model)
-
-            # have the option to write a first shard with only the metadata
-            if split_style != SPLIT_STYLE_NONE:
-
-                of.gguf.add_uint16(LLM_KV_SPLIT_NO, i)
-                of.gguf.add_uint16(LLM_KV_SPLIT_COUNT, total_shards)
-                of.gguf.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, total_tensors)
-
-                if small_first_shard and i == 0:
-                    of.write_meta()
-                    of.close()
-                    continue
-
-            print(f"Writing shard {i + 1}/{total_shards} with {len(shard_tensors)} tensors")
-
-            # tensor info
-            for name, lazy_tensor in shard_tensors:
-                of.add_tensor_info(name, lazy_tensor)
-
-            of.write_meta()
-            of.write_tensor_info()
-            of.write_tensor_data(ftype, dict(shard_tensors), concurrency)
-
-            of.close()
+        else:  # NoVocab
+            of.gguf.add_tokenizer_model(vocab.tokenizer_model)
+
+        # tensor info
+        for name, lazy_tensor in model.items():
+            of.gguf.add_tensor_info(name, lazy_tensor)
+
+        of.write_tensors(ftype, concurrency)
+
+        of.close()
 
 
-def split_str_to_n_bytes(split_str: str) -> int:
-    if split_str.endswith("K"):
-        n = int(split_str[:-1]) * 1024
-    elif split_str.endswith("M"):
-        n = int(split_str[:-1]) * 1024 * 1024
-    elif split_str.endswith("G"):
-        n = int(split_str[:-1]) * 1024 * 1024 * 1024
-    elif split_str.isnumeric():
-        n = int(split_str)
-    else:
-        raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")
-
-    if n <= 0:
-        raise ValueError(f"Invalid split size: {split_str}, must be positive")
-
-    return n
-
-
-def format_n_bytes_to_str(num: int) -> str:
-    num = float(num)
-    for unit in ("", "K", "M", "G"):
-        if abs(num) < 1024.0:
-            return f"{num:3.1f}{unit}"
-        num /= 1024.0
-    return f"{num:.1f}T - over 1TB, --split recommended"
-
-
-def get_tensor_size(tensor: LazyTensor) -> int:
-    return tensor.data_type.elements_to_bytes(np.prod(tensor.shape))
-
-
-SplitStrategy: TypeAlias = 'list[tuple[Path, list[tuple[str, LazyTensor]]]]'
-
-
-def create_split_strategy(split_style: int, fname_out: Path, model: LazyModel, tensors_per_shard: int, tensors_max_size: int, small_first_shard: bool) -> SplitStrategy:
-    if split_style == SPLIT_STYLE_NONE:
-        return [(fname_out, list(model.items()))]
-
-    elif split_style == SPLIT_STYLE_BY_TENSORS:
-        total_shards = math.ceil(len(model) / tensors_per_shard) + small_first_shard
-        shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, total_shards)) for i in range(total_shards)]
-        splits = []
-
-        if small_first_shard:
-            splits.append((shard_files[0], None))
-
-        for i, shard in enumerate(shard_files[small_first_shard:]):
-            start = i * tensors_per_shard
-            stop = min((i + 1) * tensors_per_shard, len(model))
-            splits.append((shard, list(model.items())[start:stop]))
-
-        return splits
-
-    elif split_style == SPLIT_STYLE_BY_SIZE:
-        shards = []
-
-        # we have to determine the shards first to determine how many shards there will be in total - two passes
-        for i, shard in enumerate(list(model.items())):
-            if i == 0:
-                shards.append([shard])
-                continue
-            if get_tensor_size(shard[1]) + sum(get_tensor_size(t[1]) for t in shards[-1]) > tensors_max_size:
-                shards.append([shard])
-            else:
-                shards[-1].append(shard)
-
-        total_shards = len(shards) + small_first_shard
-        shard_offset = 1
-        splits = []
-
-        if small_first_shard:
-            splits.append((fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, shard_offset, total_shards)), None))
-            shard_offset += 1
-
-        for i, shard in enumerate(shards):
-            splits.append((fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + shard_offset, total_shards)), shard))
-
-        return splits
 
 
 def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
     wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
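`split_str_to_n_bytes` and `format_n_bytes_to_str` are not gone either; they move onto the `SplitStrategy` class in gguf_manager.py. Note the parser uses binary multiples (1K = 1024) and case-sensitive suffixes. Worked examples:

```python
# split_str_to_n_bytes("1K")   -> 1024
# split_str_to_n_bytes("500M") -> 524288000    (500 * 1024 * 1024)
# split_str_to_n_bytes("2G")   -> 2147483648   (2 * 1024 ** 3)
# split_str_to_n_bytes("4096") -> 4096         (a bare number is bytes)
# split_str_to_n_bytes("0")    -> ValueError   (must be positive)
# split_str_to_n_bytes("2k")   -> ValueError   (suffixes are case-sensitive)
```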
@@ -1607,8 +1464,8 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
     parser.add_argument("--split", action="store_true", help="split the converted model into multiple files")
-    parser.add_argument("--split-max-tensors", type=int, help=f"max tensors in each split")
-    parser.add_argument("--split-max-size", type=str, help=f"max size per split")
+    parser.add_argument("--split-max-tensors", type=int, help="max tensors in each split")
+    parser.add_argument("--split-max-size", type=str, help="max size per split N(M|G)+")
     parser.add_argument("--dry-run", action="store_true", help="only print out a split plan and exit, without writing any new files")
     parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default: metadata only)")
 
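Taken together, these flags select the split style: `--split` with `--split-max-tensors N` shards by tensor count, `--split` with `--split-max-size` (for example `500M` or `2G`) shards by payload size, and `--dry-run` prints the plan without writing anything. An illustrative invocation would be `python convert.py <model-dir> --split --split-max-tensors 256 --dry-run`.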
@@ -1628,7 +1485,7 @@ def main(args_in: list[str] | None = None) -> None:
         raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
 
     if args.split_max_size:
-        args.split_max_size = split_str_to_n_bytes(args.split_max_size)
+        args.split_max_size = gguf.SplitStrategy.split_str_to_n_bytes(args.split_max_size)
 
     if not args.vocab_only:
         model_plus = load_some_model(args.model)
@@ -1693,9 +1550,8 @@ def main(args_in: list[str] | None = None) -> None:
         params.ftype = ftype
 
         print(f"Writing {outfile}, format {ftype}")
-        OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, args.split_max_tensors,
-                             args.split_max_size, dry_run=args.dry_run, concurrency=args.concurrency,
-                             endianess=endianess, pad_vocab=args.pad_vocab, small_first_shard=not args.large_first_shard)
+        OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, args,
+                             concurrency=args.concurrency, endianess=endianess)
         if not args.dry_run:
             print(f"Wrote {outfile}")
 
gguf-py/gguf/__init__.py
@@ -1,5 +1,6 @@
 from .constants import *
 from .gguf_reader import *
 from .gguf_writer import *
+from .gguf_manager import *
 from .tensor_mapping import *
 from .vocab import *
523 gguf-py/gguf/gguf_manager.py (new file)

@@ -0,0 +1,523 @@
+from __future__ import annotations
+
+import os
+import shutil
+import struct
+import tempfile
+from enum import IntEnum
+from typing import TYPE_CHECKING, Any, Sequence, Mapping
+from string import ascii_letters, digits
+from argparse import Namespace
+from math import ceil
+
+import numpy as np
+
+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
+from .constants import (
+    GGMLQuantizationType,
+    GGUFEndian,
+    GGUFValueType,
+    Keys,
+    RopeScalingType,
+    PoolingType,
+    TokenType,
+)
+from .gguf_writer import GGUFWriter
+
+
+SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
+
+LLM_KV_SPLIT_NO = "split.no"
+LLM_KV_SPLIT_COUNT = "split.count"
+LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
+
+SplitTensorsPerFile: TypeAlias = list[tuple[os.PathLike[str], list[tuple[str, Any]], GGUFWriter]]  # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)]
+KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]]  # {key: (value, type)}
+TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any]]  # (tensor name, tensor data), aka LazyModel
+
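The aliases spell out the bookkeeping shape used throughout the file: one `(path, tensors, writer)` triple per output shard, and KV pairs staged alongside their GGUF value type. With toy values (all names hypothetical):

```python
# data: SplitTensorsPerFile, one entry per shard
#   [(Path("m-00001-of-00002.gguf"), [("token_embd.weight", t0)], writer0),
#    (Path("m-00002-of-00002.gguf"), [("output.weight", t1)], writer1)]
#
# kv_data: KVTempData
#   {"general.architecture": ("llama", GGUFValueType.STRING)}
```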
+
+class SplitStyle(IntEnum):
+    NONE = 0
+    TENSORS = 1
+    SIZE = 2
+
+
+class SplitStrategy:
+    data: SplitTensorsPerFile
+
+    def __init__(self, split_style: SplitStyle, fname_out: os.PathLike[str], model: list[TensorTempData],
+                 args: Namespace, arch: str, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE, small_first_shard: bool = True
+                 ):
+        self.data = []
+
+        if split_style == SplitStyle.NONE:
+            self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess)))
+
+        elif split_style == SplitStyle.TENSORS:
+            total_shards = ceil(len(model) / args.split_max_tensors) + small_first_shard
+            shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, total_shards)) for i in range(total_shards)]
+
+            if small_first_shard:
+                self.append((shard_files[0], None, GGUFWriter(shard_files[0], arch, use_temp_file=use_temp_file, endianess=endianess)))
+
+            for i, shard in enumerate(shard_files[small_first_shard:]):
+                start = i * args.split_max_tensors
+                stop = min((i + 1) * args.split_max_tensors, len(model))
+                self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
+
+        elif split_style == SplitStyle.SIZE:
+            shards = []
+
+            # we have to determine the shards first to determine how many shards there will be in total - two passes
+            for i, shard in enumerate(model):
+                if i == 0:
+                    shards.append([shard])
+                    continue
+                if SplitStrategy.get_tensor_size(shard[1]) + sum(SplitStrategy.get_tensor_size(t[1]) for t in shards[-1]) > args.split_max_size:
+                    shards.append([shard])
+                else:
+                    shards[-1].append(shard)
+
+            total_shards = len(shards) + small_first_shard
+            shard_offset = 1
+
+            if small_first_shard:
+                outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, shard_offset, total_shards))
+                self.append((outname, None, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
+                shard_offset += 1
+
+            for i, shard in enumerate(shards):
+                outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + shard_offset, total_shards))
+                self.append((outname, shard, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
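The SIZE branch above packs tensors greedily in model order: a new shard starts as soon as the next tensor would push the current one past `split_max_size`. A self-contained sketch of the same loop with made-up byte sizes:

```python
# Hypothetical tensor sizes in bytes, with a 100-byte limit per shard.
sizes = [60, 50, 40, 30, 90]
limit = 100

shards: list[list[int]] = []
for i, size in enumerate(sizes):
    # start a new shard for the first tensor, or when this one would overflow
    if i == 0 or size + sum(shards[-1]) > limit:
        shards.append([size])
    else:
        shards[-1].append(size)

print(shards)  # [[60], [50, 40], [30], [90]]
```

Being greedy, the pass never revisits earlier shards (60 and 30 could have shared one), which keeps it linear and keeps shard order identical to tensor order.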
+    def __getitem__(self, index):
+        return self.data[index]
+
+    def __setitem__(self, index, value):
+        self.data[index] = value
+
+    def __len__(self):
+        return len(self.data)
+
+    def append(self, value: TensorTempData):
+        self.data.append(value)
+
+    def remove(self, item: TensorTempData):
+        self.data.remove(item)
+
+    @staticmethod
+    def get_tensor_size(tensor) -> int:
+        # we don't have the LazyTensor class here from convert.py but we can try
+        try:
+            return tensor.data_type.elements_to_bytes(np.prod(tensor.shape))
+        except AttributeError:  # numpy ndarray[Any, Any]
+            return tensor.nbytes
+        except:  # this should never happen
+            raise ValueError(f"Invalid tensor type: {type(tensor)}")
+
+    @staticmethod
+    def split_str_to_n_bytes(split_str: str) -> int:
+        if split_str.endswith("K"):
+            n = int(split_str[:-1]) * 1024
+        elif split_str.endswith("M"):
+            n = int(split_str[:-1]) * 1024 * 1024
+        elif split_str.endswith("G"):
+            n = int(split_str[:-1]) * 1024 * 1024 * 1024
+        elif split_str.isnumeric():
+            n = int(split_str)
+        else:
+            raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")
+
+        if n <= 0:
+            raise ValueError(f"Invalid split size: {split_str}, must be positive")
+
+        return n
+
+    @staticmethod
+    def format_n_bytes_to_str(num: int) -> str:
+        num = float(num)
+        for unit in ("", "K", "M", "G"):
+            if abs(num) < 1024.0:
+                return f"{num:3.1f}{unit}"
+            num /= 1024.0
+        return f"{num:.1f}T - over 1TB, --split recommended"
+
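The two static helpers round-trip for display purposes, both in binary multiples; worked examples:

```python
# SplitStrategy.split_str_to_n_bytes("2G")        -> 2147483648
# SplitStrategy.format_n_bytes_to_str(2147483648) -> "2.0G"
# SplitStrategy.format_n_bytes_to_str(
#     SplitStrategy.split_str_to_n_bytes("1500M")) -> "1.5G"
```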
+
+# ideally this has most of the same signatures as GGUFWriter so it's nearly a drop-in replacement
+class GGUFManager:
+    kv_data: KVTempData
+    tensors: list[TensorTempData]
+    split_style: SplitStyle
+    split_strategy: SplitStrategy
+
+    def __init__(self, path: os.PathLike[str] | str, arch: str, args: Namespace, use_temp_file: bool = True,
+                 endianess: GGUFEndian = GGUFEndian.LITTLE) -> None:
+        self.arch = arch
+        self.path = path
+        self.endianess = endianess
+        self.offset_tensor = 0
+        self.kv_data = {}
+        self.tensors = []
+        self.args = args
+        self.split_style = SplitStyle.NONE if not args.split \
+            else SplitStyle.TENSORS if args.split_max_tensors \
+            else SplitStyle.SIZE
+        self.split_strategy = None
+        self.total_shards = None
+        self.total_tensors = None
+        self.use_temp_file = use_temp_file
+
+        self.add_architecture()
+
+    # have to consolidate because we need to know kv data count and tensor count before we can write the header
+    # and we need to write tensor info before we can write metadata
+    # these all kinda show up around the same places anyway so it's not a huge deal?
+    def write_to_file(self, meta_only: bool = False, ftype: int = 0, concurrency: int = 8, write_tensor_data: function = None) -> None:
+
+        # here is the first place you can assume you have all tensors written and you can establish the size of the file - so logic goes here
+        self.total_tensors = len(self.tensors)
+        total_size = sum(SplitStrategy.get_tensor_size(tensor[1]) for tensor in self.tensors)
+
+        if self.args.split_max_tensors and self.total_tensors < self.args.split_max_tensors:
+            print("Model has fewer tensors than the split threshold, not splitting")
+            self.split_style = SplitStyle.NONE
+
+        if self.args.split_max_size and total_size < self.args.split_max_size:
+            print("Model has smaller size than the split threshold, not splitting")
+            self.split_style = SplitStyle.NONE
+
+        self.split_strategy = SplitStrategy(self.split_style, self.path, self.tensors, self.args, not self.args.large_first_shard)
+        self.total_shards = len(self.split_strategy)
+
+        # only the first shard needs all the KV data
+        for key, (value, etype) in self.kv_data.items():
+            self.split_strategy[0][2].add_key(key)
+            self.split_strategy[0][2].add_val(value, etype)
+
+        if self.split_style != SplitStyle.NONE:
+            for i, (_, _, writer) in enumerate(self.split_strategy):
+                writer.add_uint16(LLM_KV_SPLIT_NO, i)
+                writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
+                writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
+
+        # metadata/vocab only can write and return here
+        if meta_only:
+            for i, (_, _, writer) in enumerate(self.split_strategy):
+                writer.write_header_to_file()
+                writer.write_kv_data_to_file()
+            return
+
+        # tensor writing code starts here
+
+        print("\nWriting the following files:")
+        for (shard_path, shard_tensors, _) in self.split_strategy:
+            size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
+            print(f"  {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")
+
+        if self.args.dry_run:
+            print("\nDry run, not writing files")
+            return
+
+        # run add_tensor_info, write data, then write_tensor_data - taken from convert.py
+        running_total = self.total_tensors
+        for i, (_, tensors, writer) in enumerate(self.split_strategy):
+
+            if tensors:
+                for name, tensor in tensors:
+                    n_elements = int(np.prod(tensor.shape))
+                    raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+                    data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+                    data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+                    writer.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
+
+            writer.write_header_to_file()
+            writer.write_kv_data_to_file()
+            writer.write_tensors_to_file()
+
+            if tensors:
+                print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
+                running_total -= len(tensors)
+
+                # convert.py's write_tensor_data is dependent on so many objects in convert.py itself that it's easier to pass the function as a parameter and call it here
+                write_tensor_data(ftype, dict(tensors), concurrency, writer)
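As the comments in `write_to_file` note, a GGUF header stores the KV-pair count and tensor count up front, so nothing can be written until everything is staged. The resulting per-shard order of operations, as implemented above:

```python
# Per shard, write_to_file effectively does:
#   1. writer.add_tensor_info(...)      # stage tensor metadata (so it gets counted)
#   2. writer.write_header_to_file()    # header needs the kv + tensor counts
#   3. writer.write_kv_data_to_file()   # full KV set goes to shard 0 only
#   4. writer.write_tensors_to_file()   # tensor info section
#   5. write_tensor_data(...)           # injected callback streams the payload
```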
+    def add_uint8(self, key: str, val: int) -> None:
+        self.kv_data[key] = (val, GGUFValueType.UINT8)
+
+    def add_int8(self, key: str, val: int) -> None:
+        self.kv_data[key] = (val, GGUFValueType.INT8)
+
+    def add_uint16(self, key: str, val: int) -> None:
+        self.kv_data[key] = (val, GGUFValueType.UINT16)
+
+    def add_int16(self, key: str, val: int) -> None:
+        self.kv_data[key] = (val, GGUFValueType.INT16)
+
+    def add_uint32(self, key: str, val: int) -> None:
+        self.kv_data[key] = (val, GGUFValueType.UINT32)
+
+    def add_int32(self, key: str, val: int) -> None:
+        self.kv_data[key] = (val, GGUFValueType.INT32)
+
+    def add_float32(self, key: str, val: float) -> None:
+        self.kv_data[key] = (val, GGUFValueType.FLOAT32)
+
+    def add_uint64(self, key: str, val: int) -> None:
+        self.kv_data[key] = (val, GGUFValueType.UINT64)
+
+    def add_int64(self, key: str, val: int) -> None:
+        self.kv_data[key] = (val, GGUFValueType.INT64)
+
+    def add_float64(self, key: str, val: float) -> None:
+        self.kv_data[key] = (val, GGUFValueType.FLOAT64)
+
+    def add_bool(self, key: str, val: bool) -> None:
+        self.kv_data[key] = (val, GGUFValueType.BOOL)
+
+    def add_string(self, key: str, val: str) -> None:
+        if not val:
+            return
+        self.kv_data[key] = (val, GGUFValueType.STRING)
+
+    def add_array(self, key: str, val: Sequence[Any]) -> None:
+        if not isinstance(val, Sequence):
+            raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
+        self.kv_data[key] = (val, GGUFValueType.ARRAY)
+
+    # this method is exclusive to convert.py - we don't have LazyTensor so Any type is used
+    def add_tensor_info(self, name: str, tensor: Any) -> None:
+        self.tensors.append((name, tensor))
+
+    # these methods are everywhere but convert.py (and convert-lora-to-ggml.py since that doesn't use the class)
+    def add_tensor(
+        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        raw_dtype: GGMLQuantizationType | None = None,
+    ) -> None:
+        # TODO WRITE
+        pass
+
+    def write_tensors_to_file(self) -> None:
+        # TODO WRITE
+        pass
+
+    def close(self) -> None:
+        for _, _, writer in self.split_strategy:
+            writer.close()
+
+    def add_architecture(self) -> None:
+        self.add_string(Keys.General.ARCHITECTURE, self.arch)
+
+    def add_author(self, author: str) -> None:
+        self.add_string(Keys.General.AUTHOR, author)
+
+    def add_version(self, version: str) -> None:
+        self.add_string(Keys.General.VERSION, version)
+
+    def add_tensor_data_layout(self, layout: str) -> None:
+        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+
+    def add_url(self, url: str) -> None:
+        self.add_string(Keys.General.URL, url)
+
+    def add_description(self, description: str) -> None:
+        self.add_string(Keys.General.DESCRIPTION, description)
+
+    def add_licence(self, licence: str) -> None:
+        self.add_string(Keys.General.LICENSE, licence)
+
+    def add_source_url(self, url: str) -> None:
+        self.add_string(Keys.General.SOURCE_URL, url)
+
+    def add_source_hf_repo(self, repo: str) -> None:
+        self.add_string(Keys.General.SOURCE_HF_REPO, repo)
+
+    def add_file_type(self, ftype: int) -> None:
+        self.add_uint32(Keys.General.FILE_TYPE, ftype)
+
+    def add_name(self, name: str) -> None:
+        self.add_string(Keys.General.NAME, name)
+
+    def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
+        self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
+
+    def add_custom_alignment(self, alignment: int) -> None:
+        self.data_alignment = alignment
+        self.add_uint32(Keys.General.ALIGNMENT, alignment)
+
+    def add_vocab_size(self, size: int) -> None:
+        self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
+
+    def add_context_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
+
+    def add_embedding_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+    def add_block_count(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
+
+    def add_feed_forward_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
+    def add_parallel_residual(self, use: bool) -> None:
+        self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+
+    def add_head_count(self, count: int) -> None:
+        self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+
+    def add_head_count_kv(self, count: int) -> None:
+        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+
+    def add_key_length(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
+
+    def add_value_length(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
+
+    def add_max_alibi_bias(self, bias: float) -> None:
+        self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
+
+    def add_clamp_kqv(self, value: float) -> None:
+        self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
+
+    def add_logit_scale(self, value: float) -> None:
+        self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
+
+    def add_expert_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
+
+    def add_expert_used_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
+
+    def add_layer_norm_eps(self, value: float) -> None:
+        self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
+
+    def add_layer_norm_rms_eps(self, value: float) -> None:
+        self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
+
+    def add_causal_attention(self, value: bool) -> None:
+        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
+
+    def add_pooling_type(self, value: PoolingType) -> None:
+        self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
+
+    def add_rope_dimension_count(self, count: int) -> None:
+        self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
+
+    def add_rope_freq_base(self, value: float) -> None:
+        self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
+
+    def add_rope_scaling_type(self, value: RopeScalingType) -> None:
+        self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
+
+    def add_rope_scaling_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
+        self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
+
+    def add_rope_scaling_finetuned(self, value: bool) -> None:
+        self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
+
+    def add_ssm_conv_kernel(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
+
+    def add_ssm_inner_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_state_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_time_step_rank(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
+
+    def add_tokenizer_model(self, model: str) -> None:
+        self.add_string(Keys.Tokenizer.MODEL, model)
+
+    def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
+        self.add_array(Keys.Tokenizer.LIST, tokens)
+
+    def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
+        self.add_array(Keys.Tokenizer.MERGES, merges)
+
+    def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
+        self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
+
+    def add_token_type_count(self, value: int) -> None:
+        self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
+
+    def add_token_scores(self, scores: Sequence[float]) -> None:
+        self.add_array(Keys.Tokenizer.SCORES, scores)
+
+    def add_bos_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.BOS_ID, id)
+
+    def add_eos_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.EOS_ID, id)
+
+    def add_unk_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.UNK_ID, id)
+
+    def add_sep_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.SEP_ID, id)
+
+    def add_pad_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.PAD_ID, id)
+
+    def add_cls_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.CLS_ID, id)
+
+    def add_mask_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.MASK_ID, id)
+
+    def add_add_bos_token(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_BOS, value)
+
+    def add_add_eos_token(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_EOS, value)
+
+    def add_add_space_prefix(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
+
+    def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
+        if isinstance(value, list):
+            template_default = None
+            template_names = set()
+
+            for choice in value:
+                name = choice.get('name', '')
+                template = choice.get('template')
+
+                # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it
+                name = ''.join((c if c in ascii_letters + digits else '_' for c in name))
+
+                if name and template is not None:
+                    if name == 'default':
+                        template_default = template
+                    else:
+                        template_names.add(name)
+                        self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)
+
+            if template_names:
+                self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))
+
+            if template_default is None:
+                return
+
+            value = template_default
+
+        self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
+
+    def add_prefix_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.PREFIX_ID, id)
+
+    def add_suffix_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id)
+
+    def add_middle_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id)
+
+    def add_eot_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.EOT_ID, id)
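A usage sketch of the class as a whole (argument values hypothetical): since GGUFManager is intended to mirror GGUFWriter's method surface, the convert.py call sites stay almost unchanged apart from construction and the consolidated write:

```python
# manager = GGUFManager("model.gguf", "llama", args)      # instead of GGUFWriter(...)
# manager.add_context_length(4096)                        # staged in kv_data
# manager.add_tensor_info("token_embd.weight", lazy_t)    # staged in tensors
# manager.write_to_file(ftype=ftype, concurrency=8,
#                       write_tensor_data=OutputFile.write_tensor_data)
# manager.close()
```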