make pathlib explicit

This commit is contained in:
Christian Zhou-Zheng 2024-06-06 09:00:59 -04:00
parent 2037eabb64
commit 83e4a3f5cc

View file

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Sequence
from argparse import Namespace from argparse import Namespace
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
import numpy as np import numpy as np
@ -30,7 +31,7 @@ TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationTyp
@dataclass @dataclass
class Shard: class Shard:
path: str path: Path
tensor_count: int tensor_count: int
size: int size: int
tensors: deque[TensorTempData] tensors: deque[TensorTempData]
@ -56,7 +57,6 @@ class SplitArguments:
class GGUFManager(GGUFWriter): class GGUFManager(GGUFWriter):
kv_data: KVTempData kv_data: KVTempData
tensors: list[TensorTempData]
split_arguments: SplitArguments split_arguments: SplitArguments
shards: list[Shard] shards: list[Shard]
shard_writers: list[GGUFWriter] shard_writers: list[GGUFWriter]
@ -66,7 +66,7 @@ class GGUFManager(GGUFWriter):
) -> None: ) -> None:
# we intentionally don't call superclass constructor # we intentionally don't call superclass constructor
self.arch = arch self.arch = arch
self.path = path self.path = Path(path)
self.endianess = endianess self.endianess = endianess
self.kv_data = {} self.kv_data = {}
self.shards = [] self.shards = []
@ -78,7 +78,7 @@ class GGUFManager(GGUFWriter):
self.state = WriterState.EMPTY self.state = WriterState.EMPTY
if self.split_arguments.small_first_shard: if self.split_arguments.small_first_shard:
self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque())) self.shards.append(Shard(Path(), 0, METADATA_ONLY_INDICATOR, deque()))
def init_shards(self) -> None: def init_shards(self) -> None:
self.total_tensors = sum(shard.tensor_count for shard in self.shards) self.total_tensors = sum(shard.tensor_count for shard in self.shards)
@ -95,14 +95,13 @@ class GGUFManager(GGUFWriter):
# no shards are created when writing vocab so make one # no shards are created when writing vocab so make one
if not self.shards: if not self.shards:
self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque())) self.shards.append(Shard(Path(), 0, METADATA_ONLY_INDICATOR, deque()))
# format shard names # format shard names
if len(self.shards) == 1: if len(self.shards) == 1:
self.shards[0].path = self.path self.shards[0].path = self.path
else: else:
for i in range(len(self.shards)): for i in range(len(self.shards)):
# TODO with_name is not explicit - import pathlib
self.shards[i].path = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards))) self.shards[i].path = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))
# print shard info # print shard info
@ -211,7 +210,7 @@ class GGUFManager(GGUFWriter):
and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)): and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
# we fill in the name later when we know how many shards there are # we fill in the name later when we know how many shards there are
self.shards.append(Shard("", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)]))) self.shards.append(Shard(Path(), 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
else: else:
self.shards[-1].tensor_count += 1 self.shards[-1].tensor_count += 1
self.shards[-1].size += GGUFManager.get_tensor_size(tensor) self.shards[-1].size += GGUFManager.get_tensor_size(tensor)