convert-*.py: add base_version and add tags

This commit is contained in:
brian khuu 2024-06-02 15:11:52 +10:00
parent b36e391b87
commit 8f734083dd
5 changed files with 27 additions and 0 deletions

View file

@ -243,6 +243,8 @@ class Model:
self.gguf_writer.add_author(self.metadata.author) self.gguf_writer.add_author(self.metadata.author)
if self.metadata.version is not None: if self.metadata.version is not None:
self.gguf_writer.add_version(self.metadata.version) self.gguf_writer.add_version(self.metadata.version)
if self.metadata.base_version is not None:
self.gguf_writer.add_base_version(self.metadata.base_version)
if self.metadata.url is not None: if self.metadata.url is not None:
self.gguf_writer.add_url(self.metadata.url) self.gguf_writer.add_url(self.metadata.url)
if self.metadata.description is not None: if self.metadata.description is not None:
@ -253,6 +255,8 @@ class Model:
self.gguf_writer.add_source_url(self.metadata.source_url) self.gguf_writer.add_source_url(self.metadata.source_url)
if self.metadata.source_hf_repo is not None: if self.metadata.source_hf_repo is not None:
self.gguf_writer.add_source_hf_repo(self.metadata.source_hf_repo) self.gguf_writer.add_source_hf_repo(self.metadata.source_hf_repo)
if self.metadata.tags is not None:
self.gguf_writer.add_tags(self.metadata.tags)
def set_gguf_parameters(self): def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_block_count(self.block_count)

View file

@ -792,6 +792,8 @@ class OutputFile:
self.gguf.add_author(metadata.author) self.gguf.add_author(metadata.author)
if metadata.version is not None: if metadata.version is not None:
self.gguf.add_version(metadata.version) self.gguf.add_version(metadata.version)
if metadata.base_version is not None:
self.gguf.add_base_version(metadata.base_version)
if metadata.url is not None: if metadata.url is not None:
self.gguf.add_url(metadata.url) self.gguf.add_url(metadata.url)
if metadata.description is not None: if metadata.description is not None:
@ -802,6 +804,8 @@ class OutputFile:
self.gguf.add_source_url(metadata.source_url) self.gguf.add_source_url(metadata.source_url)
if metadata.source_hf_repo is not None: if metadata.source_hf_repo is not None:
self.gguf.add_source_hf_repo(metadata.source_hf_repo) self.gguf.add_source_hf_repo(metadata.source_hf_repo)
if metadata.tags is not None:
    # Every other metadata call in this method goes through ``self.gguf``;
    # the previous ``self.gguf_writer`` attribute is the name used by the
    # other converter's Model class and is not referenced anywhere else here.
    self.gguf.add_tags(metadata.tags)
def add_meta_arch(self, params: Params) -> None: def add_meta_arch(self, params: Params) -> None:
# Metadata About The Neural Architecture Itself # Metadata About The Neural Architecture Itself

View file

@ -28,6 +28,7 @@ class Keys:
FINETUNE = "general.finetune" FINETUNE = "general.finetune"
AUTHOR = "general.author" AUTHOR = "general.author"
VERSION = "general.version" VERSION = "general.version"
BASE_VERSION = "general.base_version"
URL = "general.url" URL = "general.url"
DESCRIPTION = "general.description" DESCRIPTION = "general.description"
LICENSE = "general.license" LICENSE = "general.license"
@ -36,6 +37,7 @@ class Keys:
SOURCE_URL = "general.source.url" SOURCE_URL = "general.source.url"
SOURCE_HF_REPO = "general.source.huggingface.repository" SOURCE_HF_REPO = "general.source.huggingface.repository"
FILE_TYPE = "general.file_type" FILE_TYPE = "general.file_type"
TAGS = "general.tags"
class LLM: class LLM:
VOCAB_SIZE = "{arch}.vocab_size" VOCAB_SIZE = "{arch}.vocab_size"

View file

@ -442,6 +442,9 @@ class GGUFWriter:
def add_version(self, version: str) -> None: def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version) self.add_string(Keys.General.VERSION, version)
def add_base_version(self, version: str) -> None:
    """Write ``version`` as a string value under ``Keys.General.BASE_VERSION``."""
    key = Keys.General.BASE_VERSION
    self.add_string(key, version)
def add_tensor_data_layout(self, layout: str) -> None: def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
@ -469,6 +472,9 @@ class GGUFWriter:
def add_file_type(self, ftype: int) -> None: def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype) self.add_uint32(Keys.General.FILE_TYPE, ftype)
def add_tags(self, tags: Sequence[str]) -> None:
    """Write the model's tags as an array under ``Keys.General.TAGS``.

    Args:
        tags: Tag strings to store in the ``general.tags`` metadata entry.
    """
    # TAGS is declared alongside the other "general.*" keys, and the metadata
    # override loader reads it as Keys.General.TAGS, so it must be written
    # under the same group rather than the tokenizer key group.
    self.add_array(Keys.General.TAGS, tags)
def add_name(self, name: str) -> None: def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name) self.add_string(Keys.General.NAME, name)

View file

@ -18,6 +18,7 @@ class Metadata:
finetune: Optional[str] = None finetune: Optional[str] = None
author: Optional[str] = None author: Optional[str] = None
version: Optional[str] = None version: Optional[str] = None
base_version: Optional[str] = None
url: Optional[str] = None url: Optional[str] = None
description: Optional[str] = None description: Optional[str] = None
license: Optional[str] = None license: Optional[str] = None
@ -25,6 +26,7 @@ class Metadata:
license_link: Optional[str] = None license_link: Optional[str] = None
source_url: Optional[str] = None source_url: Optional[str] = None
source_hf_repo: Optional[str] = None source_hf_repo: Optional[str] = None
tags: Optional[List[str]] = None
@staticmethod @staticmethod
def load(metadata_override_path: Path, model_path: Path) -> Metadata: def load(metadata_override_path: Path, model_path: Path) -> Metadata:
@ -40,6 +42,8 @@ class Metadata:
model_card = Metadata.load_model_card(model_path) model_card = Metadata.load_model_card(model_path)
if metadata.name is None: if metadata.name is None:
if "model-index" in model_card and len(model_card["model_name"]) == 1 and "name" in model_card["model_name"][0]: if "model-index" in model_card and len(model_card["model_name"]) == 1 and "name" in model_card["model_name"][0]:
# Only take the name from the model-index when it lists exactly one model entry
# (this is the safe choice in case a repo ever contains multiple models)
metadata.name = model_card["model_name"][0].get("name") metadata.name = model_card["model_name"][0].get("name")
elif "model_name" in model_card: elif "model_name" in model_card:
# non huggingface model card standard but notice some model creator using it # non huggingface model card standard but notice some model creator using it
@ -50,6 +54,11 @@ class Metadata:
metadata.license_name = model_card.get("license_name") metadata.license_name = model_card.get("license_name")
if metadata.license_link is None: if metadata.license_link is None:
metadata.license_link = model_card.get("license_link") metadata.license_link = model_card.get("license_link")
if metadata.author is None:
# not part of the Hugging Face model card standard, but some model creators use it
metadata.author = model_card.get("model_creator")
if metadata.tags is None:
metadata.tags = model_card.get("tags", [])
# load huggingface parameters if available # load huggingface parameters if available
hf_params = Metadata.load_huggingface_parameters(model_path) hf_params = Metadata.load_huggingface_parameters(model_path)
@ -72,6 +81,7 @@ class Metadata:
metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202 metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202
metadata.author = metadata_override.get(Keys.General.AUTHOR , metadata.author ) # noqa: E202 metadata.author = metadata_override.get(Keys.General.AUTHOR , metadata.author ) # noqa: E202
metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202 metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202
metadata.base_version = metadata_override.get(Keys.General.BASE_VERSION , metadata.base_version ) # noqa: E202
metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202 metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202
metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202 metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202
metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202 metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202
@ -79,6 +89,7 @@ class Metadata:
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link ) # noqa: E202 metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link ) # noqa: E202
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202 metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202
metadata.source_hf_repo = metadata_override.get(Keys.General.SOURCE_HF_REPO, metadata.source_hf_repo) # noqa: E202 metadata.source_hf_repo = metadata_override.get(Keys.General.SOURCE_HF_REPO, metadata.source_hf_repo) # noqa: E202
metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202
return metadata return metadata