rename GGUFManager to GGUFWriterSplit
This commit is contained in:
parent
13ffe22ca7
commit
6d3a256d1d
3 changed files with 11 additions and 11 deletions
|
@ -96,7 +96,7 @@ class Model:
|
|||
ftype_lw: str = ftype_up.lower()
|
||||
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
||||
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
||||
self.gguf_writer = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], split_arguments,
|
||||
self.gguf_writer = gguf.GGUFWriterSplit(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], split_arguments,
|
||||
endianess=self.endianess, use_temp_file=self.use_temp_file)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -2,7 +2,7 @@ from .constants import *
|
|||
from .lazy import *
|
||||
from .gguf_reader import *
|
||||
from .gguf_writer import *
|
||||
from .gguf_manager import *
|
||||
from .gguf_writer_split import *
|
||||
from .quants import *
|
||||
from .tensor_mapping import *
|
||||
from .vocab import *
|
||||
|
|
|
@ -47,7 +47,7 @@ class SplitArguments:
|
|||
def __init__(self, args: Namespace) -> None:
|
||||
self.split = args.split
|
||||
self.split_max_tensors = args.split_max_tensors if args.split else 0
|
||||
self.split_max_size = GGUFManager.split_str_to_n_bytes(args.split_max_size) if args.split and args.split_max_size else 0
|
||||
self.split_max_size = GGUFWriterSplit.split_str_to_n_bytes(args.split_max_size) if args.split and args.split_max_size else 0
|
||||
self.split_style = SplitStyle.NONE if not self.split \
|
||||
else SplitStyle.TENSORS if self.split_max_tensors \
|
||||
else SplitStyle.SIZE
|
||||
|
@ -55,7 +55,7 @@ class SplitArguments:
|
|||
self.small_first_shard = args.small_first_shard
|
||||
|
||||
|
||||
class GGUFManager(GGUFWriter):
|
||||
class GGUFWriterSplit(GGUFWriter):
|
||||
kv_data: KVTempData
|
||||
split_arguments: SplitArguments
|
||||
shards: list[Shard]
|
||||
|
@ -107,7 +107,7 @@ class GGUFManager(GGUFWriter):
|
|||
# print shard info
|
||||
print("\nWriting the following files:")
|
||||
for shard in self.shards:
|
||||
print(f" {shard.path}: n_tensors = {shard.tensor_count}, total_size = {GGUFManager.format_n_bytes_to_str(shard.size)}")
|
||||
print(f" {shard.path}: n_tensors = {shard.tensor_count}, total_size = {GGUFWriterSplit.format_n_bytes_to_str(shard.size)}")
|
||||
print()
|
||||
|
||||
if self.split_arguments.dry_run:
|
||||
|
@ -144,7 +144,7 @@ class GGUFManager(GGUFWriter):
|
|||
|
||||
def write_header_to_file(self) -> None:
|
||||
if self.state is not WriterState.EMPTY:
|
||||
raise ValueError(f'Expected GGUFManager state to be EMPTY, got {self.state}')
|
||||
raise ValueError(f'Expected GGUFWriterSplit state to be EMPTY, got {self.state}')
|
||||
|
||||
for writer in self.shard_writers:
|
||||
writer.write_header_to_file()
|
||||
|
@ -153,7 +153,7 @@ class GGUFManager(GGUFWriter):
|
|||
|
||||
def write_kv_data_to_file(self) -> None:
|
||||
if self.state is not WriterState.HEADER:
|
||||
raise ValueError(f'Expected GGUFManager state to be HEADER, got {self.state}')
|
||||
raise ValueError(f'Expected GGUFWriterSplit state to be HEADER, got {self.state}')
|
||||
|
||||
for writer in self.shard_writers:
|
||||
writer.write_kv_data_to_file()
|
||||
|
@ -162,7 +162,7 @@ class GGUFManager(GGUFWriter):
|
|||
|
||||
def write_tensors_to_file(self, *, progress: bool = False) -> None:
|
||||
if self.state is not WriterState.KV_DATA:
|
||||
raise ValueError(f'Expected GGUFManager state to be KV_DATA, got {self.state}')
|
||||
raise ValueError(f'Expected GGUFWriterSplit state to be KV_DATA, got {self.state}')
|
||||
|
||||
running_total = self.total_tensors
|
||||
for i in range(len(self.shard_writers)):
|
||||
|
@ -207,13 +207,13 @@ class GGUFManager(GGUFWriter):
|
|||
and self.shards[-1].tensor_count >= self.split_arguments.split_max_tensors) \
|
||||
# or split when over size limit
|
||||
or (self.split_arguments.split_style == SplitStyle.SIZE \
|
||||
and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
|
||||
and self.shards[-1].size + GGUFWriterSplit.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
|
||||
|
||||
# we fill in the name later when we know how many shards there are
|
||||
self.shards.append(Shard(Path(), 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
|
||||
self.shards.append(Shard(Path(), 1, GGUFWriterSplit.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
|
||||
else:
|
||||
self.shards[-1].tensor_count += 1
|
||||
self.shards[-1].size += GGUFManager.get_tensor_size(tensor)
|
||||
self.shards[-1].size += GGUFWriterSplit.get_tensor_size(tensor)
|
||||
self.shards[-1].tensors.append((name, tensor, raw_dtype))
|
||||
|
||||
def close(self) -> None:
|
Loading…
Add table
Add a link
Reference in a new issue