cleanup round 1
This commit is contained in:
parent
49b9fbe942
commit
0471f67f4f
1 changed file with 23 additions and 42 deletions
|
@ -56,14 +56,6 @@ class GGUFValue:
|
|||
type: GGUFValueType
|
||||
|
||||
|
||||
@dataclass
|
||||
class Shard:
|
||||
path: Path
|
||||
tensor_count: int
|
||||
size: int
|
||||
tensors: deque[TensorTempData]
|
||||
|
||||
|
||||
class SplitArguments:
|
||||
def __init__(self, args: Namespace) -> None:
|
||||
self.split_max_tensors = args.split_max_tensors if args.split_max_tensors else 0
|
||||
|
@ -91,10 +83,10 @@ class SplitStyle(Enum):
|
|||
|
||||
|
||||
class GGUFWriter:
|
||||
fout: list[BufferedWriter | None] | None
|
||||
fout: list[BufferedWriter] | None
|
||||
path: os.PathLike[str] | str | None
|
||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||
tensors: list[dict[str, TensorInfo | np.ndarray[Any, Any]]]
|
||||
tensors: list[dict[str, TensorInfo]]
|
||||
kv_data: list[dict[str, GGUFValue]]
|
||||
state: WriterState
|
||||
_simple_value_packing = {
|
||||
|
@ -137,7 +129,7 @@ class GGUFWriter:
|
|||
|
||||
def verify_arguments(self) -> None:
|
||||
total_tensors = sum(len(ti) for ti in self.tensors)
|
||||
total_size = sum(sum(GGUFWriter.get_tensor_size(ti) for ti in t.values()) for t in self.tensors)
|
||||
total_size = sum(ti.nbytes for t in self.tensors for ti in t.values())
|
||||
|
||||
if self.split_arguments.split_max_tensors and total_tensors < self.split_arguments.split_max_tensors:
|
||||
logger.warning("Model has fewer tensors than the split threshold, not splitting")
|
||||
|
@ -149,10 +141,10 @@ class GGUFWriter:
|
|||
|
||||
# no shards are created when writing vocab so make one
|
||||
if not self.tensors or len(self.tensors) == 0:
|
||||
self.tensors.append(dict())
|
||||
self.tensors = [dict()]
|
||||
|
||||
def format_shard_names(self) -> list[os.PathLike[str]]:
|
||||
pathobj = Path(self.path)
|
||||
def format_shard_names(self, path: os.PathLike[str] | str | None = None) -> list[os.PathLike[str]]:
|
||||
pathobj = Path(path)
|
||||
if self.split_arguments.split_style == SplitStyle.NONE:
|
||||
return [pathobj]
|
||||
|
||||
|
@ -174,14 +166,15 @@ class GGUFWriter:
|
|||
|
||||
if self.path is not None:
|
||||
self.fout = []
|
||||
for fout in self.format_shard_names():
|
||||
for fout in self.format_shard_names(self.path):
|
||||
self.fout.append(open(fout, "wb"))
|
||||
self.state = WriterState.EMPTY
|
||||
|
||||
def print_plan(self) -> None:
|
||||
def print_plan(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
logger.info("Writing the following files:")
|
||||
for i in range(len(self.fout)):
|
||||
logger.info(f"{self.fout[i].name}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(GGUFWriter.get_tensors_total_size(self.tensors[i].values()))}")
|
||||
filenames = self.format_shard_names(path)
|
||||
for i in range(len(filenames)):
|
||||
logger.info(f"{filenames[i]}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in self.tensors[i].values()))}")
|
||||
|
||||
if self.split_arguments.dry_run:
|
||||
logger.info("Dry run, not writing files")
|
||||
|
@ -204,8 +197,8 @@ class GGUFWriter:
|
|||
|
||||
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
self.verify_arguments()
|
||||
self.print_plan(path)
|
||||
self.open_output_file(path)
|
||||
self.print_plan()
|
||||
|
||||
if self.state is not WriterState.EMPTY:
|
||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
||||
|
@ -215,13 +208,12 @@ class GGUFWriter:
|
|||
|
||||
self.add_shard_kv_data()
|
||||
|
||||
for i in range(len(self.fout)):
|
||||
fout = self.fout[i]
|
||||
for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
|
||||
self._write_packed(fout, "<I", GGUF_MAGIC, skip_pack_prefix = True)
|
||||
self._write_packed(fout, "I", GGUF_VERSION)
|
||||
self._write_packed(fout, "Q", len(self.tensors[i]))
|
||||
self._write_packed(fout, "Q", len(self.kv_data[i]))
|
||||
self.fout[i].flush()
|
||||
self._write_packed(fout, "Q", len(tensors))
|
||||
self._write_packed(fout, "Q", len(kv_data))
|
||||
fout.flush()
|
||||
self.state = WriterState.HEADER
|
||||
|
||||
def write_kv_data_to_file(self) -> None:
|
||||
|
@ -246,12 +238,12 @@ class GGUFWriter:
|
|||
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
|
||||
assert self.fout is not None
|
||||
|
||||
for i in range(len(self.fout)):
|
||||
assert self.fout[i] is not None
|
||||
for fout, tensors in zip(self.fout, self.tensors):
|
||||
assert fout is not None
|
||||
ti_data = bytearray()
|
||||
offset_tensor = 0
|
||||
|
||||
for name, ti in self.tensors[i].items():
|
||||
for name, ti in tensors.items():
|
||||
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||
n_dims = len(ti.shape)
|
||||
ti_data += self._pack("I", n_dims)
|
||||
|
@ -261,8 +253,8 @@ class GGUFWriter:
|
|||
ti_data += self._pack("Q", offset_tensor)
|
||||
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
|
||||
|
||||
self.fout[i].write(ti_data)
|
||||
self.fout[i].flush()
|
||||
fout.write(ti_data)
|
||||
fout.flush()
|
||||
self.state = WriterState.TI_DATA
|
||||
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
||||
|
@ -359,7 +351,7 @@ class GGUFWriter:
|
|||
and len(self.tensors[-1]) >= self.split_arguments.split_max_tensors) \
|
||||
# or split when over size limit
|
||||
or (self.split_arguments.split_style == SplitStyle.SIZE \
|
||||
and GGUFWriter.get_tensors_total_size(self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)):
|
||||
and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)):
|
||||
|
||||
self.tensors.append(dict())
|
||||
|
||||
|
@ -424,7 +416,7 @@ class GGUFWriter:
|
|||
if progress:
|
||||
from tqdm import tqdm
|
||||
|
||||
total_bytes = GGUFWriter.get_tensors_total_size(self.tensors[i].values())
|
||||
total_bytes = sum(ti.nbytes for ti in self.tensors[i].values())
|
||||
|
||||
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
||||
|
||||
|
@ -739,17 +731,6 @@ class GGUFWriter:
|
|||
assert fout is not None
|
||||
fout.write(self._pack(fmt, value, skip_pack_prefix))
|
||||
|
||||
@staticmethod
|
||||
def get_tensor_size(tensor) -> int:
|
||||
try:
|
||||
return tensor.data_type.elements_to_bytes(np.prod(tensor.shape))
|
||||
except AttributeError: # numpy ndarray[Any, Any]
|
||||
return tensor.nbytes
|
||||
|
||||
@staticmethod
|
||||
def get_tensors_total_size(tensors) -> int:
|
||||
return sum(GGUFWriter.get_tensor_size(ti) for ti in tensors)
|
||||
|
||||
@staticmethod
|
||||
def split_str_to_n_bytes(split_str: str) -> int:
|
||||
if split_str.endswith("K"):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue