cleanup round 1

This commit is contained in:
Christian Zhou-Zheng 2024-06-09 12:40:02 -04:00
parent 49b9fbe942
commit 0471f67f4f

View file

@ -56,14 +56,6 @@ class GGUFValue:
type: GGUFValueType type: GGUFValueType
@dataclass
class Shard:
path: Path
tensor_count: int
size: int
tensors: deque[TensorTempData]
class SplitArguments: class SplitArguments:
def __init__(self, args: Namespace) -> None: def __init__(self, args: Namespace) -> None:
self.split_max_tensors = args.split_max_tensors if args.split_max_tensors else 0 self.split_max_tensors = args.split_max_tensors if args.split_max_tensors else 0
@ -91,10 +83,10 @@ class SplitStyle(Enum):
class GGUFWriter: class GGUFWriter:
fout: list[BufferedWriter | None] | None fout: list[BufferedWriter] | None
path: os.PathLike[str] | str | None path: os.PathLike[str] | str | None
temp_file: tempfile.SpooledTemporaryFile[bytes] | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: list[dict[str, TensorInfo | np.ndarray[Any, Any]]] tensors: list[dict[str, TensorInfo]]
kv_data: list[dict[str, GGUFValue]] kv_data: list[dict[str, GGUFValue]]
state: WriterState state: WriterState
_simple_value_packing = { _simple_value_packing = {
@ -137,7 +129,7 @@ class GGUFWriter:
def verify_arguments(self) -> None: def verify_arguments(self) -> None:
total_tensors = sum(len(ti) for ti in self.tensors) total_tensors = sum(len(ti) for ti in self.tensors)
total_size = sum(sum(GGUFWriter.get_tensor_size(ti) for ti in t.values()) for t in self.tensors) total_size = sum(ti.nbytes for t in self.tensors for ti in t.values())
if self.split_arguments.split_max_tensors and total_tensors < self.split_arguments.split_max_tensors: if self.split_arguments.split_max_tensors and total_tensors < self.split_arguments.split_max_tensors:
logger.warning("Model has fewer tensors than the split threshold, not splitting") logger.warning("Model has fewer tensors than the split threshold, not splitting")
@ -149,10 +141,10 @@ class GGUFWriter:
# no shards are created when writing vocab so make one # no shards are created when writing vocab so make one
if not self.tensors or len(self.tensors) == 0: if not self.tensors or len(self.tensors) == 0:
self.tensors.append(dict()) self.tensors = [dict()]
def format_shard_names(self) -> list[os.PathLike[str]]: def format_shard_names(self, path: os.PathLike[str] | str | None = None) -> list[os.PathLike[str]]:
pathobj = Path(self.path) pathobj = Path(path)
if self.split_arguments.split_style == SplitStyle.NONE: if self.split_arguments.split_style == SplitStyle.NONE:
return [pathobj] return [pathobj]
@ -174,14 +166,15 @@ class GGUFWriter:
if self.path is not None: if self.path is not None:
self.fout = [] self.fout = []
for fout in self.format_shard_names(): for fout in self.format_shard_names(self.path):
self.fout.append(open(fout, "wb")) self.fout.append(open(fout, "wb"))
self.state = WriterState.EMPTY self.state = WriterState.EMPTY
def print_plan(self) -> None: def print_plan(self, path: os.PathLike[str] | str | None = None) -> None:
logger.info("Writing the following files:") logger.info("Writing the following files:")
for i in range(len(self.fout)): filenames = self.format_shard_names(path)
logger.info(f"{self.fout[i].name}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(GGUFWriter.get_tensors_total_size(self.tensors[i].values()))}") for i in range(len(filenames)):
logger.info(f"{filenames[i]}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in self.tensors[i].values()))}")
if self.split_arguments.dry_run: if self.split_arguments.dry_run:
logger.info("Dry run, not writing files") logger.info("Dry run, not writing files")
@ -204,8 +197,8 @@ class GGUFWriter:
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
self.verify_arguments() self.verify_arguments()
self.print_plan(path)
self.open_output_file(path) self.open_output_file(path)
self.print_plan()
if self.state is not WriterState.EMPTY: if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected output file to be empty, got {self.state}') raise ValueError(f'Expected output file to be empty, got {self.state}')
@ -215,13 +208,12 @@ class GGUFWriter:
self.add_shard_kv_data() self.add_shard_kv_data()
for i in range(len(self.fout)): for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
fout = self.fout[i]
self._write_packed(fout, "<I", GGUF_MAGIC, skip_pack_prefix = True) self._write_packed(fout, "<I", GGUF_MAGIC, skip_pack_prefix = True)
self._write_packed(fout, "I", GGUF_VERSION) self._write_packed(fout, "I", GGUF_VERSION)
self._write_packed(fout, "Q", len(self.tensors[i])) self._write_packed(fout, "Q", len(tensors))
self._write_packed(fout, "Q", len(self.kv_data[i])) self._write_packed(fout, "Q", len(kv_data))
self.fout[i].flush() fout.flush()
self.state = WriterState.HEADER self.state = WriterState.HEADER
def write_kv_data_to_file(self) -> None: def write_kv_data_to_file(self) -> None:
@ -246,12 +238,12 @@ class GGUFWriter:
raise ValueError(f'Expected output file to contain KV data, got {self.state}') raise ValueError(f'Expected output file to contain KV data, got {self.state}')
assert self.fout is not None assert self.fout is not None
for i in range(len(self.fout)): for fout, tensors in zip(self.fout, self.tensors):
assert self.fout[i] is not None assert fout is not None
ti_data = bytearray() ti_data = bytearray()
offset_tensor = 0 offset_tensor = 0
for name, ti in self.tensors[i].items(): for name, ti in tensors.items():
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
n_dims = len(ti.shape) n_dims = len(ti.shape)
ti_data += self._pack("I", n_dims) ti_data += self._pack("I", n_dims)
@ -261,8 +253,8 @@ class GGUFWriter:
ti_data += self._pack("Q", offset_tensor) ti_data += self._pack("Q", offset_tensor)
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
self.fout[i].write(ti_data) fout.write(ti_data)
self.fout[i].flush() fout.flush()
self.state = WriterState.TI_DATA self.state = WriterState.TI_DATA
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
@ -359,7 +351,7 @@ class GGUFWriter:
and len(self.tensors[-1]) >= self.split_arguments.split_max_tensors) \ and len(self.tensors[-1]) >= self.split_arguments.split_max_tensors) \
# or split when over size limit # or split when over size limit
or (self.split_arguments.split_style == SplitStyle.SIZE \ or (self.split_arguments.split_style == SplitStyle.SIZE \
and GGUFWriter.get_tensors_total_size(self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)): and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)):
self.tensors.append(dict()) self.tensors.append(dict())
@ -424,7 +416,7 @@ class GGUFWriter:
if progress: if progress:
from tqdm import tqdm from tqdm import tqdm
total_bytes = GGUFWriter.get_tensors_total_size(self.tensors[i].values()) total_bytes = sum(ti.nbytes for ti in self.tensors[i].values())
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
@ -739,17 +731,6 @@ class GGUFWriter:
assert fout is not None assert fout is not None
fout.write(self._pack(fmt, value, skip_pack_prefix)) fout.write(self._pack(fmt, value, skip_pack_prefix))
@staticmethod
def get_tensor_size(tensor) -> int:
try:
return tensor.data_type.elements_to_bytes(np.prod(tensor.shape))
except AttributeError: # numpy ndarray[Any, Any]
return tensor.nbytes
@staticmethod
def get_tensors_total_size(tensors) -> int:
return sum(GGUFWriter.get_tensor_size(ti) for ti in tensors)
@staticmethod @staticmethod
def split_str_to_n_bytes(split_str: str) -> int: def split_str_to_n_bytes(split_str: str) -> int:
if split_str.endswith("K"): if split_str.endswith("K"):