Shard dataclass and un-negative dont_add_architecture

This commit is contained in:
parent 6a05183b97
commit 3328b0a991

2 changed files with 31 additions and 22 deletions
@@ -5,6 +5,7 @@ from enum import IntEnum
 from typing import TYPE_CHECKING, Any, Sequence
 from argparse import Namespace
 from collections import deque
+from dataclasses import dataclass
 
 import numpy as np
 
@@ -28,7 +29,14 @@ LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
 
 KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]] # {key: (value, type)}
 TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationType] # (tensor name, tensor data, tensor dtype)
-Shard: TypeAlias = list[os.PathLike[str], int, int, deque[TensorTempData]] # [shard filename, shard tensor count, shard size, [tensor data]]
+
+
+@dataclass
+class Shard:
+    path: str
+    tensor_count: int
+    size: int
+    tensors: deque[TensorTempData]
 
 
 class SplitStyle(IntEnum):
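
Not part of the diff, just an illustrative sketch of what the dataclass buys over the old TypeAlias: named fields instead of positional list indices (the file name and numbers below are made up):

    from collections import deque
    from dataclasses import dataclass

    @dataclass
    class Shard:
        path: str
        tensor_count: int
        size: int
        tensors: deque

    # old list-style shard: each slot's meaning depends on its position
    old_shard = ["shard-a.gguf", 12, 4096, deque()]
    print(old_shard[1])            # 12 -- the tensor count, but only by convention

    # new dataclass shard: the same data behind self-documenting field names
    new_shard = Shard("shard-a.gguf", 12, 4096, deque())
    print(new_shard.tensor_count)  # 12
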
@@ -73,11 +81,11 @@ class GGUFManager(GGUFWriter):
         self.state = WriterState.EMPTY
 
         if self.split_arguments.small_first_shard:
-            self.shards.append(["", 0, METADATA_ONLY_INDICATOR, None])
+            self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque()))
 
     def init_shards(self) -> None:
-        self.total_tensors = sum(shard[1] for shard in self.shards)
-        total_size = sum(shard[2] for shard in self.shards)
+        self.total_tensors = sum(shard.tensor_count for shard in self.shards)
+        total_size = sum(shard.size for shard in self.shards)
 
         # check if we need to split
         if self.split_arguments.split_max_tensors and self.total_tensors < self.split_arguments.split_max_tensors:
@@ -90,19 +98,20 @@ class GGUFManager(GGUFWriter):
 
         # no shards are created when writing vocab so make one
         if not self.shards:
-            self.shards.append(["", 0, METADATA_ONLY_INDICATOR, None])
+            self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque()))
 
         # format shard names
         if len(self.shards) == 1:
-            self.shards[0][0] = self.path
+            self.shards[0].path = self.path
         else:
             for i in range(len(self.shards)):
-                self.shards[i][0] = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))
+                # TODO with_name is not explicit - import pathlib
+                self.shards[i].path = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))
 
         # print shard info
         print("\nWriting the following files:")
-        for (path, tensor_ct, size, _) in self.shards:
-            print(f" {path}: n_tensors = {tensor_ct}, total_size = {GGUFManager.format_n_bytes_to_str(size)}")
+        for shard in self.shards:
+            print(f" {shard.path}: n_tensors = {shard.tensor_count}, total_size = {GGUFManager.format_n_bytes_to_str(shard.size)}")
         print()
 
         if self.split_arguments.dry_run:
@@ -110,10 +119,10 @@ class GGUFManager(GGUFWriter):
             exit()
 
         # we don't want to initialize GGUFWriters until now because they create files
-        for i, (path, _, _, tensors) in enumerate(self.shards):
-            # dont_add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
-            writer = GGUFWriter(path, self.arch, use_temp_file=self.use_temp_file,
-                endianess=self.endianess, dont_add_architecture=not (i == 0))
+        for i, shard in enumerate(self.shards):
+            # add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
+            writer = GGUFWriter(shard.path, self.arch, use_temp_file=self.use_temp_file,
+                endianess=self.endianess, add_architecture=(i == 0))
 
             # only the first shard needs all the KV data
             if i == 0:
@@ -130,7 +139,7 @@ class GGUFManager(GGUFWriter):
             # add tensors, deque popleft() ensures references to eager tensors are not kept
             while True:
                 try:
-                    (name, tensor, dtype) = tensors.popleft()
+                    (name, tensor, dtype) = shard.tensors.popleft()
                     writer.add_tensor(name, tensor, raw_dtype=dtype)
                 except:
                     break
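
An aside on the popleft() comment in the hunk above (illustrative sketch, not from the commit): draining the deque with popleft() means the shard stops referencing each tensor as soon as it is handed to the writer, so large arrays can be freed instead of living until the loop ends. Roughly:

    from collections import deque

    import numpy as np

    # stand-in tensor queue; the real code stores (name, tensor data, dtype) tuples
    tensors = deque((f"t{i}", np.zeros(1024)) for i in range(3))

    while True:
        try:
            name, data = tensors.popleft()  # the deque drops its reference here
        except IndexError:                  # queue drained, nothing left to write
            break
        print(name, data.nbytes)            # stand-in for writer.add_tensor(...)
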
@@ -199,17 +208,17 @@ class GGUFManager(GGUFWriter):
         if (len(self.shards) == self.split_arguments.small_first_shard \
             # or split when over tensor limit
             or (self.split_arguments.split_style == SplitStyle.TENSORS \
-                and self.shards[-1][1] >= self.split_arguments.split_max_tensors) \
+                and self.shards[-1].tensor_count >= self.split_arguments.split_max_tensors) \
             # or split when over size limit
             or (self.split_arguments.split_style == SplitStyle.SIZE \
-                and self.shards[-1][2] + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
+                and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
 
             # we fill in the name later when we know how many shards there are
-            self.shards.append(["", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])])
+            self.shards.append(Shard("", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
         else:
-            self.shards[-1][1] += 1
-            self.shards[-1][2] += GGUFManager.get_tensor_size(tensor)
-            self.shards[-1][3].append((name, tensor, raw_dtype))
+            self.shards[-1].tensor_count += 1
+            self.shards[-1].size += GGUFManager.get_tensor_size(tensor)
+            self.shards[-1].tensors.append((name, tensor, raw_dtype))
 
     def close(self) -> None:
         for writer in self.shard_writers:
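
For readers skimming the condition in the last hunk, a standalone toy restatement of the same decision (simplified names, not the project's API): a new shard is started when the tensor-count or size limit would be exceeded.

    # toy version of the split decision above; names and values are invented
    def needs_new_shard(last_count, last_size, tensor_size, style, max_tensors=0, max_size=0):
        # split when the last shard already holds the maximum number of tensors ...
        if style == "tensors" and last_count >= max_tensors:
            return True
        # ... or when adding this tensor would push the last shard over the size limit
        if style == "size" and last_size + tensor_size > max_size:
            return True
        return False

    print(needs_new_shard(256, 0, 0, "tensors", max_tensors=256))  # True: tensor limit reached
    print(needs_new_shard(10, 900, 200, "size", max_size=1000))    # True: 900 + 200 > 1000
    print(needs_new_shard(10, 900, 50, "size", max_size=1000))     # False: still fits
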
@@ -57,7 +57,7 @@ class GGUFWriter:
 
     def __init__(
         self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
-        endianess: GGUFEndian = GGUFEndian.LITTLE, dont_add_architecture: bool = False
+        endianess: GGUFEndian = GGUFEndian.LITTLE, add_architecture: bool = True
     ):
         self.fout = open(path, "wb")
         self.arch = arch
@@ -77,7 +77,7 @@ class GGUFWriter:
         ))
         self.state = WriterState.EMPTY
 
-        if not dont_add_architecture:
+        if add_architecture:
             self.add_architecture()
 
     def write_header_to_file(self) -> None:
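
Illustrative only, not part of the commit: renaming dont_add_architecture to add_architecture lets callers pass the positive condition directly (e.g. add_architecture=(i == 0)) instead of a double negative. A toy sketch of the pattern, assuming "general.architecture" as the key the real writer records:

    # toy illustration of the positive add_architecture flag; not the real GGUFWriter
    class ToyWriter:
        def __init__(self, path, arch, add_architecture=True):
            self.path = path
            self.kv = {}
            if add_architecture:                        # opt in, rather than "not dont_..."
                self.kv["general.architecture"] = arch

    # only the first shard records the architecture key, matching add_architecture=(i == 0)
    writers = [ToyWriter(f"shard{i}", "llama", add_architecture=(i == 0)) for i in range(3)]
    print([w.kv for w in writers])  # [{'general.architecture': 'llama'}, {}, {}]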