diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index d18ab400e..30d0719d1 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -59,36 +59,36 @@ class Metadata: # This is based on LLM_KV_NAMES mapping in llama.cpp metadata_override = Metadata.load_metadata_override(metadata_override_path) - metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) - metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) - metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization) + metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) + metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) + metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization) - metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune) - metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename) + metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune) + metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename) - metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description) - metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by) + metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description) + metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by) - metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label) - metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name) - metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link) + metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label) + metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name) + metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link) - metadata.url = metadata_override.get(Keys.General.URL, metadata.url) - metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi) - metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid) - metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url) + metadata.url = metadata_override.get(Keys.General.URL, metadata.url) + metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi) + metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid) + metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url) - metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url) - metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi) - metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid) - metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url) + metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url) + metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi) + metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid) + metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url) # Base Models is received here as an array of models - metadata.base_models = metadata_override.get("general.base_models", metadata.base_models) + metadata.base_models = metadata_override.get("general.base_models", metadata.base_models) - metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) - metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages) - metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets) + metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) + metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages) + metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets) # Direct Metadata Override (via direct cli argument) if model_name is not None: @@ -115,7 +115,7 @@ class Metadata: return {} with open(model_card_path, "r", encoding="utf-8") as f: - return cast("dict[str, Any]", frontmatter.load(f)) + return frontmatter.load(f).to_dict() @staticmethod def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: