catch oversights
This commit is contained in:
parent
f7e7983946
commit
79bd2bfcb0
1 changed files with 18 additions and 17 deletions
|
@ -122,11 +122,11 @@ class GGUFWriter:
|
|||
self.path = path
|
||||
|
||||
if self.path is not None:
|
||||
self.print_plan()
|
||||
self.fout = [open(filename, "wb") for filename in self.format_shard_names(self.path)]
|
||||
filenames = self.print_plan()
|
||||
self.fout = [open(filename, "wb") for filename in filenames]
|
||||
self.state = WriterState.EMPTY
|
||||
|
||||
def print_plan(self) -> None:
|
||||
def print_plan(self) -> list[Path]:
|
||||
logger.info("Writing the following files:")
|
||||
assert self.path is not None
|
||||
filenames = self.format_shard_names(self.path)
|
||||
|
@ -138,6 +138,8 @@ class GGUFWriter:
|
|||
logger.info("Dry run, not writing files")
|
||||
exit()
|
||||
|
||||
return filenames
|
||||
|
||||
def add_shard_kv_data(self) -> None:
|
||||
if len(self.tensors) == 1:
|
||||
return
|
||||
|
@ -152,7 +154,7 @@ class GGUFWriter:
|
|||
self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
|
||||
|
||||
def write_header_to_file(self, path: Path | None = None) -> None:
|
||||
if len(self.tensors) == 1:
|
||||
if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
|
||||
logger.warning("Model fails split requirements, not splitting")
|
||||
|
||||
self.open_output_file(path)
|
||||
|
@ -298,13 +300,15 @@ class GGUFWriter:
|
|||
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
||||
|
||||
# make sure there is at least one tensor before splitting
|
||||
if (len(self.tensors[-1]) > 0
|
||||
# split when over tensor limit
|
||||
and (self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors)
|
||||
# or split when over size limit
|
||||
or (self.split_max_size != 0 and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size)):
|
||||
|
||||
self.tensors.append(dict())
|
||||
if len(self.tensors[-1]) > 0:
|
||||
if ( # split when over tensor limit
|
||||
self.split_max_tensors != 0
|
||||
and len(self.tensors[-1]) >= self.split_max_tensors
|
||||
) or ( # split when over size limit
|
||||
self.split_max_size != 0
|
||||
and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
|
||||
):
|
||||
self.tensors.append({})
|
||||
|
||||
self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
|
||||
|
||||
|
@ -367,12 +371,12 @@ class GGUFWriter:
|
|||
total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
|
||||
|
||||
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
||||
shard_bar = tqdm(desc="Shard progress", total=total_bytes, unit="byte", unit_scale=True)
|
||||
if len(self.fout) > 1:
|
||||
shard_bar = tqdm(desc=f"Shard (1/{len(self.fout)})", total=total_bytes, unit="byte", unit_scale=True)
|
||||
|
||||
for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
|
||||
if bar and len(self.fout) > 1:
|
||||
bar.desc = f"Writing ({i + 1}/{len(self.fout)})"
|
||||
if shard_bar and len(self.fout) > 1:
|
||||
shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
|
||||
total = sum(ti.nbytes for ti in tensors.values())
|
||||
# bar behaves weirdly when total is 0
|
||||
if total > 0:
|
||||
|
@ -686,9 +690,6 @@ class GGUFWriter:
|
|||
|
||||
return kv_data
|
||||
|
||||
def _write_packed(self, fout: BufferedWriter, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
|
||||
fout.write(self._pack(fmt, value, skip_pack_prefix))
|
||||
|
||||
@staticmethod
|
||||
def format_n_bytes_to_str(num: int) -> str:
|
||||
if num == 0:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue