Fix eager tensor memory leak and remove convert.py changes

Fixed a memory leak caused by unexpected reference retention of eager tensors.
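
In miniature, the leak looked like this: every eagerly materialized numpy tensor stayed reachable through a long-lived container (and a leftover loop variable) for the whole write, so nothing could be garbage-collected until the end. A minimal sketch of the pattern and the fix — the names (`shards`) and the `/dev/null` write are hypothetical stand-ins, not code from this commit:

    import numpy as np
    from collections import deque

    shards = deque(np.zeros((1024, 1024), dtype=np.float32) for _ in range(8))

    # Leaky shape: `for t in shards: write(t)` keeps every tensor alive
    # through `shards` until both the loop and the container are gone.

    # Fixed shape: pop each tensor, write it, then drop the reference so
    # CPython's refcounting can reclaim the buffer before the next shard.
    while True:
        try:
            t = shards.popleft()       # container gives up its reference
        except IndexError:
            break
        t.tofile('/dev/null')          # stand-in for the real shard write
        del t                          # last reference gone; memory reclaimable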

Also removed the GGUFManager integration from convert.py in favor of specializing it for convert-hf-to-gguf.py.
Christian Zhou-Zheng 2024-05-23 18:50:21 -04:00
parent 2dd784108b
commit 3ff27efa89
4 changed files with 87 additions and 120 deletions

--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py

@@ -2570,7 +2570,7 @@ def main() -> None:
     if args.split_max_tensors and args.split_max_size:
         raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
-    split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()
+    split_arguments = gguf.SplitArguments(args=args) if args.split else gguf.SplitArguments()
 
     ftype_map = {
         "f32": gguf.LlamaFileType.ALL_F32,

--- a/convert.py
+++ b/convert.py

@@ -24,17 +24,14 @@ from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
-# TEMPORARY IMPORT - TODO REMOVE
-import importlib
-gguf = importlib.import_module("gguf-py.gguf")
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable, Optional
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
 
 if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
-# import gguf
+import gguf
 
 if TYPE_CHECKING:
     from typing_extensions import Self, TypeAlias
@@ -1103,8 +1100,8 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False)
 class OutputFile:
-    def __init__(self, fname_out: Path, split_arguments: gguf.SplitArguments, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
-        self.gguf = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], split_arguments, endianess=endianess)
+    def __init__(self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
+        self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
     def add_meta_model(self, params: Params, metadata: Metadata) -> None:
         # Metadata About The Model And Its Provenence
@@ -1204,15 +1201,21 @@ class OutputFile:
     def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
         svocab.add_to_gguf(self.gguf)
 
+    def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
+        n_elements = int(np.prod(tensor.shape))
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
+
     def write_meta(self) -> None:
-        self.gguf.write_to_file(meta_only=True)
+        self.gguf.write_header_to_file()
+        self.gguf.write_kv_data_to_file()
 
-    def write_tensors(self, ftype: GGMLFileType, concurrency: int) -> None:
-        self.gguf.write_to_file(ftype=ftype, concurrency=concurrency, write_tensor_data=OutputFile.write_tensor_data)
+    def write_tensor_info(self) -> None:
+        self.gguf.write_ti_data_to_file()
 
-    # really awkward with how this is managed with gguf_manager.py: maybe refactor at some point?
-    @staticmethod
-    def write_tensor_data(ftype: GGMLFileType, model: LazyModel, concurrency: int, writer: gguf.GGUFWriter) -> None:
+    def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None:
         ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency)
         if ftype == GGMLFileType.MostlyQ8_0:
             ndarrays = bounded_parallel_map(
@@ -1230,7 +1233,7 @@ class OutputFile:
             logger.info(
                 f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
             )
-            writer.write_tensor_data(ndarray)
+            self.gguf.write_tensor_data(ndarray)
 
     def close(self) -> None:
         self.gguf.close()
@@ -1242,7 +1245,7 @@ class OutputFile:
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
-        of = OutputFile(fname_out, gguf.SplitArguments(), endianess=endianess)
+        of = OutputFile(fname_out, endianess=endianess)
 
         # meta data
         of.add_meta_model(params, metadata)
@@ -1270,11 +1273,13 @@ class OutputFile:
     @staticmethod
     def write_all(
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
-        split_arguments: gguf.SplitArguments, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
-        pad_vocab: bool = False, metadata: Metadata = None,
+        concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False,
+        metadata: Metadata = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
-        of = OutputFile(fname_out, split_arguments, endianess=endianess)
+        of = OutputFile(fname_out, endianess=endianess)
 
         # meta data
         of.add_meta_model(params, metadata)
@@ -1287,9 +1292,13 @@ class OutputFile:
         # tensor info
         for name, lazy_tensor in model.items():
-            of.gguf.add_tensor_info(name, lazy_tensor)
+            of.add_tensor_info(name, lazy_tensor)
 
-        of.write_tensors(ftype, concurrency)
+        of.write_meta()
+        of.write_tensor_info()
+
+        # tensor data
+        of.write_tensor_data(ftype, model, concurrency)
 
         of.close()
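
The write_meta / write_tensor_info / write_tensor_data split mirrors the on-disk order of a GGUF file: header, KV metadata, tensor info, then tensor payloads. For orientation, a minimal sketch of driving gguf-py's GGUFWriter directly, using only calls that appear in this diff (the path, architecture string, key, and tensor are made up for illustration):

    import numpy as np
    import gguf

    writer = gguf.GGUFWriter("/tmp/example.gguf", "llama")
    writer.add_uint32("example.count", 1)                              # KV metadata
    writer.add_tensor("tok_embd.weight", np.zeros((16, 16), dtype=np.float32))

    writer.write_header_to_file()    # header
    writer.write_kv_data_to_file()   # KV section
    writer.write_tensors_to_file()   # tensor info + tensor data
    writer.close()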
@@ -1364,7 +1373,7 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
                     experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"])
                     del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]
                 else:
-                    raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.model_classweight")
+                    raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight")
             tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts)
 
     # HF models permut or pack some of the tensors, so we need to undo that
@@ -1584,11 +1593,6 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
     parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
-    parser.add_argument("--split", action="store_true", help="split the converted model into multiple files")
-    parser.add_argument("--split-max-tensors", type=int, help="max tensors in each split")
-    parser.add_argument("--split-max-size", type=str, help="max size per split N(M|G)")
-    parser.add_argument("--dry-run", action="store_true", help="only print out a split plan and exit, without writing any new files")
-    parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default: metadata only)")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file")
     parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name")
@@ -1622,14 +1626,6 @@ def main(args_in: list[str] | None = None) -> None:
         do_dump_model(model_plus)
         return
 
-    if args.split and not (args.split_max_tensors or args.split_max_size):
-        raise ValueError("Need to specify one of --split-max-tensors or --split-max-size when splitting")
-
-    if args.split_max_tensors and args.split_max_size:
-        raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
-
-    split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()
-
     if not args.vocab_only:
         model_plus = load_some_model(args.model)
     else:
@@ -1707,13 +1703,11 @@ def main(args_in: list[str] | None = None) -> None:
     outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)
 
     params.ftype = ftype
     logger.info(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, split_arguments,
+    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
                          concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata)
-    if not args.dry_run:
-        logger.info(f"Wrote {outfile}")
+    logger.info(f"Wrote {outfile}")
 
 
 if __name__ == '__main__':

--- a/gguf-py/gguf/gguf_manager.py
+++ b/gguf-py/gguf/gguf_manager.py

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Sequence, Mapping
 from string import ascii_letters, digits
 from argparse import Namespace
 from math import ceil
+from collections import deque
 
 import numpy as np
@@ -34,7 +35,7 @@ LLM_KV_SPLIT_NO = "split.no"
 LLM_KV_SPLIT_COUNT = "split.count"
 LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
 
-SplitTensorsPerFile: TypeAlias = list[tuple[os.PathLike[str], list[tuple[str, Any]], GGUFWriter]] # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)]
+SplitTensorsPerFile: TypeAlias = deque[tuple[os.PathLike[str], deque[tuple[str, Any]], GGUFWriter]] # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)]
 KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]] # {key: (value, type)}
 TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any]] # (tensor name, tensor data), aka LazyModel
@@ -53,23 +54,23 @@ class SplitArguments:
     split_max_size: int
     split_style: SplitStyle
 
-    def __init__(self) -> None:
-        self.split = False
-        self.dry_run = False
-        self.small_first_shard = False
-        self.split_max_tensors = 0
-        self.split_max_size = 0
-        self.split_style = SplitStyle.NONE
-
-    def __init__(self, args: Namespace) -> None:
-        self.split = args.split
-        self.split_max_tensors = args.split_max_tensors
-        self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else None
-        self.dry_run = args.dry_run
-        self.small_first_shard = not args.large_first_shard
-        self.split_style = SplitStyle.NONE if not self.split \
-            else SplitStyle.TENSORS if self.split_max_tensors \
-            else SplitStyle.SIZE
+    def __init__(self, args: Namespace = None) -> None:
+        if args is None:
+            self.split = False
+            self.dry_run = False
+            self.small_first_shard = False
+            self.split_max_tensors = 0
+            self.split_max_size = 0
+            self.split_style = SplitStyle.NONE
+        else:
+            self.split = args.split
+            self.split_max_tensors = args.split_max_tensors
+            self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else None
+            self.dry_run = args.dry_run
+            self.small_first_shard = not args.large_first_shard
+            self.split_style = SplitStyle.NONE if not self.split \
+                else SplitStyle.TENSORS if self.split_max_tensors \
+                else SplitStyle.SIZE
 
 
 class SplitStrategy:
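
A note on the SplitArguments hunk above: Python has no constructor overloading, so the old second `def __init__` simply rebound the name and the zero-argument form above it was dead code — the zero-argument call `gguf.SplitArguments()` that convert.py made would have raised a TypeError. The merged default-argument version introduced here is the form that actually supports both call styles. The shadowing behavior in miniature:

    class C:
        def __init__(self):            # shadowed: never callable
            self.x = 0

        def __init__(self, x=1):       # this later definition wins
            self.x = x

    assert C().x == 1                  # the first __init__ is unreachable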
@@ -78,7 +79,7 @@ class SplitStrategy:
     def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arch: str,
                  split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE,
                  ):
-        self.data = []
+        self.data = deque()
 
         if split_arguments.split_style == SplitStyle.NONE:
             self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess)))
@@ -96,7 +97,7 @@ class SplitStrategy:
                 self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
 
         elif split_arguments.split_style == SplitStyle.SIZE:
-            shards = []
+            shards = deque()
 
             # we have to determine the shards first to determine how many shards there will be in total - two passes
             for i, shard in enumerate(model):
@@ -118,13 +119,7 @@ class SplitStrategy:
             for i, shard in enumerate(shards):
                 outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + shard_offset, total_shards))
-                self.append((outname, shard, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
-
-    def __getitem__(self, index):
-        return self.data[index]
-
-    def __setitem__(self, index, value):
-        self.data[index] = value
+                self.append((outname, deque(shard), GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
 
     def __len__(self):
         return len(self.data)
@@ -176,7 +171,7 @@ class SplitStrategy:
 # ideally this has most of the same signatures as GGUFWriter so it's nearly a drop-in replacement
 class GGUFManager:
     kv_data: KVTempData
-    tensors: list[TensorTempData]
+    tensors: deque[TensorTempData]
     split_arguments: SplitArguments
     split_strategy: SplitStrategy
@@ -188,7 +183,7 @@ class GGUFManager:
         self.endianess = endianess
         self.offset_tensor = 0
         self.kv_data = {}
-        self.tensors = []
+        self.tensors = deque()
         self.split_strategy = None
         self.total_shards = None
         self.total_tensors = None
@@ -200,9 +195,7 @@ class GGUFManager:
     # have to consolidate because we need to know kv data count and tensor count before we can write the header
     # and we need to write tensor info before we can write metadata
     # these all kinda show up around the same places anyway so it's not a huge deal?
-    def write_to_file(self, meta_only: bool = False, ftype: int = 0, concurrency: int = 8,
-                      write_tensor_data: function = None
-                      ) -> None:
+    def write_to_file(self, meta_only: bool = False) -> None:
 
         # here is the first place you can assume you have all tensors written and you can establish the size of the file - so logic goes here
         self.total_tensors = len(self.tensors)
@@ -218,22 +211,23 @@ class GGUFManager:
         self.split_strategy = SplitStrategy(self.path, self.tensors, self.arch, self.split_arguments,
                                             use_temp_file=self.use_temp_file, endianess=self.endianess)
+        del self.tensors
         self.total_shards = len(self.split_strategy)
 
         # only the first shard needs all the KV data
         for key, (value, etype) in self.kv_data.items():
-            self.split_strategy[0][2].add_key(key)
-            self.split_strategy[0][2].add_val(value, etype)
+            self.split_strategy.data[0][2].add_key(key)
+            self.split_strategy.data[0][2].add_val(value, etype)
 
         if self.split_arguments.split_style != SplitStyle.NONE:
-            for i, (_, _, writer) in enumerate(self.split_strategy):
+            for i, (_, _, writer) in enumerate(self.split_strategy.data):
                 writer.add_uint16(LLM_KV_SPLIT_NO, i)
                 writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
                 writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
 
         # metadata/vocab only can write and return here
         if meta_only:
-            for i, (_, _, writer) in enumerate(self.split_strategy):
+            for i, (_, _, writer) in enumerate(self.split_strategy.data):
                 writer.write_header_to_file()
                 writer.write_kv_data_to_file()
             return
@@ -241,57 +235,44 @@ class GGUFManager:
 
         # tensor writing code starts here
         print("\nWriting the following files:")
-        for (shard_path, shard_tensors, _) in self.split_strategy:
+        for (shard_path, shard_tensors, _) in self.split_strategy.data:
             size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
             print(f"  {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")
 
         if self.split_arguments.dry_run:
             print("\nDry run, not writing files")
             # instantiating GGUFWriters creates files
-            for name, _, _ in self.split_strategy:
+            for name, _, _ in self.split_strategy.data:
                 os.remove(name)
             return
 
         # run add_tensor_info, write data, then write_tensor_data - taken from convert.py
         running_total = self.total_tensors
-        start = time.time()
-        for i, (_, tensors, writer) in enumerate(self.split_strategy):
+        ct = 0
+        while True:
+            try:
+                (_, tensors, writer) = self.split_strategy.data.popleft()
+            except IndexError:
+                break
+
+            shard_num_tensors = len(tensors) if tensors else 0
+
             if tensors:
-                print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
-                for j, (name, tensor) in enumerate(tensors):
-                    n_elements = int(np.prod(tensor.shape))
-                    # logic from convert.py
-                    if getattr(tensor, 'data_type', None):
-                        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
-                        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
-                        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
-                        writer.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
-                    # logic from convert-hf-to-gguf.py
-                    else:
-                        # stolen from write_tensor_data because that doesn't get called with this logic
-                        elapsed = time.time() - start
-                        size = ' x '.join(f"{dim:6d}" for dim in tensor.shape)
-                        padi = len(str(self.total_tensors))
-                        dtype = str(tensor.dtype)
-                        print(
-                            f"[{j + 1:{padi}d}/{len(tensors)}] Writing tensor {name:38s} | size {size:16} | type {dtype:8} | T+{int(elapsed):4}"
-                        )
-                        writer.add_tensor(name, tensor)
-            print(f"Writing to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
+                while True:
+                    try:
+                        (name, tensor) = tensors.popleft()
+                    except IndexError:
+                        break
+                    writer.add_tensor(name, tensor)
+
+            print(f"Writing to shard {ct + 1}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
+            running_total -= shard_num_tensors
 
             writer.write_header_to_file()
             writer.write_kv_data_to_file()
-            writer.write_tensors_to_file()
-
-            if tensors:
-                # TODO this shows up AFTER writing which we don't really want - move it
-                running_total -= len(tensors)
-                if write_tensor_data:
-                    # convert.py's write_tensor_data is dependent on so many objects in convert.py itself that it's easier to pass the function as a parameter and call it here
-                    write_tensor_data(ftype, dict(tensors), concurrency, writer)
+            writer.write_tensors_to_file(progress=True)
+            ct = ct + 1
+            del tensors
 
     def add_uint8(self, key: str, val: int) -> None:
         self.kv_data[key] = (val, GGUFValueType.UINT8)
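
The rewritten loop above is where the retention actually gets broken: `popleft()` removes each shard (and each tensor) from its deque before use, and the trailing `del tensors` drops the loop's last binding, so a shard's eager tensors become collectible as soon as it is flushed instead of surviving until write_to_file returns. The consume-as-you-go idiom in isolation — a minimal sketch assuming plain numpy arrays, not code from this commit:

    from collections import deque
    import numpy as np

    queue = deque((f"t{i}", np.ones(1 << 20, dtype=np.float32)) for i in range(4))

    while True:
        try:
            name, tensor = queue.popleft()   # deque gives up its reference
        except IndexError:
            break
        _ = tensor.nbytes                    # ... write the tensor here ...
        del tensor                           # last reference gone; buffer reclaimable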
@@ -336,11 +317,6 @@ class GGUFManager:
             raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
         self.kv_data[key] = (val, GGUFValueType.ARRAY)
 
-    # this method is exclusive to convert.py - we don't have LazyTensor so Any type is used
-    def add_tensor_info(self, name: str, tensor: Any) -> None:
-        self.tensors.append((name, tensor))
-
-    # these methods are everywhere but convert.py (and convert-lora-to-ggml.py since that doesn't use the class)
     def add_tensor(
         self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
@@ -354,7 +330,7 @@ class GGUFManager:
         # fp.seek(0)
         # self.temp_file = fp
 
-        self.add_tensor_info(name, tensor)
+        self.tensors.append((name, tensor))
 
         #if self.temp_file is None:
         #    self.tensors.append(tensor)
@@ -363,12 +339,8 @@ class GGUFManager:
         #tensor.tofile(self.temp_file)
         #self.write_padding(self.temp_file, tensor.nbytes)
 
-    def write_tensors_to_file(self) -> None:
-        # TODO WRITE
-        pass
-
     def close(self) -> None:
-        for _, _, writer in self.split_strategy:
+        for _, _, writer in self.split_strategy.data:
             writer.close()
 
     def add_architecture(self) -> None:

--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py

@@ -301,6 +301,7 @@ class GGUFWriter:
                 tensor.tofile(self.fout)
                 bar.update(tensor.nbytes)
                 self.write_padding(self.fout, tensor.nbytes)
+                del tensor
             return
         while True:
             try:
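
This gguf_writer.py line is the leak fix named in the commit title: each tensor has already been popped from the writer's pending list, so `del tensor` drops the last remaining reference as soon as the data is on disk, letting CPython's refcounting reclaim each buffer immediately rather than when the loop variable is next rebound. Reference-drop reclamation is easy to verify in isolation — an illustrative sketch, not part of the commit (ndarrays support weak references):

    import weakref
    import numpy as np

    a = np.zeros(1 << 20, dtype=np.float32)
    r = weakref.ref(a)      # watch the array without keeping it alive
    del a                   # drop the only strong reference
    print(r() is None)      # True: the buffer was reclaimed immediately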