rename GGUFManager to GGUFWriterSplit

Christian Zhou-Zheng 2024-06-07 09:12:44 -04:00
parent 13ffe22ca7
commit 6d3a256d1d
3 changed files with 11 additions and 11 deletions


@@ -96,7 +96,7 @@ class Model:
         ftype_lw: str = ftype_up.lower()
         # allow templating the file name with the output ftype, useful with the "auto" ftype
         self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
-        self.gguf_writer = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], split_arguments,
+        self.gguf_writer = gguf.GGUFWriterSplit(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], split_arguments,
                                              endianess=self.endianess, use_temp_file=self.use_temp_file)

     @classmethod
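For callers, the rename is the only change: the positional and keyword arguments of the constructor stay the same. A minimal sketch of the equivalent call outside the converter, where fname_out, model_arch, split_arguments, endianess and use_temp_file are placeholders for the Model attributes used in the hunk above:

    import gguf

    # hedged sketch mirroring the constructor call shown in the diff;
    # the variable names are placeholders, not part of this commit
    writer = gguf.GGUFWriterSplit(
        fname_out,                            # output path, possibly a name template
        gguf.MODEL_ARCH_NAMES[model_arch],    # architecture name string
        split_arguments,                      # SplitArguments built from the CLI args
        endianess=endianess,
        use_temp_file=use_temp_file,
    )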


@@ -2,7 +2,7 @@ from .constants import *
 from .lazy import *
 from .gguf_reader import *
 from .gguf_writer import *
-from .gguf_manager import *
+from .gguf_writer_split import *
 from .quants import *
 from .tensor_mapping import *
 from .vocab import *


@@ -47,7 +47,7 @@ class SplitArguments:
     def __init__(self, args: Namespace) -> None:
         self.split = args.split
         self.split_max_tensors = args.split_max_tensors if args.split else 0
-        self.split_max_size = GGUFManager.split_str_to_n_bytes(args.split_max_size) if args.split and args.split_max_size else 0
+        self.split_max_size = GGUFWriterSplit.split_str_to_n_bytes(args.split_max_size) if args.split and args.split_max_size else 0
         self.split_style = SplitStyle.NONE if not self.split \
             else SplitStyle.TENSORS if self.split_max_tensors \
             else SplitStyle.SIZE
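The chained conditional above resolves the split style from the parsed arguments. A standalone restatement of the same decision, written as a hypothetical helper (resolve_split_style is illustration only, not part of the module):

    # no split flag -> NONE, a tensor cap -> TENSORS, otherwise a size cap -> SIZE,
    # mirroring the chained conditional in the hunk above
    def resolve_split_style(split: bool, split_max_tensors: int, split_max_size: int) -> SplitStyle:
        if not split:
            return SplitStyle.NONE
        if split_max_tensors:
            return SplitStyle.TENSORS
        return SplitStyle.SIZE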
@@ -55,7 +55,7 @@ class SplitArguments:
         self.small_first_shard = args.small_first_shard


-class GGUFManager(GGUFWriter):
+class GGUFWriterSplit(GGUFWriter):
     kv_data: KVTempData
     split_arguments: SplitArguments
     shards: list[Shard]
@@ -107,7 +107,7 @@ class GGUFManager(GGUFWriter):
         # print shard info
         print("\nWriting the following files:")
         for shard in self.shards:
-            print(f"  {shard.path}: n_tensors = {shard.tensor_count}, total_size = {GGUFManager.format_n_bytes_to_str(shard.size)}")
+            print(f"  {shard.path}: n_tensors = {shard.tensor_count}, total_size = {GGUFWriterSplit.format_n_bytes_to_str(shard.size)}")
         print()

         if self.split_arguments.dry_run:
@@ -144,7 +144,7 @@ class GGUFManager(GGUFWriter):
     def write_header_to_file(self) -> None:
         if self.state is not WriterState.EMPTY:
-            raise ValueError(f'Expected GGUFManager state to be EMPTY, got {self.state}')
+            raise ValueError(f'Expected GGUFWriterSplit state to be EMPTY, got {self.state}')

         for writer in self.shard_writers:
             writer.write_header_to_file()
@@ -153,7 +153,7 @@ class GGUFManager(GGUFWriter):
     def write_kv_data_to_file(self) -> None:
         if self.state is not WriterState.HEADER:
-            raise ValueError(f'Expected GGUFManager state to be HEADER, got {self.state}')
+            raise ValueError(f'Expected GGUFWriterSplit state to be HEADER, got {self.state}')

         for writer in self.shard_writers:
             writer.write_kv_data_to_file()
@@ -162,7 +162,7 @@ class GGUFManager(GGUFWriter):
     def write_tensors_to_file(self, *, progress: bool = False) -> None:
         if self.state is not WriterState.KV_DATA:
-            raise ValueError(f'Expected GGUFManager state to be KV_DATA, got {self.state}')
+            raise ValueError(f'Expected GGUFWriterSplit state to be KV_DATA, got {self.state}')

         running_total = self.total_tensors
         for i in range(len(self.shard_writers)):
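The three hunks above only touch the error messages; the GGUFWriter state machine itself is unchanged by the rename: header first, then KV data, then tensors, each step fanning out to every shard writer. A hedged sketch of the call order implied by those state checks (writer is a placeholder instance):

    # expected call order implied by the WriterState checks above
    writer.write_header_to_file()                  # requires WriterState.EMPTY
    writer.write_kv_data_to_file()                 # requires WriterState.HEADER
    writer.write_tensors_to_file(progress=True)    # requires WriterState.KV_DATA
    writer.close()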
@@ -207,13 +207,13 @@ class GGUFManager(GGUFWriter):
                 and self.shards[-1].tensor_count >= self.split_arguments.split_max_tensors) \
             # or split when over size limit
             or (self.split_arguments.split_style == SplitStyle.SIZE \
-                and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
+                and self.shards[-1].size + GGUFWriterSplit.get_tensor_size(tensor) > self.split_arguments.split_max_size)):

             # we fill in the name later when we know how many shards there are
-            self.shards.append(Shard(Path(), 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
+            self.shards.append(Shard(Path(), 1, GGUFWriterSplit.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
         else:
             self.shards[-1].tensor_count += 1
-            self.shards[-1].size += GGUFManager.get_tensor_size(tensor)
+            self.shards[-1].size += GGUFWriterSplit.get_tensor_size(tensor)
             self.shards[-1].tensors.append((name, tensor, raw_dtype))

     def close(self) -> None:
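The final hunk is the shard-boundary decision: a new shard is started when the current one has reached the tensor cap (TENSORS mode) or when adding the tensor would push it past the byte cap (SIZE mode); otherwise the tensor is appended to the current shard. A standalone restatement of that rule as a hypothetical helper (needs_new_shard is illustration only, not part of the module):

    # hypothetical helper restating the boundary check in the hunk above
    def needs_new_shard(shard: Shard, tensor, args: SplitArguments) -> bool:
        if args.split_style == SplitStyle.TENSORS:
            return shard.tensor_count >= args.split_max_tensors
        if args.split_style == SplitStyle.SIZE:
            return shard.size + GGUFWriterSplit.get_tensor_size(tensor) > args.split_max_size
        return False  # SplitStyle.NONE: everything stays in one shard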