diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 9fe81d1a2..d3c1e4c0c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -31,46 +31,6 @@ import gguf

 logger = logging.getLogger("hf-to-gguf")

-@dataclass
-class Metadata:
-    name: Optional[str] = None
-    basename: Optional[str] = None
-    finetune: Optional[str] = None
-    author: Optional[str] = None
-    version: Optional[str] = None
-    url: Optional[str] = None
-    description: Optional[str] = None
-    licence: Optional[str] = None
-    source_url: Optional[str] = None
-    source_hf_repo: Optional[str] = None
-
-    @staticmethod
-    def load(metadata_path: Path) -> Metadata:
-        if metadata_path is None or not metadata_path.exists():
-            return Metadata()
-
-        with open(metadata_path, 'r') as file:
-            data = json.load(file)
-
-        # Create a new Metadata instance
-        metadata = Metadata()
-
-        # Assigning values to Metadata attributes if they exist in the JSON file
-        # This is based on LLM_KV_NAMES mapping in llama.cpp
-        metadata.name = data.get("general.name")
-        metadata.basename = data.get("general.basename")
-        metadata.finetune = data.get("general.finetune")
-        metadata.author = data.get("general.author")
-        metadata.version = data.get("general.version")
-        metadata.url = data.get("general.url")
-        metadata.description = data.get("general.description")
-        metadata.license = data.get("general.license")
-        metadata.source_url = data.get("general.source.url")
-        metadata.source_hf_repo = data.get("general.source.huggingface.repository")
-
-        return metadata
-
-
 ###### MODEL DEFINITIONS ######

 class SentencePieceTokenTypes(IntEnum):
@@ -105,12 +65,12 @@ class Model:
     fname_out: Path
     fname_default: Path
     gguf_writer: gguf.GGUFWriter
-    metadata: Metadata
+    metadata: gguf.Metadata

     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: Metadata,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: gguf.Metadata,
                  model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@@ -164,72 +124,23 @@ class Model:
         if self.model_card is not None and "license" in self.model_card:
             self.metadata.source_hf_repo = self.model_card["license"]

-        # Set model name based on latest metadata either provided or calculated from environment
-        def get_model_name(metadata, huggingface_parameters, dir_model, model_arch):
-            if metadata is not None and metadata.name is not None:
-                # Explicit Metadata Was Provided By User
-                return metadata.name
-            elif huggingface_parameters is not None and "_name_or_path" in huggingface_parameters:
-                # Hugging Face Parameters Model Name or Model Folder Name is Provided
-                return huggingface_parameters["_name_or_path"]
-            elif huggingface_parameters is not None and "model_type" in huggingface_parameters:
-                # Hugging Face Parameters Model Type is Provided
-                return huggingface_parameters["model_type"]
-            elif dir_model is not None and dir_model.name is not None:
-                # Use directory folder name
-                return dir_model.name
-            else:
-                return gguf.MODEL_ARCH_NAMES[model_arch]
-        self.model_name = get_model_name(self.metadata, self.hparams, self.dir_model, self.model_arch)
+        self.model_name = Model.get_model_name(self.metadata, self.hparams, self.dir_model, self.model_arch)

         # Extracts and converts the encoding scheme from the given file type name. e.g. 'gguf.LlamaFileType.ALL_F32' --> 'F32'
-        encodingScheme = self.ftype.name.partition("_")[2]
+        encoding_scheme = self.ftype.name.partition("_")[2]

         # Get Expert Count From huggingface_parameters
         expert_count = self.hparams["num_local_experts"] if "num_local_experts" in self.hparams else None

-        def per_model_weight_count_estimation(tensors, expert_count):
-            # TODO: Ensure parameter count is accurate throughout various model type
-            # May currently overestimate parameter count in Mamba model because
-            # output weights is tied with token embeddings.
-            sum_weight_estimate = 0
-            for name, data_torch in tensors:
-                # Got A Tensor
-
-                # We don't need these
-                if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
-                    continue
-
-                # Calculate Tensor Volume
-                sum_weights_in_tensor = 1
-                for dim in data_torch.shape:
-                    sum_weights_in_tensor *= dim
-
-                # Add Tensor Volume To Running Count
-                sum_weight_estimate += sum_weights_in_tensor
-
-            # Calculate weight estimate per model
-            per_model_weight_estimate = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
-
-            return per_model_weight_estimate
-
-        weight_estimate = per_model_weight_count_estimation(model_tensors, expert_count)
+        weight_estimate = gguf.per_model_weight_count_estimation(model_tensors, expert_count)

         # Generate default filename based on model specification and available metadata
-        self.fname_default = gguf.naming_convention(self.model_name, self.metadata.basename, self.metadata.finetune, self.metadata.version, expert_count, weight_estimate, encodingScheme)
+        self.fname_default = gguf.naming_convention(self.model_name, self.metadata.basename, self.metadata.finetune, self.metadata.version, expert_count, weight_estimate, encoding_scheme)

         # Filename Output
         if fname_out is not None:
             # custom defined filename and path was provided
-            def fill_templated_filename(filename: str, encodingScheme: str):
-                # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
-                ftype_uppercase: str = encodingScheme.upper()
-                ftype_lowercase: str = encodingScheme.lower()
-                return filename.format(ftype_lowercase,
-                                       outtype=ftype_lowercase, ftype=ftype_lowercase,
-                                       OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
-
-            self.fname_out = fname_out.parent / fill_templated_filename(fname_out.name, encodingScheme)
+            self.fname_out = fname_out.parent / gguf.fill_templated_filename(fname_out.name, encoding_scheme)
         else:
             # output in the same directory as the model by default
             self.fname_out = dir_model.parent / self.fname_default
@@ -499,6 +410,24 @@ class Model:
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()

+    # Set model name based on latest metadata either provided or calculated from environment
+    @staticmethod
+    def get_model_name(metadata, huggingface_parameters, dir_model, model_arch):
+        if metadata is not None and metadata.name is not None:
+            # Explicit Metadata Was Provided By User
+            return metadata.name
+        elif huggingface_parameters is not None and "_name_or_path" in huggingface_parameters:
+            # Hugging Face Parameters Model Name or Model Folder Name is Provided
+            return huggingface_parameters["_name_or_path"]
+        elif huggingface_parameters is not None and "model_type" in huggingface_parameters:
+            # Hugging Face Parameters Model Type is Provided
+            return huggingface_parameters["model_type"]
+        elif dir_model is not None and dir_model.name is not None:
+            # Use directory folder name
+            return dir_model.name
+        else:
+            return gguf.MODEL_ARCH_NAMES[model_arch]
+
     @staticmethod
     def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
         part_names: list[str] = []
@@ -3682,7 +3611,7 @@ def main() -> None:
     else:
         logging.basicConfig(level=logging.INFO)

-    metadata = Metadata.load(args.metadata)
+    metadata = gguf.Metadata.load(args.metadata)

     dir_model = args.model
     if not dir_model.is_dir():
@@ -3713,7 +3642,7 @@ def main() -> None:
     hparams = Model.load_hparams(dir_model)

     with torch.inference_mode():
-        encodingScheme = ftype_map[args.outtype]
+        encoding_scheme = ftype_map[args.outtype]
         model_architecture = hparams["architectures"][0]

         try:
diff --git a/examples/convert_legacy_llama.py b/examples/convert_legacy_llama.py
index 874f8f8e6..38d1745f5 100755
--- a/examples/convert_legacy_llama.py
+++ b/examples/convert_legacy_llama.py
@@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar, Optional
+from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar

 import numpy as np

@@ -346,46 +346,6 @@ class Params:

         return params

-@dataclass
-class Metadata:
-    name: Optional[str] = None
-    basename: Optional[str] = None
-    finetune: Optional[str] = None
-    author: Optional[str] = None
-    version: Optional[str] = None
-    url: Optional[str] = None
-    description: Optional[str] = None
-    license: Optional[str] = None
-    source_url: Optional[str] = None
-    source_hf_repo: Optional[str] = None
-
-    @staticmethod
-    def load(metadata_path: Path) -> Metadata:
-        if metadata_path is None or not metadata_path.exists():
-            return Metadata()
-
-        with open(metadata_path, 'r') as file:
-            data = json.load(file)
-
-        # Create a new Metadata instance
-        metadata = Metadata()
-
-        # Assigning values to Metadata attributes if they exist in the JSON file
-        # This is based on LLM_KV_NAMES mapping in llama.cpp
-        metadata.name = data.get("general.name")
-        metadata.basename = data.get("general.basename")
-        metadata.finetune = data.get("general.finetune")
-        metadata.author = data.get("general.author")
-        metadata.version = data.get("general.version")
-        metadata.url = data.get("general.url")
-        metadata.description = data.get("general.description")
-        metadata.license = data.get("general.license")
-        metadata.source_url = data.get("general.source.url")
-        metadata.source_hf_repo = data.get("general.source.huggingface.repository")
-
-        return metadata
-
-
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
@@ -810,7 +770,7 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

-    def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
+    def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
         # Metadata About The Model And Its Provenence
         name = "LLaMA"
         if metadata is not None and metadata.name is not None:
@@ -952,7 +912,7 @@
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)

@@ -986,7 +946,7 @@
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False,
-        metadata: Metadata | None = None,
+        metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)

@@ -1029,10 +989,10 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT

     raise ValueError(f"Unexpected combination of types: {name_to_type}")


-def per_model_weight_count_estimation(model: LazyModel, expert_count:int) -> int:
+def per_model_weight_count_estimation(tensors: dict[str, LazyTensor], expert_count:int) -> int:
     # TODO: Ensure parameter count is accurate throughout various model type
     sum_weight_estimate = 0
-    for name, lazy_tensor in model.items():
+    for name, lazy_tensor in tensors:
         # We don't need these
         if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
             continue
@@ -1232,7 +1192,7 @@ class VocabFactory:
         return vocab, special_vocab


-def default_convention_outfile(file_type: GGMLFileType, model_name:str, expert_count:int, model_params_count: int, metadata: Metadata) -> str:
+def default_convention_outfile(file_type: GGMLFileType, model_name:str, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> str:
     name = metadata.name if metadata is not None and metadata.name is not None else model_name
     basename = metadata.basename if metadata is not None and metadata.basename is not None else None
     finetune = metadata.finetune if metadata is not None and metadata.finetune is not None else None
@@ -1247,7 +1207,7 @@ def default_convention_outfile(file_type: GGMLFileType, model_name:str, expert_c

     return gguf.naming_convention(name, basename, finetune, version, expert_count, model_params_count, encodingScheme)


-def default_outfile(model_paths: list[Path], file_type: GGMLFileType, model_name:str, expert_count:int, model_params_count: int, metadata: Metadata) -> Path:
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, model_name:str, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> Path:
     default_filename = default_convention_outfile(file_type, model_name, expert_count, model_params_count, metadata)
     ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
@@ -1300,13 +1260,13 @@ def main(args_in: list[str] | None = None) -> None:
     else:
         logging.basicConfig(level=logging.INFO)

-    metadata = Metadata.load(args.metadata)
+    metadata = gguf.Metadata.load(args.metadata)

     if args.get_outfile:
         model_plus = load_some_model(args.model)
         params = Params.load(model_plus)
         model = convert_model_names(model_plus.model, params, args.skip_unknown)
-        model_params_count = per_model_weight_count_estimation(model_plus.model, params.n_experts)
+        model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
         ftype = pick_output_type(model, args.outtype)
         print(f"{default_convention_outfile(ftype, params.path_model.name, params.n_experts, model_params_count, metadata)}")  # noqa: NP100
         return
@@ -1324,7 +1284,7 @@ def main(args_in: list[str] | None = None) -> None:
     else:
         model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)

-    model_params_count = per_model_weight_count_estimation(model_plus.model, params.n_experts)
+    model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
     logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count)})")

     if args.dump:
diff --git a/gguf-py/gguf/__init__.py b/gguf-py/gguf/__init__.py
index a07b8ff0d..243defc4c 100644
--- a/gguf-py/gguf/__init__.py
+++ b/gguf-py/gguf/__init__.py
@@ -6,3 +6,4 @@ from .quants import *
 from .tensor_mapping import *
 from .vocab import *
 from .utility import *
+from .metadata import *
diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py
new file mode 100644
index 000000000..0d175605a
--- /dev/null
+++ b/gguf-py/gguf/metadata.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+from typing import Optional
+from dataclasses import dataclass
+
+from .constants import Keys
+
+
+@dataclass
+class Metadata:
+    name: Optional[str] = None
+    basename: Optional[str] = None
+    finetune: Optional[str] = None
+    author: Optional[str] = None
+    version: Optional[str] = None
+    url: Optional[str] = None
+    description: Optional[str] = None
+    licence: Optional[str] = None
+    source_url: Optional[str] = None
+    source_hf_repo: Optional[str] = None
+
+    @staticmethod
+    def load(metadata_path: Path) -> Metadata:
+        if metadata_path is None or not metadata_path.exists():
+            return Metadata()
+
+        with open(metadata_path, 'r') as file:
+            data = json.load(file)
+
+        # Create a new Metadata instance
+        metadata = Metadata()
+
+        # Assigning values to Metadata attributes if they exist in the JSON file
+        # This is based on LLM_KV_NAMES mapping in llama.cpp
+        metadata.name = data.get(Keys.General.NAME)
+        metadata.basename = data.get(Keys.General.BASENAME)
+        metadata.finetune = data.get(Keys.General.FINETUNE)
+        metadata.author = data.get(Keys.General.AUTHOR)
+        metadata.version = data.get(Keys.General.VERSION)
+        metadata.url = data.get(Keys.General.URL)
+        metadata.description = data.get(Keys.General.DESCRIPTION)
+        metadata.license = data.get(Keys.General.LICENSE)
+        metadata.source_url = data.get(Keys.General.SOURCE_URL)
+        metadata.source_hf_repo = data.get(Keys.General.SOURCE_HF_REPO)
+
+        return metadata
diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index 0919a744e..3a6046277 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -1,5 +1,45 @@
 from __future__ import annotations

+from typing import TYPE_CHECKING, Iterator
+
+if TYPE_CHECKING:
+    from torch import Tensor
+
+
+def fill_templated_filename(filename: str, encoding_scheme: str):
+    # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
+    ftype_uppercase: str = encoding_scheme.upper()
+    ftype_lowercase: str = encoding_scheme.lower()
+    return filename.format(ftype_lowercase,
+                           outtype=ftype_lowercase, ftype=ftype_lowercase,
+                           OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
+
+
+def per_model_weight_count_estimation(tensors: Iterator[tuple[str, Tensor]], expert_count: int) -> int:
+    # TODO: Ensure parameter count is accurate throughout various model type
+    # May currently overestimate parameter count in Mamba model because
+    # output weights is tied with token embeddings.
+    sum_weight_estimate = 0
+    for name, data_torch in tensors:
+        # Got A Tensor
+
+        # We don't need these
+        if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
+            continue
+
+        # Calculate Tensor Volume
+        sum_weights_in_tensor = 1
+        for dim in data_torch.shape:
+            sum_weights_in_tensor *= dim
+
+        # Add Tensor Volume To Running Count
+        sum_weight_estimate += sum_weights_in_tensor
+
+    # Calculate weight estimate per model
+    per_model_weight_estimate = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
+
+    return per_model_weight_estimate
+

 def model_weight_count_rounded_notation(model_params_count: int) -> str:
     if model_params_count > 1e15 :