progress bar, fix split logic
parent 70a6bc91cc
commit 1e2d9cb589

1 changed file with 24 additions and 14 deletions
```diff
@@ -145,12 +145,8 @@ class GGUFWriter:
         total_tensors = sum(len(t) for t in self.tensors)
         assert self.fout is not None
         total_splits = len(self.fout)
+        self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
         for i in range(total_splits):
-            # just see whether it exists
-            try:
-                self.kv_data[i]
-            except IndexError:
-                self.kv_data.append(dict())
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
```
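The try/except-IndexError growth loop is replaced with a single `extend` that pads `self.kv_data` with one fresh dict per missing split. A standalone sketch of the idiom, with illustrative data in place of the writer's real state:

```python
# Pad a list of per-split dicts up to total_splits in one call.
kv_data: list[dict] = [{"existing": "metadata"}]  # pretend split 0 is already populated
total_splits = 3

# One *fresh* dict per missing index; `[{}] * n` would alias a single shared dict.
kv_data.extend({} for _ in range(len(kv_data), total_splits))

assert len(kv_data) == total_splits
assert kv_data[1] is not kv_data[2]  # distinct dicts, no aliasing
```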
```diff
@@ -301,10 +297,12 @@ class GGUFWriter:
         if tensor_dtype == np.uint8:
             tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
 
+        # make sure there is at least one tensor before splitting
+        if (len(self.tensors[-1]) > 0
                 # split when over tensor limit
-        if (self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors \
+                and (self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors)
                 # or split when over size limit
-                or self.split_max_size != 0 and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size):
+                or (self.split_max_size != 0 and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size)):
 
             self.tensors.append(dict())
 
```
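The reworked condition relies on Python grouping `and` tighter than `or`, and the added parentheses pin that grouping down explicitly. A quick standalone illustration of the precedence at play:

```python
# `and` binds tighter than `or`: `a and b or c` parses as `(a and b) or c`.
a, b, c = False, True, True

assert (a and b or c) is ((a and b) or c)  # same parse, both True here
assert (a and (b or c)) is False           # different grouping, different result
```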
```diff
@@ -360,15 +358,25 @@ class GGUFWriter:
             self.write_padding(fout, fout.tell())
 
         if self.temp_file is None:
-            for fout, tensors in zip(self.fout, self.tensors):
-                bar = None
+            bar = None
+            shard_bar = None
 
-                if progress:
-                    from tqdm import tqdm
+            if progress:
+                from tqdm import tqdm
 
-                    total_bytes = sum(ti.nbytes for ti in tensors.values())
+                total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
 
-                    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+                shard_bar = tqdm(desc="Shard progress", total=total_bytes, unit="byte", unit_scale=True)
 
+            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
+                if bar and len(self.fout) > 1:
+                    bar.desc = f"Writing ({i + 1}/{len(self.fout)})"
+                if shard_bar and len(self.fout) > 1:
+                    total = sum(ti.nbytes for ti in tensors.values())
+                    # bar behaves weirdly when total is 0
+                    if total > 0:
+                        shard_bar.reset(total=total)
                 # relying on the fact that Python dicts preserve insertion order (since 3.7)
                 for ti in tensors.values():
```
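The new `shard_bar` is created once and rewound per shard with `tqdm.reset(total=...)` rather than reallocated, and the `total > 0` guard works around tqdm rendering oddly for empty shards. A minimal sketch of the pattern, with made-up shard sizes standing in for the real `ti.nbytes` sums:

```python
from tqdm import tqdm

# Hypothetical per-shard byte totals; in the writer these come from ti.nbytes.
shard_sizes = [1 << 20, 0, 512 << 10]
total_bytes = sum(shard_sizes)

bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
shard_bar = tqdm(desc="Shard progress", total=total_bytes, unit="byte", unit_scale=True)

for i, size in enumerate(shard_sizes):
    bar.desc = f"Writing ({i + 1}/{len(shard_sizes)})"
    if size > 0:                     # reset() renders oddly when total is 0
        shard_bar.reset(total=size)  # rewind the same bar for the next shard
    bar.update(size)                 # stand-in for the per-tensor writes
    shard_bar.update(size)

shard_bar.close()
bar.close()
```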
```diff
@@ -377,6 +385,8 @@ class GGUFWriter:
                     ti.tensor.tofile(fout)
                     if bar is not None:
                         bar.update(ti.nbytes)
+                    if shard_bar is not None:
+                        shard_bar.update(ti.nbytes)
                     self.write_padding(fout, ti.nbytes)
                     ti.tensor = None
         else:
```
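Each tensor write then advances both bars by `ti.nbytes`. Schematically, with a hypothetical in-memory file and tensor dict standing in for the writer's state:

```python
import io
import numpy as np
from tqdm import tqdm

# Hypothetical stand-ins for one shard's output file and its tensors.
fout = io.BytesIO()
tensors = {"tok_embd.weight": np.zeros((4, 4), dtype=np.float32)}

nbytes = sum(t.nbytes for t in tensors.values())
bar = tqdm(desc="Writing", total=nbytes, unit="byte", unit_scale=True)
shard_bar = tqdm(desc="Shard progress", total=nbytes, unit="byte", unit_scale=True)

for t in tensors.values():
    fout.write(t.tobytes())     # ndarray.tofile() needs a real file, so write bytes
    bar.update(t.nbytes)        # overall progress across all shards
    shard_bar.update(t.nbytes)  # progress within the current shard

shard_bar.close()
bar.close()
```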