convert-*.py: separated unit test, hf_repo to repo_url

This commit is contained in:
brian khuu 2024-06-08 21:54:20 +10:00
parent d060fcdbe2
commit eaa47f5546
8 changed files with 409 additions and 273 deletions

View file

@ -239,63 +239,71 @@ class Model:
def set_gguf_meta_model(self): def set_gguf_meta_model(self):
self.gguf_writer.add_name(self.metadata.name) self.gguf_writer.add_name(self.metadata.name)
if self.metadata.basename is not None:
self.gguf_writer.add_basename(self.metadata.basename)
if self.metadata.finetune is not None:
self.gguf_writer.add_finetune(self.metadata.finetune)
if self.metadata.author is not None: if self.metadata.author is not None:
self.gguf_writer.add_author(self.metadata.author) self.gguf_writer.add_author(self.metadata.author)
if self.metadata.quantized_by is not None:
self.gguf_writer.add_quantized_by(self.metadata.quantized_by)
if self.metadata.organization is not None:
self.gguf_writer.add_organization(self.metadata.organization)
if self.metadata.version is not None: if self.metadata.version is not None:
self.gguf_writer.add_version(self.metadata.version) self.gguf_writer.add_version(self.metadata.version)
if self.metadata.url is not None: if self.metadata.organization is not None:
self.gguf_writer.add_url(self.metadata.url) self.gguf_writer.add_organization(self.metadata.organization)
if self.metadata.doi is not None:
self.gguf_writer.add_doi(self.metadata.doi) if self.metadata.finetune is not None:
if self.metadata.uuid is not None: self.gguf_writer.add_finetune(self.metadata.finetune)
self.gguf_writer.add_uuid(self.metadata.uuid) if self.metadata.basename is not None:
if self.metadata.hf_repo is not None: self.gguf_writer.add_basename(self.metadata.basename)
self.gguf_writer.add_hf_repo(self.metadata.hf_repo)
if self.metadata.description is not None: if self.metadata.description is not None:
self.gguf_writer.add_description(self.metadata.description) self.gguf_writer.add_description(self.metadata.description)
if self.metadata.quantized_by is not None:
self.gguf_writer.add_quantized_by(self.metadata.quantized_by)
if self.metadata.parameter_class_attribute is not None:
self.gguf_writer.add_parameter_class_attribute(self.metadata.parameter_class_attribute)
if self.metadata.license is not None: if self.metadata.license is not None:
self.gguf_writer.add_license(self.metadata.license) self.gguf_writer.add_license(self.metadata.license)
if self.metadata.license_name is not None: if self.metadata.license_name is not None:
self.gguf_writer.add_license_name(self.metadata.license_name) self.gguf_writer.add_license_name(self.metadata.license_name)
if self.metadata.license_link is not None: if self.metadata.license_link is not None:
self.gguf_writer.add_license_link(self.metadata.license_link) self.gguf_writer.add_license_link(self.metadata.license_link)
if self.metadata.url is not None:
self.gguf_writer.add_url(self.metadata.url)
if self.metadata.doi is not None:
self.gguf_writer.add_doi(self.metadata.doi)
if self.metadata.uuid is not None:
self.gguf_writer.add_uuid(self.metadata.uuid)
if self.metadata.repo_url is not None:
self.gguf_writer.add_repo_url(self.metadata.repo_url)
if self.metadata.source_url is not None: if self.metadata.source_url is not None:
self.gguf_writer.add_source_url(self.metadata.source_url) self.gguf_writer.add_source_url(self.metadata.source_url)
if self.metadata.source_doi is not None: if self.metadata.source_doi is not None:
self.gguf_writer.add_source_doi(self.metadata.source_doi) self.gguf_writer.add_source_doi(self.metadata.source_doi)
if self.metadata.source_uuid is not None: if self.metadata.source_uuid is not None:
self.gguf_writer.add_source_uuid(self.metadata.source_uuid) self.gguf_writer.add_source_uuid(self.metadata.source_uuid)
if self.metadata.source_hf_repo is not None: if self.metadata.source_repo_url is not None:
self.gguf_writer.add_source_hf_repo(self.metadata.source_hf_repo) self.gguf_writer.add_source_repo_url(self.metadata.source_repo_url)
if self.metadata.parameter_class_attribute is not None:
self.gguf_writer.add_parameter_class_attribute(self.metadata.parameter_class_attribute) if self.metadata.base_models is not None:
if self.metadata.parents is not None: self.gguf_writer.add_base_model_count(len(self.metadata.base_models))
metadata.parent_count = len(self.metadata.parents) for key, base_model_entry in enumerate(self.metadata.base_models):
for key, parent_entry in self.metadata.parents: if "name" in base_model_entry:
if "name" in parent_entry: self.gguf_writer.add_base_model_name(key, base_model_entry["name"])
self.gguf_writer.add_parent_name(key, parent_entry.get("name")) if "author" in base_model_entry:
if "author" in parent_entry: self.gguf_writer.add_base_model_author(key, base_model_entry["author"])
self.gguf_writer.add_parent_author(key, parent_entry.get("author")) if "version" in base_model_entry:
if "version" in parent_entry: self.gguf_writer.add_base_model_version(key, base_model_entry["version"])
self.gguf_writer.add_parent_version(key, parent_entry.get("version")) if "organization" in base_model_entry:
if "organization" in parent_entry: self.gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
self.gguf_writer.add_parent_organization(key, parent_entry.get("organization")) if "url" in base_model_entry:
if "url" in parent_entry: self.gguf_writer.add_base_model_url(key, base_model_entry["url"])
self.gguf_writer.add_parent_url(key, parent_entry.get("url")) if "doi" in base_model_entry:
if "doi" in parent_entry: self.gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
self.gguf_writer.add_parent_doi(key, parent_entry.get("doi")) if "uuid" in base_model_entry:
if "uuid" in parent_entry: self.gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
self.gguf_writer.add_parent_uuid(key, parent_entry.get("uuid")) if "repo_url" in base_model_entry:
if "hf_repo" in parent_entry: self.gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
self.gguf_writer.add_parent_hf_repo(key, parent_entry.get("hf_repo"))
if self.metadata.tags is not None: if self.metadata.tags is not None:
self.gguf_writer.add_tags(self.metadata.tags) self.gguf_writer.add_tags(self.metadata.tags)
if self.metadata.languages is not None: if self.metadata.languages is not None:

View file

@ -783,57 +783,71 @@ class OutputFile:
self.gguf.add_name(name) self.gguf.add_name(name)
if metadata.basename is not None:
self.gguf.add_basename(metadata.basename)
if metadata.finetune is not None:
self.gguf.add_finetune(metadata.finetune)
if metadata.author is not None: if metadata.author is not None:
self.gguf.add_author(metadata.author) self.gguf.add_author(metadata.author)
if metadata.organization is not None:
self.add_organization(metadata.organization)
if metadata.version is not None: if metadata.version is not None:
self.gguf.add_version(metadata.version) self.gguf.add_version(metadata.version)
if metadata.url is not None: if metadata.organization is not None:
self.gguf.add_url(metadata.url) self.gguf.add_organization(metadata.organization)
if metadata.doi is not None:
self.gguf.add_doi(metadata.doi) if metadata.finetune is not None:
if metadata.uuid is not None: self.gguf.add_finetune(metadata.finetune)
self.gguf.add_uuid(metadata.uuid) if metadata.basename is not None:
if metadata.hf_repo is not None: self.gguf.add_basename(metadata.basename)
self.gguf.add_hf_repo(metadata.hf_repo)
if metadata.description is not None: if metadata.description is not None:
self.gguf.add_description(metadata.description) self.gguf.add_description(metadata.description)
if metadata.quantized_by is not None:
self.gguf.add_quantized_by(metadata.quantized_by)
if metadata.parameter_class_attribute is not None:
self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute)
if metadata.license is not None: if metadata.license is not None:
self.gguf.add_license(metadata.license) self.gguf.add_license(metadata.license)
if metadata.license_name is not None: if metadata.license_name is not None:
self.gguf.add_license_name(metadata.license_name) self.gguf.add_license_name(metadata.license_name)
if metadata.license_link is not None: if metadata.license_link is not None:
self.gguf.add_license_link(metadata.license_link) self.gguf.add_license_link(metadata.license_link)
if metadata.url is not None:
self.gguf.add_url(metadata.url)
if metadata.doi is not None:
self.gguf.add_doi(metadata.doi)
if metadata.uuid is not None:
self.gguf.add_uuid(metadata.uuid)
if metadata.repo_url is not None:
self.gguf.add_repo_url(metadata.repo_url)
if metadata.source_url is not None: if metadata.source_url is not None:
self.gguf.add_source_url(metadata.source_url) self.gguf.add_source_url(metadata.source_url)
if metadata.source_hf_repo is not None: if metadata.source_doi is not None:
self.gguf.add_source_hf_repo(metadata.source_hf_repo) self.gguf.add_source_doi(metadata.source_doi)
if metadata.parameter_class_attribute is not None: if metadata.source_uuid is not None:
self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute) self.gguf.add_source_uuid(metadata.source_uuid)
if metadata.parents is not None: if metadata.source_repo_url is not None:
metadata.parent_count = len(metadata.parents) self.gguf.add_source_repo_url(metadata.source_repo_url)
for key, parent_entry in metadata.parents:
if "name" in parent_entry: if metadata.base_models is not None:
self.gguf.add_parent_name(key, parent_entry.get("name")) self.gguf.add_base_model_count(len(metadata.base_models))
if "author" in parent_entry: for key, base_model_entry in enumerate(metadata.base_models):
self.gguf.add_parent_author(key, parent_entry.get("author")) if "name" in base_model_entry:
if "version" in parent_entry: self.gguf.add_base_model_name(key, base_model_entry["name"])
self.gguf.add_parent_version(key, parent_entry.get("version")) if "author" in base_model_entry:
if "organization" in parent_entry: self.gguf.add_base_model_author(key, base_model_entry["author"])
self.gguf.add_parent_organization(key, parent_entry.get("organization")) if "version" in base_model_entry:
if "url" in parent_entry: self.gguf.add_base_model_version(key, base_model_entry["version"])
self.gguf.add_parent_url(key, parent_entry.get("url")) if "organization" in base_model_entry:
if "doi" in parent_entry: self.gguf.add_base_model_organization(key, base_model_entry["organization"])
self.gguf.add_parent_doi(key, parent_entry.get("doi")) if "url" in base_model_entry:
if "uuid" in parent_entry: self.gguf.add_base_model_url(key, base_model_entry["url"])
self.gguf.add_parent_uuid(key, parent_entry.get("uuid")) if "doi" in base_model_entry:
if "hf_repo" in parent_entry: self.gguf.add_base_model_doi(key, base_model_entry["doi"])
self.gguf.add_parent_hf_repo(key, parent_entry.get("hf_repo")) if "uuid" in base_model_entry:
self.gguf.add_base_model_uuid(key, base_model_entry["uuid"])
if "repo_url" in base_model_entry:
self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
if metadata.tags is not None: if metadata.tags is not None:
self.gguf.add_tags(metadata.tags) self.gguf.add_tags(metadata.tags)
if metadata.languages is not None: if metadata.languages is not None:

View file

@ -78,5 +78,13 @@ python -m build
python -m twine upload dist/* python -m twine upload dist/*
``` ```
## Run Unit Tests
From root of this repository you can run this command to run all the unit tests
```bash
python -m unittest discover ./gguf-py -v
```
## TODO ## TODO
- [ ] Include conversion scripts as command line entry points in this package. - [ ] Include conversion scripts as command line entry points in this package.

View file

@ -31,10 +31,12 @@ class Keys:
VERSION = "general.version" VERSION = "general.version"
ORGANIZATION = "general.organization" ORGANIZATION = "general.organization"
BASENAME = "general.basename"
FINETUNE = "general.finetune" FINETUNE = "general.finetune"
BASENAME = "general.basename"
DESCRIPTION = "general.description" DESCRIPTION = "general.description"
QUANTIZED_BY = "general.quantized_by" QUANTIZED_BY = "general.quantized_by"
PARAMETER_CLASS_ATTRIBUTE = "general.parameter_class_attribute" PARAMETER_CLASS_ATTRIBUTE = "general.parameter_class_attribute"
# Licensing details # Licensing details
@ -43,31 +45,29 @@ class Keys:
LICENSE_LINK = "general.license.link" LICENSE_LINK = "general.license.link"
# Typically represents the converted GGUF repo (Unless native) # Typically represents the converted GGUF repo (Unless native)
URL = "general.url" URL = "general.url" # Model Website/Paper
DOI = "general.doi" DOI = "general.doi"
UUID = "general.uuid" UUID = "general.uuid"
HF_URL = "general.huggingface.repository" REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
# Typically represents the original source repository (e.g. safetensors) # Model Source during conversion
# that this was SOURCE_URL = "general.source.url" # Model Website/Paper
SOURCE_URL = "general.source.url"
SOURCE_DOI = "general.source.doi" SOURCE_DOI = "general.source.doi"
SOURCE_UUID = "general.source.uuid" SOURCE_UUID = "general.source.uuid"
SOURCE_HF_REPO = "general.source.huggingface.repository" SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
# This represents the parent model that the converted/source model was # Base Model Source. There can be more than one source if it's a merged
# derived from on allowing users to trace the linage of a model. # model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
# E.g. A finetune model would have the base model as the parent # tracing linage of models as it is finetuned or merged over time.
# (A model can have multiple parent, especially if it's a merged model) BASE_MODEL_COUNT = "general.base_model.count"
PARENTS_COUNT = "general.parents.count" BASE_MODEL_NAME = "general.base_model.{id}.name"
PARENTS_NAME = "general.parents.{id}.name" BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
PARENTS_AUTHOR = "general.parents.{id}.author" BASE_MODEL_VERSION = "general.base_model.{id}.version"
PARENTS_VERSION = "general.parents.{id}.version" BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
PARENTS_ORGANIZATION = "general.parents.{id}.organization" BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
PARENTS_URL = "general.parents.{id}.url" BASE_MODEL_DOI = "general.base_model.{id}.doi"
PARENTS_DOI = "general.parents.{id}.doi" BASE_MODEL_UUID = "general.base_model.{id}.uuid"
PARENTS_UUID = "general.parents.{id}.uuid" BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
PARENTS_HF_REPO = "general.parents.{id}.huggingface.repository"
# Array based KV stores # Array based KV stores
TAGS = "general.tags" TAGS = "general.tags"
@ -1273,8 +1273,6 @@ KEY_GENERAL_AUTHOR = Keys.General.AUTHOR
KEY_GENERAL_URL = Keys.General.URL KEY_GENERAL_URL = Keys.General.URL
KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
KEY_GENERAL_LICENSE = Keys.General.LICENSE KEY_GENERAL_LICENSE = Keys.General.LICENSE
KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
# LLM # LLM

View file

@ -430,42 +430,43 @@ class GGUFWriter:
def add_architecture(self) -> None: def add_architecture(self) -> None:
self.add_string(Keys.General.ARCHITECTURE, self.arch) self.add_string(Keys.General.ARCHITECTURE, self.arch)
def add_basename(self, basename: str) -> None: def add_quantization_version(self, quantization_version: int) -> None:
self.add_string(Keys.General.BASENAME, basename) self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_finetune(self, finetune: str) -> None: def add_custom_alignment(self, alignment: int) -> None:
self.add_string(Keys.General.FINETUNE, finetune) self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype)
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)
def add_author(self, author: str) -> None: def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author) self.add_string(Keys.General.AUTHOR, author)
def add_quantized_by(self, quantized: str) -> None: def add_version(self, version: str) -> None:
self.add_string(Keys.General.QUANTIZED_BY, quantized) self.add_string(Keys.General.VERSION, version)
def add_organization(self, organization: str) -> None: def add_organization(self, organization: str) -> None:
self.add_string(Keys.General.ORGANIZATION, organization) self.add_string(Keys.General.ORGANIZATION, organization)
def add_version(self, version: str) -> None: def add_finetune(self, finetune: str) -> None:
self.add_string(Keys.General.VERSION, version) self.add_string(Keys.General.FINETUNE, finetune)
def add_tensor_data_layout(self, layout: str) -> None: def add_basename(self, basename: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) self.add_string(Keys.General.BASENAME, basename)
def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)
def add_doi(self, doi: str) -> None:
self.add_string(Keys.General.DOI, doi)
def add_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.UUID, uuid)
def add_hf_repo(self, hf_repo: str) -> None:
self.add_string(Keys.General.HF_REPO, hf_repo)
def add_description(self, description: str) -> None: def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description) self.add_string(Keys.General.DESCRIPTION, description)
def add_quantized_by(self, quantized: str) -> None:
self.add_string(Keys.General.QUANTIZED_BY, quantized)
def add_parameter_class_attribute(self, parameter_class_attribute: str) -> None:
self.add_string(Keys.General.PARAMETER_CLASS_ATTRIBUTE, parameter_class_attribute)
def add_license(self, license: str) -> None: def add_license(self, license: str) -> None:
self.add_string(Keys.General.LICENSE, license) self.add_string(Keys.General.LICENSE, license)
@ -475,44 +476,56 @@ class GGUFWriter:
def add_license_link(self, license: str) -> None: def add_license_link(self, license: str) -> None:
self.add_string(Keys.General.LICENSE_LINK, license) self.add_string(Keys.General.LICENSE_LINK, license)
def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)
def add_doi(self, doi: str) -> None:
self.add_string(Keys.General.DOI, doi)
def add_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.UUID, uuid)
def add_repo_url(self, repo_url: str) -> None:
self.add_string(Keys.General.REPO_URL, repo_url)
def add_source_url(self, url: str) -> None: def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url) self.add_string(Keys.General.SOURCE_URL, url)
def add_source_hf_repo(self, repo: str) -> None: def add_source_doi(self, doi: str) -> None:
self.add_string(Keys.General.SOURCE_HF_REPO, repo) self.add_string(Keys.General.SOURCE_DOI, doi)
def add_file_type(self, ftype: int) -> None: def add_source_uuid(self, uuid: str) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype) self.add_string(Keys.General.SOURCE_UUID, uuid)
def add_parameter_class_attribute(self, parameter_class_attribute: str) -> None: def add_source_repo_url(self, repo_url: str) -> None:
self.add_string(Keys.General.PARAMETER_CLASS_ATTRIBUTE, parameter_class_attribute) self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
def add_parent_count(self, parent_count: int) -> None: def add_base_model_count(self, source_count: int) -> None:
self.add_uint32(Keys.General.PARENTS_COUNT, parent_count) self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
def add_parent_name(self, parent_id: int, name: str) -> None: def add_base_model_name(self, source_id: int, name: str) -> None:
self.add_string(Keys.General.PARENTS_NAME.format(id=self.parent_id), name) self.add_string(Keys.General.BASE_MODEL_NAME.format(id=self.source_id), name)
def add_parent_author(self, parent_id: int, author: str) -> None: def add_base_model_author(self, source_id: int, author: str) -> None:
self.add_string(Keys.General.PARENTS_AUTHOR.format(id=self.parent_id), author) self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=self.source_id), author)
def add_parent_version(self, parent_id: int, version: str) -> None: def add_base_model_version(self, source_id: int, version: str) -> None:
self.add_string(Keys.General.PARENTS_VERSION.format(id=self.parent_id), version) self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=self.source_id), version)
def add_parent_organization(self, parent_id: int, organization: str) -> None: def add_base_model_organization(self, source_id: int, organization: str) -> None:
self.add_string(Keys.General.PARENTS_ORGANIZATION.format(id=self.parent_id), organization) self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=self.source_id), organization)
def add_parent_url(self, parent_id: int, url: str) -> None: def add_base_model_url(self, source_id: int, url: str) -> None:
self.add_string(Keys.General.PARENTS_URL.format(id=self.parent_id), url) self.add_string(Keys.General.BASE_MODEL_URL.format(id=self.source_id), url)
def add_parent_doi(self, parent_id: int, doi: str) -> None: def add_base_model_doi(self, source_id: int, doi: str) -> None:
self.add_string(Keys.General.PARENTS_DOI.format(id=self.parent_id), doi) self.add_string(Keys.General.BASE_MODEL_DOI.format(id=self.source_id), doi)
def add_parent_uuid(self, parent_id: int, uuid: str) -> None: def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
self.add_string(Keys.General.PARENTS_UUID.format(id=self.parent_id), uuid) self.add_string(Keys.General.BASE_MODEL_UUID.format(id=self.source_id), uuid)
def add_parent_hf_repo(self, parent_id: int, hf_repo: str) -> None: def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
self.add_string(Keys.General.PARENTS_HF_REPO.format(id=self.parent_id), hf_repo) self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=self.source_id), repo_url)
def add_tags(self, tags: Sequence[str]) -> None: def add_tags(self, tags: Sequence[str]) -> None:
self.add_array(Keys.General.TAGS, tags) self.add_array(Keys.General.TAGS, tags)
@ -523,16 +536,8 @@ class GGUFWriter:
def add_datasets(self, datasets: Sequence[str]) -> None: def add_datasets(self, datasets: Sequence[str]) -> None:
self.add_array(Keys.General.DATASETS, datasets) self.add_array(Keys.General.DATASETS, datasets)
def add_name(self, name: str) -> None: def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.General.NAME, name) self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(
Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_custom_alignment(self, alignment: int) -> None:
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
def add_vocab_size(self, size: int) -> None: def add_vocab_size(self, size: int) -> None:
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size) self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)

270
gguf-py/gguf/metadata.py Executable file → Normal file
View file

@ -1,45 +1,40 @@
#!/usr/bin/env python3
from __future__ import annotations from __future__ import annotations
import re import re
import json import json
import unittest import uuid
import frontmatter import frontmatter
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
if __name__ == '__main__': from .constants import Keys
from constants import Keys
else:
from .constants import Keys
@dataclass @dataclass
class Metadata: class Metadata:
# Authorship Metadata to be written to GGUF KV Store # Authorship Metadata to be written to GGUF KV Store
name: Optional[str] = None name: Optional[str] = None
basename: Optional[str] = None
finetune: Optional[str] = None
author: Optional[str] = None author: Optional[str] = None
quantized_by: Optional[str] = None
organization: Optional[str] = None
version: Optional[str] = None version: Optional[str] = None
organization: Optional[str] = None
finetune: Optional[str] = None
basename: Optional[str] = None
description: Optional[str] = None
quantized_by: Optional[str] = None
parameter_class_attribute: Optional[str] = None
url: Optional[str] = None url: Optional[str] = None
doi: Optional[str] = None doi: Optional[str] = None
uuid: Optional[str] = None uuid: Optional[str] = None
hf_repo: Optional[str] = None repo_url: Optional[str] = None
description: Optional[str] = None
license: Optional[str] = None
license_name: Optional[str] = None
license_link: Optional[str] = None
source_url: Optional[str] = None source_url: Optional[str] = None
source_doi: Optional[str] = None source_doi: Optional[str] = None
source_uuid: Optional[str] = None source_uuid: Optional[str] = None
source_hf_repo: Optional[str] = None source_repo_url: Optional[str] = None
parameter_class_attribute: Optional[str] = None license: Optional[str] = None
parents: Optional[list[dict]] = None license_name: Optional[str] = None
license_link: Optional[str] = None
base_models: Optional[list[dict]] = None
tags: Optional[list[str]] = None tags: Optional[list[str]] = None
languages: Optional[list[str]] = None languages: Optional[list[str]] = None
datasets: Optional[list[str]] = None datasets: Optional[list[str]] = None
@ -68,10 +63,12 @@ class Metadata:
metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202 metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION , metadata.organization ) # noqa: E202 metadata.organization = metadata_override.get(Keys.General.ORGANIZATION , metadata.organization ) # noqa: E202
metadata.basename = metadata_override.get(Keys.General.BASENAME , metadata.basename ) # noqa: E202
metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202 metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202
metadata.basename = metadata_override.get(Keys.General.BASENAME , metadata.basename ) # noqa: E202
metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202 metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by ) # noqa: E202 metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by ) # noqa: E202
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202 metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202
metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202 metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202
@ -81,14 +78,14 @@ class Metadata:
metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202 metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202
metadata.doi = metadata_override.get(Keys.General.DOI , metadata.doi ) # noqa: E202 metadata.doi = metadata_override.get(Keys.General.DOI , metadata.doi ) # noqa: E202
metadata.uuid = metadata_override.get(Keys.General.UUID , metadata.uuid ) # noqa: E202 metadata.uuid = metadata_override.get(Keys.General.UUID , metadata.uuid ) # noqa: E202
metadata.hf_repo = metadata_override.get(Keys.General.HF_REPO , metadata.hf_repo ) # noqa: E202 metadata.repo_url = metadata_override.get(Keys.General.REPO_URL , metadata.repo_url ) # noqa: E202
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202 metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi ) # noqa: E202 metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi ) # noqa: E202
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid ) # noqa: E202 metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid ) # noqa: E202
metadata.source_hf_repo = metadata_override.get(Keys.General.SOURCE_HF_REPO , metadata.source_hf_repo ) # noqa: E202 metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url ) # noqa: E202
metadata.parent_count = metadata_override.get("general.parents" , metadata.parent_count ) # noqa: E202 metadata.base_models = metadata_override.get("general.base_models" , metadata.base_models ) # noqa: E202
metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202 metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202
metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202 metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202
@ -98,6 +95,9 @@ class Metadata:
if model_name is not None: if model_name is not None:
metadata.name = model_name metadata.name = model_name
# If any UUID is still missing at this point, then we should fill it in
metadata = Metadata.generate_any_missing_uuid(metadata)
return metadata return metadata
@staticmethod @staticmethod
@ -137,8 +137,7 @@ class Metadata:
@staticmethod @staticmethod
def id_to_title(string): def id_to_title(string):
# Convert capitalization into title form unless acronym or version number # Convert capitalization into title form unless acronym or version number
string = string.strip().replace('-', ' ') return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
return ' '.join([w.title() if w.islower() and not re.match(r'^v\d+(?:\.\d+)*$', w) else w for w in string.split()])
@staticmethod @staticmethod
def get_model_id_components(model_id: Optional[str] = None) -> dict[str, object]: def get_model_id_components(model_id: Optional[str] = None) -> dict[str, object]:
@ -186,67 +185,119 @@ class Metadata:
@staticmethod @staticmethod
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None) -> Metadata: def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None) -> Metadata:
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
found_base_model = False
# Model Card Heuristics # Model Card Heuristics
######################## ########################
if model_card is not None: if model_card is not None:
if "model_name" in model_card: if "model_name" in model_card and metadata.name is None:
# Not part of huggingface model card standard but notice some model creator using it # Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke who would encode 'Mixtral 8X7B Instruct v0.1' into model_name # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.name = model_card.get("model_name") metadata.name = model_card.get("model_name")
if "base_model" in model_card and isinstance(model_card["base_model"], str) and not found_base_model: if "model_creator" in model_card and metadata.author is None:
# Check if string. We cannot handle lists as that is too ambagious # Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.author = model_card.get("model_creator")
if "model_type" in model_card and metadata.basename is None:
# Not part of huggingface model card standard but notice some model creator using it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
metadata.basename = model_card.get("model_type")
if "base_model" in model_card:
# This represents the parent models that this is based on
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges) # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
model_id = model_card.get("base_model") # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
metadata_base_models = []
base_model_value = model_card.get("base_model", None)
if base_model_value is not None:
if isinstance(base_model_value, str):
metadata_base_models.append(base_model_value)
elif isinstance(base_model_value, list):
metadata_base_models.extend(base_model_value)
if metadata.base_models is None:
metadata.base_models = []
for model_id in metadata_base_models:
model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id) model_full_name_component, org_component, basename, finetune, version, parameter_class_attribute = Metadata.get_model_id_components(model_id)
if metadata.name is None and model_full_name_component is not None: base_model = {}
metadata.name = Metadata.id_to_title(model_full_name_component) if model_full_name_component is not None:
if metadata.organization is None and org_component is not None: base_model["name"] = Metadata.id_to_title(model_full_name_component)
metadata.organization = Metadata.id_to_title(org_component) if org_component is not None:
if metadata.basename is None and basename is not None: base_model["organization"] = Metadata.id_to_title(org_component)
metadata.basename = basename if version is not None:
if metadata.finetune is None and finetune is not None: base_model["version"] = version
metadata.finetune = finetune if org_component is not None and model_full_name_component is not None:
if metadata.version is None and version is not None: base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
metadata.version = version metadata.base_models.append(base_model)
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
metadata.parameter_class_attribute = parameter_class_attribute
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
found_base_model = True if "quantized_by" in model_card and metadata.quantized_by is None:
if metadata.quantized_by is None:
# Not part of hugging face model card standard, but is used by TheBloke to credit them self for quantizing 3rd party models # Not part of hugging face model card standard, but is used by TheBloke to credit them self for quantizing 3rd party models
metadata.quantized_by = model_card.get("quantized_by") metadata.quantized_by = model_card.get("quantized_by")
if metadata.license is None:
if "license" in model_card and metadata.license is None:
metadata.license = model_card.get("license") metadata.license = model_card.get("license")
if metadata.license_name is None:
if "license_name" in model_card and metadata.license_name is None:
metadata.license_name = model_card.get("license_name") metadata.license_name = model_card.get("license_name")
if metadata.license_link is None:
if "license_link" in model_card and metadata.license_link is None:
metadata.license_link = model_card.get("license_link") metadata.license_link = model_card.get("license_link")
if metadata.author is None:
# non huggingface model card standard but notice some model creator using it tags_value = model_card.get("tags", None)
metadata.author = model_card.get("model_creator") if tags_value is not None:
if metadata.tags is None: if metadata.tags is None:
metadata.tags = model_card.get("tags", None) metadata.tags = []
if isinstance(tags_value, str):
metadata.tags.append(tags_value)
elif isinstance(tags_value, list):
metadata.tags.extend(tags_value)
pipeline_tags_value = model_card.get("pipeline_tag", None)
if pipeline_tags_value is not None:
if metadata.tags is None:
metadata.tags = []
if isinstance(pipeline_tags_value, str):
metadata.tags.append(pipeline_tags_value)
elif isinstance(pipeline_tags_value, list):
metadata.tags.extend(pipeline_tags_value)
language_value = model_card.get("languages", model_card.get("language", None))
if language_value is not None:
if metadata.languages is None: if metadata.languages is None:
metadata.languages = model_card.get("language", model_card.get("languages", None)) metadata.languages = []
if isinstance(language_value, str):
metadata.languages.append(language_value)
elif isinstance(language_value, list):
metadata.languages.extend(language_value)
dataset_value = model_card.get("datasets", model_card.get("dataset", None))
if dataset_value is not None:
if metadata.datasets is None: if metadata.datasets is None:
metadata.datasets = model_card.get("datasets", model_card.get("dataset", None)) metadata.datasets = []
if isinstance(dataset_value, str):
metadata.datasets.append(dataset_value)
elif isinstance(dataset_value, list):
metadata.datasets.extend(dataset_value)
# Hugging Face Parameter Heuristics # Hugging Face Parameter Heuristics
#################################### ####################################
if hf_params is not None: if hf_params is not None:
hf_name_or_path = hf_params.get("_name_or_path")
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1 and not found_base_model: hf_name_or_path = hf_params.get("_name_or_path")
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
# Use _name_or_path only if its actually a model name and not some computer path # Use _name_or_path only if its actually a model name and not some computer path
# e.g. 'meta-llama/Llama-2-7b-hf' # e.g. 'meta-llama/Llama-2-7b-hf'
model_id = hf_name_or_path model_id = hf_name_or_path
@ -263,10 +314,6 @@ class Metadata:
metadata.version = version metadata.version = version
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None: if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
metadata.parameter_class_attribute = parameter_class_attribute metadata.parameter_class_attribute = parameter_class_attribute
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
# Directory Folder Name Fallback Heuristics # Directory Folder Name Fallback Heuristics
############################################ ############################################
@ -285,66 +332,57 @@ class Metadata:
metadata.version = version metadata.version = version
if metadata.parameter_class_attribute is None and parameter_class_attribute is not None: if metadata.parameter_class_attribute is None and parameter_class_attribute is not None:
metadata.parameter_class_attribute = parameter_class_attribute metadata.parameter_class_attribute = parameter_class_attribute
if metadata.source_url is None and org_component is not None and model_full_name_component is not None:
metadata.source_url = f"https://huggingface.co/{org_component}/{model_full_name_component}"
if metadata.source_hf_repo is None and org_component is not None and model_full_name_component is not None:
metadata.source_hf_repo = f"{org_component}/{model_full_name_component}"
return metadata return metadata
@staticmethod
def generate_any_missing_uuid(metadata: Metadata) -> Metadata:
class TestStringMethods(unittest.TestCase): # UUID Generation if not already provided
if metadata.uuid is None:
# Generate UUID based on provided links/id. UUIDv4 used as fallback
new_uuid = None
def test_get_model_id_components(self): if metadata.doi is not None:
self.assertEqual(Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"), new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, f"https://doi.org/{metadata.doi}")
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B')) elif metadata.repo_url is not None:
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"), new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, metadata.repo_url)
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B')) elif metadata.url is not None:
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-Instruct"), new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, metadata.url)
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B')) else:
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B-v0.1"), new_uuid = uuid.uuid4() # every model must have at least a random UUIDv4
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
self.assertEqual(Metadata.get_model_id_components("Mixtral-8x7B"),
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
self.assertEqual(Metadata.get_model_id_components("Mixtral"),
('Mixtral', None, 'Mixtral', None, None, None))
self.assertEqual(Metadata.get_model_id_components("Mixtral-v0.1"),
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
self.assertEqual(Metadata.get_model_id_components("hermes-2-pro-llama-3-8b-DPO"),
('hermes-2-pro-llama-3-8b-DPO', None, 'hermes-2-pro-llama-3', 'DPO', None, '8b'))
self.assertEqual(Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, "8B"))
def test_apply_metadata_heuristic_from_model_card(self): if new_uuid is not None:
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/README.md metadata.uuid = str(new_uuid)
model_card = {
'base_model': 'NousResearch/Meta-Llama-3-8B',
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
'model-index': [{'name': 'Hermes-2-Pro-Llama-3-8B', 'results': []}],
'language': ['en'],
'datasets': ['teknium/OpenHermes-2.5'],
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}]
}
expected = Metadata(name='Meta Llama 3 8B', basename='Meta-Llama-3', finetune=None, author=None, quantized_by=None, organization='NousResearch', version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url='https://huggingface.co/NousResearch/Meta-Llama-3-8B', source_doi=None, source_uuid=None, source_hf_repo='NousResearch/Meta-Llama-3-8B', parameter_class_attribute='8B', parents=None, tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], languages=['en'], datasets=['teknium/OpenHermes-2.5'])
got = Metadata.apply_metadata_heuristic(Metadata(), model_card, None, None) if metadata.source_uuid is None:
# Generate a UUID based on provided links/id only if source provided
new_uuid = None
self.assertEqual(got, expected) if metadata.source_doi is not None:
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, f"https://doi.org/{metadata.source_doi}")
elif metadata.source_repo_url is not None:
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, metadata.source_repo_url)
elif metadata.source_url is not None:
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, metadata.source_url)
def test_apply_metadata_heuristic_from_hf_parameters(self): if new_uuid is not None:
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json metadata.source_uuid = str(new_uuid)
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_doi=None, source_uuid=None, source_hf_repo=None, parameter_class_attribute='8b', parents=None, tags=None, languages=None, datasets=None)
got = Metadata.apply_metadata_heuristic(Metadata(), None, hf_params, None)
self.assertEqual(got, expected)
def test_apply_metadata_heuristic_from_model_dir(self): if metadata.base_models is not None:
# Source: https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/config.json for model_entry in metadata.base_models:
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO") if "uuid" not in model_entry:
expected = Metadata(name='Hermes 2 Pro Llama 3 8B DPO', basename='hermes-2-pro-llama-3', finetune='DPO', author=None, quantized_by=None, organization=None, version=None, url=None, doi=None, uuid=None, hf_repo=None, description=None, license=None, license_name=None, license_link=None, source_url=None, source_doi=None, source_uuid=None, source_hf_repo=None, parameter_class_attribute='8b', parents=None, tags=None, languages=None, datasets=None) # Generate a UUID based on provided links/id only if source provided
got = Metadata.apply_metadata_heuristic(Metadata(), None, None, model_dir_path) new_uuid = None
self.assertEqual(got, expected)
if "repo_url" in model_entry:
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, model_entry["repo_url"])
elif "url" in model_entry:
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, model_entry["url"])
elif "doi" in model_entry:
new_uuid = uuid.uuid5(uuid.NAMESPACE_URL, model_entry["doi"])
if __name__ == '__main__': if new_uuid is not None:
unittest.main() model_entry["uuid"] = str(new_uuid)
return metadata

View file

@ -0,0 +1 @@
from .test_metadata import *

64
gguf-py/tests/test_metadata.py Executable file
View file

@ -0,0 +1,64 @@
#!/usr/bin/env python3
import unittest
import gguf # noqa: F401
from pathlib import Path
class TestMetadataMethod(unittest.TestCase):
def test_id_to_title(self):
self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
def test_get_model_id_components(self):
self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
('Mixtral', None, 'Mixtral', None, None, None))
self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
self.assertEqual(gguf.Metadata.get_model_id_components("hermes-2-pro-llama-3-8b-DPO"),
('hermes-2-pro-llama-3-8b-DPO', None, 'hermes-2-pro-llama-3', 'DPO', None, '8b'))
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, "8B"))
def test_apply_metadata_heuristic_from_model_card(self):
model_card = {
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
'model-index': [{'name': 'Hermes-2-Pro-Llama-3-8B', 'results': []}],
'language': ['en'],
'datasets': ['teknium/OpenHermes-2.5'],
'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
}
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
expect = gguf.Metadata(name=None, author=None, version=None, organization=None, finetune=None, basename=None, description=None, quantized_by=None, parameter_class_attribute=None, url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}], tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], languages=['en'], datasets=['teknium/OpenHermes-2.5'])
self.assertEqual(got, expect)
def test_apply_metadata_heuristic_from_hf_parameters(self):
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), None, hf_params, None)
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', author=None, version=None, organization=None, finetune='DPO', basename='hermes-2-pro-llama-3', description=None, quantized_by=None, parameter_class_attribute='8b', url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=None, tags=None, languages=None, datasets=None)
self.assertEqual(got, expect)
def test_apply_metadata_heuristic_from_model_dir(self):
model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), None, None, model_dir_path)
expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', author=None, version=None, organization=None, finetune='DPO', basename='hermes-2-pro-llama-3', description=None, quantized_by=None, parameter_class_attribute='8b', url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=None, tags=None, languages=None, datasets=None)
self.assertEqual(got, expect)
def test_generate_any_missing_uuid(self):
metadata = gguf.Metadata(repo_url="example.com", source_url="example.com", base_models=[{"doi":"10.57967/hf/2410"},{"doi":"10.47366/sabia.v5n1a3"}])
got = gguf.Metadata.generate_any_missing_uuid(metadata)
expect = gguf.Metadata(name=None, author=None, version=None, organization=None, finetune=None, basename=None, description=None, quantized_by=None, parameter_class_attribute=None, url=None, doi=None, uuid='a5cf6e8e-4cfa-5f31-a804-6de6d1245e26', repo_url='example.com', source_url='example.com', source_doi=None, source_uuid='a5cf6e8e-4cfa-5f31-a804-6de6d1245e26', source_repo_url=None, license=None, license_name=None, license_link=None, base_models=[{'doi': '10.57967/hf/2410', 'uuid': '26ce8128-2d34-5ea2-bc50-b5b90e21ed71'}, {'doi': '10.47366/sabia.v5n1a3', 'uuid': 'a15b24d6-5657-5d52-aaed-20dad7f4c500'}], tags=None, languages=None, datasets=None)
self.assertEqual(got, expect)