convert-*.py: adjusted authorship KV store

This commit is contained in:
brian khuu 2024-06-07 03:33:21 +10:00
parent 91e65d9485
commit d060fcdbe2
5 changed files with 179 additions and 72 deletions

View file

@ -251,20 +251,51 @@ class Model:
self.gguf_writer.add_organization(self.metadata.organization) self.gguf_writer.add_organization(self.metadata.organization)
if self.metadata.version is not None: if self.metadata.version is not None:
self.gguf_writer.add_version(self.metadata.version) self.gguf_writer.add_version(self.metadata.version)
if self.metadata.base_version is not None:
self.gguf_writer.add_base_version(self.metadata.base_version)
if self.metadata.url is not None: if self.metadata.url is not None:
self.gguf_writer.add_url(self.metadata.url) self.gguf_writer.add_url(self.metadata.url)
if self.metadata.doi is not None:
self.gguf_writer.add_doi(self.metadata.doi)
if self.metadata.uuid is not None:
self.gguf_writer.add_uuid(self.metadata.uuid)
if self.metadata.hf_repo is not None:
self.gguf_writer.add_hf_repo(self.metadata.hf_repo)
if self.metadata.description is not None: if self.metadata.description is not None:
self.gguf_writer.add_description(self.metadata.description) self.gguf_writer.add_description(self.metadata.description)
if self.metadata.license is not None: if self.metadata.license is not None:
self.gguf_writer.add_license(self.metadata.license) self.gguf_writer.add_license(self.metadata.license)
if self.metadata.license_name is not None:
self.gguf_writer.add_license_name(self.metadata.license_name)
if self.metadata.license_link is not None:
self.gguf_writer.add_license_link(self.metadata.license_link)
if self.metadata.source_url is not None: if self.metadata.source_url is not None:
self.gguf_writer.add_source_url(self.metadata.source_url) self.gguf_writer.add_source_url(self.metadata.source_url)
if self.metadata.source_doi is not None:
self.gguf_writer.add_source_doi(self.metadata.source_doi)
if self.metadata.source_uuid is not None:
self.gguf_writer.add_source_uuid(self.metadata.source_uuid)
if self.metadata.source_hf_repo is not None: if self.metadata.source_hf_repo is not None:
self.gguf_writer.add_source_hf_repo(self.metadata.source_hf_repo) self.gguf_writer.add_source_hf_repo(self.metadata.source_hf_repo)
if self.metadata.parameter_class_attribute is not None: if self.metadata.parameter_class_attribute is not None:
self.gguf_writer.add_parameter_class_attribute(self.metadata.parameter_class_attribute) self.gguf_writer.add_parameter_class_attribute(self.metadata.parameter_class_attribute)
if self.metadata.parents is not None:
metadata.parent_count = len(self.metadata.parents)
for key, parent_entry in self.metadata.parents:
if "name" in parent_entry:
self.gguf_writer.add_parent_name(key, parent_entry.get("name"))
if "author" in parent_entry:
self.gguf_writer.add_parent_author(key, parent_entry.get("author"))
if "version" in parent_entry:
self.gguf_writer.add_parent_version(key, parent_entry.get("version"))
if "organization" in parent_entry:
self.gguf_writer.add_parent_organization(key, parent_entry.get("organization"))
if "url" in parent_entry:
self.gguf_writer.add_parent_url(key, parent_entry.get("url"))
if "doi" in parent_entry:
self.gguf_writer.add_parent_doi(key, parent_entry.get("doi"))
if "uuid" in parent_entry:
self.gguf_writer.add_parent_uuid(key, parent_entry.get("uuid"))
if "hf_repo" in parent_entry:
self.gguf_writer.add_parent_hf_repo(key, parent_entry.get("hf_repo"))
if self.metadata.tags is not None: if self.metadata.tags is not None:
self.gguf_writer.add_tags(self.metadata.tags) self.gguf_writer.add_tags(self.metadata.tags)
if self.metadata.languages is not None: if self.metadata.languages is not None:

View file

@ -793,20 +793,47 @@ class OutputFile:
self.add_organization(metadata.organization) self.add_organization(metadata.organization)
if metadata.version is not None: if metadata.version is not None:
self.gguf.add_version(metadata.version) self.gguf.add_version(metadata.version)
if metadata.base_version is not None:
self.gguf.add_base_version(metadata.base_version)
if metadata.url is not None: if metadata.url is not None:
self.gguf.add_url(metadata.url) self.gguf.add_url(metadata.url)
if metadata.doi is not None:
self.gguf.add_doi(metadata.doi)
if metadata.uuid is not None:
self.gguf.add_uuid(metadata.uuid)
if metadata.hf_repo is not None:
self.gguf.add_hf_repo(metadata.hf_repo)
if metadata.description is not None: if metadata.description is not None:
self.gguf.add_description(metadata.description) self.gguf.add_description(metadata.description)
if metadata.license is not None: if metadata.license is not None:
self.gguf.add_license(metadata.license) self.gguf.add_license(metadata.license)
if metadata.license_name is not None:
self.gguf.add_license_name(metadata.license_name)
if metadata.license_link is not None:
self.gguf.add_license_link(metadata.license_link)
if metadata.source_url is not None: if metadata.source_url is not None:
self.gguf.add_source_url(metadata.source_url) self.gguf.add_source_url(metadata.source_url)
if metadata.source_hf_repo is not None: if metadata.source_hf_repo is not None:
self.gguf.add_source_hf_repo(metadata.source_hf_repo) self.gguf.add_source_hf_repo(metadata.source_hf_repo)
if metadata.parameter_class_attribute is not None: if metadata.parameter_class_attribute is not None:
self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute) self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute)
if metadata.parents is not None:
metadata.parent_count = len(metadata.parents)
for key, parent_entry in metadata.parents:
if "name" in parent_entry:
self.gguf.add_parent_name(key, parent_entry.get("name"))
if "author" in parent_entry:
self.gguf.add_parent_author(key, parent_entry.get("author"))
if "version" in parent_entry:
self.gguf.add_parent_version(key, parent_entry.get("version"))
if "organization" in parent_entry:
self.gguf.add_parent_organization(key, parent_entry.get("organization"))
if "url" in parent_entry:
self.gguf.add_parent_url(key, parent_entry.get("url"))
if "doi" in parent_entry:
self.gguf.add_parent_doi(key, parent_entry.get("doi"))
if "uuid" in parent_entry:
self.gguf.add_parent_uuid(key, parent_entry.get("uuid"))
if "hf_repo" in parent_entry:
self.gguf.add_parent_hf_repo(key, parent_entry.get("hf_repo"))
if metadata.tags is not None: if metadata.tags is not None:
self.gguf.add_tags(metadata.tags) self.gguf.add_tags(metadata.tags)
if metadata.languages is not None: if metadata.languages is not None:

View file

@ -23,23 +23,53 @@ class Keys:
ARCHITECTURE = "general.architecture" ARCHITECTURE = "general.architecture"
QUANTIZATION_VERSION = "general.quantization_version" QUANTIZATION_VERSION = "general.quantization_version"
ALIGNMENT = "general.alignment" ALIGNMENT = "general.alignment"
FILE_TYPE = "general.file_type"
# Authorship Metadata
NAME = "general.name" NAME = "general.name"
AUTHOR = "general.author"
VERSION = "general.version"
ORGANIZATION = "general.organization"
BASENAME = "general.basename" BASENAME = "general.basename"
FINETUNE = "general.finetune" FINETUNE = "general.finetune"
AUTHOR = "general.author"
QUANTIZED_BY = "general.quantized_by"
ORGANIZATION = "general.organization"
VERSION = "general.version"
BASE_VERSION = "general.base_version"
URL = "general.url"
DESCRIPTION = "general.description" DESCRIPTION = "general.description"
QUANTIZED_BY = "general.quantized_by"
PARAMETER_CLASS_ATTRIBUTE = "general.parameter_class_attribute"
# Licensing details
LICENSE = "general.license" LICENSE = "general.license"
LICENSE_NAME = "general.license.name" LICENSE_NAME = "general.license.name"
LICENSE_LINK = "general.license.link" LICENSE_LINK = "general.license.link"
# Typically represents the converted GGUF repo (Unless native)
URL = "general.url"
DOI = "general.doi"
UUID = "general.uuid"
HF_URL = "general.huggingface.repository"
# Typically represents the original source repository (e.g. safetensors)
# that this was
SOURCE_URL = "general.source.url" SOURCE_URL = "general.source.url"
SOURCE_DOI = "general.source.doi"
SOURCE_UUID = "general.source.uuid"
SOURCE_HF_REPO = "general.source.huggingface.repository" SOURCE_HF_REPO = "general.source.huggingface.repository"
FILE_TYPE = "general.file_type"
PARAMETER_CLASS_ATTRIBUTE = "general.parameter_class_attribute" # This represents the parent model that the converted/source model was
# derived from on allowing users to trace the linage of a model.
# E.g. A finetune model would have the base model as the parent
# (A model can have multiple parent, especially if it's a merged model)
PARENTS_COUNT = "general.parents.count"
PARENTS_NAME = "general.parents.{id}.name"
PARENTS_AUTHOR = "general.parents.{id}.author"
PARENTS_VERSION = "general.parents.{id}.version"
PARENTS_ORGANIZATION = "general.parents.{id}.organization"
PARENTS_URL = "general.parents.{id}.url"
PARENTS_DOI = "general.parents.{id}.doi"
PARENTS_UUID = "general.parents.{id}.uuid"
PARENTS_HF_REPO = "general.parents.{id}.huggingface.repository"
# Array based KV stores
TAGS = "general.tags" TAGS = "general.tags"
LANGUAGES = "general.languages" LANGUAGES = "general.languages"
DATASETS = "general.datasets" DATASETS = "general.datasets"

View file

@ -448,15 +448,21 @@ class GGUFWriter:
def add_version(self, version: str) -> None: def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version) self.add_string(Keys.General.VERSION, version)
def add_base_version(self, version: str) -> None:
self.add_string(Keys.General.BASE_VERSION, version)
def add_tensor_data_layout(self, layout: str) -> None: def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_url(self, url: str) -> None: def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url) self.add_string(Keys.General.URL, url)
def add_doi(self, doi: str) -> None:
self.add_string(Keys.General.DOI, doi)
def add_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.UUID, uuid)
def add_hf_repo(self, hf_repo: str) -> None:
self.add_string(Keys.General.HF_REPO, hf_repo)
def add_description(self, description: str) -> None: def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description) self.add_string(Keys.General.DESCRIPTION, description)
@ -481,6 +487,33 @@ class GGUFWriter:
def add_parameter_class_attribute(self, parameter_class_attribute: str) -> None: def add_parameter_class_attribute(self, parameter_class_attribute: str) -> None:
self.add_string(Keys.General.PARAMETER_CLASS_ATTRIBUTE, parameter_class_attribute) self.add_string(Keys.General.PARAMETER_CLASS_ATTRIBUTE, parameter_class_attribute)
def add_parent_count(self, parent_count: int) -> None:
self.add_uint32(Keys.General.PARENTS_COUNT, parent_count)
def add_parent_name(self, parent_id: int, name: str) -> None:
self.add_string(Keys.General.PARENTS_NAME.format(id=self.parent_id), name)
def add_parent_author(self, parent_id: int, author: str) -> None:
self.add_string(Keys.General.PARENTS_AUTHOR.format(id=self.parent_id), author)
def add_parent_version(self, parent_id: int, version: str) -> None:
self.add_string(Keys.General.PARENTS_VERSION.format(id=self.parent_id), version)
def add_parent_organization(self, parent_id: int, organization: str) -> None:
self.add_string(Keys.General.PARENTS_ORGANIZATION.format(id=self.parent_id), organization)
def add_parent_url(self, parent_id: int, url: str) -> None:
self.add_string(Keys.General.PARENTS_URL.format(id=self.parent_id), url)
def add_parent_doi(self, parent_id: int, doi: str) -> None:
self.add_string(Keys.General.PARENTS_DOI.format(id=self.parent_id), doi)
def add_parent_uuid(self, parent_id: int, uuid: str) -> None:
self.add_string(Keys.General.PARENTS_UUID.format(id=self.parent_id), uuid)
def add_parent_hf_repo(self, parent_id: int, hf_repo: str) -> None:
self.add_string(Keys.General.PARENTS_HF_REPO.format(id=self.parent_id), hf_repo)
def add_tags(self, tags: Sequence[str]) -> None: def add_tags(self, tags: Sequence[str]) -> None:
self.add_array(Keys.General.TAGS, tags) self.add_array(Keys.General.TAGS, tags)

View file

@ -26,15 +26,20 @@ class Metadata:
quantized_by: Optional[str] = None quantized_by: Optional[str] = None
organization: Optional[str] = None organization: Optional[str] = None
version: Optional[str] = None version: Optional[str] = None
base_version: Optional[str] = None
url: Optional[str] = None url: Optional[str] = None
doi: Optional[str] = None
uuid: Optional[str] = None
hf_repo: Optional[str] = None
description: Optional[str] = None description: Optional[str] = None
license: Optional[str] = None license: Optional[str] = None
license_name: Optional[str] = None license_name: Optional[str] = None
license_link: Optional[str] = None license_link: Optional[str] = None
source_url: Optional[str] = None source_url: Optional[str] = None
source_doi: Optional[str] = None
source_uuid: Optional[str] = None
source_hf_repo: Optional[str] = None source_hf_repo: Optional[str] = None
parameter_class_attribute: Optional[str] = None parameter_class_attribute: Optional[str] = None
parents: Optional[list[dict]] = None
tags: Optional[list[str]] = None tags: Optional[list[str]] = None
languages: Optional[list[str]] = None languages: Optional[list[str]] = None
datasets: Optional[list[str]] = None datasets: Optional[list[str]] = None
@ -57,22 +62,34 @@ class Metadata:
# Metadata Override File Provided # Metadata Override File Provided
# This is based on LLM_KV_NAMES mapping in llama.cpp # This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path) metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.name = metadata_override.get(Keys.General.NAME , metadata.name ) # noqa: E202 metadata.name = metadata_override.get(Keys.General.NAME , metadata.name ) # noqa: E202
metadata.author = metadata_override.get(Keys.General.AUTHOR , metadata.author ) # noqa: E202
metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION , metadata.organization ) # noqa: E202
metadata.basename = metadata_override.get(Keys.General.BASENAME , metadata.basename ) # noqa: E202 metadata.basename = metadata_override.get(Keys.General.BASENAME , metadata.basename ) # noqa: E202
metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202 metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202
metadata.author = metadata_override.get(Keys.General.AUTHOR , metadata.author ) # noqa: E202
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by ) # noqa: E202
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION , metadata.organization ) # noqa: E202
metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202
metadata.base_version = metadata_override.get(Keys.General.BASE_VERSION , metadata.base_version ) # noqa: E202
metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202
metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202 metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by ) # noqa: E202
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202
metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202 metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name ) # noqa: E202 metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name ) # noqa: E202
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link ) # noqa: E202 metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link ) # noqa: E202
metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202
metadata.doi = metadata_override.get(Keys.General.DOI , metadata.doi ) # noqa: E202
metadata.uuid = metadata_override.get(Keys.General.UUID , metadata.uuid ) # noqa: E202
metadata.hf_repo = metadata_override.get(Keys.General.HF_REPO , metadata.hf_repo ) # noqa: E202
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202 metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi ) # noqa: E202
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid ) # noqa: E202
metadata.source_hf_repo = metadata_override.get(Keys.General.SOURCE_HF_REPO , metadata.source_hf_repo ) # noqa: E202 metadata.source_hf_repo = metadata_override.get(Keys.General.SOURCE_HF_REPO , metadata.source_hf_repo ) # noqa: E202
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202
metadata.parent_count = metadata_override.get("general.parents" , metadata.parent_count ) # noqa: E202
metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202 metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202
metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202 metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202
metadata.datasets = metadata_override.get(Keys.General.DATASETS , metadata.datasets ) # noqa: E202 metadata.datasets = metadata_override.get(Keys.General.DATASETS , metadata.datasets ) # noqa: E202
@ -169,7 +186,7 @@ class Metadata:
@staticmethod @staticmethod
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None) -> Metadata: def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None) -> Metadata:
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
found_model_name = False found_base_model = False
# Model Card Heuristics # Model Card Heuristics
######################## ########################
@ -180,32 +197,7 @@ class Metadata:
# such as TheBloke who would encode 'Mixtral 8X7B Instruct v0.1' into model_name # such as TheBloke who would encode 'Mixtral 8X7B Instruct v0.1' into model_name
metadata.name = model_card.get("model_name") metadata.name = model_card.get("model_name")
if "model-index" in model_card and len(model_card["model-index"]) == 1 and "name" in model_card["model-index"][0]: if "base_model" in model_card and isinstance(model_card["base_model"], str) and not found_base_model:
# This is a model index which has model id that can be extracted into organization and model name
# if so then we can safely extract organization and name
# (This is a safe choice in case there is multiple models in one repo in the future)
model_id = model_card["model-index"][0]["name"]
model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id)
if metadata.name is None and model_full_name_component is not None:
metadata.name = Metadata.id_to_title(model_full_name_component)
if metadata.organization is None and org_component is not None:
metadata.organization = Metadata.id_to_title(org_component)
if metadata.basename is None and basename is not None:
metadata.basename = basename
if metadata.finetune is None and finetune is not None:
metadata.finetune = finetune
if metadata.version is None and version is not None:
metadata.version = version
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
metadata.parameter_class_attribute = parameter_class_attribute
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
found_model_name = True
if "base_model" in model_card and isinstance(model_card["base_model"], str) and not found_model_name:
# Check if string. We cannot handle lists as that is too ambagious # Check if string. We cannot handle lists as that is too ambagious
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges) # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
model_id = model_card.get("base_model") model_id = model_card.get("base_model")
@ -227,7 +219,7 @@ class Metadata:
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None: if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}" metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
found_model_name = True found_base_model = True
if metadata.quantized_by is None: if metadata.quantized_by is None:
# Not part of hugging face model card standard, but is used by TheBloke to credit them self for quantizing 3rd party models # Not part of hugging face model card standard, but is used by TheBloke to credit them self for quantizing 3rd party models
@ -254,7 +246,7 @@ class Metadata:
if hf_params is not None: if hf_params is not None:
hf_name_or_path = hf_params.get("_name_or_path") hf_name_or_path = hf_params.get("_name_or_path")
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1 and not found_model_name: if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1 and not found_base_model:
# Use _name_or_path only if its actually a model name and not some computer path # Use _name_or_path only if its actually a model name and not some computer path
# e.g. 'meta-llama/Llama-2-7b-hf' # e.g. 'meta-llama/Llama-2-7b-hf'
model_id = hf_name_or_path model_id = hf_name_or_path
@ -333,7 +325,7 @@ class TestStringMethods(unittest.TestCase):
'datasets': ['teknium/OpenHermes-2.5'], 'datasets': ['teknium/OpenHermes-2.5'],
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}] 'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}]
} }
expected = Metadata(name='Hermes 2 Pro Llama 3 8B', basename='Hermes-2-Pro-Llama-3', finetune=None, author=None, quantized_by=None, organization=None, version=None, base_version=None, url=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_hf_repo=None, parameter_class_attribute='8B', tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], languages=['en'], datasets=['teknium/OpenHermes-2.5']) expected = Metadata(name='Meta Llama 3 8B', basename='Meta-Llama-3', finetune=None, author=None, quantized_by=None, organization='NousResearch', version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url='https://huggingface.co/NousResearch/Meta-Llama-3-8B', source_doi=None, source_uuid=None, source_hf_repo='NousResearch/Meta-Llama-3-8B', parameter_class_attribute='8B', parents=None, tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], languages=['en'], datasets=['teknium/OpenHermes-2.5'])
got = Metadata.apply_metadata_heuristic(Metadata(), model_card, None, None) got = Metadata.apply_metadata_heuristic(Metadata(), model_card, None, None)
@ -342,21 +334,15 @@ class TestStringMethods(unittest.TestCase):
def test_apply_metadata_heuristic_from_hf_parameters(self): def test_apply_metadata_heuristic_from_hf_parameters(self):
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json # Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"} hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_doi=None, source_uuid=None, source_hf_repo=None, parameter_class_attribute='8b', parents=None, tags=None, languages=None, datasets=None)
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, base_version=None, url=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_hf_repo=None, parameter_class_attribute='8b', tags=None, languages=None, datasets=None)
got = Metadata.apply_metadata_heuristic(Metadata(), None, hf_params, None) got = Metadata.apply_metadata_heuristic(Metadata(), None, hf_params, None)
self.assertEqual(got, expected) self.assertEqual(got, expected)
def test_apply_metadata_heuristic_from_model_dir(self): def test_apply_metadata_heuristic_from_model_dir(self):
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json # Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO") model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_doi=None, source_uuid=None, source_hf_repo=None, parameter_class_attribute='8b', parents=None, tags=None, languages=None, datasets=None)
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, base_version=None, url=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_hf_repo=None, parameter_class_attribute='8b', tags=None, languages=None, datasets=None)
got = Metadata.apply_metadata_heuristic(Metadata(), None, None, model_dir_path) got = Metadata.apply_metadata_heuristic(Metadata(), None, None, model_dir_path)
self.assertEqual(got, expected) self.assertEqual(got, expected)