diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 37b4462d3..5b96cbbb3 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -48,6 +48,7 @@ class Model:
 
     dir_model: Path
     ftype: gguf.LlamaFileType
+    fname_out: Path | None
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
@@ -58,8 +59,6 @@ class Model:
     block_count: int
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
-    fname_out: Path
-    fname_default: str
     gguf_writer: gguf.GGUFWriter
     metadata: gguf.Metadata
 
@@ -76,6 +75,7 @@ class Model:
 
         self.dir_model = dir_model
         self.ftype = ftype
+        self.fname_out = fname_out
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
@@ -101,37 +101,8 @@ class Model:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16
 
-        # Fallback to model directory name if metadata name is still missing
-        if self.metadata.name is None:
-            self.metadata.name = dir_model.name
-
-        # Generate parameter weight class (useful for leader boards) if not yet determined
-        if self.metadata.size_label is None:
-            expert_count = self.hparams.get("num_local_experts", 0)
-            sum_weight_estimate = self.calculate_total_weight_count()
-
-            # Calculate weight estimate per model
-            per_model_weight_estimate: int = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
-
-            self.metadata.size_label = gguf.size_label(expert_count, per_model_weight_estimate)
-
-        # Extracts and converts the encoding scheme from the given file type name. e.g. 'gguf.LlamaFileType.ALL_F32' --> 'F32'
-        output_type = self.ftype.name.partition("_")[2]
-
-        # Generate default filename based on model specification and available metadata
-        self.fname_default = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type)
-
-        # Filename Output
-        if fname_out is not None:
-            # custom defined filename and path was provided
-            # allow templating the file name with the output ftype, useful with the "auto" ftype
-            self.fname_out = fname_out.parent / gguf.fill_templated_filename(fname_out.name, output_type)
-        else:
-            # output in the same directory as the model by default
-            self.fname_out = dir_model / f"{self.fname_default}.gguf"
-
         # Configure GGUF Writer
-        self.gguf_writer = gguf.GGUFWriter(path=self.fname_out, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
+        self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
 
     @classmethod
@@ -191,23 +162,6 @@ class Model:
             if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
                 raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
 
-    def calculate_total_weight_count(self) -> int:
-        sum_weight_estimate = 0
-        for name, data_torch in self.get_tensors():
-            # we don't need these
-            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
-                continue
-
-            # Calculate Tensor Volume
-            sum_weights_in_tensor = 1
-            for dim in data_torch.shape:
-                sum_weights_in_tensor *= dim
-
-            # Add Tensor Volume To Running Count
-            sum_weight_estimate += sum_weights_in_tensor
-
-        return sum_weight_estimate
-
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
             raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
@@ -376,7 +330,37 @@ class Model:
 
             self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
 
-    def prepare_key_value_store(self):
+    def prepare_metadata(self):
+
+        # Fallback to model directory name if metadata name is still missing
+        if self.metadata.name is None:
+            self.metadata.name = self.dir_model.name
+
+        # Generate parameter weight class (useful for leader boards) if not yet determined
+        if self.metadata.size_label is None:
+            total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
+
+            if (total_params > 0):
+                self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
+
+        # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
+        output_type = self.ftype.name.partition("_")[2]
+
+        # Generate default filename based on model specification and available metadata
+        fname_default = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type)
+
+        # Filename Output
+        if self.fname_out is not None:
+            if not self.fname_out.is_dir():
+                # custom defined filename and path was provided
+                # allow templating the file name with the output ftype, useful with the "auto" ftype
+                self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
+            else:
+                # the target file is a directory
+                self.fname_out = self.fname_out / f"{fname_default}.gguf"
+        else:
+            # output in the same directory as the model by default
+            self.fname_out = self.dir_model / f"{fname_default}.gguf"
 
         logger.info("Set meta model")
         self.metadata.set_gguf_meta_model(self.gguf_writer)
@@ -393,8 +377,8 @@ class Model:
 
     def write(self):
         self.prepare_tensors()
-        self.prepare_key_value_store()
-        self.gguf_writer.write_header_to_file()
+        self.prepare_metadata()
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
@@ -403,8 +387,8 @@ class Model:
         if len(self.gguf_writer.tensors) != 1:
             raise ValueError('Splitting the vocabulary is not supported')
 
-        self.prepare_key_value_store()
-        self.gguf_writer.write_header_to_file()
+        self.prepare_metadata()
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
 
@@ -3545,10 +3529,6 @@ def parse_args() -> argparse.Namespace:
         "--metadata", type=Path,
         help="Specify the path for an authorship metadata override file"
     )
-    parser.add_argument(
-        "--get-outfile", action="store_true",
-        help="print calculated output file name then exit"
-    )
 
     return parser.parse_args()
 
@@ -3576,9 +3556,6 @@ def main() -> None:
 
     if args.verbose:
         logging.basicConfig(level=logging.DEBUG)
-    elif args.get_outfile:
-        # Avoid printing anything besides the dump output
-        logging.basicConfig(level=logging.WARNING)
    else:
        logging.basicConfig(level=logging.INFO)
 
@@ -3629,10 +3606,6 @@ def main() -> None:
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split)
 
-        if args.get_outfile:
-            print(f"{model_instance.fname_default}")  # noqa: NP100
-            return
-
         if args.vocab_only:
             logger.info("Exporting model vocab...")
             model_instance.write_vocab()
@@ -3640,6 +3613,7 @@ def main() -> None:
         else:
             logger.info("Exporting model...")
             model_instance.write()
+            assert model_instance.fname_out is not None
             out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
             logger.info(f"Model successfully exported to {out_path}")
 
diff --git a/examples/convert_legacy_llama.py b/examples/convert_legacy_llama.py
index fc8a08cdf..9ab9ab06e 100755
--- a/examples/convert_legacy_llama.py
+++ b/examples/convert_legacy_llama.py
@@ -1042,9 +1042,11 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
     raise ValueError(f"Unexpected combination of types: {name_to_type}")
 
 
-def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]], expert_count:int | None) -> int:
-    # TODO: Ensure parameter count is accurate throughout various model type
-    sum_weight_estimate: int = 0
+def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]:
+    total_params = 0
+    shared_params = 0
+    expert_params = 0
+
     for name, lazy_tensor in tensors:
         # We don't need these
         if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
@@ -1057,17 +1059,15 @@ def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]],
         for dim in lazy_tensor.shape:
             sum_weights_in_tensor *= dim
 
-        # Add Tensor Volume To Running Count
-        sum_weight_estimate += sum_weights_in_tensor
+        if ".experts." in name:
+            if ".experts.0." in name:
+                expert_params += sum_weights_in_tensor
+        else:
+            shared_params += sum_weights_in_tensor
 
-    if expert_count is None:
-        return sum_weight_estimate
+        total_params += sum_weights_in_tensor
 
-    if expert_count is not None and expert_count == 0:
-        return sum_weight_estimate
-
-    # Calculate weight estimate per model
-    return int(sum_weight_estimate / expert_count)
+    return total_params, shared_params, expert_params
 
 
 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -1249,12 +1249,12 @@ class VocabFactory:
         return vocab, special_vocab
 
 
-def default_convention_outfile(file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> str:
+def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str:
     name = metadata.name if metadata.name is not None else None
     basename = metadata.basename if metadata.basename is not None else None
     finetune = metadata.finetune if metadata.finetune is not None else None
     version = metadata.version if metadata.version is not None else None
-    size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(expert_count, model_params_count)
+    size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0)
 
     output_type = {
         GGMLFileType.AllF32: "F32",
@@ -1265,7 +1265,7 @@ def default_convention_outfile(file_type: GGMLFileType, expert_count:int | None,
     return gguf.naming_convention(name, basename, finetune, version, size_label, output_type)
 
 
-def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> Path:
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path:
     default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
     ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
@@ -1328,7 +1328,7 @@ def main(args_in: list[str] | None = None) -> None:
         model_plus = load_some_model(dir_model)
         params = Params.load(model_plus)
         model = convert_model_names(model_plus.model, params, args.skip_unknown)
-        model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
+        model_params_count = per_model_weight_count_estimation(model_plus.model.items())
         ftype = pick_output_type(model, args.outtype)
 
         if (metadata is None or metadata.name is None) and params.path_model is not None:
@@ -1415,8 +1415,8 @@ def main(args_in: list[str] | None = None) -> None:
     if metadata.name is None and params.path_model is not None:
         metadata.name = params.path_model.name
 
-    model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
-    logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count)})")
+    model_params_count = per_model_weight_count_estimation(model_plus.model.items())
+    logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})")
 
     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
@@ -1426,7 +1426,7 @@ def main(args_in: list[str] | None = None) -> None:
     model = convert_to_output_type(model, ftype)
     outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata)
 
-    metadata.size_label = gguf.size_label(params.n_experts, model_params_count)
+    metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0)
 
     params.ftype = ftype
     logger.info(f"Writing {outfile}, format {ftype}")
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 9cb9415d0..f0f029a18 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -17,6 +17,7 @@ import numpy as np
 from .constants import (
     GGUF_DEFAULT_ALIGNMENT,
     GGUF_MAGIC,
+    GGML_QUANT_SIZES,
     GGUF_VERSION,
     GGMLQuantizationType,
     GGUFEndian,
@@ -106,6 +107,36 @@ class GGUFWriter:
 
         self.add_architecture()
 
+    def get_total_parameter_count(self) -> tuple[int, int, int, int]:
+        total_params = 0
+        shared_params = 0
+        expert_params = 0
+
+        expert_sum = 0
+        n_expert_tensors = 0
+
+        for tensors in self.tensors:
+            for name, info in tensors.items():
+
+                block_size, type_size = GGML_QUANT_SIZES[info.dtype]
+
+                size = (info.nbytes // type_size) * block_size
+
+                if "_exps." in name:
+                    expert_params += (size // info.shape[-3])
+                    expert_sum += info.shape[-3]
+                    n_expert_tensors += 1
+                else:
+                    shared_params += size
+
+                total_params += size
+
+        # Hopefully this should work even for variable-expert-count models
+        expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
+
+        # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
+        return total_params, shared_params, expert_params, expert_count
+
     def format_shard_names(self, path: Path) -> list[Path]:
         if len(self.tensors) == 1:
             return [path]
diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index 88502e180..873ca406c 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -10,12 +10,8 @@ def fill_templated_filename(filename: str, output_type: str):
                            OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
 
 
-def model_weight_count_rounded_notation(model_params_count: int) -> str:
-    if model_params_count > 1e15 :
-        # Quadrillion Of Parameters
-        scaled_model_params = model_params_count * 1e-15
-        scale_suffix = "Q"
-    elif model_params_count > 1e12 :
+def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
+    if model_params_count > 1e12 :
         # Trillions Of Parameters
         scaled_model_params = model_params_count * 1e-12
         scale_suffix = "T"
@@ -31,21 +27,24 @@ def model_weight_count_rounded_notation(model_params_count: int) -> str:
         # Thousands Of Parameters
         scaled_model_params = model_params_count * 1e-3
         scale_suffix = "K"
-    return f"{round(scaled_model_params)}{scale_suffix}"
+
+    fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
+
+    return f"{scaled_model_params:.{fix}f}{scale_suffix}"
 
 
-def size_label(expert_count_int:int | None, model_params_count: int) -> str:
-    per_model_rounded_weight_estimate = model_weight_count_rounded_notation(model_params_count)
+def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
 
-    if expert_count_int is not None and expert_count_int > 0:
-        size_class = f"{expert_count_int}x{per_model_rounded_weight_estimate}"
+    if expert_count > 0:
+        pretty_size = model_weight_count_rounded_notation(shared_params + expert_params, min_digits=2)
+        size_class = f"{expert_count}x{pretty_size}"
     else:
-        size_class = f"{per_model_rounded_weight_estimate}"
+        size_class = model_weight_count_rounded_notation(total_params, min_digits=2)
 
     return size_class
 
 
-def naming_convention(model_name: str | None, base_name: str | None, finetune_string:str | None, version_string:str | None, size_label: str | None, output_type: str | None) -> str:
+def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None) -> str:
     # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
 
     if base_name is not None:
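
Note on the new size notation (not part of the patch): with `min_digits=2`, `model_weight_count_rounded_notation` keeps a second significant digit that the old `round()`-only formatting dropped, so a ~6.7e9-parameter model is labelled `6.7B` rather than `7B`. Below is a minimal standalone sketch of the rule; the `B`/`M` branches are elided from the hunk and assumed here to mirror the visible `T`/`K` branches, and the parameter counts are illustrative only.

```python
def rounded_notation(n: int, min_digits: int = 2) -> str:
    # Sketch of model_weight_count_rounded_notation after this patch.
    if n > 1e12:
        scaled, suffix = n * 1e-12, "T"
    elif n > 1e9:
        scaled, suffix = n * 1e-9, "B"
    elif n > 1e6:
        scaled, suffix = n * 1e-6, "M"
    else:
        scaled, suffix = n * 1e-3, "K"
    # Pad with decimals until at least `min_digits` significant digits remain.
    fix = max(min_digits - len(str(round(scaled)).lstrip('0')), 0)
    return f"{scaled:.{fix}f}{suffix}"

print(rounded_notation(6_738_415_616))  # 6.7B (second digit now kept)
print(rounded_notation(354_820_000))    # 355M (three digits already, no decimals added)
```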
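Note on the MoE accounting in `get_total_parameter_count` (not part of the patch): tensors whose name contains `"_exps."` are treated as stacked 3D expert tensors with the expert count in dimension `-3`, so dividing the element count by `shape[-3]` yields the per-expert contribution, and `size_label` can then report `{expert_count}x{shared + per-expert}`. A toy check under that assumption; the tensor shape and dtype below are made up for illustration.

```python
import numpy as np

# Hypothetical stacked-experts FFN tensor: (n_expert, n_ff, n_embd).
n_expert, n_ff, n_embd = 8, 64, 32
t = np.zeros((n_expert, n_ff, n_embd), dtype=np.float32)

# GGML_QUANT_SIZES maps a dtype to (block_size, type_size); for F32 it is (1, 4),
# so this recovers the element count even when nbytes describes quantized data.
block_size, type_size = 1, 4
size = (t.nbytes // type_size) * block_size

per_expert = size // t.shape[-3]  # parameters contributed by a single expert
assert per_expert == n_ff * n_embd
print(per_expert, t.shape[-3])    # 2048 parameters per expert, 8 experts
```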