make pathlib explicit
This commit is contained in:
parent
2037eabb64
commit
83e4a3f5cc
1 changed file with 6 additions and 7 deletions
|
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Sequence
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -30,7 +31,7 @@ TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationTyp
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Shard:
|
class Shard:
|
||||||
path: str
|
path: Path
|
||||||
tensor_count: int
|
tensor_count: int
|
||||||
size: int
|
size: int
|
||||||
tensors: deque[TensorTempData]
|
tensors: deque[TensorTempData]
|
||||||
|
@ -56,7 +57,6 @@ class SplitArguments:
|
||||||
|
|
||||||
class GGUFManager(GGUFWriter):
|
class GGUFManager(GGUFWriter):
|
||||||
kv_data: KVTempData
|
kv_data: KVTempData
|
||||||
tensors: list[TensorTempData]
|
|
||||||
split_arguments: SplitArguments
|
split_arguments: SplitArguments
|
||||||
shards: list[Shard]
|
shards: list[Shard]
|
||||||
shard_writers: list[GGUFWriter]
|
shard_writers: list[GGUFWriter]
|
||||||
|
@ -66,7 +66,7 @@ class GGUFManager(GGUFWriter):
|
||||||
) -> None:
|
) -> None:
|
||||||
# we intentionally don't call superclass constructor
|
# we intentionally don't call superclass constructor
|
||||||
self.arch = arch
|
self.arch = arch
|
||||||
self.path = path
|
self.path = Path(path)
|
||||||
self.endianess = endianess
|
self.endianess = endianess
|
||||||
self.kv_data = {}
|
self.kv_data = {}
|
||||||
self.shards = []
|
self.shards = []
|
||||||
|
@ -78,7 +78,7 @@ class GGUFManager(GGUFWriter):
|
||||||
self.state = WriterState.EMPTY
|
self.state = WriterState.EMPTY
|
||||||
|
|
||||||
if self.split_arguments.small_first_shard:
|
if self.split_arguments.small_first_shard:
|
||||||
self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque()))
|
self.shards.append(Shard(Path(), 0, METADATA_ONLY_INDICATOR, deque()))
|
||||||
|
|
||||||
def init_shards(self) -> None:
|
def init_shards(self) -> None:
|
||||||
self.total_tensors = sum(shard.tensor_count for shard in self.shards)
|
self.total_tensors = sum(shard.tensor_count for shard in self.shards)
|
||||||
|
@ -95,14 +95,13 @@ class GGUFManager(GGUFWriter):
|
||||||
|
|
||||||
# no shards are created when writing vocab so make one
|
# no shards are created when writing vocab so make one
|
||||||
if not self.shards:
|
if not self.shards:
|
||||||
self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque()))
|
self.shards.append(Shard(Path(), 0, METADATA_ONLY_INDICATOR, deque()))
|
||||||
|
|
||||||
# format shard names
|
# format shard names
|
||||||
if len(self.shards) == 1:
|
if len(self.shards) == 1:
|
||||||
self.shards[0].path = self.path
|
self.shards[0].path = self.path
|
||||||
else:
|
else:
|
||||||
for i in range(len(self.shards)):
|
for i in range(len(self.shards)):
|
||||||
# TODO with_name is not explicit - import pathlib
|
|
||||||
self.shards[i].path = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))
|
self.shards[i].path = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))
|
||||||
|
|
||||||
# print shard info
|
# print shard info
|
||||||
|
@ -211,7 +210,7 @@ class GGUFManager(GGUFWriter):
|
||||||
and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
|
and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
|
||||||
|
|
||||||
# we fill in the name later when we know how many shards there are
|
# we fill in the name later when we know how many shards there are
|
||||||
self.shards.append(Shard("", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
|
self.shards.append(Shard(Path(), 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
|
||||||
else:
|
else:
|
||||||
self.shards[-1].tensor_count += 1
|
self.shards[-1].tensor_count += 1
|
||||||
self.shards[-1].size += GGUFManager.get_tensor_size(tensor)
|
self.shards[-1].size += GGUFManager.get_tensor_size(tensor)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue