diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index 4425970eb..4872247d7 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -4,7 +4,7 @@ import re import json import frontmatter from pathlib import Path -from typing import Optional, cast +from typing import Any, Optional from dataclasses import dataclass from .constants import Keys @@ -59,38 +59,36 @@ class Metadata: # This is based on LLM_KV_NAMES mapping in llama.cpp metadata_override = Metadata.load_metadata_override(metadata_override_path) - metadata.author = cast(Optional[str], metadata_override.get(Keys.General.AUTHOR , metadata.author )) # noqa: E202 - metadata.version = cast(Optional[str], metadata_override.get(Keys.General.VERSION , metadata.version )) # noqa: E202 - metadata.organization = cast(Optional[str], metadata_override.get(Keys.General.ORGANIZATION , metadata.organization )) # noqa: E202 + metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) + metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) + metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization) - metadata.finetune = cast(Optional[str], metadata_override.get(Keys.General.FINETUNE , metadata.finetune )) # noqa: E202 - metadata.basename = cast(Optional[str], metadata_override.get(Keys.General.BASENAME , metadata.basename )) # noqa: E202 + metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune) + metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename) - metadata.description = cast(Optional[str], metadata_override.get(Keys.General.DESCRIPTION , metadata.description )) # noqa: E202 - metadata.quantized_by = cast(Optional[str], metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by )) # noqa: E202 + metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description) + metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by) - metadata.parameter_class_attribute = cast(Optional[str], metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)) # noqa: E202 + metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) + metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name) + metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link) - metadata.license = cast(Optional[str], metadata_override.get(Keys.General.LICENSE , metadata.license )) # noqa: E202 - metadata.license_name = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name )) # noqa: E202 - metadata.license_link = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link )) # noqa: E202 + metadata.url = metadata_override.get(Keys.General.URL, metadata.url) + metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi) + metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid) + metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url) - metadata.url = cast(Optional[str], metadata_override.get(Keys.General.URL , metadata.url )) # noqa: E202 - metadata.doi = cast(Optional[str], metadata_override.get(Keys.General.DOI , metadata.doi )) # noqa: E202 - metadata.uuid = cast(Optional[str], metadata_override.get(Keys.General.UUID , metadata.uuid )) # noqa: E202 - metadata.repo_url = cast(Optional[str], metadata_override.get(Keys.General.REPO_URL , metadata.repo_url )) # noqa: E202 - - metadata.source_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url )) # noqa: E202 - metadata.source_doi = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi )) # noqa: E202 - metadata.source_uuid = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid )) # noqa: E202 - metadata.source_repo_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url )) # noqa: E202 + metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url) + metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi) + metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid) + metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url) # Base Models is received here as an array of models - metadata.base_models = cast("Optional[list[dict]]", metadata_override.get("general.base_models" , metadata.base_models )) # noqa: E202 + metadata.base_models = metadata_override.get("general.base_models", metadata.base_models) - metadata.tags = cast("Optional[list[str]]", metadata_override.get(Keys.General.TAGS , metadata.tags )) # noqa: E202 - metadata.languages = cast("Optional[list[str]]", metadata_override.get(Keys.General.LANGUAGES , metadata.languages )) # noqa: E202 - metadata.datasets = cast("Optional[list[str]]", metadata_override.get(Keys.General.DATASETS , metadata.datasets )) # noqa: E202 + metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) + metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages) + metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets) # Direct Metadata Override (via direct cli argument) if model_name is not None: @@ -99,7 +97,7 @@ class Metadata: return metadata @staticmethod - def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, object]: + def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]: if metadata_override_path is None or not metadata_override_path.exists(): return {}