convert_hf : use GGUFWriter to count model parameters

Author:    Francis Couture-Harpin (2024-07-14 20:38:26 -04:00)
Committer: brian khuu
Parent:    78a42fbee5
Commit:    417d7a7c62

4 changed files with 101 additions and 97 deletions

convert_hf_to_gguf.py

@@ -48,6 +48,7 @@ class Model:
     dir_model: Path
     ftype: gguf.LlamaFileType
+    fname_out: Path | None
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
@@ -58,8 +59,6 @@ class Model:
     block_count: int
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
-    fname_out: Path
-    fname_default: str
     gguf_writer: gguf.GGUFWriter
     metadata: gguf.Metadata
@@ -76,6 +75,7 @@ class Model:
         self.dir_model = dir_model
         self.ftype = ftype
+        self.fname_out = fname_out
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
@@ -101,37 +101,8 @@ class Model:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16

-        # Fallback to model directory name if metadata name is still missing
-        if self.metadata.name is None:
-            self.metadata.name = dir_model.name
-
-        # Generate parameter weight class (useful for leader boards) if not yet determined
-        if self.metadata.size_label is None:
-            expert_count = self.hparams.get("num_local_experts", 0)
-            sum_weight_estimate = self.calculate_total_weight_count()
-
-            # Calculate weight estimate per model
-            per_model_weight_estimate: int = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
-
-            self.metadata.size_label = gguf.size_label(expert_count, per_model_weight_estimate)
-
-        # Extracts and converts the encoding scheme from the given file type name. e.g. 'gguf.LlamaFileType.ALL_F32' --> 'F32'
-        output_type = self.ftype.name.partition("_")[2]
-
-        # Generate default filename based on model specification and available metadata
-        self.fname_default = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type)
-
-        # Filename Output
-        if fname_out is not None:
-            # custom defined filename and path was provided
-            # allow templating the file name with the output ftype, useful with the "auto" ftype
-            self.fname_out = fname_out.parent / gguf.fill_templated_filename(fname_out.name, output_type)
-        else:
-            # output in the same directory as the model by default
-            self.fname_out = dir_model / f"{self.fname_default}.gguf"
-
         # Configure GGUF Writer
-        self.gguf_writer = gguf.GGUFWriter(path=self.fname_out, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
+        self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)

     @classmethod
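Note: the writer is now constructed with path=None and only learns its output path when the header is written (see the write() hunk below). A minimal sketch of that deferred-path pattern, assuming the gguf-py version from this commit; the file name and block count are invented for illustration:

    from pathlib import Path
    import gguf

    writer = gguf.GGUFWriter(path=None, arch="llama")            # no output path yet
    writer.add_block_count(32)                                   # ... metadata / tensors added here ...
    writer.write_header_to_file(path=Path("model-Q8_0.gguf"))    # path supplied late
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()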
@@ -191,23 +162,6 @@ class Model:
         if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
             raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")

-    def calculate_total_weight_count(self) -> int:
-        sum_weight_estimate = 0
-        for name, data_torch in self.get_tensors():
-            # we don't need these
-            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
-                continue
-
-            # Calculate Tensor Volume
-            sum_weights_in_tensor = 1
-            for dim in data_torch.shape:
-                sum_weights_in_tensor *= dim
-
-            # Add Tensor Volume To Running Count
-            sum_weight_estimate += sum_weights_in_tensor
-
-        return sum_weight_estimate

     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
             raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
@@ -376,7 +330,37 @@ class Model:
         self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)

-    def prepare_key_value_store(self):
+    def prepare_metadata(self):
+
+        # Fallback to model directory name if metadata name is still missing
+        if self.metadata.name is None:
+            self.metadata.name = self.dir_model.name
+
+        # Generate parameter weight class (useful for leader boards) if not yet determined
+        if self.metadata.size_label is None:
+            total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
+            if (total_params > 0):
+                self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
+
+        # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
+        output_type = self.ftype.name.partition("_")[2]
+
+        # Generate default filename based on model specification and available metadata
+        fname_default = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type)
+
+        # Filename Output
+        if self.fname_out is not None:
+            if not self.fname_out.is_dir():
+                # custom defined filename and path was provided
+                # allow templating the file name with the output ftype, useful with the "auto" ftype
+                self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
+            else:
+                # the target file is a directory
+                self.fname_out = self.fname_out / f"{fname_default}.gguf"
+        else:
+            # output in the same directory as the model by default
+            self.fname_out = self.dir_model / f"{fname_default}.gguf"
+
         logger.info("Set meta model")
         self.metadata.set_gguf_meta_model(self.gguf_writer)
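Note: the templating branch above goes through gguf.fill_templated_filename (partially visible in the gguf-py/gguf/utility.py hunk further down). A rough, hypothetical example of how an --outfile value would be expanded, assuming '{ftype}' maps to the lower-case and '{FTYPE}' to the upper-case type name:

    from pathlib import Path
    import gguf

    out = Path("Mistral-7B-v0.1-{ftype}.gguf")   # hypothetical --outfile value
    out = out.parent / gguf.fill_templated_filename(out.name, "Q8_0")
    # expected result: something like Mistral-7B-v0.1-q8_0.gguf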
@@ -393,8 +377,8 @@ class Model:
     def write(self):
         self.prepare_tensors()
-        self.prepare_key_value_store()
-        self.gguf_writer.write_header_to_file()
+        self.prepare_metadata()
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
@@ -403,8 +387,8 @@ class Model:
         if len(self.gguf_writer.tensors) != 1:
             raise ValueError('Splitting the vocabulary is not supported')

-        self.prepare_key_value_store()
-        self.gguf_writer.write_header_to_file()
+        self.prepare_metadata()
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
@@ -3545,10 +3529,6 @@ def parse_args() -> argparse.Namespace:
         "--metadata", type=Path,
         help="Specify the path for an authorship metadata override file"
     )
-    parser.add_argument(
-        "--get-outfile", action="store_true",
-        help="print calculated output file name then exit"
-    )

     return parser.parse_args()
@@ -3576,9 +3556,6 @@ def main() -> None:
     if args.verbose:
         logging.basicConfig(level=logging.DEBUG)
-    elif args.get_outfile:
-        # Avoid printing anything besides the dump output
-        logging.basicConfig(level=logging.WARNING)
     else:
         logging.basicConfig(level=logging.INFO)
@@ -3629,10 +3606,6 @@ def main() -> None:
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split)

-        if args.get_outfile:
-            print(f"{model_instance.fname_default}") # noqa: NP100
-            return
-
        if args.vocab_only:
            logger.info("Exporting model vocab...")
            model_instance.write_vocab()
@@ -3640,6 +3613,7 @@ def main() -> None:
         else:
             logger.info("Exporting model...")
             model_instance.write()
+            assert model_instance.fname_out is not None
             out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
             logger.info(f"Model successfully exported to {out_path}")

examples/convert_legacy_llama.py

@@ -1042,9 +1042,11 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
     raise ValueError(f"Unexpected combination of types: {name_to_type}")

-def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]], expert_count:int | None) -> int:
-    # TODO: Ensure parameter count is accurate throughout various model type
-    sum_weight_estimate: int = 0
+def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]:
+    total_params = 0
+    shared_params = 0
+    expert_params = 0
     for name, lazy_tensor in tensors:
         # We don't need these
         if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
@@ -1057,17 +1059,15 @@ def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]],
         for dim in lazy_tensor.shape:
             sum_weights_in_tensor *= dim

-        # Add Tensor Volume To Running Count
-        sum_weight_estimate += sum_weights_in_tensor
+        if ".experts." in name:
+            if ".experts.0." in name:
+                expert_params += sum_weights_in_tensor
+        else:
+            shared_params += sum_weights_in_tensor

-    if expert_count is None:
-        return sum_weight_estimate
+        total_params += sum_weights_in_tensor

-    if expert_count is not None and expert_count == 0:
-        return sum_weight_estimate
-
-    # Calculate weight estimate per model
-    return int(sum_weight_estimate / expert_count)
+    return total_params, shared_params, expert_params

def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
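Note: in the new counting rule, only the first expert (".experts.0.") contributes to expert_params, the remaining experts only contribute to total_params, and everything outside ".experts." goes to shared_params. A toy illustration with made-up tensor names and shapes, assuming per_model_weight_count_estimation is importable from this script (a SimpleNamespace stands in for LazyTensor since only .shape is read):

    from types import SimpleNamespace

    tensors = [
        ("model.layers.0.attn_q.weight",       SimpleNamespace(shape=(4096, 4096))),   # shared
        ("model.layers.0.experts.0.w1.weight", SimpleNamespace(shape=(14336, 4096))),  # expert 0
        ("model.layers.0.experts.1.w1.weight", SimpleNamespace(shape=(14336, 4096))),  # other expert
    ]
    total, shared, expert = per_model_weight_count_estimation(tensors)
    # shared counts only the attention tensor, expert counts only ".experts.0.",
    # total counts every tensor, so here total == shared + 2 * expert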
@@ -1249,12 +1249,12 @@ class VocabFactory:
         return vocab, special_vocab

-def default_convention_outfile(file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> str:
+def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str:
     name = metadata.name if metadata.name is not None else None
     basename = metadata.basename if metadata.basename is not None else None
     finetune = metadata.finetune if metadata.finetune is not None else None
     version = metadata.version if metadata.version is not None else None
-    size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(expert_count, model_params_count)
+    size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0)

     output_type = {
         GGMLFileType.AllF32: "F32",
@@ -1265,7 +1265,7 @@ def default_convention_outfile(file_type: GGMLFileType, expert_count:int | None,
     return gguf.naming_convention(name, basename, finetune, version, size_label, output_type)

-def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> Path:
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path:
     default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
     ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
@@ -1328,7 +1328,7 @@ def main(args_in: list[str] | None = None) -> None:
        model_plus = load_some_model(dir_model)
        params = Params.load(model_plus)
        model = convert_model_names(model_plus.model, params, args.skip_unknown)
-       model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
+       model_params_count = per_model_weight_count_estimation(model_plus.model.items())
        ftype = pick_output_type(model, args.outtype)

     if (metadata is None or metadata.name is None) and params.path_model is not None:
@@ -1415,8 +1415,8 @@ def main(args_in: list[str] | None = None) -> None:
     if metadata.name is None and params.path_model is not None:
         metadata.name = params.path_model.name

-    model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
-    logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count)})")
+    model_params_count = per_model_weight_count_estimation(model_plus.model.items())
+    logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})")

     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
@@ -1426,7 +1426,7 @@ def main(args_in: list[str] | None = None) -> None:
     model = convert_to_output_type(model, ftype)
     outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata)

-    metadata.size_label = gguf.size_label(params.n_experts, model_params_count)
+    metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0)

     params.ftype = ftype
     logger.info(f"Writing {outfile}, format {ftype}")

gguf-py/gguf/gguf_writer.py

@@ -17,6 +17,7 @@ import numpy as np
 from .constants import (
     GGUF_DEFAULT_ALIGNMENT,
     GGUF_MAGIC,
+    GGML_QUANT_SIZES,
     GGUF_VERSION,
     GGMLQuantizationType,
     GGUFEndian,
@@ -106,6 +107,36 @@ class GGUFWriter:
         self.add_architecture()

+    def get_total_parameter_count(self) -> tuple[int, int, int, int]:
+        total_params = 0
+        shared_params = 0
+        expert_params = 0
+
+        expert_sum = 0
+        n_expert_tensors = 0
+
+        for tensors in self.tensors:
+            for name, info in tensors.items():
+                block_size, type_size = GGML_QUANT_SIZES[info.dtype]
+
+                size = (info.nbytes // type_size) * block_size
+
+                if "_exps." in name:
+                    expert_params += (size // info.shape[-3])
+                    expert_sum += info.shape[-3]
+                    n_expert_tensors += 1
+                else:
+                    shared_params += size
+
+                total_params += size
+
+        # Hopefully this should work even for variable-expert-count models
+        expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
+
+        # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
+        return total_params, shared_params, expert_params, expert_count
+
     def format_shard_names(self, path: Path) -> list[Path]:
         if len(self.tensors) == 1:
             return [path]
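Note: the tensor info only stores a byte count, so the element count is recovered as (nbytes // type_size) * block_size via GGML_QUANT_SIZES; for stacked expert tensors (names containing "_exps.") shape[-3] is the expert dimension, so dividing by it yields a per-expert count. A small worked example, assuming the usual ggml packing of Q4_0 (32 weights per 18-byte block) and F32 (1 weight per 4 bytes):

    n_elements = 4096 * 4096                      # logical weight count of one tensor
    q4_0_bytes = n_elements // 32 * 18            # bytes the writer records for a Q4_0 tensor
    f32_bytes  = n_elements * 4                   # bytes for an F32 tensor
    assert (q4_0_bytes // 18) * 32 == n_elements  # (nbytes // type_size) * block_size
    assert (f32_bytes // 4) * 1 == n_elements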

gguf-py/gguf/utility.py

@@ -10,12 +10,8 @@ def fill_templated_filename(filename: str, output_type: str):
                            OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)

-def model_weight_count_rounded_notation(model_params_count: int) -> str:
-    if model_params_count > 1e15 :
-        # Quadrillion Of Parameters
-        scaled_model_params = model_params_count * 1e-15
-        scale_suffix = "Q"
-    elif model_params_count > 1e12 :
+def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
+    if model_params_count > 1e12 :
         # Trillions Of Parameters
         scaled_model_params = model_params_count * 1e-12
         scale_suffix = "T"
@@ -31,21 +27,24 @@ def model_weight_count_rounded_notation(model_params_count: int) -> str:
         # Thousands Of Parameters
         scaled_model_params = model_params_count * 1e-3
         scale_suffix = "K"
-    return f"{round(scaled_model_params)}{scale_suffix}"

+    fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
+    return f"{scaled_model_params:.{fix}f}{scale_suffix}"

-def size_label(expert_count_int:int | None, model_params_count: int) -> str:
-    per_model_rounded_weight_estimate = model_weight_count_rounded_notation(model_params_count)
-    if expert_count_int is not None and expert_count_int > 0:
-        size_class = f"{expert_count_int}x{per_model_rounded_weight_estimate}"
+def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
+    if expert_count > 0:
+        pretty_size = model_weight_count_rounded_notation(shared_params + expert_params, min_digits=2)
+        size_class = f"{expert_count}x{pretty_size}"
     else:
-        size_class = f"{per_model_rounded_weight_estimate}"
+        size_class = model_weight_count_rounded_notation(total_params, min_digits=2)

     return size_class

-def naming_convention(model_name: str | None, base_name: str | None, finetune_string:str | None, version_string:str | None, size_label: str | None, output_type: str | None) -> str:
+def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None) -> str:
     # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
     if base_name is not None:
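Note: min_digits pads short mantissas so a label keeps at least two significant digits, and an MoE label is built from shared + per-expert parameters. A few illustrative values, assuming the gguf-py version from this commit; the parameter counts are invented (roughly Mixtral-shaped), not measured:

    from gguf import model_weight_count_rounded_notation, size_label

    assert model_weight_count_rounded_notation(6_738_415_616) == "6.7B"    # fix == 1
    assert model_weight_count_rounded_notation(46_703_000_000) == "47B"    # fix == 0
    # size_label(total, shared, expert, expert_count): the "8x..." part uses shared + expert
    assert size_label(46_702_792_704, 1_594_884_096, 5_638_488_576, 8) == "8x7.2B"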