add progress bar

This commit is contained in:
Sigbjørn Skjæret 2024-05-04 20:29:40 +02:00 committed by GitHub
parent d39f20359e
commit 158215c828
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -7,6 +7,7 @@ import json
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from tqdm import tqdm
from typing import Any, Sequence, NamedTuple from typing import Any, Sequence, NamedTuple
# Necessary to load the local gguf package # Necessary to load the local gguf package
@ -113,17 +114,23 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
writer.add_key(key) writer.add_key(key)
writer.add_val(val.value, val.type) writer.add_val(val.value, val.type)
total_bytes = 0
for tensor in reader.tensors: for tensor in reader.tensors:
total_bytes += tensor.n_bytes
# Dimensions are written in reverse order, so flip them first # Dimensions are written in reverse order, so flip them first
shape = np.flipud(tensor.shape).tolist() shape = np.flipud(tensor.shape).tolist()
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type) writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
writer.write_header_to_file() writer.write_header_to_file()
writer.write_kv_data_to_file() writer.write_kv_data_to_file()
writer.write_ti_data_to_file() writer.write_ti_data_to_file()
for tensor in reader.tensors: for tensor in reader.tensors:
writer.write_tensor_data(tensor.data) writer.write_tensor_data(tensor.data)
bar.update(tensor.n_bytes)
writer.close() writer.close()