convert-*.py: metadata class moved to utility

brian khuu 2024-06-02 01:49:58 +10:00
parent 916872f72f
commit 4d5f18a0e6
5 changed files with 128 additions and 149 deletions
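In short, the Metadata dataclass that each converter script previously carried its own copy of now lives in gguf-py/gguf/metadata.py, and the filename-templating and parameter-count helpers move into gguf-py/gguf/utility.py, so the scripts call into the gguf package instead. A minimal sketch of the call sites after this change (the metadata.json path is illustrative, not part of the commit):

from pathlib import Path

import gguf

# Metadata now comes from gguf-py/gguf/metadata.py
metadata = gguf.Metadata.load(Path("metadata.json"))

# Filename templating now comes from gguf-py/gguf/utility.py
fname = gguf.fill_templated_filename("some-model-name.{ftype}.gguf", "F16")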


@@ -31,46 +31,6 @@ import gguf
logger = logging.getLogger("hf-to-gguf")
@dataclass
class Metadata:
name: Optional[str] = None
basename: Optional[str] = None
finetune: Optional[str] = None
author: Optional[str] = None
version: Optional[str] = None
url: Optional[str] = None
description: Optional[str] = None
licence: Optional[str] = None
source_url: Optional[str] = None
source_hf_repo: Optional[str] = None
@staticmethod
def load(metadata_path: Path) -> Metadata:
if metadata_path is None or not metadata_path.exists():
return Metadata()
with open(metadata_path, 'r') as file:
data = json.load(file)
# Create a new Metadata instance
metadata = Metadata()
# Assigning values to Metadata attributes if they exist in the JSON file
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata.name = data.get("general.name")
metadata.basename = data.get("general.basename")
metadata.finetune = data.get("general.finetune")
metadata.author = data.get("general.author")
metadata.version = data.get("general.version")
metadata.url = data.get("general.url")
metadata.description = data.get("general.description")
metadata.license = data.get("general.license")
metadata.source_url = data.get("general.source.url")
metadata.source_hf_repo = data.get("general.source.huggingface.repository")
return metadata
###### MODEL DEFINITIONS ######
class SentencePieceTokenTypes(IntEnum):
@@ -105,12 +65,12 @@ class Model:
fname_out: Path
fname_default: Path
gguf_writer: gguf.GGUFWriter
metadata: Metadata
metadata: gguf.Metadata
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: Metadata,
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: gguf.Metadata,
model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@@ -164,72 +124,23 @@ class Model:
if self.model_card is not None and "license" in self.model_card:
self.metadata.source_hf_repo = self.model_card["license"]
# Set model name based on latest metadata either provided or calculated from environment
def get_model_name(metadata, huggingface_parameters, dir_model, model_arch):
if metadata is not None and metadata.name is not None:
# Explicit Metadata Was Provided By User
return metadata.name
elif huggingface_parameters is not None and "_name_or_path" in huggingface_parameters:
# Hugging Face Parameters Model Name or Model Folder Name is Provided
return huggingface_parameters["_name_or_path"]
elif huggingface_parameters is not None and "model_type" in huggingface_parameters:
# Hugging Face Parameters Model Type is Provided
return huggingface_parameters["model_type"]
elif dir_model is not None and dir_model.name is not None:
# Use directory folder name
return dir_model.name
else:
return gguf.MODEL_ARCH_NAMES[model_arch]
self.model_name = get_model_name(self.metadata, self.hparams, self.dir_model, self.model_arch)
self.model_name = Model.get_model_name(self.metadata, self.hparams, self.dir_model, self.model_arch)
# Extracts and converts the encoding scheme from the given file type name. e.g. 'gguf.LlamaFileType.ALL_F32' --> 'F32'
encodingScheme = self.ftype.name.partition("_")[2]
encoding_scheme = self.ftype.name.partition("_")[2]
# Get Expert Count From huggingface_parameters
expert_count = self.hparams["num_local_experts"] if "num_local_experts" in self.hparams else None
def per_model_weight_count_estimation(tensors, expert_count):
# TODO: Ensure parameter count is accurate throughout various model type
# May currently overestimate parameter count in Mamba model because
# output weights is tied with token embeddings.
sum_weight_estimate = 0
for name, data_torch in tensors:
# Got A Tensor
# We don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue
# Calculate Tensor Volume
sum_weights_in_tensor = 1
for dim in data_torch.shape:
sum_weights_in_tensor *= dim
# Add Tensor Volume To Running Count
sum_weight_estimate += sum_weights_in_tensor
# Calculate weight estimate per model
per_model_weight_estimate = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
return per_model_weight_estimate
weight_estimate = per_model_weight_count_estimation(model_tensors, expert_count)
weight_estimate = gguf.per_model_weight_count_estimation(model_tensors, expert_count)
# Generate default filename based on model specification and available metadata
self.fname_default = gguf.naming_convention(self.model_name, self.metadata.basename, self.metadata.finetune, self.metadata.version, expert_count, weight_estimate, encodingScheme)
self.fname_default = gguf.naming_convention(self.model_name, self.metadata.basename, self.metadata.finetune, self.metadata.version, expert_count, weight_estimate, encoding_scheme)
# Filename Output
if fname_out is not None:
# custom defined filename and path was provided
def fill_templated_filename(filename: str, encodingScheme: str):
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
ftype_uppercase: str = encodingScheme.upper()
ftype_lowercase: str = encodingScheme.lower()
return filename.format(ftype_lowercase,
outtype=ftype_lowercase, ftype=ftype_lowercase,
OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
self.fname_out = fname_out.parent / fill_templated_filename(fname_out.name, encodingScheme)
self.fname_out = fname_out.parent / gguf.fill_templated_filename(fname_out.name, encoding_scheme)
else:
# output in the same directory as the model by default
self.fname_out = dir_model.parent / self.fname_default
@@ -499,6 +410,24 @@ class Model:
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.close()
# Set model name based on latest metadata either provided or calculated from environment
@staticmethod
def get_model_name(metadata, huggingface_parameters, dir_model, model_arch):
if metadata is not None and metadata.name is not None:
# Explicit Metadata Was Provided By User
return metadata.name
elif huggingface_parameters is not None and "_name_or_path" in huggingface_parameters:
# Hugging Face Parameters Model Name or Model Folder Name is Provided
return huggingface_parameters["_name_or_path"]
elif huggingface_parameters is not None and "model_type" in huggingface_parameters:
# Hugging Face Parameters Model Type is Provided
return huggingface_parameters["model_type"]
elif dir_model is not None and dir_model.name is not None:
# Use directory folder name
return dir_model.name
else:
return gguf.MODEL_ARCH_NAMES[model_arch]
@staticmethod
def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
part_names: list[str] = []
@@ -3682,7 +3611,7 @@ def main() -> None:
else:
logging.basicConfig(level=logging.INFO)
metadata = Metadata.load(args.metadata)
metadata = gguf.Metadata.load(args.metadata)
dir_model = args.model
if not dir_model.is_dir():
@@ -3713,7 +3642,7 @@ def main() -> None:
hparams = Model.load_hparams(dir_model)
with torch.inference_mode():
encodingScheme = ftype_map[args.outtype]
encoding_scheme = ftype_map[args.outtype]
model_architecture = hparams["architectures"][0]
try:


@@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar, Optional
from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar
import numpy as np
@@ -346,46 +346,6 @@ class Params:
return params
@dataclass
class Metadata:
name: Optional[str] = None
basename: Optional[str] = None
finetune: Optional[str] = None
author: Optional[str] = None
version: Optional[str] = None
url: Optional[str] = None
description: Optional[str] = None
license: Optional[str] = None
source_url: Optional[str] = None
source_hf_repo: Optional[str] = None
@staticmethod
def load(metadata_path: Path) -> Metadata:
if metadata_path is None or not metadata_path.exists():
return Metadata()
with open(metadata_path, 'r') as file:
data = json.load(file)
# Create a new Metadata instance
metadata = Metadata()
# Assigning values to Metadata attributes if they exist in the JSON file
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata.name = data.get("general.name")
metadata.basename = data.get("general.basename")
metadata.finetune = data.get("general.finetune")
metadata.author = data.get("general.author")
metadata.version = data.get("general.version")
metadata.url = data.get("general.url")
metadata.description = data.get("general.description")
metadata.license = data.get("general.license")
metadata.source_url = data.get("general.source.url")
metadata.source_hf_repo = data.get("general.source.huggingface.repository")
return metadata
#
# data loading
# TODO: reuse (probably move to gguf.py?)
@@ -810,7 +770,7 @@ class OutputFile:
def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
# Metadata About The Model And Its Provenence
name = "LLaMA"
if metadata is not None and metadata.name is not None:
@@ -952,7 +912,7 @@ class OutputFile:
@staticmethod
def write_vocab_only(
fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
@@ -986,7 +946,7 @@ class OutputFile:
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
metadata: Metadata | None = None,
metadata: gguf.Metadata | None = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
@@ -1029,10 +989,10 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
raise ValueError(f"Unexpected combination of types: {name_to_type}")
def per_model_weight_count_estimation(model: LazyModel, expert_count:int) -> int:
def per_model_weight_count_estimation(tensors: dict[str, LazyTensor], expert_count:int) -> int:
# TODO: Ensure parameter count is accurate throughout various model type
sum_weight_estimate = 0
for name, lazy_tensor in model.items():
for name, lazy_tensor in tensors:
# We don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue
@@ -1232,7 +1192,7 @@ class VocabFactory:
return vocab, special_vocab
def default_convention_outfile(file_type: GGMLFileType, model_name:str, expert_count:int, model_params_count: int, metadata: Metadata) -> str:
def default_convention_outfile(file_type: GGMLFileType, model_name:str, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> str:
name = metadata.name if metadata is not None and metadata.name is not None else model_name
basename = metadata.basename if metadata is not None and metadata.basename is not None else None
finetune = metadata.finetune if metadata is not None and metadata.finetune is not None else None
@@ -1247,7 +1207,7 @@ def default_convention_outfile(file_type: GGMLFileType, model_name:str, expert_c
return gguf.naming_convention(name, basename, finetune, version, expert_count, model_params_count, encodingScheme)
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, model_name:str, expert_count:int, model_params_count: int, metadata: Metadata) -> Path:
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, model_name:str, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> Path:
default_filename = default_convention_outfile(file_type, model_name, expert_count, model_params_count, metadata)
ret = model_paths[0].parent / f"{default_filename}.gguf"
if ret in model_paths:
@@ -1300,13 +1260,13 @@ def main(args_in: list[str] | None = None) -> None:
else:
logging.basicConfig(level=logging.INFO)
metadata = Metadata.load(args.metadata)
metadata = gguf.Metadata.load(args.metadata)
if args.get_outfile:
model_plus = load_some_model(args.model)
params = Params.load(model_plus)
model = convert_model_names(model_plus.model, params, args.skip_unknown)
model_params_count = per_model_weight_count_estimation(model_plus.model, params.n_experts)
model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
ftype = pick_output_type(model, args.outtype)
print(f"{default_convention_outfile(ftype, params.path_model.name, params.n_experts, model_params_count, metadata)}") # noqa: NP100
return
@@ -1324,7 +1284,7 @@ def main(args_in: list[str] | None = None) -> None:
else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
model_params_count = per_model_weight_count_estimation(model_plus.model, params.n_experts)
model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count)})")
if args.dump:


@@ -6,3 +6,4 @@ from .quants import *
from .tensor_mapping import *
from .vocab import *
from .utility import *
from .metadata import *

gguf-py/gguf/metadata.py Normal file

@@ -0,0 +1,49 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from .constants import Keys
@dataclass
class Metadata:
name: Optional[str] = None
basename: Optional[str] = None
finetune: Optional[str] = None
author: Optional[str] = None
version: Optional[str] = None
url: Optional[str] = None
description: Optional[str] = None
licence: Optional[str] = None
source_url: Optional[str] = None
source_hf_repo: Optional[str] = None
@staticmethod
def load(metadata_path: Path) -> Metadata:
if metadata_path is None or not metadata_path.exists():
return Metadata()
with open(metadata_path, 'r') as file:
data = json.load(file)
# Create a new Metadata instance
metadata = Metadata()
# Assigning values to Metadata attributes if they exist in the JSON file
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata.name = data.get(Keys.General.NAME)
metadata.basename = data.get(Keys.General.BASENAME)
metadata.finetune = data.get(Keys.General.FINETUNE)
metadata.author = data.get(Keys.General.AUTHOR)
metadata.version = data.get(Keys.General.VERSION)
metadata.url = data.get(Keys.General.URL)
metadata.description = data.get(Keys.General.DESCRIPTION)
metadata.license = data.get(Keys.General.LICENSE)
metadata.source_url = data.get(Keys.General.SOURCE_URL)
metadata.source_hf_repo = data.get(Keys.General.SOURCE_HF_REPO)
return metadata
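
The keys read here are the same general.* strings the removed per-script class used, so an existing metadata override file keeps working unchanged. A small sketch (field values invented) of writing such a file and loading it through the relocated class:

import json
from pathlib import Path

import gguf

# Hypothetical override file; the keys mirror the general.* names handled above.
override = {
    "general.name": "Example Model",
    "general.author": "Example Author",
    "general.version": "v1.0",
    "general.url": "https://example.com/model",
}
Path("metadata.json").write_text(json.dumps(override))

metadata = gguf.Metadata.load(Path("metadata.json"))
assert metadata.name == "Example Model"
assert metadata.author == "Example Author"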


@@ -1,5 +1,45 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Iterator
if TYPE_CHECKING:
from torch import Tensor
def fill_templated_filename(filename: str, encoding_scheme: str):
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
ftype_uppercase: str = encoding_scheme.upper()
ftype_lowercase: str = encoding_scheme.lower()
return filename.format(ftype_lowercase,
outtype=ftype_lowercase, ftype=ftype_lowercase,
OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
def per_model_weight_count_estimation(tensors: Iterator[tuple[str, Tensor]], expert_count: int) -> int:
# TODO: Ensure parameter count is accurate throughout various model type
# May currently overestimate parameter count in Mamba model because
# output weights is tied with token embeddings.
sum_weight_estimate = 0
for name, data_torch in tensors:
# Got A Tensor
# We don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue
# Calculate Tensor Volume
sum_weights_in_tensor = 1
for dim in data_torch.shape:
sum_weights_in_tensor *= dim
# Add Tensor Volume To Running Count
sum_weight_estimate += sum_weights_in_tensor
# Calculate weight estimate per model
per_model_weight_estimate = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
return per_model_weight_estimate
def model_weight_count_rounded_notation(model_params_count: int) -> str:
if model_params_count > 1e15 :
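
Since gguf/__init__.py re-exports utility.py (see the __init__ hunk above), the relocated helpers can be exercised directly. A rough sanity check, with invented tensor names and shapes and torch standing in for the converters' lazy tensors:

import torch

import gguf

# '{ftype}' in a templated filename is replaced with the lower-cased encoding scheme.
print(gguf.fill_templated_filename("some-model-name.{ftype}.gguf", "F16"))
# some-model-name.f16.gguf

# Parameter-count estimate over (name, tensor) pairs; expert_count == 0 skips the MoE division.
tensors = [
    ("token_embd.weight", torch.zeros(1024, 256)),
    ("blk.0.attn_q.weight", torch.zeros(256, 256)),
]
n_params = gguf.per_model_weight_count_estimation(tensors, 0)
print(gguf.model_weight_count_rounded_notation(n_params))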