convert-*.py: GGUF Naming Convention Refactor and Metadata Override Refactor (#7499)
Main thing is that the default output filename will take this form {name}{parameters}{finetune}{version}{encoding}{kind} In addition this add and remove some entries in the KV store and adds a metadata class with automatic heuristics capability to derive some values based on model card content * No Change: - Internal GGUF Spec - `general.architecture` - `general.quantization_version` - `general.alignment` - `general.file_type` - General Model Details - `general.name` - `general.author` - `general.version` - `general.description` - Licensing details - `general.license` - Typically represents the converted GGUF repo (Unless made from scratch) - `general.url` - Model Source during conversion - `general.source.url` * Removed: - Model Source during conversion - `general.source.huggingface.repository` * Added: - General Model Details - `general.organization` - `general.finetune` - `general.basename` - `general.quantized_by` - `general.size_label` - Licensing details - `general.license.name` - `general.license.link` - Typically represents the converted GGUF repo (Unless made from scratch) - `general.doi` - `general.uuid` - `general.repo_url` - Model Source during conversion - `general.source.doi` - `general.source.uuid` - `general.source.repo_url` - Base Model Source - `general.base_model.count` - `general.base_model.{id}.name` - `general.base_model.{id}.author` - `general.base_model.{id}.version` - `general.base_model.{id}.organization` - `general.base_model.{id}.url` (Model Website/Paper) - `general.base_model.{id}.doi` - `general.base_model.{id}.uuid` - `general.base_model.{id}.repo_url` (Model Source Repository (git/svn/etc...)) - Array based KV stores - `general.tags` - `general.languages` - `general.datasets` --------- Co-authored-by: compilade <git@compilade.net> Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
parent
3807c3de04
commit
672a6f1018
13 changed files with 1185 additions and 239 deletions
|
@ -78,5 +78,13 @@ python -m build
|
|||
python -m twine upload dist/*
|
||||
```
|
||||
|
||||
## Run Unit Tests
|
||||
|
||||
From root of this repository you can run this command to run all the unit tests
|
||||
|
||||
```bash
|
||||
python -m unittest discover ./gguf-py -v
|
||||
```
|
||||
|
||||
## TODO
|
||||
- [ ] Include conversion scripts as command line entry points in this package.
|
||||
|
|
|
@ -5,3 +5,5 @@ from .gguf_writer import *
|
|||
from .quants import *
|
||||
from .tensor_mapping import *
|
||||
from .vocab import *
|
||||
from .utility import *
|
||||
from .metadata import *
|
||||
|
|
|
@ -19,19 +19,60 @@ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
|
|||
|
||||
class Keys:
|
||||
class General:
|
||||
TYPE = "general.type"
|
||||
ARCHITECTURE = "general.architecture"
|
||||
QUANTIZATION_VERSION = "general.quantization_version"
|
||||
ALIGNMENT = "general.alignment"
|
||||
NAME = "general.name"
|
||||
AUTHOR = "general.author"
|
||||
VERSION = "general.version"
|
||||
URL = "general.url"
|
||||
DESCRIPTION = "general.description"
|
||||
LICENSE = "general.license"
|
||||
SOURCE_URL = "general.source.url"
|
||||
SOURCE_HF_REPO = "general.source.huggingface.repository"
|
||||
FILE_TYPE = "general.file_type"
|
||||
TYPE = "general.type"
|
||||
ARCHITECTURE = "general.architecture"
|
||||
QUANTIZATION_VERSION = "general.quantization_version"
|
||||
ALIGNMENT = "general.alignment"
|
||||
FILE_TYPE = "general.file_type"
|
||||
|
||||
# Authorship Metadata
|
||||
NAME = "general.name"
|
||||
AUTHOR = "general.author"
|
||||
VERSION = "general.version"
|
||||
ORGANIZATION = "general.organization"
|
||||
|
||||
FINETUNE = "general.finetune"
|
||||
BASENAME = "general.basename"
|
||||
|
||||
DESCRIPTION = "general.description"
|
||||
QUANTIZED_BY = "general.quantized_by"
|
||||
|
||||
SIZE_LABEL = "general.size_label"
|
||||
|
||||
# Licensing details
|
||||
LICENSE = "general.license"
|
||||
LICENSE_NAME = "general.license.name"
|
||||
LICENSE_LINK = "general.license.link"
|
||||
|
||||
# Typically represents the converted GGUF repo (Unless native)
|
||||
URL = "general.url" # Model Website/Paper
|
||||
DOI = "general.doi"
|
||||
UUID = "general.uuid"
|
||||
REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
|
||||
# Model Source during conversion
|
||||
SOURCE_URL = "general.source.url" # Model Website/Paper
|
||||
SOURCE_DOI = "general.source.doi"
|
||||
SOURCE_UUID = "general.source.uuid"
|
||||
SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
|
||||
# Base Model Source. There can be more than one source if it's a merged
|
||||
# model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
|
||||
# tracing linage of models as it is finetuned or merged over time.
|
||||
BASE_MODEL_COUNT = "general.base_model.count"
|
||||
BASE_MODEL_NAME = "general.base_model.{id}.name"
|
||||
BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
|
||||
BASE_MODEL_VERSION = "general.base_model.{id}.version"
|
||||
BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
|
||||
BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
|
||||
BASE_MODEL_DOI = "general.base_model.{id}.doi"
|
||||
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
|
||||
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
|
||||
# Array based KV stores
|
||||
TAGS = "general.tags"
|
||||
LANGUAGES = "general.languages"
|
||||
DATASETS = "general.datasets"
|
||||
|
||||
class LLM:
|
||||
VOCAB_SIZE = "{arch}.vocab_size"
|
||||
|
@ -1233,7 +1274,6 @@ KEY_GENERAL_URL = Keys.General.URL
|
|||
KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
|
||||
KEY_GENERAL_LICENSE = Keys.General.LICENSE
|
||||
KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
|
||||
KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
|
||||
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
|
||||
|
||||
# LLM
|
||||
|
|
|
@ -7,6 +7,7 @@ import struct
|
|||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from math import prod
|
||||
from pathlib import Path
|
||||
from io import BufferedWriter
|
||||
from typing import IO, Any, Sequence, Mapping
|
||||
|
@ -106,6 +107,53 @@ class GGUFWriter:
|
|||
|
||||
self.add_architecture()
|
||||
|
||||
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
|
||||
total_params = 0
|
||||
shared_params = 0
|
||||
expert_params = 0
|
||||
|
||||
expert_sum = 0
|
||||
n_expert_tensors = 0
|
||||
|
||||
last_lora_a: tuple[str, TensorInfo] | None = None
|
||||
|
||||
for tensors in self.tensors:
|
||||
for name, info in tensors.items():
|
||||
|
||||
shape = info.shape
|
||||
|
||||
if name.endswith(".lora_a"):
|
||||
last_lora_a = (name, info)
|
||||
continue
|
||||
elif name.endswith(".lora_b"):
|
||||
if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
|
||||
# Bail when the LoRA pair can't be found trivially
|
||||
logger.warning("can't measure LoRA size correctly, tensor order is unusual")
|
||||
return 0, 0, 0, 0
|
||||
else:
|
||||
shape = (*shape[:-1], last_lora_a[1].shape[-1])
|
||||
|
||||
size = prod(shape)
|
||||
|
||||
if "_exps." in name:
|
||||
expert_params += (size // shape[-3])
|
||||
expert_sum += shape[-3]
|
||||
n_expert_tensors += 1
|
||||
else:
|
||||
shared_params += size
|
||||
|
||||
total_params += size
|
||||
|
||||
# Hopefully this should work even for variable-expert-count models
|
||||
expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
|
||||
|
||||
# Negate the total to signal it's likely not exact
|
||||
if last_lora_a is not None:
|
||||
total_params = -total_params
|
||||
|
||||
# NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
|
||||
return total_params, shared_params, expert_params, expert_count
|
||||
|
||||
def format_shard_names(self, path: Path) -> list[Path]:
|
||||
if len(self.tensors) == 1:
|
||||
return [path]
|
||||
|
@ -115,6 +163,7 @@ class GGUFWriter:
|
|||
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
|
||||
# allow calling this multiple times as long as the path is the same
|
||||
return
|
||||
|
||||
if self.state is not WriterState.NO_FILE:
|
||||
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
|
||||
|
||||
|
@ -136,6 +185,8 @@ class GGUFWriter:
|
|||
|
||||
if self.dry_run:
|
||||
logger.info("Dry run, not writing files")
|
||||
for name in filenames:
|
||||
print(name) # noqa: NP100
|
||||
exit()
|
||||
|
||||
return filenames
|
||||
|
@ -430,29 +481,12 @@ class GGUFWriter:
|
|||
def add_architecture(self) -> None:
|
||||
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
||||
|
||||
def add_author(self, author: str) -> None:
|
||||
self.add_string(Keys.General.AUTHOR, author)
|
||||
def add_quantization_version(self, quantization_version: int) -> None:
|
||||
self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
|
||||
|
||||
def add_version(self, version: str) -> None:
|
||||
self.add_string(Keys.General.VERSION, version)
|
||||
|
||||
def add_tensor_data_layout(self, layout: str) -> None:
|
||||
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
|
||||
|
||||
def add_url(self, url: str) -> None:
|
||||
self.add_string(Keys.General.URL, url)
|
||||
|
||||
def add_description(self, description: str) -> None:
|
||||
self.add_string(Keys.General.DESCRIPTION, description)
|
||||
|
||||
def add_licence(self, licence: str) -> None:
|
||||
self.add_string(Keys.General.LICENSE, licence)
|
||||
|
||||
def add_source_url(self, url: str) -> None:
|
||||
self.add_string(Keys.General.SOURCE_URL, url)
|
||||
|
||||
def add_source_hf_repo(self, repo: str) -> None:
|
||||
self.add_string(Keys.General.SOURCE_HF_REPO, repo)
|
||||
def add_custom_alignment(self, alignment: int) -> None:
|
||||
self.data_alignment = alignment
|
||||
self.add_uint32(Keys.General.ALIGNMENT, alignment)
|
||||
|
||||
def add_file_type(self, ftype: int) -> None:
|
||||
self.add_uint32(Keys.General.FILE_TYPE, ftype)
|
||||
|
@ -460,13 +494,101 @@ class GGUFWriter:
|
|||
def add_name(self, name: str) -> None:
|
||||
self.add_string(Keys.General.NAME, name)
|
||||
|
||||
def add_quantization_version(self, quantization_version: int) -> None:
|
||||
self.add_uint32(
|
||||
Keys.General.QUANTIZATION_VERSION, quantization_version)
|
||||
def add_author(self, author: str) -> None:
|
||||
self.add_string(Keys.General.AUTHOR, author)
|
||||
|
||||
def add_custom_alignment(self, alignment: int) -> None:
|
||||
self.data_alignment = alignment
|
||||
self.add_uint32(Keys.General.ALIGNMENT, alignment)
|
||||
def add_version(self, version: str) -> None:
|
||||
self.add_string(Keys.General.VERSION, version)
|
||||
|
||||
def add_organization(self, organization: str) -> None:
|
||||
self.add_string(Keys.General.ORGANIZATION, organization)
|
||||
|
||||
def add_finetune(self, finetune: str) -> None:
|
||||
self.add_string(Keys.General.FINETUNE, finetune)
|
||||
|
||||
def add_basename(self, basename: str) -> None:
|
||||
self.add_string(Keys.General.BASENAME, basename)
|
||||
|
||||
def add_description(self, description: str) -> None:
|
||||
self.add_string(Keys.General.DESCRIPTION, description)
|
||||
|
||||
def add_quantized_by(self, quantized: str) -> None:
|
||||
self.add_string(Keys.General.QUANTIZED_BY, quantized)
|
||||
|
||||
def add_size_label(self, size_label: str) -> None:
|
||||
self.add_string(Keys.General.SIZE_LABEL, size_label)
|
||||
|
||||
def add_license(self, license: str) -> None:
|
||||
self.add_string(Keys.General.LICENSE, license)
|
||||
|
||||
def add_license_name(self, license: str) -> None:
|
||||
self.add_string(Keys.General.LICENSE_NAME, license)
|
||||
|
||||
def add_license_link(self, license: str) -> None:
|
||||
self.add_string(Keys.General.LICENSE_LINK, license)
|
||||
|
||||
def add_url(self, url: str) -> None:
|
||||
self.add_string(Keys.General.URL, url)
|
||||
|
||||
def add_doi(self, doi: str) -> None:
|
||||
self.add_string(Keys.General.DOI, doi)
|
||||
|
||||
def add_uuid(self, uuid: str) -> None:
|
||||
self.add_string(Keys.General.UUID, uuid)
|
||||
|
||||
def add_repo_url(self, repo_url: str) -> None:
|
||||
self.add_string(Keys.General.REPO_URL, repo_url)
|
||||
|
||||
def add_source_url(self, url: str) -> None:
|
||||
self.add_string(Keys.General.SOURCE_URL, url)
|
||||
|
||||
def add_source_doi(self, doi: str) -> None:
|
||||
self.add_string(Keys.General.SOURCE_DOI, doi)
|
||||
|
||||
def add_source_uuid(self, uuid: str) -> None:
|
||||
self.add_string(Keys.General.SOURCE_UUID, uuid)
|
||||
|
||||
def add_source_repo_url(self, repo_url: str) -> None:
|
||||
self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
|
||||
|
||||
def add_base_model_count(self, source_count: int) -> None:
|
||||
self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
|
||||
|
||||
def add_base_model_name(self, source_id: int, name: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
|
||||
|
||||
def add_base_model_author(self, source_id: int, author: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
|
||||
|
||||
def add_base_model_version(self, source_id: int, version: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
|
||||
|
||||
def add_base_model_organization(self, source_id: int, organization: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
|
||||
|
||||
def add_base_model_url(self, source_id: int, url: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
|
||||
|
||||
def add_base_model_doi(self, source_id: int, doi: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
|
||||
|
||||
def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
|
||||
|
||||
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
|
||||
|
||||
def add_tags(self, tags: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.TAGS, tags)
|
||||
|
||||
def add_languages(self, languages: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.LANGUAGES, languages)
|
||||
|
||||
def add_datasets(self, datasets: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.DATASETS, datasets)
|
||||
|
||||
def add_tensor_data_layout(self, layout: str) -> None:
|
||||
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
|
||||
|
||||
def add_vocab_size(self, size: int) -> None:
|
||||
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
|
||||
|
|
485
gguf-py/gguf/metadata.py
Normal file
485
gguf-py/gguf/metadata.py
Normal file
|
@ -0,0 +1,485 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import json
|
||||
import yaml
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .constants import Keys
|
||||
|
||||
import gguf
|
||||
|
||||
logger = logging.getLogger("metadata")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metadata:
|
||||
# Authorship Metadata to be written to GGUF KV Store
|
||||
name: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
finetune: Optional[str] = None
|
||||
basename: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
quantized_by: Optional[str] = None
|
||||
size_label: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
doi: Optional[str] = None
|
||||
uuid: Optional[str] = None
|
||||
repo_url: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
source_doi: Optional[str] = None
|
||||
source_uuid: Optional[str] = None
|
||||
source_repo_url: Optional[str] = None
|
||||
license: Optional[str] = None
|
||||
license_name: Optional[str] = None
|
||||
license_link: Optional[str] = None
|
||||
base_models: Optional[list[dict]] = None
|
||||
tags: Optional[list[str]] = None
|
||||
languages: Optional[list[str]] = None
|
||||
datasets: Optional[list[str]] = None
|
||||
|
||||
@staticmethod
|
||||
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
|
||||
# This grabs as many contextual authorship metadata as possible from the model repository
|
||||
# making any conversion as required to match the gguf kv store metadata format
|
||||
# as well as giving users the ability to override any authorship metadata that may be incorrect
|
||||
|
||||
# Create a new Metadata instance
|
||||
metadata = Metadata()
|
||||
|
||||
model_card = Metadata.load_model_card(model_path)
|
||||
hf_params = Metadata.load_hf_parameters(model_path)
|
||||
|
||||
# heuristics
|
||||
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
|
||||
|
||||
# Metadata Override File Provided
|
||||
# This is based on LLM_KV_NAMES mapping in llama.cpp
|
||||
metadata_override = Metadata.load_metadata_override(metadata_override_path)
|
||||
|
||||
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
|
||||
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
|
||||
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
|
||||
|
||||
metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
|
||||
metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
|
||||
|
||||
metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
|
||||
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
|
||||
|
||||
metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
|
||||
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
|
||||
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
|
||||
|
||||
metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
|
||||
metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
|
||||
metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
|
||||
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
|
||||
|
||||
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
|
||||
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
|
||||
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
|
||||
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
|
||||
|
||||
# Base Models is received here as an array of models
|
||||
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
|
||||
|
||||
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
|
||||
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
|
||||
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
|
||||
|
||||
# Direct Metadata Override (via direct cli argument)
|
||||
if model_name is not None:
|
||||
metadata.name = model_name
|
||||
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
|
||||
if metadata_override_path is None or not metadata_override_path.is_file():
|
||||
return {}
|
||||
|
||||
with open(metadata_override_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
|
||||
if model_path is None or not model_path.is_dir():
|
||||
return {}
|
||||
|
||||
model_card_path = model_path / "README.md"
|
||||
|
||||
if not model_card_path.is_file():
|
||||
return {}
|
||||
|
||||
# The model card metadata is assumed to always be in YAML
|
||||
# ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
|
||||
with open(model_card_path, "r", encoding="utf-8") as f:
|
||||
if f.readline() == "---\n":
|
||||
raw = f.read().partition("---\n")[0]
|
||||
data = yaml.safe_load(raw)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
else:
|
||||
logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
|
||||
if model_path is None or not model_path.is_dir():
|
||||
return {}
|
||||
|
||||
config_path = model_path / "config.json"
|
||||
|
||||
if not config_path.is_file():
|
||||
return {}
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def id_to_title(string):
|
||||
# Convert capitalization into title form unless acronym or version number
|
||||
return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
|
||||
|
||||
@staticmethod
|
||||
def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
|
||||
# Huggingface often store model id as '<org>/<model name>'
|
||||
# so let's parse it and apply some heuristics if possible for model name components
|
||||
|
||||
if model_id is None:
|
||||
# model ID missing
|
||||
return None, None, None, None, None, None
|
||||
|
||||
if ' ' in model_id:
|
||||
# model ID is actually a normal human sentence
|
||||
# which means its most likely a normal model name only
|
||||
# not part of the hugging face naming standard, but whatever
|
||||
return model_id, None, None, None, None, None
|
||||
|
||||
if '/' in model_id:
|
||||
# model ID (huggingface style)
|
||||
org_component, model_full_name_component = model_id.split('/', 1)
|
||||
else:
|
||||
# model ID but missing org components
|
||||
org_component, model_full_name_component = None, model_id
|
||||
|
||||
# Check if we erroneously matched against './' or '../' etc...
|
||||
if org_component is not None and org_component[0] == '.':
|
||||
org_component = None
|
||||
|
||||
name_parts: list[str] = model_full_name_component.split('-')
|
||||
name_types: list[
|
||||
set[Literal["basename", "size_label", "finetune", "version", "type"]]
|
||||
] = [set() for _ in name_parts]
|
||||
|
||||
# Annotate the name
|
||||
for i, part in enumerate(name_parts):
|
||||
# Version
|
||||
if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
|
||||
name_types[i].add("version")
|
||||
# Quant type (should not be there for base models, but still annotated)
|
||||
elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
|
||||
name_types[i].add("type")
|
||||
name_parts[i] = part.upper()
|
||||
# Model size
|
||||
elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
|
||||
part = part.replace("_", ".")
|
||||
# Handle weird bloom-7b1 notation
|
||||
if part[-1].isdecimal():
|
||||
part = part[:-2] + "." + part[-1] + part[-2]
|
||||
# Normalize the size suffixes
|
||||
if len(part) > 1 and part[-2].isdecimal():
|
||||
if part[-1] in "kmbt":
|
||||
part = part[:-1] + part[-1].upper()
|
||||
if total_params != 0:
|
||||
try:
|
||||
label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
|
||||
# Only use it as a size label if it's close or bigger than the model size
|
||||
# Note that LoRA adapters don't necessarily include all layers,
|
||||
# so this is why bigger label sizes are accepted.
|
||||
# Do not use the size label when it's smaller than 1/8 of the model size
|
||||
if (total_params < 0 and label_params < abs(total_params) // 8) or (
|
||||
# Check both directions when the current model isn't a LoRA adapter
|
||||
total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
|
||||
):
|
||||
# Likely a context length
|
||||
name_types[i].add("finetune")
|
||||
# Lowercase the size when it's a context length
|
||||
part = part[:-1] + part[-1].lower()
|
||||
except ValueError:
|
||||
# Failed to convert the size label to float, use it anyway
|
||||
pass
|
||||
if len(name_types[i]) == 0:
|
||||
name_types[i].add("size_label")
|
||||
name_parts[i] = part
|
||||
# Some easy to recognize finetune names
|
||||
elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
|
||||
name_types[i].add("finetune")
|
||||
if part.lower() == "lora":
|
||||
name_parts[i] = "LoRA"
|
||||
|
||||
at_start = True
|
||||
# Find the basename through the annotated name
|
||||
for part, t in zip(name_parts, name_types):
|
||||
if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
|
||||
t.add("basename")
|
||||
else:
|
||||
if at_start:
|
||||
at_start = False
|
||||
if len(t) == 0:
|
||||
t.add("finetune")
|
||||
|
||||
# Remove the basename annotation from trailing version
|
||||
for part, t in zip(reversed(name_parts), reversed(name_types)):
|
||||
if "basename" in t:
|
||||
if len(t) > 1:
|
||||
t.remove("basename")
|
||||
else:
|
||||
break
|
||||
|
||||
basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
|
||||
size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
|
||||
finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
|
||||
# TODO: should the basename version always be excluded?
|
||||
# TODO: should multiple versions be joined together?
|
||||
version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
|
||||
|
||||
if size_label is None and finetune is None and version is None:
|
||||
# Too ambiguous, output nothing
|
||||
basename = None
|
||||
|
||||
return model_full_name_component, org_component, basename, finetune, version, size_label
|
||||
|
||||
@staticmethod
|
||||
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
|
||||
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
||||
|
||||
# Model Card Heuristics
|
||||
########################
|
||||
if model_card is not None:
|
||||
|
||||
if "model_name" in model_card and metadata.name is None:
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||
metadata.name = model_card.get("model_name")
|
||||
|
||||
if "model_creator" in model_card and metadata.author is None:
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||
metadata.author = model_card.get("model_creator")
|
||||
|
||||
if "model_type" in model_card and metadata.basename is None:
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||
metadata.basename = model_card.get("model_type")
|
||||
|
||||
if "base_model" in model_card:
|
||||
# This represents the parent models that this is based on
|
||||
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
|
||||
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
|
||||
metadata_base_models = []
|
||||
base_model_value = model_card.get("base_model", None)
|
||||
|
||||
if base_model_value is not None:
|
||||
if isinstance(base_model_value, str):
|
||||
metadata_base_models.append(base_model_value)
|
||||
elif isinstance(base_model_value, list):
|
||||
metadata_base_models.extend(base_model_value)
|
||||
|
||||
if metadata.base_models is None:
|
||||
metadata.base_models = []
|
||||
|
||||
for model_id in metadata_base_models:
|
||||
# NOTE: model size of base model is assumed to be similar to the size of the current model
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
|
||||
base_model = {}
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(model_full_name_component)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
if org_component is not None and model_full_name_component is not None:
|
||||
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
metadata.base_models.append(base_model)
|
||||
|
||||
if "license" in model_card and metadata.license is None:
|
||||
metadata.license = model_card.get("license")
|
||||
|
||||
if "license_name" in model_card and metadata.license_name is None:
|
||||
metadata.license_name = model_card.get("license_name")
|
||||
|
||||
if "license_link" in model_card and metadata.license_link is None:
|
||||
metadata.license_link = model_card.get("license_link")
|
||||
|
||||
tags_value = model_card.get("tags", None)
|
||||
if tags_value is not None:
|
||||
|
||||
if metadata.tags is None:
|
||||
metadata.tags = []
|
||||
|
||||
if isinstance(tags_value, str):
|
||||
metadata.tags.append(tags_value)
|
||||
elif isinstance(tags_value, list):
|
||||
metadata.tags.extend(tags_value)
|
||||
|
||||
pipeline_tags_value = model_card.get("pipeline_tag", None)
|
||||
if pipeline_tags_value is not None:
|
||||
|
||||
if metadata.tags is None:
|
||||
metadata.tags = []
|
||||
|
||||
if isinstance(pipeline_tags_value, str):
|
||||
metadata.tags.append(pipeline_tags_value)
|
||||
elif isinstance(pipeline_tags_value, list):
|
||||
metadata.tags.extend(pipeline_tags_value)
|
||||
|
||||
language_value = model_card.get("languages", model_card.get("language", None))
|
||||
if language_value is not None:
|
||||
|
||||
if metadata.languages is None:
|
||||
metadata.languages = []
|
||||
|
||||
if isinstance(language_value, str):
|
||||
metadata.languages.append(language_value)
|
||||
elif isinstance(language_value, list):
|
||||
metadata.languages.extend(language_value)
|
||||
|
||||
dataset_value = model_card.get("datasets", model_card.get("dataset", None))
|
||||
if dataset_value is not None:
|
||||
|
||||
if metadata.datasets is None:
|
||||
metadata.datasets = []
|
||||
|
||||
if isinstance(dataset_value, str):
|
||||
metadata.datasets.append(dataset_value)
|
||||
elif isinstance(dataset_value, list):
|
||||
metadata.datasets.extend(dataset_value)
|
||||
|
||||
# Hugging Face Parameter Heuristics
|
||||
####################################
|
||||
|
||||
if hf_params is not None:
|
||||
|
||||
hf_name_or_path = hf_params.get("_name_or_path")
|
||||
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
|
||||
# Use _name_or_path only if its actually a model name and not some computer path
|
||||
# e.g. 'meta-llama/Llama-2-7b-hf'
|
||||
model_id = hf_name_or_path
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
|
||||
if metadata.name is None and model_full_name_component is not None:
|
||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
||||
if metadata.organization is None and org_component is not None:
|
||||
metadata.organization = Metadata.id_to_title(org_component)
|
||||
if metadata.basename is None and basename is not None:
|
||||
metadata.basename = basename
|
||||
if metadata.finetune is None and finetune is not None:
|
||||
metadata.finetune = finetune
|
||||
if metadata.version is None and version is not None:
|
||||
metadata.version = version
|
||||
if metadata.size_label is None and size_label is not None:
|
||||
metadata.size_label = size_label
|
||||
|
||||
# Directory Folder Name Fallback Heuristics
|
||||
############################################
|
||||
if model_path is not None:
|
||||
model_id = model_path.name
|
||||
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
|
||||
if metadata.name is None and model_full_name_component is not None:
|
||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
||||
if metadata.organization is None and org_component is not None:
|
||||
metadata.organization = Metadata.id_to_title(org_component)
|
||||
if metadata.basename is None and basename is not None:
|
||||
metadata.basename = basename
|
||||
if metadata.finetune is None and finetune is not None:
|
||||
metadata.finetune = finetune
|
||||
if metadata.version is None and version is not None:
|
||||
metadata.version = version
|
||||
if metadata.size_label is None and size_label is not None:
|
||||
metadata.size_label = size_label
|
||||
|
||||
return metadata
|
||||
|
||||
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
|
||||
assert self.name is not None
|
||||
gguf_writer.add_name(self.name)
|
||||
|
||||
if self.author is not None:
|
||||
gguf_writer.add_author(self.author)
|
||||
if self.version is not None:
|
||||
gguf_writer.add_version(self.version)
|
||||
if self.organization is not None:
|
||||
gguf_writer.add_organization(self.organization)
|
||||
|
||||
if self.finetune is not None:
|
||||
gguf_writer.add_finetune(self.finetune)
|
||||
if self.basename is not None:
|
||||
gguf_writer.add_basename(self.basename)
|
||||
|
||||
if self.description is not None:
|
||||
gguf_writer.add_description(self.description)
|
||||
if self.quantized_by is not None:
|
||||
gguf_writer.add_quantized_by(self.quantized_by)
|
||||
|
||||
if self.size_label is not None:
|
||||
gguf_writer.add_size_label(self.size_label)
|
||||
|
||||
if self.license is not None:
|
||||
gguf_writer.add_license(self.license)
|
||||
if self.license_name is not None:
|
||||
gguf_writer.add_license_name(self.license_name)
|
||||
if self.license_link is not None:
|
||||
gguf_writer.add_license_link(self.license_link)
|
||||
|
||||
if self.url is not None:
|
||||
gguf_writer.add_url(self.url)
|
||||
if self.doi is not None:
|
||||
gguf_writer.add_doi(self.doi)
|
||||
if self.uuid is not None:
|
||||
gguf_writer.add_uuid(self.uuid)
|
||||
if self.repo_url is not None:
|
||||
gguf_writer.add_repo_url(self.repo_url)
|
||||
|
||||
if self.source_url is not None:
|
||||
gguf_writer.add_source_url(self.source_url)
|
||||
if self.source_doi is not None:
|
||||
gguf_writer.add_source_doi(self.source_doi)
|
||||
if self.source_uuid is not None:
|
||||
gguf_writer.add_source_uuid(self.source_uuid)
|
||||
if self.source_repo_url is not None:
|
||||
gguf_writer.add_source_repo_url(self.source_repo_url)
|
||||
|
||||
if self.base_models is not None:
|
||||
gguf_writer.add_base_model_count(len(self.base_models))
|
||||
for key, base_model_entry in enumerate(self.base_models):
|
||||
if "name" in base_model_entry:
|
||||
gguf_writer.add_base_model_name(key, base_model_entry["name"])
|
||||
if "author" in base_model_entry:
|
||||
gguf_writer.add_base_model_author(key, base_model_entry["author"])
|
||||
if "version" in base_model_entry:
|
||||
gguf_writer.add_base_model_version(key, base_model_entry["version"])
|
||||
if "organization" in base_model_entry:
|
||||
gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
|
||||
if "url" in base_model_entry:
|
||||
gguf_writer.add_base_model_url(key, base_model_entry["url"])
|
||||
if "doi" in base_model_entry:
|
||||
gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
|
||||
if "uuid" in base_model_entry:
|
||||
gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
|
||||
if "repo_url" in base_model_entry:
|
||||
gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
|
||||
|
||||
if self.tags is not None:
|
||||
gguf_writer.add_tags(self.tags)
|
||||
if self.languages is not None:
|
||||
gguf_writer.add_languages(self.languages)
|
||||
if self.datasets is not None:
|
||||
gguf_writer.add_datasets(self.datasets)
|
69
gguf-py/gguf/utility.py
Normal file
69
gguf-py/gguf/utility.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
|
||||
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
||||
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
|
||||
ftype_lowercase: str = output_type.lower() if output_type is not None else ""
|
||||
ftype_uppercase: str = output_type.upper() if output_type is not None else ""
|
||||
return filename.format(ftype_lowercase,
|
||||
outtype=ftype_lowercase, ftype=ftype_lowercase,
|
||||
OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
|
||||
|
||||
|
||||
def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
|
||||
if model_params_count > 1e12 :
|
||||
# Trillions Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-12
|
||||
scale_suffix = "T"
|
||||
elif model_params_count > 1e9 :
|
||||
# Billions Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-9
|
||||
scale_suffix = "B"
|
||||
elif model_params_count > 1e6 :
|
||||
# Millions Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-6
|
||||
scale_suffix = "M"
|
||||
else:
|
||||
# Thousands Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-3
|
||||
scale_suffix = "K"
|
||||
|
||||
fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
|
||||
|
||||
return f"{scaled_model_params:.{fix}f}{scale_suffix}"
|
||||
|
||||
|
||||
def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
|
||||
|
||||
if expert_count > 0:
|
||||
pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
|
||||
size_class = f"{expert_count}x{pretty_size}"
|
||||
else:
|
||||
size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
|
||||
|
||||
return size_class
|
||||
|
||||
|
||||
def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
|
||||
# Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
|
||||
|
||||
if base_name is not None:
|
||||
name = base_name.strip().title().replace(' ', '-').replace('/', '-')
|
||||
elif model_name is not None:
|
||||
name = model_name.strip().title().replace(' ', '-').replace('/', '-')
|
||||
else:
|
||||
name = "ggml-model"
|
||||
|
||||
parameters = f"-{size_label}" if size_label is not None else ""
|
||||
|
||||
finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
|
||||
|
||||
version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
|
||||
|
||||
encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
|
||||
|
||||
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
|
||||
|
||||
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
|
|
@ -22,6 +22,7 @@ classifiers = [
|
|||
python = ">=3.8"
|
||||
numpy = ">=1.17"
|
||||
tqdm = ">=4.27"
|
||||
pyyaml = ">=5.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^5.2"
|
||||
|
|
1
gguf-py/tests/__init__.py
Normal file
1
gguf-py/tests/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .test_metadata import *
|
|
@ -1,7 +0,0 @@
|
|||
import gguf # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# TODO: add tests
|
||||
|
||||
|
||||
def test_write_gguf() -> None:
|
||||
pass
|
158
gguf-py/tests/test_metadata.py
Executable file
158
gguf-py/tests/test_metadata.py
Executable file
|
@ -0,0 +1,158 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Necessary to load the local gguf package
|
||||
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import gguf
|
||||
|
||||
|
||||
class TestMetadataMethod(unittest.TestCase):
|
||||
|
||||
def test_id_to_title(self):
|
||||
self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
|
||||
self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
|
||||
self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
|
||||
|
||||
def test_get_model_id_components(self):
|
||||
# This is the basic standard form with organization marker
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
|
||||
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
||||
|
||||
# Similar to basic standard form but without organization marker
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
|
||||
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
||||
|
||||
# Missing version
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
|
||||
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
|
||||
|
||||
# Missing finetune
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
|
||||
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
|
||||
|
||||
# Base name and size label only
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
|
||||
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
|
||||
|
||||
# Base name and version only
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
|
||||
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
|
||||
|
||||
## Edge Cases ##
|
||||
|
||||
# This is too ambiguous... best to err on caution and output nothing
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
|
||||
('Mixtral', None, None, None, None, None))
|
||||
|
||||
# Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
|
||||
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
|
||||
|
||||
# Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
|
||||
('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
|
||||
|
||||
# Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
|
||||
('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))
|
||||
|
||||
# Check that it can handle a real model id with no version code
|
||||
# Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),
|
||||
('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))
|
||||
|
||||
# There is some legitimate models with only thousands of parameters
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
|
||||
('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
|
||||
|
||||
# None standard and not easy to disambiguate
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
|
||||
('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
|
||||
|
||||
# This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
|
||||
('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))
|
||||
|
||||
# This is a real model id where the weight size has a decimal point
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
|
||||
('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))
|
||||
|
||||
# Uses an underscore in the size label
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
|
||||
('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))
|
||||
|
||||
# Uses Iter3 for the version
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"),
|
||||
('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B'))
|
||||
|
||||
# Has two potential versions in the basename
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"),
|
||||
('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B'))
|
||||
|
||||
# Potential version in the basename
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"),
|
||||
('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B'))
|
||||
|
||||
# Underscore in the basename, and 1m for the context size
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9),
|
||||
('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B'))
|
||||
|
||||
# Version before the finetune name
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"),
|
||||
('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M'))
|
||||
|
||||
# TODO: hf suffix which could be ignored but isn't
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"),
|
||||
('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B'))
|
||||
|
||||
# Two sizes, don't merge them, the other is the number of tokens on which it was trained
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6),
|
||||
('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M'))
|
||||
|
||||
# It's a trap, there is no size label
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6),
|
||||
('relu-100B', 'SparseLLM', 'relu', '100b', None, None))
|
||||
|
||||
# Weird size notation
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
|
||||
('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
|
||||
|
||||
def test_apply_metadata_heuristic_from_model_card(self):
|
||||
model_card = {
|
||||
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
|
||||
'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
|
||||
'language': ['en'],
|
||||
'datasets': ['teknium/OpenHermes-2.5'],
|
||||
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
|
||||
'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
|
||||
}
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||
expect = gguf.Metadata()
|
||||
expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
|
||||
expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
|
||||
expect.languages=['en']
|
||||
expect.datasets=['teknium/OpenHermes-2.5']
|
||||
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
def test_apply_metadata_heuristic_from_hf_parameters(self):
|
||||
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
|
||||
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
def test_apply_metadata_heuristic_from_model_dir(self):
|
||||
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
|
||||
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
|
||||
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
|
||||
self.assertEqual(got, expect)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue