convert-*.py: GGUF Naming Convention Refactor and Metadata Override Refactor (#7499)
Main thing is that the default output filename will take this form:

    {name}{parameters}{finetune}{version}{encoding}{kind}

In addition this adds and removes some entries in the KV store and adds a metadata class with automatic heuristics capability to derive some values based on model card content.

* No Change:
  - Internal GGUF Spec
    - `general.architecture`
    - `general.quantization_version`
    - `general.alignment`
    - `general.file_type`
  - General Model Details
    - `general.name`
    - `general.author`
    - `general.version`
    - `general.description`
  - Licensing details
    - `general.license`
  - Typically represents the converted GGUF repo (Unless made from scratch)
    - `general.url`
  - Model Source during conversion
    - `general.source.url`

* Removed:
  - Model Source during conversion
    - `general.source.huggingface.repository`

* Added:
  - General Model Details
    - `general.organization`
    - `general.finetune`
    - `general.basename`
    - `general.quantized_by`
    - `general.size_label`
  - Licensing details
    - `general.license.name`
    - `general.license.link`
  - Typically represents the converted GGUF repo (Unless made from scratch)
    - `general.doi`
    - `general.uuid`
    - `general.repo_url`
  - Model Source during conversion
    - `general.source.doi`
    - `general.source.uuid`
    - `general.source.repo_url`
  - Base Model Source
    - `general.base_model.count`
    - `general.base_model.{id}.name`
    - `general.base_model.{id}.author`
    - `general.base_model.{id}.version`
    - `general.base_model.{id}.organization`
    - `general.base_model.{id}.url` (Model Website/Paper)
    - `general.base_model.{id}.doi`
    - `general.base_model.{id}.uuid`
    - `general.base_model.{id}.repo_url` (Model Source Repository (git/svn/etc...))
  - Array based KV stores
    - `general.tags`
    - `general.languages`
    - `general.datasets`

---------

Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
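For illustration, here is a minimal sketch (not the actual gguf-py implementation) of how a default filename following the convention above could be assembled; all example values are hypothetical:

```python
# Hypothetical sketch of the {name}{parameters}{finetune}{version}{encoding}
# filename scheme; the real logic lives in gguf-py (gguf.naming_convention).
def build_default_filename(name: str, size_label: str | None, finetune: str | None,
                           version: str | None, encoding: str | None) -> str:
    parts = [name.strip().replace(' ', '-')]
    for component in (size_label, finetune, version, encoding):
        if component:
            parts.append(component.strip().replace(' ', '-'))
    return "-".join(parts)

print(build_default_filename("Mixtral", "8x7B", "Instruct", "v0.1", "Q8_0"))
# Mixtral-8x7B-Instruct-v0.1-Q8_0  (the converter appends ".gguf")
```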
Commit 672a6f1018 (parent 3807c3de04)
13 changed files with 1185 additions and 239 deletions
```diff
@@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar, Optional
+from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar

 import numpy as np

```
```diff
@@ -346,42 +346,6 @@ class Params:
         return params


-@dataclass
-class Metadata:
-    name: Optional[str] = None
-    author: Optional[str] = None
-    version: Optional[str] = None
-    url: Optional[str] = None
-    description: Optional[str] = None
-    license: Optional[str] = None
-    source_url: Optional[str] = None
-    source_hf_repo: Optional[str] = None
-
-    @staticmethod
-    def load(metadata_path: Path) -> Metadata:
-        if metadata_path is None or not metadata_path.exists():
-            return Metadata()
-
-        with open(metadata_path, 'r') as file:
-            data = json.load(file)
-
-        # Create a new Metadata instance
-        metadata = Metadata()
-
-        # Assigning values to Metadata attributes if they exist in the JSON file
-        # This is based on LLM_KV_NAMES mapping in llama.cpp
-        metadata.name = data.get("general.name")
-        metadata.author = data.get("general.author")
-        metadata.version = data.get("general.version")
-        metadata.url = data.get("general.url")
-        metadata.description = data.get("general.description")
-        metadata.license = data.get("general.license")
-        metadata.source_url = data.get("general.source.url")
-        metadata.source_hf_repo = data.get("general.source.huggingface.repository")
-
-        return metadata
-
-
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
```
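This `Metadata` class (relocated into gguf-py by this PR) reads a flat JSON file of `general.*` keys. A hypothetical example of producing such an authorship override file for `--metadata`; the file name and all values are made up:

```python
# Write a metadata override file using the "general.*" keys that
# Metadata.load() above consumes. Values here are purely illustrative.
import json
from pathlib import Path

override = {
    "general.name": "TinyLlama",
    "general.author": "Example Author",
    "general.version": "v1.0",
    "general.url": "https://example.org/models/tinyllama",  # hypothetical URL
    "general.description": "A small example model",
    "general.license": "apache-2.0",
}
Path("metadata_override.json").write_text(json.dumps(override, indent=4))
```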
```diff
@@ -806,7 +770,7 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

-    def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
+    def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
         # Metadata About The Model And Its Provenence
         name = "LLaMA"
         if metadata is not None and metadata.name is not None:
```
```diff
@@ -824,16 +788,73 @@ class OutputFile:
             self.gguf.add_author(metadata.author)
         if metadata.version is not None:
             self.gguf.add_version(metadata.version)
-        if metadata.url is not None:
-            self.gguf.add_url(metadata.url)
+        if metadata.organization is not None:
+            self.gguf.add_organization(metadata.organization)
+
+        if metadata.finetune is not None:
+            self.gguf.add_finetune(metadata.finetune)
+        if metadata.basename is not None:
+            self.gguf.add_basename(metadata.basename)
+
         if metadata.description is not None:
             self.gguf.add_description(metadata.description)
+        if metadata.quantized_by is not None:
+            self.gguf.add_quantized_by(metadata.quantized_by)
+
+        if metadata.size_label is not None:
+            self.gguf.add_size_label(metadata.size_label)
+
         if metadata.license is not None:
-            self.gguf.add_licence(metadata.license)
+            self.gguf.add_license(metadata.license)
+        if metadata.license_name is not None:
+            self.gguf.add_license_name(metadata.license_name)
+        if metadata.license_link is not None:
+            self.gguf.add_license_link(metadata.license_link)
+
+        if metadata.url is not None:
+            self.gguf.add_url(metadata.url)
+        if metadata.doi is not None:
+            self.gguf.add_doi(metadata.doi)
+        if metadata.uuid is not None:
+            self.gguf.add_uuid(metadata.uuid)
+        if metadata.repo_url is not None:
+            self.gguf.add_repo_url(metadata.repo_url)
+
         if metadata.source_url is not None:
             self.gguf.add_source_url(metadata.source_url)
-        if metadata.source_hf_repo is not None:
-            self.gguf.add_source_hf_repo(metadata.source_hf_repo)
+        if metadata.source_doi is not None:
+            self.gguf.add_source_doi(metadata.source_doi)
+        if metadata.source_uuid is not None:
+            self.gguf.add_source_uuid(metadata.source_uuid)
+        if metadata.source_repo_url is not None:
+            self.gguf.add_source_repo_url(metadata.source_repo_url)
+
+        if metadata.base_models is not None:
+            self.gguf.add_base_model_count(len(metadata.base_models))
+            for key, base_model_entry in enumerate(metadata.base_models):
+                if "name" in base_model_entry:
+                    self.gguf.add_base_model_name(key, base_model_entry["name"])
+                if "author" in base_model_entry:
+                    self.gguf.add_base_model_author(key, base_model_entry["author"])
+                if "version" in base_model_entry:
+                    self.gguf.add_base_model_version(key, base_model_entry["version"])
+                if "organization" in base_model_entry:
+                    self.gguf.add_base_model_organization(key, base_model_entry["organization"])
+                if "url" in base_model_entry:
+                    self.gguf.add_base_model_url(key, base_model_entry["url"])
+                if "doi" in base_model_entry:
+                    self.gguf.add_base_model_doi(key, base_model_entry["doi"])
+                if "uuid" in base_model_entry:
+                    self.gguf.add_base_model_uuid(key, base_model_entry["uuid"])
+                if "repo_url" in base_model_entry:
+                    self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
+
+        if metadata.tags is not None:
+            self.gguf.add_tags(metadata.tags)
+        if metadata.languages is not None:
+            self.gguf.add_languages(metadata.languages)
+        if metadata.datasets is not None:
+            self.gguf.add_datasets(metadata.datasets)

     def add_meta_arch(self, params: Params) -> None:
         # Metadata About The Neural Architecture Itself
```
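A hedged sketch of feeding a populated `gguf.Metadata` through `add_meta_model()` above. The field names follow this diff; the concrete values (and the assumption that fields can be assigned directly on a default-constructed instance) are illustrative:

```python
# Sketch only: populate the new metadata fields and let add_meta_model()
# emit the corresponding general.* KV entries. All values are invented.
import gguf

metadata = gguf.Metadata()
metadata.name = "TinyLlama"
metadata.finetune = "Chat"
metadata.size_label = "1.1B"
metadata.license = "apache-2.0"
metadata.base_models = [{
    "name": "TinyLlama",
    "organization": "TinyLlama",
    "repo_url": "https://huggingface.co/TinyLlama",  # hypothetical URL
}]
metadata.tags = ["llama", "chat"]

# OutputFile(fname_out).add_meta_model(params, metadata) would then write
# general.finetune, general.size_label, general.base_model.count = 1,
# general.base_model.0.name, general.tags, etc. into the KV store.
```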
```diff
@@ -944,7 +965,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)

```
```diff
@@ -978,7 +999,7 @@ class OutputFile:
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
-        metadata: Metadata | None = None,
+        metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)

```
```diff
@@ -1021,35 +1042,32 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
     raise ValueError(f"Unexpected combination of types: {name_to_type}")


-def model_parameter_count(model: LazyModel) -> int:
-    total_model_parameters = 0
-    for i, (name, lazy_tensor) in enumerate(model.items()):
-        sum_weights_in_tensor = 1
+def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]:
+    total_params = 0
+    shared_params = 0
+    expert_params = 0
+
+    for name, lazy_tensor in tensors:
+        # We don't need these
+        if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
+            continue
+
+        # Got A Tensor
+        sum_weights_in_tensor: int = 1
+
+        # Tensor Volume
         for dim in lazy_tensor.shape:
             sum_weights_in_tensor *= dim
-        total_model_parameters += sum_weights_in_tensor
-    return total_model_parameters
-
-
-def model_parameter_count_rounded_notation(model_params_count: int) -> str:
-    if model_params_count > 1e12 :
-        # Trillions Of Parameters
-        scaled_model_params = model_params_count * 1e-12
-        scale_suffix = "T"
-    elif model_params_count > 1e9 :
-        # Billions Of Parameters
-        scaled_model_params = model_params_count * 1e-9
-        scale_suffix = "B"
-    elif model_params_count > 1e6 :
-        # Millions Of Parameters
-        scaled_model_params = model_params_count * 1e-6
-        scale_suffix = "M"
-    else:
-        # Thousands Of Parameters
-        scaled_model_params = model_params_count * 1e-3
-        scale_suffix = "K"
+
+        if ".experts." in name:
+            if ".experts.0." in name:
+                expert_params += sum_weights_in_tensor
+        else:
+            shared_params += sum_weights_in_tensor
+
+        total_params += sum_weights_in_tensor

-    return f"{round(scaled_model_params)}{scale_suffix}"
+    return total_params, shared_params, expert_params


 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
```
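The new estimator splits weights into shared and per-expert counts, counting only expert 0 as representative so MoE models get a per-expert size rather than the sum over all experts. A standalone sketch of that split logic, using stand-in `(name, shape)` pairs instead of `LazyTensor`s (the tensor names here are hypothetical):

```python
# Mirror of the shared/expert split above, runnable without llama.cpp.
def weight_counts(tensors):
    total = shared = expert = 0
    for name, shape in tensors:
        if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
            continue
        n = 1
        for dim in shape:
            n *= dim
        if ".experts." in name:
            if ".experts.0." in name:   # count one expert as representative
                expert += n
        else:
            shared += n
        total += n                       # total still counts every expert
    return total, shared, expert

print(weight_counts([
    ("blk.0.attn_q.weight", (4096, 4096)),             # shared
    ("blk.0.ffn.experts.0.w1.weight", (4096, 14336)),  # expert 0 (counted)
    ("blk.0.ffn.experts.1.w1.weight", (4096, 14336)),  # other experts (total only)
]))
# (134217728, 16777216, 58720256)
```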
```diff
@@ -1231,34 +1249,24 @@ class VocabFactory:
         return vocab, special_vocab


-def default_convention_outfile(file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> str:
-    quantization = {
+def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str:
+    name = metadata.name if metadata.name is not None else None
+    basename = metadata.basename if metadata.basename is not None else None
+    finetune = metadata.finetune if metadata.finetune is not None else None
+    version = metadata.version if metadata.version is not None else None
+    size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0)
+
+    output_type = {
         GGMLFileType.AllF32: "F32",
         GGMLFileType.MostlyF16: "F16",
         GGMLFileType.MostlyQ8_0: "Q8_0",
     }[file_type]

-    parameters = model_parameter_count_rounded_notation(model_params_count)
-
-    expert_count = ""
-    if params.n_experts is not None:
-        expert_count = f"{params.n_experts}x"
-
-    version = ""
-    if metadata is not None and metadata.version is not None:
-        version = f"-{metadata.version}"
-
-    name = "ggml-model"
-    if metadata is not None and metadata.name is not None:
-        name = metadata.name
-    elif params.path_model is not None:
-        name = params.path_model.name
-
-    return f"{name}{version}-{expert_count}{parameters}-{quantization}"
+    return gguf.naming_convention(name, basename, finetune, version, size_label, output_type)


-def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
-    default_filename = default_convention_outfile(file_type, params, model_params_count, metadata)
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path:
+    default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
     ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
         logger.error(
```
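A hedged usage sketch for the refactored naming path. The signatures of `gguf.size_label(total, shared, expert, expert_count=...)` and `gguf.naming_convention(name, basename, finetune, version, size_label, output_type)` are taken from this diff; the counts and the printed result are illustrative only:

```python
# Sketch: derive a size label from (total, shared, expert) counts and build
# the default filename stem. Numbers are hypothetical 8-expert counts.
import gguf

model_params_count = (46_702_792_704, 1_310_720_000, 5_674_008_576)
size_label = gguf.size_label(*model_params_count, expert_count=8)    # e.g. "8x7B"
print(gguf.naming_convention("Mixtral", None, "Instruct", "v0.1", size_label, "Q8_0"))
# e.g. Mixtral-8x7B-Instruct-v0.1-Q8_0
```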
```diff
@@ -1297,8 +1305,9 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
-    parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file")
+    parser.add_argument("--metadata", type=Path, help="Specify the path for an authorship metadata override file")
     parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name")
+    parser.add_argument("--model-name", type=str, default=None, help="name of the model")

     args = parser.parse_args(args_in)

```
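Since `main(args_in)` accepts an argument list (per the signature in the hunk header), the new flags can be exercised directly from Python; the paths below are hypothetical:

```python
# Preview the default output filename for a local model directory, using the
# override file from the earlier example. Paths are made up.
main([
    "--metadata", "metadata_override.json",
    "--outtype", "q8_0",
    "--get-outfile",
    "path/to/model-dir",
])
```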
```diff
@@ -1310,32 +1319,36 @@ def main(args_in: list[str] | None = None) -> None:
     else:
         logging.basicConfig(level=logging.INFO)

-    metadata = Metadata.load(args.metadata)
+    model_name = args.model_name
+    dir_model = args.model
+
+    metadata = gguf.Metadata.load(args.metadata, dir_model, model_name)

     if args.get_outfile:
-        model_plus = load_some_model(args.model)
+        model_plus = load_some_model(dir_model)
         params = Params.load(model_plus)
-        model = convert_model_names(model_plus.model, params, args.skip_unknown)
-        model_params_count = model_parameter_count(model_plus.model)
-        ftype = pick_output_type(model, args.outtype)
-        print(f"{default_convention_outfile(ftype, params, model_params_count, metadata)}") # noqa: NP100
+        model = convert_model_names(model_plus.model, params, args.skip_unknown)
+        model_params_count = per_model_weight_count_estimation(model_plus.model.items())
+        ftype = pick_output_type(model, args.outtype)
+
+        if (metadata is None or metadata.name is None) and params.path_model is not None:
+            metadata.name = params.path_model.name
+
+        print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100
         return

     if args.no_vocab and args.vocab_only:
         raise ValueError("--vocab-only does not make sense with --no-vocab")

     if args.dump_single:
-        model_plus = lazy_load_file(args.model)
+        model_plus = lazy_load_file(dir_model)
         do_dump_model(model_plus)
         return

     if not args.vocab_only:
-        model_plus = load_some_model(args.model)
+        model_plus = load_some_model(dir_model)
     else:
-        model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
-
-    model_params_count = model_parameter_count(model_plus.model)
-    logger.info(f"model parameters count : {model_params_count} ({model_parameter_count_rounded_notation(model_params_count)})")
+        model_plus = ModelPlus(model = {}, paths = [dir_model / 'dummy'], format = 'none', vocab = None)

     if args.dump:
         do_dump_model(model_plus)
```
```diff
@@ -1368,7 +1381,7 @@ def main(args_in: list[str] | None = None) -> None:
     logger.info(f"params = {params}")

     model_parent_path = model_plus.paths[0].parent
-    vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
+    vocab_path = Path(args.vocab_dir or dir_model or model_parent_path)
     vocab_factory = VocabFactory(vocab_path)
     vocab_types = None if args.no_vocab else args.vocab_type.split(",")
     vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
```
```diff
@@ -1399,13 +1412,21 @@ def main(args_in: list[str] | None = None) -> None:

     assert params is not None

+    if metadata.name is None and params.path_model is not None:
+        metadata.name = params.path_model.name
+
+    model_params_count = per_model_weight_count_estimation(model_plus.model.items())
+    logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})")
+
     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
     model = model_plus.model
     model = convert_model_names(model, params, args.skip_unknown)
     ftype = pick_output_type(model, args.outtype)
     model = convert_to_output_type(model, ftype)
-    outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)
+    outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata)
+
+    metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0)

     params.ftype = ftype
     logger.info(f"Writing {outfile}, format {ftype}")
```
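For reference, a hedged example of the rounded-notation logging above; `model_weight_count_rounded_notation` moves into gguf-py in this refactor, and the exact formatting of its output is assumed:

```python
# Sketch: human-readable parameter count for the log line above.
import gguf

print(gguf.model_weight_count_rounded_notation(46_702_792_704))  # e.g. "47B"
```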