simplify even further and standardize with GGUFWriter

commit bb5ee02096
parent f6fd3ea4e9
Author: Christian Zhou-Zheng
Date:   2024-06-05 12:49:08 -04:00

2 changed files with 12 additions and 31 deletions

@@ -331,7 +331,7 @@ class Model:
         self.write_tensors()
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_ti_data_to_file()
+        self.gguf_writer.write_tensors_to_file()
         self.gguf_writer.close()
 
     def write_vocab(self):
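
After this rename, Model.write() drives a sharded GGUFManager and a plain GGUFWriter through the identical call sequence, so the converter no longer cares which one sits behind self.gguf_writer. A toy skeleton of that shared contract (method names from the diff; placeholder bodies, not gguf-py internals):

    class WriterLike:
        # Any object with these four methods can stand in for self.gguf_writer.
        def write_header_to_file(self) -> None:
            ...  # magic, version, tensor/KV counts
        def write_kv_data_to_file(self) -> None:
            ...  # metadata key/value section
        def write_tensors_to_file(self) -> None:
            ...  # tensor info and tensor data, sharded or not
        def close(self) -> None:
            ...  # flush and release file handles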

@@ -73,33 +73,24 @@ class SplitStrategy(deque):
                 self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
 
         elif split_arguments.split_style == SplitStyle.SIZE:
-            shards = []
+            shards = [[model[0]]]
 
             # we have to determine the shards first to determine how many shards there will be in total - two passes
-            for i, shard in enumerate(model):
-                if i == 0:
-                    shards.append([shard])
-                    continue
+            for i, shard in enumerate(model[1:]):
                 if SplitStrategy.get_tensor_size(shard[1]) + sum(SplitStrategy.get_tensor_size(t[1]) for t in shards[-1]) > split_arguments.split_max_size:
                     shards.append([shard])
                 else:
                     shards[-1].append(shard)
 
-            total_shards = len(shards) + split_arguments.small_first_shard
-            shard_offset = 1
-
             if split_arguments.small_first_shard:
-                outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, shard_offset, total_shards))
-                self.append((outname, None, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
-                shard_offset += 1
+                shards.insert(0, None)
 
             for i, shard in enumerate(shards):
-                outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + shard_offset, total_shards))
+                outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, len(shards)))
                 self.append((outname, shard, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
 
     @staticmethod
     def get_tensor_size(tensor) -> int:
         # we don't have the LazyTensor class here from convert.py but we can try
         try:
             return tensor.data_type.elements_to_bytes(np.prod(tensor.shape))
         except AttributeError: # numpy ndarray[Any, Any]
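
The SIZE branch now reads as a plain greedy grouping: the first tensor seeds the first shard, and each later tensor either joins the last shard or starts a new one when it would push the shard past split_max_size. A self-contained sketch of the same logic, with byte counts standing in for (name, tensor, dtype) entries (the function name is illustrative, not from the codebase):

    def group_by_size(sizes: list[int], max_size: int) -> list[list[int]]:
        # Same shape as the diff's loop: seed with the first item, then
        # either start a new shard (would exceed max_size) or extend the last.
        shards = [[sizes[0]]]
        for size in sizes[1:]:
            if size + sum(shards[-1]) > max_size:
                shards.append([size])
            else:
                shards[-1].append(size)
        return shards

    # e.g. group_by_size([4, 3, 5, 2, 6], max_size=8) == [[4, 3], [5, 2], [6]]

Note that a single tensor larger than max_size still gets a shard of its own; the condition only limits growth of the current shard.
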
@@ -213,7 +204,7 @@ class GGUFManager(GGUFWriter):
 
         self.state = WriterState.KV_DATA
 
-    def write_ti_data_to_file(self) -> None:
+    def write_tensors_to_file(self) -> None:
         if self.split_arguments.dry_run:
             return
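
Renaming write_ti_data_to_file to write_tensors_to_file makes the override shadow the base-class entry point, so code written against GGUFWriter picks up the sharded behavior through ordinary method dispatch. A minimal illustration with toy classes (not the real API):

    class Writer:
        def write_tensors_to_file(self) -> None:
            print("single-file write")

    class Manager(Writer):
        def write_tensors_to_file(self) -> None:  # same name: no caller changes
            print("sharded write across N files")

    def save(w: Writer) -> None:
        w.write_tensors_to_file()  # dispatches to whichever subclass it got

    save(Writer())   # single-file write
    save(Manager())  # sharded write across N files
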
@@ -222,25 +213,17 @@ class GGUFManager(GGUFWriter):
 
         running_total = self.total_tensors
         for ct in range(self.total_shards):
-            try:
-                (_, tensors, writer) = self.split_strategy.popleft()
-                tensors = deque(tensors) if tensors else None
-            except IndexError:
-                break
+            (_, tensors, writer) = self.split_strategy.popleft()
+            tensors = deque(tensors) if tensors else None
 
             shard_num_tensors = len(tensors) if tensors else 0
 
-            if tensors:
-                while True:
-                    try:
-                        (name, tensor, dtype) = tensors.popleft()
-                    except IndexError:
-                        break
-                    writer.add_tensor(name, tensor, raw_dtype=dtype)
-
             print(f"Writing to shard {ct}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
             running_total -= shard_num_tensors
+            for _ in range(shard_num_tensors):
+                (name, tensor, dtype) = tensors.popleft()
+                writer.add_tensor(name, tensor, raw_dtype=dtype)
 
             # need to write everything down here
             writer.write_header_to_file()
             writer.write_kv_data_to_file()
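
Both try/except IndexError drains go away because the loop bounds are known up front: split_strategy holds exactly total_shards entries and each tensor deque holds shard_num_tensors items. A small before/after sketch of the pattern on a deque (toy values):

    from collections import deque

    items = deque([("a", 1), ("b", 2), ("c", 3)])
    n = len(items)  # known up front, like shard_num_tensors in the diff

    # before: drain until IndexError signals exhaustion
    #     while True:
    #         try:
    #             name, value = items.popleft()
    #         except IndexError:
    #             break

    # after: a bounded loop; running short now raises instead of being swallowed
    for _ in range(n):
        name, value = items.popleft()
        print(name, value)
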
@@ -268,8 +251,6 @@ class GGUFManager(GGUFWriter):
         self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
-        if self.endianess == GGUFEndian.BIG:
-            tensor.byteswap(inplace=True)
         self.tensors.append((name, tensor, raw_dtype))
 
     def close(self) -> None:
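
The dropped byteswap avoids swapping twice: the staged tensors are re-added later through each shard's GGUFWriter.add_tensor, and the base writer performs the big-endian byteswap itself at that point (an inference from the base class, not shown in this diff). A tiny numpy check of why a second swap would undo the first:

    import numpy as np

    t = np.arange(4, dtype=np.int32)
    before = t.tobytes()
    t.byteswap(inplace=True)   # first swap: big-endian byte order
    t.byteswap(inplace=True)   # second swap restores native order
    assert t.tobytes() == before  # double-swapping is the identity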