convert-*.py: separated unit test, hf_repo to repo_url
This commit is contained in:
parent
d060fcdbe2
commit
eaa47f5546
8 changed files with 409 additions and 273 deletions
|
@ -239,63 +239,71 @@ class Model:
|
||||||
def set_gguf_meta_model(self):
|
def set_gguf_meta_model(self):
|
||||||
self.gguf_writer.add_name(self.metadata.name)
|
self.gguf_writer.add_name(self.metadata.name)
|
||||||
|
|
||||||
if self.metadata.basename is not None:
|
|
||||||
self.gguf_writer.add_basename(self.metadata.basename)
|
|
||||||
if self.metadata.finetune is not None:
|
|
||||||
self.gguf_writer.add_finetune(self.metadata.finetune)
|
|
||||||
if self.metadata.author is not None:
|
if self.metadata.author is not None:
|
||||||
self.gguf_writer.add_author(self.metadata.author)
|
self.gguf_writer.add_author(self.metadata.author)
|
||||||
if self.metadata.quantized_by is not None:
|
|
||||||
self.gguf_writer.add_quantized_by(self.metadata.quantized_by)
|
|
||||||
if self.metadata.organization is not None:
|
|
||||||
self.gguf_writer.add_organization(self.metadata.organization)
|
|
||||||
if self.metadata.version is not None:
|
if self.metadata.version is not None:
|
||||||
self.gguf_writer.add_version(self.metadata.version)
|
self.gguf_writer.add_version(self.metadata.version)
|
||||||
if self.metadata.url is not None:
|
if self.metadata.organization is not None:
|
||||||
self.gguf_writer.add_url(self.metadata.url)
|
self.gguf_writer.add_organization(self.metadata.organization)
|
||||||
if self.metadata.doi is not None:
|
|
||||||
self.gguf_writer.add_doi(self.metadata.doi)
|
if self.metadata.finetune is not None:
|
||||||
if self.metadata.uuid is not None:
|
self.gguf_writer.add_finetune(self.metadata.finetune)
|
||||||
self.gguf_writer.add_uuid(self.metadata.uuid)
|
if self.metadata.basename is not None:
|
||||||
if self.metadata.hf_repo is not None:
|
self.gguf_writer.add_basename(self.metadata.basename)
|
||||||
self.gguf_writer.add_hf_repo(self.metadata.hf_repo)
|
|
||||||
if self.metadata.description is not None:
|
if self.metadata.description is not None:
|
||||||
self.gguf_writer.add_description(self.metadata.description)
|
self.gguf_writer.add_description(self.metadata.description)
|
||||||
|
if self.metadata.quantized_by is not None:
|
||||||
|
self.gguf_writer.add_quantized_by(self.metadata.quantized_by)
|
||||||
|
|
||||||
|
if self.metadata.parameter_class_attribute is not None:
|
||||||
|
self.gguf_writer.add_parameter_class_attribute(self.metadata.parameter_class_attribute)
|
||||||
|
|
||||||
if self.metadata.license is not None:
|
if self.metadata.license is not None:
|
||||||
self.gguf_writer.add_license(self.metadata.license)
|
self.gguf_writer.add_license(self.metadata.license)
|
||||||
if self.metadata.license_name is not None:
|
if self.metadata.license_name is not None:
|
||||||
self.gguf_writer.add_license_name(self.metadata.license_name)
|
self.gguf_writer.add_license_name(self.metadata.license_name)
|
||||||
if self.metadata.license_link is not None:
|
if self.metadata.license_link is not None:
|
||||||
self.gguf_writer.add_license_link(self.metadata.license_link)
|
self.gguf_writer.add_license_link(self.metadata.license_link)
|
||||||
|
|
||||||
|
if self.metadata.url is not None:
|
||||||
|
self.gguf_writer.add_url(self.metadata.url)
|
||||||
|
if self.metadata.doi is not None:
|
||||||
|
self.gguf_writer.add_doi(self.metadata.doi)
|
||||||
|
if self.metadata.uuid is not None:
|
||||||
|
self.gguf_writer.add_uuid(self.metadata.uuid)
|
||||||
|
if self.metadata.repo_url is not None:
|
||||||
|
self.gguf_writer.add_repo_url(self.metadata.repo_url)
|
||||||
|
|
||||||
if self.metadata.source_url is not None:
|
if self.metadata.source_url is not None:
|
||||||
self.gguf_writer.add_source_url(self.metadata.source_url)
|
self.gguf_writer.add_source_url(self.metadata.source_url)
|
||||||
if self.metadata.source_doi is not None:
|
if self.metadata.source_doi is not None:
|
||||||
self.gguf_writer.add_source_doi(self.metadata.source_doi)
|
self.gguf_writer.add_source_doi(self.metadata.source_doi)
|
||||||
if self.metadata.source_uuid is not None:
|
if self.metadata.source_uuid is not None:
|
||||||
self.gguf_writer.add_source_uuid(self.metadata.source_uuid)
|
self.gguf_writer.add_source_uuid(self.metadata.source_uuid)
|
||||||
if self.metadata.source_hf_repo is not None:
|
if self.metadata.source_repo_url is not None:
|
||||||
self.gguf_writer.add_source_hf_repo(self.metadata.source_hf_repo)
|
self.gguf_writer.add_source_repo_url(self.metadata.source_repo_url)
|
||||||
if self.metadata.parameter_class_attribute is not None:
|
|
||||||
self.gguf_writer.add_parameter_class_attribute(self.metadata.parameter_class_attribute)
|
if self.metadata.base_models is not None:
|
||||||
if self.metadata.parents is not None:
|
self.gguf_writer.add_base_model_count(len(self.metadata.base_models))
|
||||||
metadata.parent_count = len(self.metadata.parents)
|
for key, base_model_entry in enumerate(self.metadata.base_models):
|
||||||
for key, parent_entry in self.metadata.parents:
|
if "name" in base_model_entry:
|
||||||
if "name" in parent_entry:
|
self.gguf_writer.add_base_model_name(key, base_model_entry["name"])
|
||||||
self.gguf_writer.add_parent_name(key, parent_entry.get("name"))
|
if "author" in base_model_entry:
|
||||||
if "author" in parent_entry:
|
self.gguf_writer.add_base_model_author(key, base_model_entry["author"])
|
||||||
self.gguf_writer.add_parent_author(key, parent_entry.get("author"))
|
if "version" in base_model_entry:
|
||||||
if "version" in parent_entry:
|
self.gguf_writer.add_base_model_version(key, base_model_entry["version"])
|
||||||
self.gguf_writer.add_parent_version(key, parent_entry.get("version"))
|
if "organization" in base_model_entry:
|
||||||
if "organization" in parent_entry:
|
self.gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
|
||||||
self.gguf_writer.add_parent_organization(key, parent_entry.get("organization"))
|
if "url" in base_model_entry:
|
||||||
if "url" in parent_entry:
|
self.gguf_writer.add_base_model_url(key, base_model_entry["url"])
|
||||||
self.gguf_writer.add_parent_url(key, parent_entry.get("url"))
|
if "doi" in base_model_entry:
|
||||||
if "doi" in parent_entry:
|
self.gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
|
||||||
self.gguf_writer.add_parent_doi(key, parent_entry.get("doi"))
|
if "uuid" in base_model_entry:
|
||||||
if "uuid" in parent_entry:
|
self.gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
|
||||||
self.gguf_writer.add_parent_uuid(key, parent_entry.get("uuid"))
|
if "repo_url" in base_model_entry:
|
||||||
if "hf_repo" in parent_entry:
|
self.gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
|
||||||
self.gguf_writer.add_parent_hf_repo(key, parent_entry.get("hf_repo"))
|
|
||||||
if self.metadata.tags is not None:
|
if self.metadata.tags is not None:
|
||||||
self.gguf_writer.add_tags(self.metadata.tags)
|
self.gguf_writer.add_tags(self.metadata.tags)
|
||||||
if self.metadata.languages is not None:
|
if self.metadata.languages is not None:
|
||||||
|
|
|
@ -783,57 +783,71 @@ class OutputFile:
|
||||||
|
|
||||||
self.gguf.add_name(name)
|
self.gguf.add_name(name)
|
||||||
|
|
||||||
if metadata.basename is not None:
|
|
||||||
self.gguf.add_basename(metadata.basename)
|
|
||||||
if metadata.finetune is not None:
|
|
||||||
self.gguf.add_finetune(metadata.finetune)
|
|
||||||
if metadata.author is not None:
|
if metadata.author is not None:
|
||||||
self.gguf.add_author(metadata.author)
|
self.gguf.add_author(metadata.author)
|
||||||
if metadata.organization is not None:
|
|
||||||
self.add_organization(metadata.organization)
|
|
||||||
if metadata.version is not None:
|
if metadata.version is not None:
|
||||||
self.gguf.add_version(metadata.version)
|
self.gguf.add_version(metadata.version)
|
||||||
if metadata.url is not None:
|
if metadata.organization is not None:
|
||||||
self.gguf.add_url(metadata.url)
|
self.gguf.add_organization(metadata.organization)
|
||||||
if metadata.doi is not None:
|
|
||||||
self.gguf.add_doi(metadata.doi)
|
if metadata.finetune is not None:
|
||||||
if metadata.uuid is not None:
|
self.gguf.add_finetune(metadata.finetune)
|
||||||
self.gguf.add_uuid(metadata.uuid)
|
if metadata.basename is not None:
|
||||||
if metadata.hf_repo is not None:
|
self.gguf.add_basename(metadata.basename)
|
||||||
self.gguf.add_hf_repo(metadata.hf_repo)
|
|
||||||
if metadata.description is not None:
|
if metadata.description is not None:
|
||||||
self.gguf.add_description(metadata.description)
|
self.gguf.add_description(metadata.description)
|
||||||
|
if metadata.quantized_by is not None:
|
||||||
|
self.gguf.add_quantized_by(metadata.quantized_by)
|
||||||
|
|
||||||
|
if metadata.parameter_class_attribute is not None:
|
||||||
|
self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute)
|
||||||
|
|
||||||
if metadata.license is not None:
|
if metadata.license is not None:
|
||||||
self.gguf.add_license(metadata.license)
|
self.gguf.add_license(metadata.license)
|
||||||
if metadata.license_name is not None:
|
if metadata.license_name is not None:
|
||||||
self.gguf.add_license_name(metadata.license_name)
|
self.gguf.add_license_name(metadata.license_name)
|
||||||
if metadata.license_link is not None:
|
if metadata.license_link is not None:
|
||||||
self.gguf.add_license_link(metadata.license_link)
|
self.gguf.add_license_link(metadata.license_link)
|
||||||
|
|
||||||
|
if metadata.url is not None:
|
||||||
|
self.gguf.add_url(metadata.url)
|
||||||
|
if metadata.doi is not None:
|
||||||
|
self.gguf.add_doi(metadata.doi)
|
||||||
|
if metadata.uuid is not None:
|
||||||
|
self.gguf.add_uuid(metadata.uuid)
|
||||||
|
if metadata.repo_url is not None:
|
||||||
|
self.gguf.add_repo_url(metadata.repo_url)
|
||||||
|
|
||||||
if metadata.source_url is not None:
|
if metadata.source_url is not None:
|
||||||
self.gguf.add_source_url(metadata.source_url)
|
self.gguf.add_source_url(metadata.source_url)
|
||||||
if metadata.source_hf_repo is not None:
|
if metadata.source_doi is not None:
|
||||||
self.gguf.add_source_hf_repo(metadata.source_hf_repo)
|
self.gguf.add_source_doi(metadata.source_doi)
|
||||||
if metadata.parameter_class_attribute is not None:
|
if metadata.source_uuid is not None:
|
||||||
self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute)
|
self.gguf.add_source_uuid(metadata.source_uuid)
|
||||||
if metadata.parents is not None:
|
if metadata.source_repo_url is not None:
|
||||||
metadata.parent_count = len(metadata.parents)
|
self.gguf.add_source_repo_url(metadata.source_repo_url)
|
||||||
for key, parent_entry in metadata.parents:
|
|
||||||
if "name" in parent_entry:
|
if metadata.base_models is not None:
|
||||||
self.gguf.add_parent_name(key, parent_entry.get("name"))
|
self.gguf.add_base_model_count(len(metadata.base_models))
|
||||||
if "author" in parent_entry:
|
for key, base_model_entry in enumerate(metadata.base_models):
|
||||||
self.gguf.add_parent_author(key, parent_entry.get("author"))
|
if "name" in base_model_entry:
|
||||||
if "version" in parent_entry:
|
self.gguf.add_base_model_name(key, base_model_entry["name"])
|
||||||
self.gguf.add_parent_version(key, parent_entry.get("version"))
|
if "author" in base_model_entry:
|
||||||
if "organization" in parent_entry:
|
self.gguf.add_base_model_author(key, base_model_entry["author"])
|
||||||
self.gguf.add_parent_organization(key, parent_entry.get("organization"))
|
if "version" in base_model_entry:
|
||||||
if "url" in parent_entry:
|
self.gguf.add_base_model_version(key, base_model_entry["version"])
|
||||||
self.gguf.add_parent_url(key, parent_entry.get("url"))
|
if "organization" in base_model_entry:
|
||||||
if "doi" in parent_entry:
|
self.gguf.add_base_model_organization(key, base_model_entry["organization"])
|
||||||
self.gguf.add_parent_doi(key, parent_entry.get("doi"))
|
if "url" in base_model_entry:
|
||||||
if "uuid" in parent_entry:
|
self.gguf.add_base_model_url(key, base_model_entry["url"])
|
||||||
self.gguf.add_parent_uuid(key, parent_entry.get("uuid"))
|
if "doi" in base_model_entry:
|
||||||
if "hf_repo" in parent_entry:
|
self.gguf.add_base_model_doi(key, base_model_entry["doi"])
|
||||||
self.gguf.add_parent_hf_repo(key, parent_entry.get("hf_repo"))
|
if "uuid" in base_model_entry:
|
||||||
|
self.gguf.add_base_model_uuid(key, base_model_entry["uuid"])
|
||||||
|
if "repo_url" in base_model_entry:
|
||||||
|
self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
|
||||||
|
|
||||||
if metadata.tags is not None:
|
if metadata.tags is not None:
|
||||||
self.gguf.add_tags(metadata.tags)
|
self.gguf.add_tags(metadata.tags)
|
||||||
if metadata.languages is not None:
|
if metadata.languages is not None:
|
||||||
|
|
|
@ -78,5 +78,13 @@ python -m build
|
||||||
python -m twine upload dist/*
|
python -m twine upload dist/*
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Run Unit Tests
|
||||||
|
|
||||||
|
From the root of this repository, run the following command to execute all the unit tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m unittest discover ./gguf-py -v
|
||||||
|
```
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
- [ ] Include conversion scripts as command line entry points in this package.
|
- [ ] Include conversion scripts as command line entry points in this package.
|
||||||
|
|
|
@ -31,10 +31,12 @@ class Keys:
|
||||||
VERSION = "general.version"
|
VERSION = "general.version"
|
||||||
ORGANIZATION = "general.organization"
|
ORGANIZATION = "general.organization"
|
||||||
|
|
||||||
BASENAME = "general.basename"
|
|
||||||
FINETUNE = "general.finetune"
|
FINETUNE = "general.finetune"
|
||||||
|
BASENAME = "general.basename"
|
||||||
|
|
||||||
DESCRIPTION = "general.description"
|
DESCRIPTION = "general.description"
|
||||||
QUANTIZED_BY = "general.quantized_by"
|
QUANTIZED_BY = "general.quantized_by"
|
||||||
|
|
||||||
PARAMETER_CLASS_ATTRIBUTE = "general.parameter_class_attribute"
|
PARAMETER_CLASS_ATTRIBUTE = "general.parameter_class_attribute"
|
||||||
|
|
||||||
# Licensing details
|
# Licensing details
|
||||||
|
@ -43,31 +45,29 @@ class Keys:
|
||||||
LICENSE_LINK = "general.license.link"
|
LICENSE_LINK = "general.license.link"
|
||||||
|
|
||||||
# Typically represents the converted GGUF repo (Unless native)
|
# Typically represents the converted GGUF repo (Unless native)
|
||||||
URL = "general.url"
|
URL = "general.url" # Model Website/Paper
|
||||||
DOI = "general.doi"
|
DOI = "general.doi"
|
||||||
UUID = "general.uuid"
|
UUID = "general.uuid"
|
||||||
HF_URL = "general.huggingface.repository"
|
REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
|
||||||
|
|
||||||
# Typically represents the original source repository (e.g. safetensors)
|
# Model Source during conversion
|
||||||
# that this was
|
SOURCE_URL = "general.source.url" # Model Website/Paper
|
||||||
SOURCE_URL = "general.source.url"
|
|
||||||
SOURCE_DOI = "general.source.doi"
|
SOURCE_DOI = "general.source.doi"
|
||||||
SOURCE_UUID = "general.source.uuid"
|
SOURCE_UUID = "general.source.uuid"
|
||||||
SOURCE_HF_REPO = "general.source.huggingface.repository"
|
SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
|
||||||
|
|
||||||
# This represents the parent model that the converted/source model was
|
# Base Model Source. There can be more than one source if it's a merged
|
||||||
# derived from on allowing users to trace the linage of a model.
|
# model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
|
||||||
# E.g. A finetune model would have the base model as the parent
|
# tracing the lineage of models as they are finetuned or merged over time.
|
||||||
# (A model can have multiple parent, especially if it's a merged model)
|
BASE_MODEL_COUNT = "general.base_model.count"
|
||||||
PARENTS_COUNT = "general.parents.count"
|
BASE_MODEL_NAME = "general.base_model.{id}.name"
|
||||||
PARENTS_NAME = "general.parents.{id}.name"
|
BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
|
||||||
PARENTS_AUTHOR = "general.parents.{id}.author"
|
BASE_MODEL_VERSION = "general.base_model.{id}.version"
|
||||||
PARENTS_VERSION = "general.parents.{id}.version"
|
BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
|
||||||
PARENTS_ORGANIZATION = "general.parents.{id}.organization"
|
BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
|
||||||
PARENTS_URL = "general.parents.{id}.url"
|
BASE_MODEL_DOI = "general.base_model.{id}.doi"
|
||||||
PARENTS_DOI = "general.parents.{id}.doi"
|
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
|
||||||
PARENTS_UUID = "general.parents.{id}.uuid"
|
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
|
||||||
PARENTS_HF_REPO = "general.parents.{id}.huggingface.repository"
|
|
||||||
|
|
||||||
# Array based KV stores
|
# Array based KV stores
|
||||||
TAGS = "general.tags"
|
TAGS = "general.tags"
|
||||||
|
@ -1273,8 +1273,6 @@ KEY_GENERAL_AUTHOR = Keys.General.AUTHOR
|
||||||
KEY_GENERAL_URL = Keys.General.URL
|
KEY_GENERAL_URL = Keys.General.URL
|
||||||
KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
|
KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
|
||||||
KEY_GENERAL_LICENSE = Keys.General.LICENSE
|
KEY_GENERAL_LICENSE = Keys.General.LICENSE
|
||||||
KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
|
|
||||||
KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
|
|
||||||
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
|
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
|
||||||
|
|
||||||
# LLM
|
# LLM
|
||||||
|
|
|
@ -430,42 +430,43 @@ class GGUFWriter:
|
||||||
def add_architecture(self) -> None:
|
def add_architecture(self) -> None:
|
||||||
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
||||||
|
|
||||||
def add_basename(self, basename: str) -> None:
|
def add_quantization_version(self, quantization_version: int) -> None:
|
||||||
self.add_string(Keys.General.BASENAME, basename)
|
self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
|
||||||
|
|
||||||
def add_finetune(self, finetune: str) -> None:
|
def add_custom_alignment(self, alignment: int) -> None:
|
||||||
self.add_string(Keys.General.FINETUNE, finetune)
|
self.data_alignment = alignment
|
||||||
|
self.add_uint32(Keys.General.ALIGNMENT, alignment)
|
||||||
|
|
||||||
|
def add_file_type(self, ftype: int) -> None:
|
||||||
|
self.add_uint32(Keys.General.FILE_TYPE, ftype)
|
||||||
|
|
||||||
|
def add_name(self, name: str) -> None:
|
||||||
|
self.add_string(Keys.General.NAME, name)
|
||||||
|
|
||||||
def add_author(self, author: str) -> None:
|
def add_author(self, author: str) -> None:
|
||||||
self.add_string(Keys.General.AUTHOR, author)
|
self.add_string(Keys.General.AUTHOR, author)
|
||||||
|
|
||||||
def add_quantized_by(self, quantized: str) -> None:
|
def add_version(self, version: str) -> None:
|
||||||
self.add_string(Keys.General.QUANTIZED_BY, quantized)
|
self.add_string(Keys.General.VERSION, version)
|
||||||
|
|
||||||
def add_organization(self, organization: str) -> None:
|
def add_organization(self, organization: str) -> None:
|
||||||
self.add_string(Keys.General.ORGANIZATION, organization)
|
self.add_string(Keys.General.ORGANIZATION, organization)
|
||||||
|
|
||||||
def add_version(self, version: str) -> None:
|
def add_finetune(self, finetune: str) -> None:
|
||||||
self.add_string(Keys.General.VERSION, version)
|
self.add_string(Keys.General.FINETUNE, finetune)
|
||||||
|
|
||||||
def add_tensor_data_layout(self, layout: str) -> None:
|
def add_basename(self, basename: str) -> None:
|
||||||
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
|
self.add_string(Keys.General.BASENAME, basename)
|
||||||
|
|
||||||
def add_url(self, url: str) -> None:
|
|
||||||
self.add_string(Keys.General.URL, url)
|
|
||||||
|
|
||||||
def add_doi(self, doi: str) -> None:
|
|
||||||
self.add_string(Keys.General.DOI, doi)
|
|
||||||
|
|
||||||
def add_uuid(self, uuid: str) -> None:
|
|
||||||
self.add_string(Keys.General.UUID, uuid)
|
|
||||||
|
|
||||||
def add_hf_repo(self, hf_repo: str) -> None:
|
|
||||||
self.add_string(Keys.General.HF_REPO, hf_repo)
|
|
||||||
|
|
||||||
def add_description(self, description: str) -> None:
|
def add_description(self, description: str) -> None:
|
||||||
self.add_string(Keys.General.DESCRIPTION, description)
|
self.add_string(Keys.General.DESCRIPTION, description)
|
||||||
|
|
||||||
|
def add_quantized_by(self, quantized: str) -> None:
|
||||||
|
self.add_string(Keys.General.QUANTIZED_BY, quantized)
|
||||||
|
|
||||||
|
def add_parameter_class_attribute(self, parameter_class_attribute: str) -> None:
|
||||||
|
self.add_string(Keys.General.PARAMETER_CLASS_ATTRIBUTE, parameter_class_attribute)
|
||||||
|
|
||||||
def add_license(self, license: str) -> None:
|
def add_license(self, license: str) -> None:
|
||||||
self.add_string(Keys.General.LICENSE, license)
|
self.add_string(Keys.General.LICENSE, license)
|
||||||
|
|
||||||
|
@ -475,44 +476,56 @@ class GGUFWriter:
|
||||||
def add_license_link(self, license: str) -> None:
|
def add_license_link(self, license: str) -> None:
|
||||||
self.add_string(Keys.General.LICENSE_LINK, license)
|
self.add_string(Keys.General.LICENSE_LINK, license)
|
||||||
|
|
||||||
|
def add_url(self, url: str) -> None:
|
||||||
|
self.add_string(Keys.General.URL, url)
|
||||||
|
|
||||||
|
def add_doi(self, doi: str) -> None:
|
||||||
|
self.add_string(Keys.General.DOI, doi)
|
||||||
|
|
||||||
|
def add_uuid(self, uuid: str) -> None:
|
||||||
|
self.add_string(Keys.General.UUID, uuid)
|
||||||
|
|
||||||
|
def add_repo_url(self, repo_url: str) -> None:
|
||||||
|
self.add_string(Keys.General.REPO_URL, repo_url)
|
||||||
|
|
||||||
def add_source_url(self, url: str) -> None:
|
def add_source_url(self, url: str) -> None:
|
||||||
self.add_string(Keys.General.SOURCE_URL, url)
|
self.add_string(Keys.General.SOURCE_URL, url)
|
||||||
|
|
||||||
def add_source_hf_repo(self, repo: str) -> None:
|
def add_source_doi(self, doi: str) -> None:
|
||||||
self.add_string(Keys.General.SOURCE_HF_REPO, repo)
|
self.add_string(Keys.General.SOURCE_DOI, doi)
|
||||||
|
|
||||||
def add_file_type(self, ftype: int) -> None:
|
def add_source_uuid(self, uuid: str) -> None:
|
||||||
self.add_uint32(Keys.General.FILE_TYPE, ftype)
|
self.add_string(Keys.General.SOURCE_UUID, uuid)
|
||||||
|
|
||||||
def add_parameter_class_attribute(self, parameter_class_attribute: str) -> None:
|
def add_source_repo_url(self, repo_url: str) -> None:
|
||||||
self.add_string(Keys.General.PARAMETER_CLASS_ATTRIBUTE, parameter_class_attribute)
|
self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
|
||||||
|
|
||||||
def add_parent_count(self, parent_count: int) -> None:
|
def add_base_model_count(self, source_count: int) -> None:
|
||||||
self.add_uint32(Keys.General.PARENTS_COUNT, parent_count)
|
self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
|
||||||
|
|
||||||
def add_parent_name(self, parent_id: int, name: str) -> None:
|
def add_base_model_name(self, source_id: int, name: str) -> None:
|
||||||
self.add_string(Keys.General.PARENTS_NAME.format(id=self.parent_id), name)
|
self.add_string(Keys.General.BASE_MODEL_NAME.format(id=self.source_id), name)
|
||||||
|
|
||||||
def add_parent_author(self, parent_id: int, author: str) -> None:
|
def add_base_model_author(self, source_id: int, author: str) -> None:
|
||||||
self.add_string(Keys.General.PARENTS_AUTHOR.format(id=self.parent_id), author)
|
self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=self.source_id), author)
|
||||||
|
|
||||||
def add_parent_version(self, parent_id: int, version: str) -> None:
|
def add_base_model_version(self, source_id: int, version: str) -> None:
|
||||||
self.add_string(Keys.General.PARENTS_VERSION.format(id=self.parent_id), version)
|
self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=self.source_id), version)
|
||||||
|
|
||||||
def add_parent_organization(self, parent_id: int, organization: str) -> None:
|
def add_base_model_organization(self, source_id: int, organization: str) -> None:
|
||||||
self.add_string(Keys.General.PARENTS_ORGANIZATION.format(id=self.parent_id), organization)
|
self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=self.source_id), organization)
|
||||||
|
|
||||||
def add_parent_url(self, parent_id: int, url: str) -> None:
|
def add_base_model_url(self, source_id: int, url: str) -> None:
|
||||||
self.add_string(Keys.General.PARENTS_URL.format(id=self.parent_id), url)
|
self.add_string(Keys.General.BASE_MODEL_URL.format(id=self.source_id), url)
|
||||||
|
|
||||||
def add_parent_doi(self, parent_id: int, doi: str) -> None:
|
def add_base_model_doi(self, source_id: int, doi: str) -> None:
|
||||||
self.add_string(Keys.General.PARENTS_DOI.format(id=self.parent_id), doi)
|
self.add_string(Keys.General.BASE_MODEL_DOI.format(id=self.source_id), doi)
|
||||||
|
|
||||||
def add_parent_uuid(self, parent_id: int, uuid: str) -> None:
|
def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
|
||||||
self.add_string(Keys.General.PARENTS_UUID.format(id=self.parent_id), uuid)
|
self.add_string(Keys.General.BASE_MODEL_UUID.format(id=self.source_id), uuid)
|
||||||
|
|
||||||
def add_parent_hf_repo(self, parent_id: int, hf_repo: str) -> None:
|
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
|
||||||
self.add_string(Keys.General.PARENTS_HF_REPO.format(id=self.parent_id), hf_repo)
|
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=self.source_id), repo_url)
|
||||||
|
|
||||||
def add_tags(self, tags: Sequence[str]) -> None:
|
def add_tags(self, tags: Sequence[str]) -> None:
|
||||||
self.add_array(Keys.General.TAGS, tags)
|
self.add_array(Keys.General.TAGS, tags)
|
||||||
|
@ -523,16 +536,8 @@ class GGUFWriter:
|
||||||
def add_datasets(self, datasets: Sequence[str]) -> None:
|
def add_datasets(self, datasets: Sequence[str]) -> None:
|
||||||
self.add_array(Keys.General.DATASETS, datasets)
|
self.add_array(Keys.General.DATASETS, datasets)
|
||||||
|
|
||||||
def add_name(self, name: str) -> None:
|
def add_tensor_data_layout(self, layout: str) -> None:
|
||||||
self.add_string(Keys.General.NAME, name)
|
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
|
||||||
|
|
||||||
def add_quantization_version(self, quantization_version: int) -> None:
|
|
||||||
self.add_uint32(
|
|
||||||
Keys.General.QUANTIZATION_VERSION, quantization_version)
|
|
||||||
|
|
||||||
def add_custom_alignment(self, alignment: int) -> None:
|
|
||||||
self.data_alignment = alignment
|
|
||||||
self.add_uint32(Keys.General.ALIGNMENT, alignment)
|
|
||||||
|
|
||||||
def add_vocab_size(self, size: int) -> None:
|
def add_vocab_size(self, size: int) -> None:
|
||||||
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
|
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
|
||||||
|
|
268
gguf-py/gguf/metadata.py
Executable file → Normal file
268
gguf-py/gguf/metadata.py
Executable file → Normal file
|
@ -1,18 +1,13 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import unittest
|
import uuid
|
||||||
import frontmatter
|
import frontmatter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
from constants import Keys
|
|
||||||
else:
|
|
||||||
from .constants import Keys
|
from .constants import Keys
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,26 +15,26 @@ else:
|
||||||
class Metadata:
|
class Metadata:
|
||||||
# Authorship Metadata to be written to GGUF KV Store
|
# Authorship Metadata to be written to GGUF KV Store
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
basename: Optional[str] = None
|
|
||||||
finetune: Optional[str] = None
|
|
||||||
author: Optional[str] = None
|
author: Optional[str] = None
|
||||||
quantized_by: Optional[str] = None
|
|
||||||
organization: Optional[str] = None
|
|
||||||
version: Optional[str] = None
|
version: Optional[str] = None
|
||||||
|
organization: Optional[str] = None
|
||||||
|
finetune: Optional[str] = None
|
||||||
|
basename: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
quantized_by: Optional[str] = None
|
||||||
|
parameter_class_attribute: Optional[str] = None
|
||||||
url: Optional[str] = None
|
url: Optional[str] = None
|
||||||
doi: Optional[str] = None
|
doi: Optional[str] = None
|
||||||
uuid: Optional[str] = None
|
uuid: Optional[str] = None
|
||||||
hf_repo: Optional[str] = None
|
repo_url: Optional[str] = None
|
||||||
description: Optional[str] = None
|
|
||||||
license: Optional[str] = None
|
|
||||||
license_name: Optional[str] = None
|
|
||||||
license_link: Optional[str] = None
|
|
||||||
source_url: Optional[str] = None
|
source_url: Optional[str] = None
|
||||||
source_doi: Optional[str] = None
|
source_doi: Optional[str] = None
|
||||||
source_uuid: Optional[str] = None
|
source_uuid: Optional[str] = None
|
||||||
source_hf_repo: Optional[str] = None
|
source_repo_url: Optional[str] = None
|
||||||
parameter_class_attribute: Optional[str] = None
|
license: Optional[str] = None
|
||||||
parents: Optional[list[dict]] = None
|
license_name: Optional[str] = None
|
||||||
|
license_link: Optional[str] = None
|
||||||
|
base_models: Optional[list[dict]] = None
|
||||||
tags: Optional[list[str]] = None
|
tags: Optional[list[str]] = None
|
||||||
languages: Optional[list[str]] = None
|
languages: Optional[list[str]] = None
|
||||||
datasets: Optional[list[str]] = None
|
datasets: Optional[list[str]] = None
|
||||||
|
@ -68,10 +63,12 @@ class Metadata:
|
||||||
metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202
|
metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202
|
||||||
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION , metadata.organization ) # noqa: E202
|
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION , metadata.organization ) # noqa: E202
|
||||||
|
|
||||||
metadata.basename = metadata_override.get(Keys.General.BASENAME , metadata.basename ) # noqa: E202
|
|
||||||
metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202
|
metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202
|
||||||
|
metadata.basename = metadata_override.get(Keys.General.BASENAME , metadata.basename ) # noqa: E202
|
||||||
|
|
||||||
metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202
|
metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202
|
||||||
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by ) # noqa: E202
|
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by ) # noqa: E202
|
||||||
|
|
||||||
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202
|
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202
|
||||||
|
|
||||||
metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202
|
metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202
|
||||||
|
@ -81,14 +78,14 @@ class Metadata:
|
||||||
metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202
|
metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202
|
||||||
metadata.doi = metadata_override.get(Keys.General.DOI , metadata.doi ) # noqa: E202
|
metadata.doi = metadata_override.get(Keys.General.DOI , metadata.doi ) # noqa: E202
|
||||||
metadata.uuid = metadata_override.get(Keys.General.UUID , metadata.uuid ) # noqa: E202
|
metadata.uuid = metadata_override.get(Keys.General.UUID , metadata.uuid ) # noqa: E202
|
||||||
metadata.hf_repo = metadata_override.get(Keys.General.HF_REPO , metadata.hf_repo ) # noqa: E202
|
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL , metadata.repo_url ) # noqa: E202
|
||||||
|
|
||||||
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202
|
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202
|
||||||
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi ) # noqa: E202
|
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi ) # noqa: E202
|
||||||
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid ) # noqa: E202
|
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid ) # noqa: E202
|
||||||
metadata.source_hf_repo = metadata_override.get(Keys.General.SOURCE_HF_REPO , metadata.source_hf_repo ) # noqa: E202
|
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url ) # noqa: E202
|
||||||
|
|
||||||
metadata.parent_count = metadata_override.get("general.parents" , metadata.parent_count ) # noqa: E202
|
metadata.base_models = metadata_override.get("general.base_models" , metadata.base_models ) # noqa: E202
|
||||||
|
|
||||||
metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202
|
metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202
|
||||||
metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202
|
metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202
|
||||||
|
@ -98,6 +95,9 @@ class Metadata:
|
||||||
if model_name is not None:
|
if model_name is not None:
|
||||||
metadata.name = model_name
|
metadata.name = model_name
|
||||||
|
|
||||||
|
# If any UUID is still missing at this point, then we should fill it in
|
||||||
|
metadata = Metadata.generate_any_missing_uuid(metadata)
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -137,8 +137,7 @@ class Metadata:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def id_to_title(string):
|
def id_to_title(string):
|
||||||
# Convert capitalization into title form unless acronym or version number
|
# Convert capitalization into title form unless acronym or version number
|
||||||
string = string.strip().replace('-', ' ')
|
return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
|
||||||
return ' '.join([w.title() if w.islower() and not re.match(r'^v\d+(?:\.\d+)*$', w) else w for w in string.split()])
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_id_components(model_id: Optional[str] = None) -> dict[str, object]:
|
def get_model_id_components(model_id: Optional[str] = None) -> dict[str, object]:
|
||||||
|
@ -186,67 +185,119 @@ class Metadata:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None) -> Metadata:
|
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None) -> Metadata:
|
||||||
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
||||||
found_base_model = False
|
|
||||||
|
|
||||||
# Model Card Heuristics
|
# Model Card Heuristics
|
||||||
########################
|
########################
|
||||||
if model_card is not None:
|
if model_card is not None:
|
||||||
|
|
||||||
if "model_name" in model_card:
|
if "model_name" in model_card and metadata.name is None:
|
||||||
# Not part of the Hugging Face model card standard, but some model creators use it
|
# Not part of the Hugging Face model card standard, but some model creators use it
|
||||||
# such as TheBloke who would encode 'Mixtral 8X7B Instruct v0.1' into model_name
|
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||||
metadata.name = model_card.get("model_name")
|
metadata.name = model_card.get("model_name")
|
||||||
|
|
||||||
if "base_model" in model_card and isinstance(model_card["base_model"], str) and not found_base_model:
|
if "model_creator" in model_card and metadata.author is None:
|
||||||
# Check if string. We cannot handle lists as that is too ambiguous
|
# Not part of the Hugging Face model card standard, but some model creators use it
|
||||||
|
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||||
|
metadata.author = model_card.get("model_creator")
|
||||||
|
|
||||||
|
if "model_type" in model_card and metadata.basename is None:
|
||||||
|
# Not part of the Hugging Face model card standard, but some model creators use it
|
||||||
|
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||||
|
metadata.basename = model_card.get("model_type")
|
||||||
|
|
||||||
|
if "base_model" in model_card:
|
||||||
|
# This represents the parent models that this is based on
|
||||||
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
|
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
|
||||||
model_id = model_card.get("base_model")
|
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
|
||||||
|
metadata_base_models = []
|
||||||
|
base_model_value = model_card.get("base_model", None)
|
||||||
|
|
||||||
|
if base_model_value is not None:
|
||||||
|
if isinstance(base_model_value, str):
|
||||||
|
metadata_base_models.append(base_model_value)
|
||||||
|
elif isinstance(base_model_value, list):
|
||||||
|
metadata_base_models.extend(base_model_value)
|
||||||
|
|
||||||
|
if metadata.base_models is None:
|
||||||
|
metadata.base_models = []
|
||||||
|
|
||||||
|
for model_id in metadata_base_models:
|
||||||
model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id)
|
model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id)
|
||||||
if metadata.name is None and model_full_name_component is not None:
|
base_model = {}
|
||||||
metadata.name = Metadata.id_to_title(model_full_name_component)
|
if model_full_name_component is not None:
|
||||||
if metadata.organization is None and org_component is not None:
|
base_model["name"] = Metadata.id_to_title(model_full_name_component)
|
||||||
metadata.organization = Metadata.id_to_title(org_component)
|
if org_component is not None:
|
||||||
if metadata.basename is None and basename is not None:
|
base_model["organization"] = Metadata.id_to_title(org_component)
|
||||||
metadata.basename = basename
|
if version is not None:
|
||||||
if metadata.finetune is None and finetune is not None:
|
base_model["version"] = version
|
||||||
metadata.finetune = finetune
|
if org_component is not None and model_full_name_component is not None:
|
||||||
if metadata.version is None and version is not None:
|
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||||
metadata.version = version
|
metadata.base_models.append(base_model)
|
||||||
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
|
||||||
metadata.parameter_class_attribute = parameter_class_attribute
|
|
||||||
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
|
|
||||||
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
|
||||||
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
|
|
||||||
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
|
|
||||||
|
|
||||||
found_base_model = True
|
if "quantized_by" in model_card and metadata.quantized_by is None:
|
||||||
|
|
||||||
if metadata.quantized_by is None:
|
|
||||||
# Not part of hugging face model card standard, but is used by TheBloke to credit themselves for quantizing 3rd party models
|
# Not part of hugging face model card standard, but is used by TheBloke to credit themselves for quantizing 3rd party models
|
||||||
metadata.quantized_by = model_card.get("quantized_by")
|
metadata.quantized_by = model_card.get("quantized_by")
|
||||||
if metadata.license is None:
|
|
||||||
|
if "license" in model_card and metadata.license is None:
|
||||||
metadata.license = model_card.get("license")
|
metadata.license = model_card.get("license")
|
||||||
if metadata.license_name is None:
|
|
||||||
|
if "license_name" in model_card and metadata.license_name is None:
|
||||||
metadata.license_name = model_card.get("license_name")
|
metadata.license_name = model_card.get("license_name")
|
||||||
if metadata.license_link is None:
|
|
||||||
|
if "license_link" in model_card and metadata.license_link is None:
|
||||||
metadata.license_link = model_card.get("license_link")
|
metadata.license_link = model_card.get("license_link")
|
||||||
if metadata.author is None:
|
|
||||||
# Not part of the Hugging Face model card standard, but some model creators use it
|
tags_value = model_card.get("tags", None)
|
||||||
metadata.author = model_card.get("model_creator")
|
if tags_value is not None:
|
||||||
|
|
||||||
if metadata.tags is None:
|
if metadata.tags is None:
|
||||||
metadata.tags = model_card.get("tags", None)
|
metadata.tags = []
|
||||||
|
|
||||||
|
if isinstance(tags_value, str):
|
||||||
|
metadata.tags.append(tags_value)
|
||||||
|
elif isinstance(tags_value, list):
|
||||||
|
metadata.tags.extend(tags_value)
|
||||||
|
|
||||||
|
pipeline_tags_value = model_card.get("pipeline_tag", None)
|
||||||
|
if pipeline_tags_value is not None:
|
||||||
|
|
||||||
|
if metadata.tags is None:
|
||||||
|
metadata.tags = []
|
||||||
|
|
||||||
|
if isinstance(pipeline_tags_value, str):
|
||||||
|
metadata.tags.append(pipeline_tags_value)
|
||||||
|
elif isinstance(pipeline_tags_value, list):
|
||||||
|
metadata.tags.extend(pipeline_tags_value)
|
||||||
|
|
||||||
|
language_value = model_card.get("languages", model_card.get("language", None))
|
||||||
|
if language_value is not None:
|
||||||
|
|
||||||
if metadata.languages is None:
|
if metadata.languages is None:
|
||||||
metadata.languages = model_card.get("language", model_card.get("languages", None))
|
metadata.languages = []
|
||||||
|
|
||||||
|
if isinstance(language_value, str):
|
||||||
|
metadata.languages.append(language_value)
|
||||||
|
elif isinstance(language_value, list):
|
||||||
|
metadata.languages.extend(language_value)
|
||||||
|
|
||||||
|
dataset_value = model_card.get("datasets", model_card.get("dataset", None))
|
||||||
|
if dataset_value is not None:
|
||||||
|
|
||||||
if metadata.datasets is None:
|
if metadata.datasets is None:
|
||||||
metadata.datasets = model_card.get("datasets", model_card.get("dataset", None))
|
metadata.datasets = []
|
||||||
|
|
||||||
|
if isinstance(dataset_value, str):
|
||||||
|
metadata.datasets.append(dataset_value)
|
||||||
|
elif isinstance(dataset_value, list):
|
||||||
|
metadata.datasets.extend(dataset_value)
|
||||||
|
|
||||||
# Hugging Face Parameter Heuristics
|
# Hugging Face Parameter Heuristics
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
if hf_params is not None:
|
if hf_params is not None:
|
||||||
hf_name_or_path = hf_params.get("_name_or_path")
|
|
||||||
|
|
||||||
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1 and not found_base_model:
|
hf_name_or_path = hf_params.get("_name_or_path")
|
||||||
|
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
|
||||||
# Use _name_or_path only if it's actually a model name and not some computer path
|
# Use _name_or_path only if it's actually a model name and not some computer path
|
||||||
# e.g. 'meta-llama/Llama-2-7b-hf'
|
# e.g. 'meta-llama/Llama-2-7b-hf'
|
||||||
model_id = hf_name_or_path
|
model_id = hf_name_or_path
|
||||||
|
@ -263,10 +314,6 @@ class Metadata:
|
||||||
metadata.version = version
|
metadata.version = version
|
||||||
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
||||||
metadata.parameter_class_attribute = parameter_class_attribute
|
metadata.parameter_class_attribute = parameter_class_attribute
|
||||||
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
|
|
||||||
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
|
||||||
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
|
|
||||||
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
|
|
||||||
|
|
||||||
# Directory Folder Name Fallback Heuristics
|
# Directory Folder Name Fallback Heuristics
|
||||||
############################################
|
############################################
|
||||||
|
@ -285,66 +332,57 @@ class Metadata:
|
||||||
metadata.version = version
|
metadata.version = version
|
||||||
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
|
||||||
metadata.parameter_class_attribute = parameter_class_attribute
|
metadata.parameter_class_attribute = parameter_class_attribute
|
||||||
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
|
|
||||||
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
|
||||||
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
|
|
||||||
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
|
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_any_missing_uuid(metadata: Metadata) -> Metadata:
|
||||||
|
|
||||||
class TestStringMethods(unittest.TestCase):
|
# UUID Generation if not already provided
|
||||||
|
if metadata.uuid is None:
|
||||||
|
# Generate UUID based on provided links/id. UUIDv4 used as fallback
|
||||||
|
new_uuid = None
|
||||||
|
|
||||||
def test_get_model_id_components(self):
|
if metadata.doi is not None:
|
||||||
self.assertEqual(Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, f"https://doi.org/{metadata.doi}")
|
||||||
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
elif metadata.repo_url is not None:
|
||||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, metadata.repo_url)
|
||||||
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
elif metadata.url is not None:
|
||||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, metadata.url)
|
||||||
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
|
else:
|
||||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
|
new_uuid = uuid.uuid4() # every model must have at least a random UUIDv4
|
||||||
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
|
|
||||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B"),
|
|
||||||
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
|
|
||||||
self.assertEqual(Metadata.get_model_id_components("Mixtral"),
|
|
||||||
('Mixtral', None, 'Mixtral', None, None, None))
|
|
||||||
self.assertEqual(Metadata.get_model_id_components("Mixtral-v0.1"),
|
|
||||||
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
|
|
||||||
self.assertEqual(Metadata.get_model_id_components("hermes-2-pro-llama-3-8b-DPO"),
|
|
||||||
('hermes-2-pro-llama-3-8b-DPO', None, 'hermes-2-pro-llama-3', 'DPO', None, '8b'))
|
|
||||||
self.assertEqual(Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
|
|
||||||
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, "8B"))
|
|
||||||
|
|
||||||
def test_apply_metadata_heuristic_from_model_card(self):
|
if new_uuid is not None:
|
||||||
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/README.md
|
metadata.uuid = str(new_uuid)
|
||||||
model_card = {
|
|
||||||
'base_model': 'NousResearch/Meta-Llama-3-8B',
|
|
||||||
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
|
|
||||||
'model-index': [{'name': 'Hermes-2-Pro-Llama-3-8B', 'results': []}],
|
|
||||||
'language': ['en'],
|
|
||||||
'datasets': ['teknium/OpenHermes-2.5'],
|
|
||||||
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}]
|
|
||||||
}
|
|
||||||
expected = Metadata(name='Meta Llama 3 8B', basename='Meta-Llama-3', finetune=None, author=None, quantized_by=None, organization='NousResearch', version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url='https://huggingface.co/NousResearch/Meta-Llama-3-8B', source_doi=None, source_uuid=None, source_hf_repo='NousResearch/Meta-Llama-3-8B', parameter_class_attribute='8B', parents=None, tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], languages=['en'], datasets=['teknium/OpenHermes-2.5'])
|
|
||||||
|
|
||||||
got = Metadata.apply_metadata_heuristic(Metadata(), model_card, None, None)
|
if metadata.source_uuid is None:
|
||||||
|
# Generate a UUID based on provided links/id only if source provided
|
||||||
|
new_uuid = None
|
||||||
|
|
||||||
self.assertEqual(got, expected)
|
if metadata.source_doi is not None:
|
||||||
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, f"https://doi.org/{metadata.source_doi}")
|
||||||
|
elif metadata.source_repo_url is not None:
|
||||||
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, metadata.source_repo_url)
|
||||||
|
elif metadata.source_url is not None:
|
||||||
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, metadata.source_url)
|
||||||
|
|
||||||
def test_apply_metadata_heuristic_from_hf_parameters(self):
|
if new_uuid is not None:
|
||||||
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json
|
metadata.source_uuid = str(new_uuid)
|
||||||
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
|
|
||||||
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_doi=None, source_uuid=None, source_hf_repo=None, parameter_class_attribute='8b', parents=None, tags=None, languages=None, datasets=None)
|
|
||||||
got = Metadata.apply_metadata_heuristic(Metadata(), None, hf_params, None)
|
|
||||||
self.assertEqual(got, expected)
|
|
||||||
|
|
||||||
def test_apply_metadata_heuristic_from_model_dir(self):
|
if metadata.base_models is not None:
|
||||||
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json
|
for model_entry in metadata.base_models:
|
||||||
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
|
if "uuid" not in model_entry:
|
||||||
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_doi=None, source_uuid=None, source_hf_repo=None, parameter_class_attribute='8b', parents=None, tags=None, languages=None, datasets=None)
|
# Generate a UUID based on provided links/id only if source provided
|
||||||
got = Metadata.apply_metadata_heuristic(Metadata(), None, None, model_dir_path)
|
new_uuid = None
|
||||||
self.assertEqual(got, expected)
|
|
||||||
|
|
||||||
|
if "repo_url" in model_entry:
|
||||||
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, model_entry["repo_url"])
|
||||||
|
elif "url" in model_entry:
|
||||||
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, model_entry["url"])
|
||||||
|
elif "doi" in model_entry:
|
||||||
|
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, model_entry["doi"])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if new_uuid is not None:
|
||||||
unittest.main()
|
model_entry["uuid"] = str(new_uuid)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
1
gguf-py/tests/__init__.py
Normal file
1
gguf-py/tests/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .test_metadata import *
|
64
gguf-py/tests/test_metadata.py
Executable file
64
gguf-py/tests/test_metadata.py
Executable file
|
@ -0,0 +1,64 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import gguf # noqa: F401
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataMethod(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_id_to_title(self):
|
||||||
|
self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
|
||||||
|
self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
|
||||||
|
self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
|
||||||
|
|
||||||
|
def test_get_model_id_components(self):
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
|
||||||
|
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
|
||||||
|
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
|
||||||
|
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
|
||||||
|
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
|
||||||
|
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
|
||||||
|
('Mixtral', None, 'Mixtral', None, None, None))
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
|
||||||
|
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("hermes-2-pro-llama-3-8b-DPO"),
|
||||||
|
('hermes-2-pro-llama-3-8b-DPO', None, 'hermes-2-pro-llama-3', 'DPO', None, '8b'))
|
||||||
|
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
|
||||||
|
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, "8B"))
|
||||||
|
|
||||||
|
def test_apply_metadata_heuristic_from_model_card(self):
|
||||||
|
model_card = {
|
||||||
|
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
|
||||||
|
'model-index': [{'name': 'Hermes-2-Pro-Llama-3-8B', 'results': []}],
|
||||||
|
'language': ['en'],
|
||||||
|
'datasets': ['teknium/OpenHermes-2.5'],
|
||||||
|
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
|
||||||
|
'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
|
||||||
|
}
|
||||||
|
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
|
||||||
|
expect = gguf.Metadata(name=None, author=None, version=None, organization=None, finetune=None, basename=None, description=None, quantized_by=None, parameter_class_attribute=None, url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}], tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], languages=['en'], datasets=['teknium/OpenHermes-2.5'])
|
||||||
|
self.assertEqual(got, expect)
|
||||||
|
|
||||||
|
def test_apply_metadata_heuristic_from_hf_parameters(self):
|
||||||
|
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
|
||||||
|
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), None, hf_params, None)
|
||||||
|
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', author=None, version=None, organization=None, finetune='DPO', basename='hermes-2-pro-llama-3', description=None, quantized_by=None, parameter_class_attribute='8b', url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=None, tags=None, languages=None, datasets=None)
|
||||||
|
self.assertEqual(got, expect)
|
||||||
|
|
||||||
|
def test_apply_metadata_heuristic_from_model_dir(self):
|
||||||
|
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
|
||||||
|
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), None, None, model_dir_path)
|
||||||
|
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', author=None, version=None, organization=None, finetune='DPO', basename='hermes-2-pro-llama-3', description=None, quantized_by=None, parameter_class_attribute='8b', url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=None, tags=None, languages=None, datasets=None)
|
||||||
|
self.assertEqual(got, expect)
|
||||||
|
|
||||||
|
def test_generate_any_missing_uuid(self):
|
||||||
|
metadata = gguf.Metadata(repo_url="example.com", source_url="example.com", base_models=[{"doi":"10.57967/hf/2410"},{"doi":"10.47366/sabia.v5n1a3"}])
|
||||||
|
got = gguf.Metadata.generate_any_missing_uuid(metadata)
|
||||||
|
expect = gguf.Metadata(name=None, author=None, version=None, organization=None, finetune=None, basename=None, description=None, quantized_by=None, parameter_class_attribute=None, url=None, doi=None, uuid='a5cf6e8e-4cfa-5f31-a804-6de6d1245e26', repo_url='example.com', source_url='example.com', source_doi=None, source_uuid='a5cf6e8e-4cfa-5f31-a804-6de6d1245e26', source_repo_url=None, license=None, license_name=None, license_link=None, base_models=[{'doi': '10.57967/hf/2410', 'uuid': '26ce8128-2d34-5ea2-bc50-b5b90e21ed71'}, {'doi': '10.47366/sabia.v5n1a3', 'uuid': 'a15b24d6-5657-5d52-aaed-20dad7f4c500'}], tags=None, languages=None, datasets=None)
|
||||||
|
self.assertEqual(got, expect)
|
Loading…
Add table
Add a link
Reference in a new issue