fix linting

This commit is contained in:
Christian Zhou-Zheng 2024-06-09 11:23:55 -04:00
parent 0779f2f74f
commit a234bf821b
2 changed files with 8 additions and 11 deletions

View file

@ -2891,15 +2891,13 @@ def main() -> None:
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
if args.vocab_only: if args.vocab_only:
logger.info(f"Exporting model vocab...") logger.info("Exporting model vocab...")
model_instance.write_vocab() model_instance.write_vocab()
logger.info(f"Model vocab successfully exported.") logger.info("Model vocab successfully exported.")
else: else:
logger.info(f"Exporting model...") logger.info("Exporting model...")
model_instance.write() model_instance.write()
logger.info(f"Model successfully exported.") logger.info("Model successfully exported.")
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View file

@ -30,7 +30,6 @@ from .constants import (
) )
from .quants import quant_shape_from_byte_shape from .quants import quant_shape_from_byte_shape
from .constants import Keys
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -191,7 +190,7 @@ class GGUFWriter:
def add_shard_kv_data(self) -> None: def add_shard_kv_data(self) -> None:
if self.split_arguments.split_style == SplitStyle.NONE: if self.split_arguments.split_style == SplitStyle.NONE:
return return
total_tensors = sum(len(t) for t in self.tensors) total_tensors = sum(len(t) for t in self.tensors)
for i in range(len(self.fout)): for i in range(len(self.fout)):
# just see whether it exists # just see whether it exists
@ -746,11 +745,11 @@ class GGUFWriter:
return tensor.data_type.elements_to_bytes(np.prod(tensor.shape)) return tensor.data_type.elements_to_bytes(np.prod(tensor.shape))
except AttributeError: # numpy ndarray[Any, Any] except AttributeError: # numpy ndarray[Any, Any]
return tensor.nbytes return tensor.nbytes
@staticmethod @staticmethod
def get_tensors_total_size(tensors) -> int: def get_tensors_total_size(tensors) -> int:
return sum(GGUFWriter.get_tensor_size(ti) for ti in tensors) return sum(GGUFWriter.get_tensor_size(ti) for ti in tensors)
@staticmethod @staticmethod
def split_str_to_n_bytes(split_str: str) -> int: def split_str_to_n_bytes(split_str: str) -> int:
if split_str.endswith("K"): if split_str.endswith("K"):
@ -778,4 +777,4 @@ class GGUFWriter:
if abs(fnum) < 1000.0: if abs(fnum) < 1000.0:
return f"{fnum:3.1f}{unit}" return f"{fnum:3.1f}{unit}"
fnum /= 1000.0 fnum /= 1000.0
return f"{fnum:.1f}T - over 1TB, --split recommended" return f"{fnum:.1f}T - over 1TB, --split recommended"