convert-*.py: GGUF Naming Convention Refactor and Metadata Override Refactor (#7499)

Main thing is that the default output filename will take this form

{name}{parameters}{finetune}{version}{encoding}{kind}

In addition this add and remove some entries in the KV store and adds a metadata class with automatic heuristics capability to derive some values based on model card content

* No Change:
  - Internal GGUF Spec
    - `general.architecture`
    - `general.quantization_version`
    - `general.alignment`
    - `general.file_type`
  - General Model Details
    - `general.name`
    - `general.author`
    - `general.version`
    - `general.description`
  - Licensing details
    - `general.license`
  - Typically represents the converted GGUF repo (Unless made from scratch)
    - `general.url`
  - Model Source during conversion
    - `general.source.url`

* Removed:
  - Model Source during conversion
    - `general.source.huggingface.repository`

* Added:
  - General Model Details
    - `general.organization`
    - `general.finetune`
    - `general.basename`
    - `general.quantized_by`
    - `general.size_label`
  - Licensing details
    - `general.license.name`
    - `general.license.link`
  - Typically represents the converted GGUF repo (Unless made from scratch)
    - `general.doi`
    - `general.uuid`
    - `general.repo_url`
  - Model Source during conversion
    - `general.source.doi`
    - `general.source.uuid`
    - `general.source.repo_url`
  - Base Model Source
    - `general.base_model.count`
    - `general.base_model.{id}.name`
    - `general.base_model.{id}.author`
    - `general.base_model.{id}.version`
    - `general.base_model.{id}.organization`
    - `general.base_model.{id}.url` (Model Website/Paper)
    - `general.base_model.{id}.doi`
    - `general.base_model.{id}.uuid`
    - `general.base_model.{id}.repo_url` (Model Source Repository (git/svn/etc...))
  - Array based KV stores
    - `general.tags`
    - `general.languages`
    - `general.datasets`

---------

Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
Brian 2024-07-18 20:40:15 +10:00 committed by GitHub
parent 3807c3de04
commit 672a6f1018
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1185 additions and 239 deletions

View file

@ -78,5 +78,13 @@ python -m build
python -m twine upload dist/*
```
## Run Unit Tests
From root of this repository you can run this command to run all the unit tests
```bash
python -m unittest discover ./gguf-py -v
```
## TODO
- [ ] Include conversion scripts as command line entry points in this package.

View file

@ -5,3 +5,5 @@ from .gguf_writer import *
from .quants import *
from .tensor_mapping import *
from .vocab import *
from .utility import *
from .metadata import *

View file

@ -19,19 +19,60 @@ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
class Keys:
class General:
TYPE = "general.type"
ARCHITECTURE = "general.architecture"
QUANTIZATION_VERSION = "general.quantization_version"
ALIGNMENT = "general.alignment"
NAME = "general.name"
AUTHOR = "general.author"
VERSION = "general.version"
URL = "general.url"
DESCRIPTION = "general.description"
LICENSE = "general.license"
SOURCE_URL = "general.source.url"
SOURCE_HF_REPO = "general.source.huggingface.repository"
FILE_TYPE = "general.file_type"
TYPE = "general.type"
ARCHITECTURE = "general.architecture"
QUANTIZATION_VERSION = "general.quantization_version"
ALIGNMENT = "general.alignment"
FILE_TYPE = "general.file_type"
# Authorship Metadata
NAME = "general.name"
AUTHOR = "general.author"
VERSION = "general.version"
ORGANIZATION = "general.organization"
FINETUNE = "general.finetune"
BASENAME = "general.basename"
DESCRIPTION = "general.description"
QUANTIZED_BY = "general.quantized_by"
SIZE_LABEL = "general.size_label"
# Licensing details
LICENSE = "general.license"
LICENSE_NAME = "general.license.name"
LICENSE_LINK = "general.license.link"
# Typically represents the converted GGUF repo (Unless native)
URL = "general.url" # Model Website/Paper
DOI = "general.doi"
UUID = "general.uuid"
REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
# Model Source during conversion
SOURCE_URL = "general.source.url" # Model Website/Paper
SOURCE_DOI = "general.source.doi"
SOURCE_UUID = "general.source.uuid"
SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
# Base Model Source. There can be more than one source if it's a merged
# model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
# tracing linage of models as it is finetuned or merged over time.
BASE_MODEL_COUNT = "general.base_model.count"
BASE_MODEL_NAME = "general.base_model.{id}.name"
BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
BASE_MODEL_VERSION = "general.base_model.{id}.version"
BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
BASE_MODEL_DOI = "general.base_model.{id}.doi"
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
# Array based KV stores
TAGS = "general.tags"
LANGUAGES = "general.languages"
DATASETS = "general.datasets"
class LLM:
VOCAB_SIZE = "{arch}.vocab_size"
@ -1233,7 +1274,6 @@ KEY_GENERAL_URL = Keys.General.URL
KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
KEY_GENERAL_LICENSE = Keys.General.LICENSE
KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
# LLM

View file

@ -7,6 +7,7 @@ import struct
import tempfile
from dataclasses import dataclass
from enum import Enum, auto
from math import prod
from pathlib import Path
from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping
@ -106,6 +107,53 @@ class GGUFWriter:
self.add_architecture()
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
total_params = 0
shared_params = 0
expert_params = 0
expert_sum = 0
n_expert_tensors = 0
last_lora_a: tuple[str, TensorInfo] | None = None
for tensors in self.tensors:
for name, info in tensors.items():
shape = info.shape
if name.endswith(".lora_a"):
last_lora_a = (name, info)
continue
elif name.endswith(".lora_b"):
if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
# Bail when the LoRA pair can't be found trivially
logger.warning("can't measure LoRA size correctly, tensor order is unusual")
return 0, 0, 0, 0
else:
shape = (*shape[:-1], last_lora_a[1].shape[-1])
size = prod(shape)
if "_exps." in name:
expert_params += (size // shape[-3])
expert_sum += shape[-3]
n_expert_tensors += 1
else:
shared_params += size
total_params += size
# Hopefully this should work even for variable-expert-count models
expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
# Negate the total to signal it's likely not exact
if last_lora_a is not None:
total_params = -total_params
# NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
return total_params, shared_params, expert_params, expert_count
def format_shard_names(self, path: Path) -> list[Path]:
if len(self.tensors) == 1:
return [path]
@ -115,6 +163,7 @@ class GGUFWriter:
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
# allow calling this multiple times as long as the path is the same
return
if self.state is not WriterState.NO_FILE:
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
@ -136,6 +185,8 @@ class GGUFWriter:
if self.dry_run:
logger.info("Dry run, not writing files")
for name in filenames:
print(name) # noqa: NP100
exit()
return filenames
@ -430,29 +481,12 @@ class GGUFWriter:
def add_architecture(self) -> None:
self.add_string(Keys.General.ARCHITECTURE, self.arch)
def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author)
def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version)
def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)
def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description)
def add_licence(self, licence: str) -> None:
self.add_string(Keys.General.LICENSE, licence)
def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url)
def add_source_hf_repo(self, repo: str) -> None:
self.add_string(Keys.General.SOURCE_HF_REPO, repo)
def add_custom_alignment(self, alignment: int) -> None:
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype)
@ -460,13 +494,101 @@ class GGUFWriter:
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)
def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(
Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author)
def add_custom_alignment(self, alignment: int) -> None:
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version)
def add_organization(self, organization: str) -> None:
self.add_string(Keys.General.ORGANIZATION, organization)
def add_finetune(self, finetune: str) -> None:
self.add_string(Keys.General.FINETUNE, finetune)
def add_basename(self, basename: str) -> None:
self.add_string(Keys.General.BASENAME, basename)
def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description)
def add_quantized_by(self, quantized: str) -> None:
self.add_string(Keys.General.QUANTIZED_BY, quantized)
def add_size_label(self, size_label: str) -> None:
self.add_string(Keys.General.SIZE_LABEL, size_label)
def add_license(self, license: str) -> None:
self.add_string(Keys.General.LICENSE, license)
def add_license_name(self, license: str) -> None:
self.add_string(Keys.General.LICENSE_NAME, license)
def add_license_link(self, license: str) -> None:
self.add_string(Keys.General.LICENSE_LINK, license)
def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)
def add_doi(self, doi: str) -> None:
self.add_string(Keys.General.DOI, doi)
def add_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.UUID, uuid)
def add_repo_url(self, repo_url: str) -> None:
self.add_string(Keys.General.REPO_URL, repo_url)
def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url)
def add_source_doi(self, doi: str) -> None:
self.add_string(Keys.General.SOURCE_DOI, doi)
def add_source_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.SOURCE_UUID, uuid)
def add_source_repo_url(self, repo_url: str) -> None:
self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
def add_base_model_count(self, source_count: int) -> None:
self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
def add_base_model_name(self, source_id: int, name: str) -> None:
self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
def add_base_model_author(self, source_id: int, author: str) -> None:
self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
def add_base_model_version(self, source_id: int, version: str) -> None:
self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
def add_base_model_organization(self, source_id: int, organization: str) -> None:
self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
def add_base_model_url(self, source_id: int, url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
def add_base_model_doi(self, source_id: int, doi: str) -> None:
self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
def add_tags(self, tags: Sequence[str]) -> None:
self.add_array(Keys.General.TAGS, tags)
def add_languages(self, languages: Sequence[str]) -> None:
self.add_array(Keys.General.LANGUAGES, languages)
def add_datasets(self, datasets: Sequence[str]) -> None:
self.add_array(Keys.General.DATASETS, datasets)
def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_vocab_size(self, size: int) -> None:
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)

485
gguf-py/gguf/metadata.py Normal file
View file

@ -0,0 +1,485 @@
from __future__ import annotations
import re
import json
import yaml
import logging
from pathlib import Path
from typing import Any, Literal, Optional
from dataclasses import dataclass
from .constants import Keys
import gguf
logger = logging.getLogger("metadata")
@dataclass
class Metadata:
# Authorship Metadata to be written to GGUF KV Store
name: Optional[str] = None
author: Optional[str] = None
version: Optional[str] = None
organization: Optional[str] = None
finetune: Optional[str] = None
basename: Optional[str] = None
description: Optional[str] = None
quantized_by: Optional[str] = None
size_label: Optional[str] = None
url: Optional[str] = None
doi: Optional[str] = None
uuid: Optional[str] = None
repo_url: Optional[str] = None
source_url: Optional[str] = None
source_doi: Optional[str] = None
source_uuid: Optional[str] = None
source_repo_url: Optional[str] = None
license: Optional[str] = None
license_name: Optional[str] = None
license_link: Optional[str] = None
base_models: Optional[list[dict]] = None
tags: Optional[list[str]] = None
languages: Optional[list[str]] = None
datasets: Optional[list[str]] = None
@staticmethod
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
# This grabs as many contextual authorship metadata as possible from the model repository
# making any conversion as required to match the gguf kv store metadata format
# as well as giving users the ability to override any authorship metadata that may be incorrect
# Create a new Metadata instance
metadata = Metadata()
model_card = Metadata.load_model_card(model_path)
hf_params = Metadata.load_hf_parameters(model_path)
# heuristics
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
# Metadata Override File Provided
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
# Base Models is received here as an array of models
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
# Direct Metadata Override (via direct cli argument)
if model_name is not None:
metadata.name = model_name
return metadata
@staticmethod
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
if metadata_override_path is None or not metadata_override_path.is_file():
return {}
with open(metadata_override_path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
model_card_path = model_path / "README.md"
if not model_card_path.is_file():
return {}
# The model card metadata is assumed to always be in YAML
# ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
with open(model_card_path, "r", encoding="utf-8") as f:
if f.readline() == "---\n":
raw = f.read().partition("---\n")[0]
data = yaml.safe_load(raw)
if isinstance(data, dict):
return data
else:
logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
return {}
else:
return {}
@staticmethod
def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
config_path = model_path / "config.json"
if not config_path.is_file():
return {}
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def id_to_title(string):
# Convert capitalization into title form unless acronym or version number
return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
@staticmethod
def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
# Huggingface often store model id as '<org>/<model name>'
# so let's parse it and apply some heuristics if possible for model name components
if model_id is None:
# model ID missing
return None, None, None, None, None, None
if ' ' in model_id:
# model ID is actually a normal human sentence
# which means its most likely a normal model name only
# not part of the hugging face naming standard, but whatever
return model_id, None, None, None, None, None
if '/' in model_id:
# model ID (huggingface style)
org_component, model_full_name_component = model_id.split('/', 1)
else:
# model ID but missing org components
org_component, model_full_name_component = None, model_id
# Check if we erroneously matched against './' or '../' etc...
if org_component is not None and org_component[0] == '.':
org_component = None
name_parts: list[str] = model_full_name_component.split('-')
name_types: list[
set[Literal["basename", "size_label", "finetune", "version", "type"]]
] = [set() for _ in name_parts]
# Annotate the name
for i, part in enumerate(name_parts):
# Version
if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
name_types[i].add("version")
# Quant type (should not be there for base models, but still annotated)
elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
name_types[i].add("type")
name_parts[i] = part.upper()
# Model size
elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
part = part.replace("_", ".")
# Handle weird bloom-7b1 notation
if part[-1].isdecimal():
part = part[:-2] + "." + part[-1] + part[-2]
# Normalize the size suffixes
if len(part) > 1 and part[-2].isdecimal():
if part[-1] in "kmbt":
part = part[:-1] + part[-1].upper()
if total_params != 0:
try:
label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
# Only use it as a size label if it's close or bigger than the model size
# Note that LoRA adapters don't necessarily include all layers,
# so this is why bigger label sizes are accepted.
# Do not use the size label when it's smaller than 1/8 of the model size
if (total_params < 0 and label_params < abs(total_params) // 8) or (
# Check both directions when the current model isn't a LoRA adapter
total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
):
# Likely a context length
name_types[i].add("finetune")
# Lowercase the size when it's a context length
part = part[:-1] + part[-1].lower()
except ValueError:
# Failed to convert the size label to float, use it anyway
pass
if len(name_types[i]) == 0:
name_types[i].add("size_label")
name_parts[i] = part
# Some easy to recognize finetune names
elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
name_types[i].add("finetune")
if part.lower() == "lora":
name_parts[i] = "LoRA"
at_start = True
# Find the basename through the annotated name
for part, t in zip(name_parts, name_types):
if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
t.add("basename")
else:
if at_start:
at_start = False
if len(t) == 0:
t.add("finetune")
# Remove the basename annotation from trailing version
for part, t in zip(reversed(name_parts), reversed(name_types)):
if "basename" in t:
if len(t) > 1:
t.remove("basename")
else:
break
basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
# TODO: should the basename version always be excluded?
# TODO: should multiple versions be joined together?
version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
if size_label is None and finetune is None and version is None:
# Too ambiguous, output nothing
basename = None
return model_full_name_component, org_component, basename, finetune, version, size_label
@staticmethod
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Model Card Heuristics
########################
if model_card is not None:
if "model_name" in model_card and metadata.name is None:
# Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.name = model_card.get("model_name")
if "model_creator" in model_card and metadata.author is None:
# Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.author = model_card.get("model_creator")
if "model_type" in model_card and metadata.basename is None:
# Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.basename = model_card.get("model_type")
if "base_model" in model_card:
# This represents the parent models that this is based on
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
metadata_base_models = []
base_model_value = model_card.get("base_model", None)
if base_model_value is not None:
if isinstance(base_model_value, str):
metadata_base_models.append(base_model_value)
elif isinstance(base_model_value, list):
metadata_base_models.extend(base_model_value)
if metadata.base_models is None:
metadata.base_models = []
for model_id in metadata_base_models:
# NOTE: model size of base model is assumed to be similar to the size of the current model
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
base_model = {}
if model_full_name_component is not None:
base_model["name"] = Metadata.id_to_title(model_full_name_component)
if org_component is not None:
base_model["organization"] = Metadata.id_to_title(org_component)
if version is not None:
base_model["version"] = version
if org_component is not None and model_full_name_component is not None:
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
metadata.base_models.append(base_model)
if "license" in model_card and metadata.license is None:
metadata.license = model_card.get("license")
if "license_name" in model_card and metadata.license_name is None:
metadata.license_name = model_card.get("license_name")
if "license_link" in model_card and metadata.license_link is None:
metadata.license_link = model_card.get("license_link")
tags_value = model_card.get("tags", None)
if tags_value is not None:
if metadata.tags is None:
metadata.tags = []
if isinstance(tags_value, str):
metadata.tags.append(tags_value)
elif isinstance(tags_value, list):
metadata.tags.extend(tags_value)
pipeline_tags_value = model_card.get("pipeline_tag", None)
if pipeline_tags_value is not None:
if metadata.tags is None:
metadata.tags = []
if isinstance(pipeline_tags_value, str):
metadata.tags.append(pipeline_tags_value)
elif isinstance(pipeline_tags_value, list):
metadata.tags.extend(pipeline_tags_value)
language_value = model_card.get("languages", model_card.get("language", None))
if language_value is not None:
if metadata.languages is None:
metadata.languages = []
if isinstance(language_value, str):
metadata.languages.append(language_value)
elif isinstance(language_value, list):
metadata.languages.extend(language_value)
dataset_value = model_card.get("datasets", model_card.get("dataset", None))
if dataset_value is not None:
if metadata.datasets is None:
metadata.datasets = []
if isinstance(dataset_value, str):
metadata.datasets.append(dataset_value)
elif isinstance(dataset_value, list):
metadata.datasets.extend(dataset_value)
# Hugging Face Parameter Heuristics
####################################
if hf_params is not None:
hf_name_or_path = hf_params.get("_name_or_path")
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
# Use _name_or_path only if its actually a model name and not some computer path
# e.g. 'meta-llama/Llama-2-7b-hf'
model_id = hf_name_or_path
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
if metadata.name is None and model_full_name_component is not None:
metadata.name = Metadata.id_to_title(model_full_name_component)
if metadata.organization is None and org_component is not None:
metadata.organization = Metadata.id_to_title(org_component)
if metadata.basename is None and basename is not None:
metadata.basename = basename
if metadata.finetune is None and finetune is not None:
metadata.finetune = finetune
if metadata.version is None and version is not None:
metadata.version = version
if metadata.size_label is None and size_label is not None:
metadata.size_label = size_label
# Directory Folder Name Fallback Heuristics
############################################
if model_path is not None:
model_id = model_path.name
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
if metadata.name is None and model_full_name_component is not None:
metadata.name = Metadata.id_to_title(model_full_name_component)
if metadata.organization is None and org_component is not None:
metadata.organization = Metadata.id_to_title(org_component)
if metadata.basename is None and basename is not None:
metadata.basename = basename
if metadata.finetune is None and finetune is not None:
metadata.finetune = finetune
if metadata.version is None and version is not None:
metadata.version = version
if metadata.size_label is None and size_label is not None:
metadata.size_label = size_label
return metadata
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
assert self.name is not None
gguf_writer.add_name(self.name)
if self.author is not None:
gguf_writer.add_author(self.author)
if self.version is not None:
gguf_writer.add_version(self.version)
if self.organization is not None:
gguf_writer.add_organization(self.organization)
if self.finetune is not None:
gguf_writer.add_finetune(self.finetune)
if self.basename is not None:
gguf_writer.add_basename(self.basename)
if self.description is not None:
gguf_writer.add_description(self.description)
if self.quantized_by is not None:
gguf_writer.add_quantized_by(self.quantized_by)
if self.size_label is not None:
gguf_writer.add_size_label(self.size_label)
if self.license is not None:
gguf_writer.add_license(self.license)
if self.license_name is not None:
gguf_writer.add_license_name(self.license_name)
if self.license_link is not None:
gguf_writer.add_license_link(self.license_link)
if self.url is not None:
gguf_writer.add_url(self.url)
if self.doi is not None:
gguf_writer.add_doi(self.doi)
if self.uuid is not None:
gguf_writer.add_uuid(self.uuid)
if self.repo_url is not None:
gguf_writer.add_repo_url(self.repo_url)
if self.source_url is not None:
gguf_writer.add_source_url(self.source_url)
if self.source_doi is not None:
gguf_writer.add_source_doi(self.source_doi)
if self.source_uuid is not None:
gguf_writer.add_source_uuid(self.source_uuid)
if self.source_repo_url is not None:
gguf_writer.add_source_repo_url(self.source_repo_url)
if self.base_models is not None:
gguf_writer.add_base_model_count(len(self.base_models))
for key, base_model_entry in enumerate(self.base_models):
if "name" in base_model_entry:
gguf_writer.add_base_model_name(key, base_model_entry["name"])
if "author" in base_model_entry:
gguf_writer.add_base_model_author(key, base_model_entry["author"])
if "version" in base_model_entry:
gguf_writer.add_base_model_version(key, base_model_entry["version"])
if "organization" in base_model_entry:
gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
if "url" in base_model_entry:
gguf_writer.add_base_model_url(key, base_model_entry["url"])
if "doi" in base_model_entry:
gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
if "uuid" in base_model_entry:
gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
if "repo_url" in base_model_entry:
gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
if self.tags is not None:
gguf_writer.add_tags(self.tags)
if self.languages is not None:
gguf_writer.add_languages(self.languages)
if self.datasets is not None:
gguf_writer.add_datasets(self.datasets)

69
gguf-py/gguf/utility.py Normal file
View file

@ -0,0 +1,69 @@
from __future__ import annotations
from typing import Literal
def fill_templated_filename(filename: str, output_type: str | None) -> str:
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
ftype_lowercase: str = output_type.lower() if output_type is not None else ""
ftype_uppercase: str = output_type.upper() if output_type is not None else ""
return filename.format(ftype_lowercase,
outtype=ftype_lowercase, ftype=ftype_lowercase,
OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
if model_params_count > 1e12 :
# Trillions Of Parameters
scaled_model_params = model_params_count * 1e-12
scale_suffix = "T"
elif model_params_count > 1e9 :
# Billions Of Parameters
scaled_model_params = model_params_count * 1e-9
scale_suffix = "B"
elif model_params_count > 1e6 :
# Millions Of Parameters
scaled_model_params = model_params_count * 1e-6
scale_suffix = "M"
else:
# Thousands Of Parameters
scaled_model_params = model_params_count * 1e-3
scale_suffix = "K"
fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
return f"{scaled_model_params:.{fix}f}{scale_suffix}"
def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
if expert_count > 0:
pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
size_class = f"{expert_count}x{pretty_size}"
else:
size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
return size_class
def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
# Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
if base_name is not None:
name = base_name.strip().title().replace(' ', '-').replace('/', '-')
elif model_name is not None:
name = model_name.strip().title().replace(' ', '-').replace('/', '-')
else:
name = "ggml-model"
parameters = f"-{size_label}" if size_label is not None else ""
finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"

View file

@ -22,6 +22,7 @@ classifiers = [
python = ">=3.8"
numpy = ">=1.17"
tqdm = ">=4.27"
pyyaml = ">=5.1"
[tool.poetry.dev-dependencies]
pytest = "^5.2"

View file

@ -0,0 +1 @@
from .test_metadata import *

View file

@ -1,7 +0,0 @@
import gguf # noqa: F401 # pyright: ignore[reportUnusedImport]
# TODO: add tests
def test_write_gguf() -> None:
pass

158
gguf-py/tests/test_metadata.py Executable file
View file

@ -0,0 +1,158 @@
#!/usr/bin/env python3
import unittest
from pathlib import Path
import os
import sys
# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
sys.path.insert(0, str(Path(__file__).parent.parent))
import gguf
class TestMetadataMethod(unittest.TestCase):
def test_id_to_title(self):
self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
def test_get_model_id_components(self):
# This is the basic standard form with organization marker
self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
# Similar to basic standard form but without organization marker
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
# Missing version
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
# Missing finetune
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
# Base name and size label only
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
# Base name and version only
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
## Edge Cases ##
# This is too ambiguous... best to err on caution and output nothing
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
('Mixtral', None, None, None, None, None))
# Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
# Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
# Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))
# Check that it can handle a real model id with no version code
# Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count
self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),
('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))
# There is some legitimate models with only thousands of parameters
self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
# None standard and not easy to disambiguate
self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
# This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))
# This is a real model id where the weight size has a decimal point
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))
# Uses an underscore in the size label
self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))
# Uses Iter3 for the version
self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"),
('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B'))
# Has two potential versions in the basename
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"),
('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B'))
# Potential version in the basename
self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"),
('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B'))
# Underscore in the basename, and 1m for the context size
self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9),
('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B'))
# Version before the finetune name
self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"),
('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M'))
# TODO: hf suffix which could be ignored but isn't
self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"),
('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B'))
# Two sizes, don't merge them, the other is the number of tokens on which it was trained
self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6),
('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M'))
# It's a trap, there is no size label
self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6),
('relu-100B', 'SparseLLM', 'relu', '100b', None, None))
# Weird size notation
self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
def test_apply_metadata_heuristic_from_model_card(self):
model_card = {
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
'language': ['en'],
'datasets': ['teknium/OpenHermes-2.5'],
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
}
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
expect = gguf.Metadata()
expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
expect.languages=['en']
expect.datasets=['teknium/OpenHermes-2.5']
self.assertEqual(got, expect)
def test_apply_metadata_heuristic_from_hf_parameters(self):
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
self.assertEqual(got, expect)
def test_apply_metadata_heuristic_from_model_dir(self):
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
self.assertEqual(got, expect)
if __name__ == "__main__":
unittest.main()