convert-*.py: cast not required if Metadata.load_metadata_override returned a dict[str, Any] instead of a dict[str, object]

Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
brian khuu 2024-07-11 20:39:10 +10:00
parent 74383ba6d2
commit 4c91d077d2

View file

@ -4,7 +4,7 @@ import re
import json import json
import frontmatter import frontmatter
from pathlib import Path from pathlib import Path
from typing import Optional, cast from typing import Any, Optional
from dataclasses import dataclass from dataclasses import dataclass
from .constants import Keys from .constants import Keys
@ -59,38 +59,36 @@ class Metadata:
# This is based on LLM_KV_NAMES mapping in llama.cpp # This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path) metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.author = cast(Optional[str], metadata_override.get(Keys.General.AUTHOR , metadata.author )) # noqa: E202 metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = cast(Optional[str], metadata_override.get(Keys.General.VERSION , metadata.version )) # noqa: E202 metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
metadata.organization = cast(Optional[str], metadata_override.get(Keys.General.ORGANIZATION , metadata.organization )) # noqa: E202 metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
metadata.finetune = cast(Optional[str], metadata_override.get(Keys.General.FINETUNE , metadata.finetune )) # noqa: E202 metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
metadata.basename = cast(Optional[str], metadata_override.get(Keys.General.BASENAME , metadata.basename )) # noqa: E202 metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
metadata.description = cast(Optional[str], metadata_override.get(Keys.General.DESCRIPTION , metadata.description )) # noqa: E202 metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
metadata.quantized_by = cast(Optional[str], metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by )) # noqa: E202 metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
metadata.parameter_class_attribute = cast(Optional[str], metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)) # noqa: E202 metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
metadata.license = cast(Optional[str], metadata_override.get(Keys.General.LICENSE , metadata.license )) # noqa: E202 metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
metadata.license_name = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name )) # noqa: E202 metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
metadata.license_link = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link )) # noqa: E202 metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
metadata.url = cast(Optional[str], metadata_override.get(Keys.General.URL , metadata.url )) # noqa: E202 metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
metadata.doi = cast(Optional[str], metadata_override.get(Keys.General.DOI , metadata.doi )) # noqa: E202 metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
metadata.uuid = cast(Optional[str], metadata_override.get(Keys.General.UUID , metadata.uuid )) # noqa: E202 metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
metadata.repo_url = cast(Optional[str], metadata_override.get(Keys.General.REPO_URL , metadata.repo_url )) # noqa: E202 metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
metadata.source_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url )) # noqa: E202
metadata.source_doi = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi )) # noqa: E202
metadata.source_uuid = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid )) # noqa: E202
metadata.source_repo_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url )) # noqa: E202
# Base Models is received here as an array of models # Base Models is received here as an array of models
metadata.base_models = cast("Optional[list[dict]]", metadata_override.get("general.base_models" , metadata.base_models )) # noqa: E202 metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
metadata.tags = cast("Optional[list[str]]", metadata_override.get(Keys.General.TAGS , metadata.tags )) # noqa: E202 metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
metadata.languages = cast("Optional[list[str]]", metadata_override.get(Keys.General.LANGUAGES , metadata.languages )) # noqa: E202 metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
metadata.datasets = cast("Optional[list[str]]", metadata_override.get(Keys.General.DATASETS , metadata.datasets )) # noqa: E202 metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
# Direct Metadata Override (via direct cli argument) # Direct Metadata Override (via direct cli argument)
if model_name is not None: if model_name is not None:
@ -99,7 +97,7 @@ class Metadata:
return metadata return metadata
@staticmethod @staticmethod
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, object]: def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
if metadata_override_path is None or not metadata_override_path.exists(): if metadata_override_path is None or not metadata_override_path.exists():
return {} return {}