convert-*.py: add unittest to metadata class
This commit is contained in:
parent
3625a42061
commit
91e65d9485
6 changed files with 275 additions and 194 deletions
|
@ -129,16 +129,16 @@ class Model:
|
|||
self.metadata.name = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
|
||||
# Generate parameter weight class (useful for leader boards) if not yet determined
|
||||
if self.metadata.parameter_weight_class is None:
|
||||
if self.metadata.parameter_class_attribute is None:
|
||||
expert_count = self.hparams["num_local_experts"] if "num_local_experts" in self.hparams else None
|
||||
weight_estimate = self.per_model_weight_count_estimation(self.get_tensors(), expert_count)
|
||||
self.metadata.parameter_weight_class = gguf.parameter_weight_class(expert_count, weight_estimate)
|
||||
self.metadata.parameter_class_attribute = gguf.parameter_class_attribute(expert_count, weight_estimate)
|
||||
|
||||
# Extracts and converts the encoding scheme from the given file type name. e.g. 'gguf.LlamaFileType.ALL_F32' --> 'F32'
|
||||
output_type = self.ftype.name.partition("_")[2]
|
||||
|
||||
# Generate default filename based on model specification and available metadata
|
||||
self.fname_default = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.parameter_weight_class, output_type)
|
||||
self.fname_default = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.parameter_class_attribute, output_type)
|
||||
|
||||
# Filename Output
|
||||
if fname_out is not None:
|
||||
|
@ -263,8 +263,8 @@ class Model:
|
|||
self.gguf_writer.add_source_url(self.metadata.source_url)
|
||||
if self.metadata.source_hf_repo is not None:
|
||||
self.gguf_writer.add_source_hf_repo(self.metadata.source_hf_repo)
|
||||
if self.metadata.parameter_weight_class is not None:
|
||||
self.gguf_writer.add_parameter_weight_class(self.metadata.parameter_weight_class)
|
||||
if self.metadata.parameter_class_attribute is not None:
|
||||
self.gguf_writer.add_parameter_class_attribute(self.metadata.parameter_class_attribute)
|
||||
if self.metadata.tags is not None:
|
||||
self.gguf_writer.add_tags(self.metadata.tags)
|
||||
if self.metadata.languages is not None:
|
||||
|
|
|
@ -805,6 +805,8 @@ class OutputFile:
|
|||
self.gguf.add_source_url(metadata.source_url)
|
||||
if metadata.source_hf_repo is not None:
|
||||
self.gguf.add_source_hf_repo(metadata.source_hf_repo)
|
||||
if metadata.parameter_class_attribute is not None:
|
||||
self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute)
|
||||
if metadata.tags is not None:
|
||||
self.gguf.add_tags(metadata.tags)
|
||||
if metadata.languages is not None:
|
||||
|
@ -961,8 +963,6 @@ class OutputFile:
|
|||
|
||||
of = OutputFile(fname_out, endianess=endianess)
|
||||
|
||||
print(metadata)
|
||||
|
||||
# meta data
|
||||
of.add_meta_model(params, metadata)
|
||||
of.add_meta_arch(params)
|
||||
|
@ -1208,7 +1208,7 @@ def default_convention_outfile(file_type: GGMLFileType, expert_count:int, model_
|
|||
basename = metadata.basename if metadata.basename is not None else None
|
||||
finetune = metadata.finetune if metadata.finetune is not None else None
|
||||
version = metadata.version if metadata.version is not None else None
|
||||
parameter_weight_class = metadata.parameter_weight_class if metadata.parameter_weight_class is not None else gguf.parameter_weight_class(expert_count, model_params_count)
|
||||
parameter_class_attribute = metadata.parameter_class_attribute if metadata.parameter_class_attribute is not None else gguf.parameter_class_attribute(expert_count, model_params_count)
|
||||
|
||||
output_type = {
|
||||
GGMLFileType.AllF32: "F32",
|
||||
|
@ -1216,7 +1216,7 @@ def default_convention_outfile(file_type: GGMLFileType, expert_count:int, model_
|
|||
GGMLFileType.MostlyQ8_0: "Q8_0",
|
||||
}[file_type]
|
||||
|
||||
return gguf.naming_convention(name, basename, finetune, version, parameter_weight_class, output_type)
|
||||
return gguf.naming_convention(name, basename, finetune, version, parameter_class_attribute, output_type)
|
||||
|
||||
|
||||
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> Path:
|
||||
|
@ -1380,7 +1380,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
model = convert_to_output_type(model, ftype)
|
||||
outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata)
|
||||
|
||||
metadata.parameter_weight_class = gguf.parameter_weight_class(params.n_experts, model_params_count)
|
||||
metadata.parameter_class_attribute = gguf.parameter_class_attribute(params.n_experts, model_params_count)
|
||||
|
||||
params.ftype = ftype
|
||||
logger.info(f"Writing {outfile}, format {ftype}")
|
||||
|
|
|
@ -39,7 +39,7 @@ class Keys:
|
|||
SOURCE_URL = "general.source.url"
|
||||
SOURCE_HF_REPO = "general.source.huggingface.repository"
|
||||
FILE_TYPE = "general.file_type"
|
||||
PARAMETER_WEIGHT_CLASS = "general.parameter_weight_class"
|
||||
PARAMETER_CLASS_ATTRIBUTE = "general.parameter_class_attribute"
|
||||
TAGS = "general.tags"
|
||||
LANGUAGES = "general.languages"
|
||||
DATASETS = "general.datasets"
|
||||
|
|
|
@ -478,8 +478,8 @@ class GGUFWriter:
|
|||
def add_file_type(self, ftype: int) -> None:
|
||||
self.add_uint32(Keys.General.FILE_TYPE, ftype)
|
||||
|
||||
def add_parameter_weight_class(self, parameter_weight_class: str) -> None:
|
||||
self.add_string(Keys.General.PARAMETER_WEIGHT_CLASS, parameter_weight_class)
|
||||
def add_parameter_class_attribute(self, parameter_class_attribute: str) -> None:
|
||||
self.add_string(Keys.General.PARAMETER_CLASS_ATTRIBUTE, parameter_class_attribute)
|
||||
|
||||
def add_tags(self, tags: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.TAGS, tags)
|
||||
|
|
389
gguf-py/gguf/metadata.py
Normal file → Executable file
389
gguf-py/gguf/metadata.py
Normal file → Executable file
|
@ -1,11 +1,18 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import json
|
||||
import unittest
|
||||
import frontmatter
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
if __name__ == '__main__':
|
||||
from constants import Keys
|
||||
else:
|
||||
from .constants import Keys
|
||||
|
||||
|
||||
|
@ -27,7 +34,7 @@ class Metadata:
|
|||
license_link: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
source_hf_repo: Optional[str] = None
|
||||
parameter_weight_class: Optional[str] = None
|
||||
parameter_class_attribute: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
languages: Optional[list[str]] = None
|
||||
datasets: Optional[list[str]] = None
|
||||
|
@ -41,120 +48,11 @@ class Metadata:
|
|||
# Create a new Metadata instance
|
||||
metadata = Metadata()
|
||||
|
||||
# load huggingface model card if available
|
||||
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
||||
model_card = Metadata.load_model_card(model_path)
|
||||
hf_params = Metadata.load_hf_parameters(model_path)
|
||||
|
||||
if "model_name" in model_card:
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke who would encode 'Mixtral 8X7B Instruct v0.1' into model_name
|
||||
metadata.name = model_card.get("model_name")
|
||||
|
||||
if "base_model" in model_card:
|
||||
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
|
||||
model_id = model_card.get("base_model")
|
||||
|
||||
# Check if string. We cannot handle lists as that is too ambagious
|
||||
if isinstance(model_id, str):
|
||||
model_name_normal, organization_name, base_name, fine_tune, version_string, parameter_weight_class = Metadata.get_model_name_components(model_id)
|
||||
if metadata.name is None and model_name_normal is not None:
|
||||
metadata.name = model_name_normal
|
||||
if metadata.organization is None and organization_name is not None:
|
||||
metadata.organization = organization_name
|
||||
if metadata.basename is None and base_name is not None:
|
||||
metadata.basename = base_name
|
||||
if metadata.finetune is None and fine_tune is not None:
|
||||
metadata.finetune = fine_tune
|
||||
if metadata.version is None and version_string is not None:
|
||||
metadata.version = version_string
|
||||
if metadata.parameter_weight_class is None and parameter_weight_class is not None:
|
||||
metadata.parameter_weight_class = parameter_weight_class
|
||||
if metadata.source_url is None:
|
||||
metadata.source_url = f"https://huggingface.co/{model_id}"
|
||||
if metadata.source_hf_repo is None:
|
||||
metadata.source_hf_repo = model_id
|
||||
|
||||
if "model-index" in model_card and len(model_card["model_name"]) == 1 and "name" in model_card["model_name"][0]:
|
||||
# This is a model index which has model id that can be extracted into organization and model name
|
||||
# if so then we can safely extract organization and name
|
||||
# (This is a safe choice in case there is multiple models in one repo in the future)
|
||||
model_id = model_card["model-index"][0].get("name")
|
||||
model_name_normal, organization_name, base_name, fine_tune, version_string, parameter_weight_class = Metadata.get_model_name_components(model_id)
|
||||
|
||||
if metadata.name is None and model_name_normal is not None:
|
||||
metadata.name = model_name_normal
|
||||
if metadata.organization is None and organization_name is not None:
|
||||
metadata.organization = organization_name
|
||||
if metadata.basename is None and base_name is not None:
|
||||
metadata.basename = base_name
|
||||
if metadata.finetune is None and fine_tune is not None:
|
||||
metadata.finetune = fine_tune
|
||||
if metadata.version is None and version_string is not None:
|
||||
metadata.version = version_string
|
||||
if metadata.parameter_weight_class is None and parameter_weight_class is not None:
|
||||
metadata.parameter_weight_class = parameter_weight_class
|
||||
|
||||
if metadata.quantized_by is None:
|
||||
# Not part of hugging face model card standard, but is used by TheBloke to credit them self for quantizing 3rd party models
|
||||
metadata.quantized_by = model_card.get("quantized_by")
|
||||
if metadata.license is None:
|
||||
metadata.license = model_card.get("license")
|
||||
if metadata.license_name is None:
|
||||
metadata.license_name = model_card.get("license_name")
|
||||
if metadata.license_link is None:
|
||||
metadata.license_link = model_card.get("license_link")
|
||||
if metadata.author is None:
|
||||
# non huggingface model card standard but notice some model creator using it
|
||||
metadata.author = model_card.get("model_creator")
|
||||
if metadata.tags is None:
|
||||
metadata.tags = model_card.get("tags", None)
|
||||
if metadata.languages is None:
|
||||
metadata.languages = model_card.get("language", model_card.get("languages", None))
|
||||
if metadata.datasets is None:
|
||||
metadata.datasets = model_card.get("datasets", model_card.get("dataset", None))
|
||||
|
||||
# load huggingface parameters if available
|
||||
hf_params = Metadata.load_huggingface_parameters(model_path)
|
||||
|
||||
hf_name_or_path = hf_params.get("_name_or_path")
|
||||
if hf_name_or_path is not None and Metadata.is_model_id(hf_name_or_path):
|
||||
# Use _name_or_path only if its actually a model name and not some computer path
|
||||
# e.g. 'meta-llama/Llama-2-7b-hf'
|
||||
model_name_normal, organization_name, base_name, fine_tune, version_string, parameter_weight_class = Metadata.get_model_name_components(hf_name_or_path)
|
||||
if metadata.name is None and model_name_normal is not None:
|
||||
metadata.name = model_name_normal
|
||||
if metadata.organization is None and organization_name is not None:
|
||||
metadata.organization = organization_name
|
||||
if metadata.basename is None and base_name is not None:
|
||||
metadata.basename = base_name
|
||||
if metadata.finetune is None and fine_tune is not None:
|
||||
metadata.finetune = fine_tune
|
||||
if metadata.version is None and version_string is not None:
|
||||
metadata.version = version_string
|
||||
if metadata.parameter_weight_class is None and parameter_weight_class is not None:
|
||||
metadata.parameter_weight_class = parameter_weight_class
|
||||
if not Metadata.is_model_name_only(hf_name_or_path):
|
||||
# Can't just have the model name as the source hf repo as a link to the huggingface website needs the org name and the model name
|
||||
if metadata.source_url is None:
|
||||
metadata.source_url = f"https://huggingface.co/{hf_name_or_path}"
|
||||
if metadata.source_hf_repo is None:
|
||||
metadata.source_hf_repo = hf_name_or_path
|
||||
|
||||
# Use Directory Folder Name As Fallback Name
|
||||
if model_path is not None and model_path.exists():
|
||||
model_name_normal, organization_name, base_name, fine_tune, version_string, parameter_weight_class = Metadata.get_model_name_components(model_path.name)
|
||||
if metadata.name is None and model_name_normal is not None:
|
||||
metadata.name = model_name_normal
|
||||
if metadata.organization is None and organization_name is not None:
|
||||
metadata.organization = organization_name
|
||||
if metadata.basename is None and base_name is not None:
|
||||
metadata.basename = base_name
|
||||
if metadata.finetune is None and fine_tune is not None:
|
||||
metadata.finetune = fine_tune
|
||||
if metadata.version is None and version_string is not None:
|
||||
metadata.version = version_string
|
||||
if metadata.parameter_weight_class is None and parameter_weight_class is not None:
|
||||
metadata.parameter_weight_class = parameter_weight_class
|
||||
# heuristics
|
||||
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path)
|
||||
|
||||
# Metadata Override File Provided
|
||||
# This is based on LLM_KV_NAMES mapping in llama.cpp
|
||||
|
@ -174,7 +72,7 @@ class Metadata:
|
|||
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link ) # noqa: E202
|
||||
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202
|
||||
metadata.source_hf_repo = metadata_override.get(Keys.General.SOURCE_HF_REPO , metadata.source_hf_repo ) # noqa: E202
|
||||
metadata.parameter_weight_class = metadata_override.get(Keys.General.PARAMETER_WEIGHT_CLASS, metadata.parameter_weight_class) # noqa: E202
|
||||
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202
|
||||
metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202
|
||||
metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202
|
||||
metadata.datasets = metadata_override.get(Keys.General.DATASETS , metadata.datasets ) # noqa: E202
|
||||
|
@ -207,7 +105,7 @@ class Metadata:
|
|||
return frontmatter.load(f)
|
||||
|
||||
@staticmethod
|
||||
def load_huggingface_parameters(model_path: Optional[Path] = None) -> dict[str, object]:
|
||||
def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, object]:
|
||||
if model_path is None or not model_path.exists():
|
||||
return {}
|
||||
|
||||
|
@ -220,64 +118,247 @@ class Metadata:
|
|||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def is_model_id(name_or_path: Optional[str] = None) -> bool:
|
||||
# Return True if the string has 1 or 0 slashes, indicating a model id
|
||||
# Created specifically because of _name_or_path in hugging face parameter
|
||||
if name_or_path is None:
|
||||
return False
|
||||
return name_or_path.count('/') <= 1
|
||||
def id_to_title(string):
|
||||
# Convert capitalization into title form unless acronym or version number
|
||||
string = string.strip().replace('-', ' ')
|
||||
return ' '.join([w.title() if w.islower() and not re.match(r'^v\d+(?:\.\d+)*$', w) else w for w in string.split()])
|
||||
|
||||
@staticmethod
|
||||
def is_model_name_only(name_or_path: Optional[str] = None) -> bool:
|
||||
# Return True if the string has 0 slashes, indicating a model name only model id
|
||||
# Created specifically because of _name_or_path in hugging face parameter
|
||||
if name_or_path is None:
|
||||
return False
|
||||
return name_or_path.count('/') == 0
|
||||
def get_model_id_components(model_id: Optional[str] = None) -> dict[str, object]:
|
||||
# Huggingface often store model id as '<org>/<model name>'
|
||||
# so let's parse it and apply some heuristics if possible for model name components
|
||||
|
||||
@staticmethod
|
||||
def get_model_name_components(model_identifier: Optional[str] = None) -> dict[str, object]:
|
||||
# Huggingface often store model id
|
||||
|
||||
if model_identifier is None:
|
||||
if model_id is None:
|
||||
# model ID missing
|
||||
return None, None, None, None, None, None
|
||||
|
||||
if ' ' in model_identifier:
|
||||
if ' ' in model_id:
|
||||
# model ID is actually a normal human sentence
|
||||
# which means its most likely a normal model name only
|
||||
# not part of the hugging face naming standard, but whatever
|
||||
return model_identifier, None, None, None, None, None
|
||||
return model_id, None, None, None, None, None
|
||||
|
||||
if '/' in model_identifier:
|
||||
if '/' in model_id:
|
||||
# model ID (huggingface style)
|
||||
organization, model = model_identifier.split('/', 1)
|
||||
org_component, model_full_name_component = model_id.split('/', 1)
|
||||
else:
|
||||
# model ID but missing org components
|
||||
model = model_identifier
|
||||
organization = None
|
||||
org_component, model_full_name_component = None, model_id
|
||||
|
||||
# Apply formatting to organization and model_name
|
||||
# 'stable-diffusion-xl-base-1.0' --> 'Stable Diffusion Xl Base 1.0'
|
||||
|
||||
organization_name = organization.strip().replace('-', ' ').title() if organization is not None else None
|
||||
model_name_normal = model.strip().replace('-', ' ').title() if model is not None else None
|
||||
# Check if we erroneously matched against './' or '../' etc...
|
||||
if org_component is not None and org_component[0] == '.':
|
||||
org_component = None
|
||||
|
||||
# Regular expression to extract model name components
|
||||
# Heuristic to match against cases such as 'Mixtral-8x7B-Instruct-v0.1' or 'Codestral-22B-v0.1'
|
||||
|
||||
regex_match = re.compile(r'^(?P<base_name>[A-Za-z0-9\s]*(?:(?:-[A-Za-z\s][A-Za-z0-9\s]*)*))'
|
||||
r'(?:-(?P<parameter_weight_class>(?:\d+x)?\d+[A-Za-z]+))?'
|
||||
r'(?:-(?P<fine_tune>[A-Za-z0-9\s-]+))?'
|
||||
r'(?:-(?P<version_string>v\d+(?:\.\d+)*))?$').match(model)
|
||||
regex_match = re.compile(r'^(?P<basename>[A-Za-z0-9\s]*(?:(?:-(?:(?:[A-Za-z\s][A-Za-z0-9\s]*)|(?:[0-9\s]*)))*))'
|
||||
r'(?:-(?P<parameter_class_attribute>(?:\d+x)?\d+[A-Za-z]+)(?:-(?P<finetune>[A-Za-z0-9\s-]+))?)?'
|
||||
r'(?:-(?P<version>v\d+(?:\.\d+)*))?$').match(model_full_name_component)
|
||||
|
||||
if not regex_match:
|
||||
return model_name_normal, organization_name, None, None, None, None
|
||||
return model_full_name_component, org_component, None, None, None, None
|
||||
|
||||
components = regex_match.groupdict()
|
||||
base_name = components.get("base_name")
|
||||
fine_tune = components.get("fine_tune")
|
||||
version_string = components.get("version_string")
|
||||
parameter_weight_class = components.get("parameter_weight_class")
|
||||
basename = components.get("basename")
|
||||
finetune = components.get("finetune")
|
||||
version = components.get("version")
|
||||
parameter_class_attribute = components.get("parameter_class_attribute")
|
||||
|
||||
return model_name_normal, organization_name, base_name, fine_tune, version_string, parameter_weight_class
|
||||
return model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute
|
||||
|
||||
@staticmethod
|
||||
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None) -> Metadata:
|
||||
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
||||
found_model_name = False
|
||||
|
||||
# Model Card Heuristics
|
||||
########################
|
||||
if model_card is not None:
|
||||
|
||||
if "model_name" in model_card:
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke who would encode 'Mixtral 8X7B Instruct v0.1' into model_name
|
||||
metadata.name = model_card.get("model_name")
|
||||
|
||||
if "model-index" in model_card and len(model_card["model-index"]) == 1 and "name" in model_card["model-index"][0]:
|
||||
# This is a model index which has model id that can be extracted into organization and model name
|
||||
# if so then we can safely extract organization and name
|
||||
# (This is a safe choice in case there is multiple models in one repo in the future)
|
||||
model_id = model_card["model-index"][0]["name"]
|
||||
model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id)
|
||||
if metadata.name is None and model_full_name_component is not None:
|
||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
||||
if metadata.organization is None and org_component is not None:
|
||||
metadata.organization = Metadata.id_to_title(org_component)
|
||||
if metadata.basename is None and basename is not None:
|
||||
metadata.basename = basename
|
||||
if metadata.finetune is None and finetune is not None:
|
||||
metadata.finetune = finetune
|
||||
if metadata.version is None and version is not None:
|
||||
metadata.version = version
|
||||
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
||||
metadata.parameter_class_attribute = parameter_class_attribute
|
||||
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
|
||||
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
|
||||
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
|
||||
|
||||
found_model_name = True
|
||||
|
||||
if "base_model" in model_card and isinstance(model_card["base_model"], str) and not found_model_name:
|
||||
# Check if string. We cannot handle lists as that is too ambagious
|
||||
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
|
||||
model_id = model_card.get("base_model")
|
||||
model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id)
|
||||
if metadata.name is None and model_full_name_component is not None:
|
||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
||||
if metadata.organization is None and org_component is not None:
|
||||
metadata.organization = Metadata.id_to_title(org_component)
|
||||
if metadata.basename is None and basename is not None:
|
||||
metadata.basename = basename
|
||||
if metadata.finetune is None and finetune is not None:
|
||||
metadata.finetune = finetune
|
||||
if metadata.version is None and version is not None:
|
||||
metadata.version = version
|
||||
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
||||
metadata.parameter_class_attribute = parameter_class_attribute
|
||||
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
|
||||
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
|
||||
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
|
||||
|
||||
found_model_name = True
|
||||
|
||||
if metadata.quantized_by is None:
|
||||
# Not part of hugging face model card standard, but is used by TheBloke to credit them self for quantizing 3rd party models
|
||||
metadata.quantized_by = model_card.get("quantized_by")
|
||||
if metadata.license is None:
|
||||
metadata.license = model_card.get("license")
|
||||
if metadata.license_name is None:
|
||||
metadata.license_name = model_card.get("license_name")
|
||||
if metadata.license_link is None:
|
||||
metadata.license_link = model_card.get("license_link")
|
||||
if metadata.author is None:
|
||||
# non huggingface model card standard but notice some model creator using it
|
||||
metadata.author = model_card.get("model_creator")
|
||||
if metadata.tags is None:
|
||||
metadata.tags = model_card.get("tags", None)
|
||||
if metadata.languages is None:
|
||||
metadata.languages = model_card.get("language", model_card.get("languages", None))
|
||||
if metadata.datasets is None:
|
||||
metadata.datasets = model_card.get("datasets", model_card.get("dataset", None))
|
||||
|
||||
# Hugging Face Parameter Heuristics
|
||||
####################################
|
||||
|
||||
if hf_params is not None:
|
||||
hf_name_or_path = hf_params.get("_name_or_path")
|
||||
|
||||
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1 and not found_model_name:
|
||||
# Use _name_or_path only if its actually a model name and not some computer path
|
||||
# e.g. 'meta-llama/Llama-2-7b-hf'
|
||||
model_id = hf_name_or_path
|
||||
model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id)
|
||||
if metadata.name is None and model_full_name_component is not None:
|
||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
||||
if metadata.organization is None and org_component is not None:
|
||||
metadata.organization = Metadata.id_to_title(org_component)
|
||||
if metadata.basename is None and basename is not None:
|
||||
metadata.basename = basename
|
||||
if metadata.finetune is None and finetune is not None:
|
||||
metadata.finetune = finetune
|
||||
if metadata.version is None and version is not None:
|
||||
metadata.version = version
|
||||
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
||||
metadata.parameter_class_attribute = parameter_class_attribute
|
||||
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
|
||||
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
|
||||
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
|
||||
|
||||
# Directory Folder Name Fallback Heuristics
|
||||
############################################
|
||||
if model_path is not None:
|
||||
model_id = model_path.name
|
||||
model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id)
|
||||
if metadata.name is None and model_full_name_component is not None:
|
||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
||||
if metadata.organization is None and org_component is not None:
|
||||
metadata.organization = Metadata.id_to_title(org_component)
|
||||
if metadata.basename is None and basename is not None:
|
||||
metadata.basename = basename
|
||||
if metadata.finetune is None and finetune is not None:
|
||||
metadata.finetune = finetune
|
||||
if metadata.version is None and version is not None:
|
||||
metadata.version = version
|
||||
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
||||
metadata.parameter_class_attribute = parameter_class_attribute
|
||||
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
|
||||
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
|
||||
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
class TestStringMethods(unittest.TestCase):
|
||||
|
||||
def test_get_model_id_components(self):
|
||||
self.assertEqual(Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
|
||||
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
|
||||
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
|
||||
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
|
||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
|
||||
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
|
||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B"),
|
||||
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
|
||||
self.assertEqual(Metadata.get_model_id_components("Mixtral"),
|
||||
('Mixtral', None, 'Mixtral', None, None, None))
|
||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-v0.1"),
|
||||
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
|
||||
self.assertEqual(Metadata.get_model_id_components("hermes-2-pro-llama-3-8b-DPO"),
|
||||
('hermes-2-pro-llama-3-8b-DPO', None, 'hermes-2-pro-llama-3', 'DPO', None, '8b'))
|
||||
self.assertEqual(Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
|
||||
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, "8B"))
|
||||
|
||||
def test_apply_metadata_heuristic_from_model_card(self):
|
||||
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/README.md
|
||||
model_card = {
|
||||
'base_model': 'NousResearch/Meta-Llama-3-8B',
|
||||
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
|
||||
'model-index': [{'name': 'Hermes-2-Pro-Llama-3-8B', 'results': []}],
|
||||
'language': ['en'],
|
||||
'datasets': ['teknium/OpenHermes-2.5'],
|
||||
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}]
|
||||
}
|
||||
expected = Metadata(name='Hermes 2 Pro Llama 3 8B', basename='Hermes-2-Pro-Llama-3', finetune=None, author=None, quantized_by=None, organization=None, version=None, base_version=None, url=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_hf_repo=None, parameter_class_attribute='8B', tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], languages=['en'], datasets=['teknium/OpenHermes-2.5'])
|
||||
|
||||
got = Metadata.apply_metadata_heuristic(Metadata(), model_card, None, None)
|
||||
|
||||
self.assertEqual(got, expected)
|
||||
|
||||
def test_apply_metadata_heuristic_from_hf_parameters(self):
|
||||
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json
|
||||
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
|
||||
|
||||
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, base_version=None, url=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_hf_repo=None, parameter_class_attribute='8b', tags=None, languages=None, datasets=None)
|
||||
|
||||
got = Metadata.apply_metadata_heuristic(Metadata(), None, hf_params, None)
|
||||
|
||||
self.assertEqual(got, expected)
|
||||
|
||||
def test_apply_metadata_heuristic_from_model_dir(self):
|
||||
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json
|
||||
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
|
||||
|
||||
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, base_version=None, url=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_hf_repo=None, parameter_class_attribute='8b', tags=None, languages=None, datasets=None)
|
||||
|
||||
got = Metadata.apply_metadata_heuristic(Metadata(), None, None, model_dir_path)
|
||||
|
||||
self.assertEqual(got, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -34,7 +34,7 @@ def model_weight_count_rounded_notation(model_params_count: int) -> str:
|
|||
return f"{round(scaled_model_params)}{scale_suffix}"
|
||||
|
||||
|
||||
def parameter_weight_class(expert_count_int:int, model_params_count: int) -> str:
|
||||
def parameter_class_attribute(expert_count_int:int, model_params_count: int) -> str:
|
||||
per_model_rounded_weight_estimate = model_weight_count_rounded_notation(model_params_count)
|
||||
|
||||
if expert_count_int is not None and expert_count_int > 0:
|
||||
|
@ -45,7 +45,7 @@ def parameter_weight_class(expert_count_int:int, model_params_count: int) -> str
|
|||
return size_class
|
||||
|
||||
|
||||
def naming_convention(model_name: str, base_name: str, finetune_string:str, version_string:str, parameter_weight_class: str, output_type: str) -> str:
|
||||
def naming_convention(model_name: str, base_name: str, finetune_string:str, version_string:str, parameter_class_attribute: str, output_type: str) -> str:
|
||||
# Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
|
||||
|
||||
if base_name is not None:
|
||||
|
@ -55,7 +55,7 @@ def naming_convention(model_name: str, base_name: str, finetune_string:str, vers
|
|||
else:
|
||||
name = "ggml-model"
|
||||
|
||||
parameters = f"-{parameter_weight_class}" if parameter_weight_class is not None else ""
|
||||
parameters = f"-{parameter_class_attribute}" if parameter_class_attribute is not None else ""
|
||||
|
||||
finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue