convert-*.py: add datasets and language to KV store

This commit is contained in:
brian khuu 2024-06-02 17:17:56 +10:00
parent 0f1d50fab7
commit 684c604eca
4 changed files with 20 additions and 0 deletions

View file

@ -256,6 +256,10 @@ class Model:
self.gguf_writer.add_parameter_size_class(self.metadata.parameter_size_class)
if self.metadata.tags is not None:
self.gguf_writer.add_tags(self.metadata.tags)
if self.metadata.languages is not None:
self.gguf_writer.add_languages(self.metadata.languages)
if self.metadata.datasets is not None:
self.gguf_writer.add_datasets(self.metadata.datasets)
def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count)

View file

@ -39,6 +39,8 @@ class Keys:
FILE_TYPE = "general.file_type"
PARAMETER_SIZE_CLASS = "general.parameter_size_class"
TAGS = "general.tags"
LANGUAGE = "general.language"
DATASETS = "general.datasets"
class LLM:
VOCAB_SIZE = "{arch}.vocab_size"

View file

@ -478,6 +478,12 @@ class GGUFWriter:
def add_tags(self, tags: Sequence[str]) -> None:
    """Write the model's tags as an array-valued metadata entry.

    Args:
        tags: The strings to store as the tag list.
    """
    # TAGS ("general.tags") is declared alongside the other general.* keys and
    # is read back as Keys.General.TAGS by the metadata override loader, so it
    # must be addressed through Keys.General rather than Keys.Tokenizer.
    self.add_array(Keys.General.TAGS, tags)
def add_languages(self, languages: Sequence[str]) -> None:
    """Write the model's languages as an array-valued metadata entry.

    Args:
        languages: The strings to store as the language list.
    """
    # The key constant is declared in singular form (LANGUAGE =
    # "general.language") next to TAGS and DATASETS among the general.* keys,
    # so it is addressed through Keys.General rather than Keys.Tokenizer.
    self.add_array(Keys.General.LANGUAGE, languages)
def add_datasets(self, datasets: Sequence[str]) -> None:
    """Write the model's datasets as an array-valued metadata entry.

    Args:
        datasets: The strings to store as the dataset list.
    """
    # DATASETS ("general.datasets") is declared with the general.* keys, so it
    # is addressed through Keys.General rather than Keys.Tokenizer.
    self.add_array(Keys.General.DATASETS, datasets)
def add_name(self, name: str) -> None:
    """Write the given name as a string-valued entry under Keys.General.NAME.

    Args:
        name: The string to store.
    """
    name_key = Keys.General.NAME
    self.add_string(name_key, name)

View file

@ -28,6 +28,8 @@ class Metadata:
source_hf_repo: Optional[str] = None
parameter_size_class: Optional[str] = None
tags: Optional[list[str]] = None
language: Optional[list[str]] = None
datasets: Optional[list[str]] = None
@staticmethod
def load(metadata_override_path: Path, model_path: Path) -> Metadata:
@ -60,6 +62,10 @@ class Metadata:
metadata.author = model_card.get("model_creator")
if metadata.tags is None:
metadata.tags = model_card.get("tags", [])
if metadata.languages is None:
metadata.languages = model_card.get("languages", [])
if metadata.datasets is None:
metadata.datasets = model_card.get("datasets", [])
# load huggingface parameters if available
hf_params = Metadata.load_huggingface_parameters(model_path)
@ -92,6 +98,8 @@ class Metadata:
metadata.source_hf_repo = metadata_override.get(Keys.General.SOURCE_HF_REPO , metadata.source_hf_repo ) # noqa: E202
metadata.parameter_size_class = metadata_override.get(Keys.General.PARAMETER_SIZE_CLASS, metadata.parameter_size_class) # noqa: E202
metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202
metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202
metadata.datasets = metadata_override.get(Keys.General.datasets , metadata.datasets ) # noqa: E202
return metadata