Fix eager tensor memory leak and remove convert.py changes
Removed a memory leak caused by unexpected retention of references to eager tensors. Also removed the GGUFManager functionality from convert.py in favor of specializing it for convert-hf-to-gguf.py.
parent 2dd784108b
commit 3ff27efa89
4 changed files with 87 additions and 120 deletions
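Note: for context, a minimal sketch of the retention pattern this commit removes (hypothetical names, not code from this commit). Iterating a list of eager tensors keeps every array reachable through the container until the whole list dies; destructively draining a deque and dropping the local name lets each buffer be freed as soon as it is written.

    import numpy as np
    from collections import deque

    def write_all_leaky(tensors: list[tuple[str, np.ndarray]], write) -> None:
        # the list keeps referencing every tensor already written,
        # so peak memory grows toward the size of the whole model
        for name, tensor in tensors:
            write(name, tensor)

    def write_all_fixed(tensors: deque[tuple[str, np.ndarray]], write) -> None:
        while True:
            try:
                name, tensor = tensors.popleft()  # container gives up its reference
            except IndexError:
                break
            write(name, tensor)
            del tensor  # drop the last local reference; the buffer is collectible now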
convert-hf-to-gguf.py

@@ -2570,7 +2570,7 @@ def main() -> None:
     if args.split_max_tensors and args.split_max_size:
         raise ValueError("Can't specify both --split-max-tensors and --split-max-size")

-    split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()
+    split_arguments = gguf.SplitArguments(args=args) if args.split else gguf.SplitArguments()

     ftype_map = {
         "f32": gguf.LlamaFileType.ALL_F32,
convert.py (70 lines changed)
@@ -24,17 +24,14 @@ from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
-# TEMPORARY IMPORT - TODO REMOVE
-import importlib
-gguf = importlib.import_module("gguf-py.gguf")
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable, Optional

 import numpy as np
 from sentencepiece import SentencePieceProcessor

 if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
-# import gguf
+import gguf

 if TYPE_CHECKING:
     from typing_extensions import Self, TypeAlias
@@ -1103,8 +1100,8 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False)


 class OutputFile:
-    def __init__(self, fname_out: Path, split_arguments: gguf.SplitArguments, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
-        self.gguf = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], split_arguments, endianess=endianess)
+    def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
+        self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

     def add_meta_model(self, params: Params, metadata: Metadata) -> None:
         # Metadata About The Model And Its Provenence
@@ -1204,15 +1201,21 @@ class OutputFile:
     def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
         svocab.add_to_gguf(self.gguf)

+    def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
+        n_elements = int(np.prod(tensor.shape))
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
+
     def write_meta(self) -> None:
-        self.gguf.write_to_file(meta_only=True)
+        self.gguf.write_header_to_file()
+        self.gguf.write_kv_data_to_file()

-    def write_tensors(self, ftype: GGMLFileType, concurrency: int) -> None:
-        self.gguf.write_to_file(ftype=ftype, concurrency=concurrency, write_tensor_data=OutputFile.write_tensor_data)
+    def write_tensor_info(self) -> None:
+        self.gguf.write_ti_data_to_file()

-    # really awkward with how this is managed with gguf_manager.py: maybe refactor at some point?
-    @staticmethod
-    def write_tensor_data(ftype: GGMLFileType, model: LazyModel, concurrency: int, writer: gguf.GGUFWriter) -> None:
+    def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None:
         ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
         if ftype == GGMLFileType.MostlyQ8_0:
             ndarrays = bounded_parallel_map(
@@ -1230,7 +1233,7 @@ class OutputFile:
             logger.info(
                 f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
             )
-            writer.write_tensor_data(ndarray)
+            self.gguf.write_tensor_data(ndarray)

     def close(self) -> None:
         self.gguf.close()
@@ -1242,7 +1245,7 @@ class OutputFile:
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)

-        of = OutputFile(fname_out, gguf.SplitArguments(), endianess=endianess)
+        of = OutputFile(fname_out, endianess=endianess)

         # meta data
         of.add_meta_model(params, metadata)
@@ -1270,11 +1273,13 @@ class OutputFile:
     @staticmethod
     def write_all(
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
-        split_arguments: gguf.SplitArguments, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
-        pad_vocab: bool = False, metadata: Metadata = None,
+        concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False,
+        metadata: Metadata = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
-        of = OutputFile(fname_out, split_arguments, endianess=endianess)
+
+        of = OutputFile(fname_out, endianess=endianess)

         # meta data
         of.add_meta_model(params, metadata)
@@ -1287,9 +1292,13 @@ class OutputFile:

         # tensor info
         for name, lazy_tensor in model.items():
-            of.gguf.add_tensor_info(name, lazy_tensor)
+            of.add_tensor_info(name, lazy_tensor)

-        of.write_tensors(ftype, concurrency)
+        of.write_meta()
+        of.write_tensor_info()
+
+        # tensor data
+        of.write_tensor_data(ftype, model, concurrency)

         of.close()
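Note: the reordered calls above mirror the GGUF on-disk layout, which must be emitted in order: header, KV metadata, tensor-info table, then the tensor data itself. A sketch of the resulting call sequence, using the names from this diff (writer internals assumed):

    of = OutputFile(fname_out, endianess=endianess)  # wraps a gguf.GGUFWriter
    # ... add_meta_*() and add_tensor_info() calls buffer everything in memory ...
    of.write_meta()                                  # header + KV pairs
    of.write_tensor_info()                           # tensor-info table (name, shape, dtype, offset)
    of.write_tensor_data(ftype, model, concurrency)  # the aligned tensor payload
    of.close()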
@@ -1364,7 +1373,7 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
                     experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"])
                     del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]
                 else:
-                    raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.model_classweight")
+                    raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight")
             tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts)

     # HF models permut or pack some of the tensors, so we need to undo that
@@ -1584,11 +1593,6 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
     parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
-    parser.add_argument("--split", action="store_true", help="split the converted model into multiple files")
-    parser.add_argument("--split-max-tensors", type=int, help="max tensors in each split")
-    parser.add_argument("--split-max-size", type=str, help="max size per split N(M|G)")
-    parser.add_argument("--dry-run", action="store_true", help="only print out a split plan and exit, without writing any new files")
-    parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default: metadata only)")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file")
     parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name")
@@ -1622,14 +1626,6 @@ def main(args_in: list[str] | None = None) -> None:
         do_dump_model(model_plus)
         return

-    if args.split and not (args.split_max_tensors or args.split_max_size):
-        raise ValueError("Need to specify one of --split-max-tensors or --split-max-size when splitting")
-
-    if args.split_max_tensors and args.split_max_size:
-        raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
-
-    split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()
-
     if not args.vocab_only:
         model_plus = load_some_model(args.model)
     else:
@@ -1707,13 +1703,11 @@ def main(args_in: list[str] | None = None) -> None:
     outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)

     params.ftype = ftype

     logger.info(f"Writing {outfile}, format {ftype}")

-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, split_arguments,
+    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
                          concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata)
-    if not args.dry_run:
-        logger.info(f"Wrote {outfile}")
+    logger.info(f"Wrote {outfile}")


 if __name__ == '__main__':
gguf-py/gguf/gguf_manager.py

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Sequence, Mapping
 from string import ascii_letters, digits
 from argparse import Namespace
 from math import ceil
+from collections import deque

 import numpy as np

@@ -34,7 +35,7 @@ LLM_KV_SPLIT_NO = "split.no"
 LLM_KV_SPLIT_COUNT = "split.count"
 LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"

-SplitTensorsPerFile: TypeAlias = list[tuple[os.PathLike[str], list[tuple[str, Any]], GGUFWriter]] # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)]
+SplitTensorsPerFile: TypeAlias = deque[tuple[os.PathLike[str], deque[tuple[str, Any]], GGUFWriter]] # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)]
 KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]] # {key: (value, type)}
 TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any]] # (tensor name, tensor data), aka LazyModel

@@ -53,23 +54,23 @@ class SplitArguments:
     split_max_size: int
     split_style: SplitStyle

-    def __init__(self) -> None:
-        self.split = False
-        self.dry_run = False
-        self.small_first_shard = False
-        self.split_max_tensors = 0
-        self.split_max_size = 0
-        self.split_style = SplitStyle.NONE
-
-    def __init__(self, args: Namespace) -> None:
-        self.split = args.split
-        self.split_max_tensors = args.split_max_tensors
-        self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else None
-        self.dry_run = args.dry_run
-        self.small_first_shard = not args.large_first_shard
-        self.split_style = SplitStyle.NONE if not self.split \
-            else SplitStyle.TENSORS if self.split_max_tensors \
-            else SplitStyle.SIZE
+    def __init__(self, args: Namespace = None) -> None:
+        if args is None:
+            self.split = False
+            self.dry_run = False
+            self.small_first_shard = False
+            self.split_max_tensors = 0
+            self.split_max_size = 0
+            self.split_style = SplitStyle.NONE
+        else:
+            self.split = args.split
+            self.split_max_tensors = args.split_max_tensors
+            self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else None
+            self.dry_run = args.dry_run
+            self.small_first_shard = not args.large_first_shard
+            self.split_style = SplitStyle.NONE if not self.split \
+                else SplitStyle.TENSORS if self.split_max_tensors \
+                else SplitStyle.SIZE


 class SplitStrategy:
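Note on the merged constructor above: Python does not overload methods, so the old class body's second def __init__ silently rebound the name and the zero-argument form was never callable; a single __init__ with a defaulted args parameter restores both call styles. A small demonstration (hypothetical class, not from this commit):

    class Demo:
        def __init__(self) -> None:
            self.x = 0

        def __init__(self, args) -> None:  # rebinds __init__, replacing the def above
            self.x = args

    Demo(42)        # works: only the second signature survives
    try:
        Demo()      # the zero-argument form no longer exists
    except TypeError as e:
        print(e)    # e.g. "__init__() missing 1 required positional argument: 'args'"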
@@ -78,7 +79,7 @@ class SplitStrategy:
     def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arch: str,
                  split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE,
                  ):
-        self.data = []
+        self.data = deque()

         if split_arguments.split_style == SplitStyle.NONE:
             self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess)))
@@ -96,7 +97,7 @@ class SplitStrategy:
             self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))

         elif split_arguments.split_style == SplitStyle.SIZE:
-            shards = []
+            shards = deque()

             # we have to determine the shards first to determine how many shards there will be in total - two passes
             for i, shard in enumerate(model):
@@ -118,13 +119,7 @@ class SplitStrategy:

             for i, shard in enumerate(shards):
                 outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + shard_offset, total_shards))
-                self.append((outname, shard, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
+                self.append((outname, deque(shard), GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))

-    def __getitem__(self, index):
-        return self.data[index]
-
-    def __setitem__(self, index, value):
-        self.data[index] = value
-
     def __len__(self):
         return len(self.data)
@@ -176,7 +171,7 @@ class SplitStrategy:
 # ideally this has most of the same signatures as GGUFWriter so it's nearly a drop-in replacement
 class GGUFManager:
     kv_data: KVTempData
-    tensors: list[TensorTempData]
+    tensors: deque[TensorTempData]
     split_arguments: SplitArguments
     split_strategy: SplitStrategy

@@ -188,7 +183,7 @@ class GGUFManager:
         self.endianess = endianess
         self.offset_tensor = 0
         self.kv_data = {}
-        self.tensors = []
+        self.tensors = deque()
         self.split_strategy = None
         self.total_shards = None
         self.total_tensors = None
@@ -200,9 +195,7 @@ class GGUFManager:
     # have to consolidate because we need to know kv data count and tensor count before we can write the header
     # and we need to write tensor info before we can write metadata
     # these all kinda show up around the same places anyway so it's not a huge deal?
-    def write_to_file(self, meta_only: bool = False, ftype: int = 0, concurrency: int = 8,
-                      write_tensor_data: function = None
-                      ) -> None:
+    def write_to_file(self, meta_only: bool = False) -> None:

         # here is the first place you can assume you have all tensors written and you can establish the size of the file - so logic goes here
         self.total_tensors = len(self.tensors)
@@ -218,22 +211,23 @@ class GGUFManager:

         self.split_strategy = SplitStrategy(self.path, self.tensors, self.arch, self.split_arguments,
                                             use_temp_file=self.use_temp_file, endianess=self.endianess)
+        del self.tensors
         self.total_shards = len(self.split_strategy)

         # only the first shard needs all the KV data
         for key, (value, etype) in self.kv_data.items():
-            self.split_strategy[0][2].add_key(key)
-            self.split_strategy[0][2].add_val(value, etype)
+            self.split_strategy.data[0][2].add_key(key)
+            self.split_strategy.data[0][2].add_val(value, etype)

         if self.split_arguments.split_style != SplitStyle.NONE:
-            for i, (_, _, writer) in enumerate(self.split_strategy):
+            for i, (_, _, writer) in enumerate(self.split_strategy.data):
                 writer.add_uint16(LLM_KV_SPLIT_NO, i)
                 writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
                 writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)

         # metadata/vocab only can write and return here
         if meta_only:
-            for i, (_, _, writer) in enumerate(self.split_strategy):
+            for i, (_, _, writer) in enumerate(self.split_strategy.data):
                 writer.write_header_to_file()
                 writer.write_kv_data_to_file()
             return
@@ -241,57 +235,44 @@ class GGUFManager:
         # tensor writing code starts here

         print("\nWriting the following files:")
-        for (shard_path, shard_tensors, _) in self.split_strategy:
+        for (shard_path, shard_tensors, _) in self.split_strategy.data:
             size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
             print(f"  {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")

         if self.split_arguments.dry_run:
             print("\nDry run, not writing files")
             # instantiating GGUFWriters creates files
-            for name, _, _ in self.split_strategy:
+            for name, _, _ in self.split_strategy.data:
                 os.remove(name)
             return

         # run add_tensor_info, write data, then write_tensor_data - taken from convert.py
         running_total = self.total_tensors
-        start = time.time()
-        for i, (_, tensors, writer) in enumerate(self.split_strategy):
-            if tensors:
-                print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
-                for j, (name, tensor) in enumerate(tensors):
-                    n_elements = int(np.prod(tensor.shape))
-                    # logic from convert.py
-                    if getattr(tensor, 'data_type', None):
-                        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
-                        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
-                        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
-                        writer.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
-                    # logic from convert-hf-to-gguf.py
-                    else:
-                        # stolen from write_tensor_data because that doesn't get called with this logic
-                        elapsed = time.time() - start
-                        size = ' x '.join(f"{dim:6d}" for dim in tensor.shape)
-                        padi = len(str(self.total_tensors))
-                        dtype = str(tensor.dtype)
-                        print(
-                            f"[{j + 1:{padi}d}/{len(tensors)}] Writing tensor {name:38s} | size {size:16} | type {dtype:8} | T+{int(elapsed):4}"
-                        )
-                        writer.add_tensor(name, tensor)
-                print(f"Writing to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
+        ct = 0
+        while True:
+            try:
+                (_, tensors, writer) = self.split_strategy.data.popleft()
+            except IndexError:
+                break
+
+            shard_num_tensors = len(tensors) if tensors else 0
+
+            if tensors:
+                while True:
+                    try:
+                        (name, tensor) = tensors.popleft()
+                    except IndexError:
+                        break
+                    writer.add_tensor(name, tensor)
+
+            print(f"Writing to shard {ct + 1}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
+            running_total -= shard_num_tensors

             writer.write_header_to_file()
             writer.write_kv_data_to_file()
-            writer.write_tensors_to_file()
-            if tensors:
-                # TODO this shows up AFTER writing which we don't really want - move it
-                running_total -= len(tensors)
-
-            if write_tensor_data:
-                # convert.py's write_tensor_data is dependent on so many objects in convert.py itself that it's easier to pass the function as a parameter and call it here
-                write_tensor_data(ftype, dict(tensors), concurrency, writer)
+            writer.write_tensors_to_file(progress=True)
+            ct = ct + 1
+            del tensors

     def add_uint8(self, key: str, val: int) -> None:
         self.kv_data[key] = (val, GGUFValueType.UINT8)
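Note: the drain idiom introduced above, shown in isolation. popleft() removes the container's reference to each item before it is processed, and the explicit ct counter replaces enumerate() because the deque is consumed as it is walked (a generic sketch, not code from this commit):

    from collections import deque

    work = deque([("tok_embd.weight", b"..."), ("output.weight", b"...")])
    ct = 0
    while True:
        try:
            name, payload = work.popleft()  # 'work' no longer references this entry
        except IndexError:                  # raised once the deque is empty
            break
        print(f"shard {ct + 1}: {name}")
        ct += 1
    # equivalent, slightly terser form:
    # while work:
    #     name, payload = work.popleft()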
@@ -336,11 +317,6 @@ class GGUFManager:
             raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
         self.kv_data[key] = (val, GGUFValueType.ARRAY)

-    # this method is exclusive to convert.py - we don't have LazyTensor so Any type is used
-    def add_tensor_info(self, name: str, tensor: Any) -> None:
-        self.tensors.append((name, tensor))
-
-    # these methods are everywhere but convert.py (and convert-lora-to-ggml.py since that doesn't use the class)
     def add_tensor(
         self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
@@ -354,7 +330,7 @@ class GGUFManager:
         # fp.seek(0)
         # self.temp_file = fp

-        self.add_tensor_info(name, tensor)
+        self.tensors.append((name, tensor))

         #if self.temp_file is None:
         #    self.tensors.append(tensor)
@@ -363,12 +339,8 @@ class GGUFManager:
         #tensor.tofile(self.temp_file)
         #self.write_padding(self.temp_file, tensor.nbytes)

-    def write_tensors_to_file(self) -> None:
-        # TODO WRITE
-        pass
-
     def close(self) -> None:
-        for _, _, writer in self.split_strategy:
+        for _, _, writer in self.split_strategy.data:
             writer.close()

     def add_architecture(self) -> None:
gguf-py/gguf/gguf_writer.py

@@ -301,6 +301,7 @@ class GGUFWriter:
                 tensor.tofile(self.fout)
                 bar.update(tensor.nbytes)
                 self.write_padding(self.fout, tensor.nbytes)
+                del tensor
             return
         while True:
             try:
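Note: the one-line gguf_writer.py change above is the leak fix named in the commit title. The loop-local name keeps each eager tensor alive until it is rebound on the next iteration, and keeps the last one alive for as long as the frame does; deleting the name lets CPython's refcounting free each buffer as soon as it has been written. A standalone illustration of the principle (hypothetical function, not from this commit):

    import numpy as np

    def write_then_wait(fout) -> None:
        tensor = np.zeros((4096, 4096), dtype=np.float32)  # ~64 MiB eager buffer
        tensor.tofile(fout)
        del tensor   # without this, the buffer stays reachable via the local name
        ...          # anything long-running here would otherwise hold the memory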