Apply suggestions from code review

Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
Brian 2024-07-14 10:29:03 +10:00 committed by brian khuu
parent ccff6c7fb2
commit 455c0e53ac

View file

@ -59,36 +59,36 @@ class Metadata:
# This is based on LLM_KV_NAMES mapping in llama.cpp # This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path) metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization) metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune) metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename) metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description) metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by) metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label) metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name) metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link) metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
metadata.url = metadata_override.get(Keys.General.URL, metadata.url) metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi) metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid) metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url) metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url) metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi) metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid) metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url) metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
# Base Models is received here as an array of models # Base Models is received here as an array of models
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models) metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages) metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets) metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
# Direct Metadata Override (via direct cli argument) # Direct Metadata Override (via direct cli argument)
if model_name is not None: if model_name is not None:
@ -115,7 +115,7 @@ class Metadata:
return {} return {}
with open(model_card_path, "r", encoding="utf-8") as f: with open(model_card_path, "r", encoding="utf-8") as f:
return cast("dict[str, Any]", frontmatter.load(f)) return frontmatter.load(f).to_dict()
@staticmethod @staticmethod
def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: