convert-*.py: cast not required if Metadata.load_metadata_override returned a dict[str, Any] instead of a dict[str, object]

Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
brian khuu 2024-07-11 20:39:10 +10:00
parent 74383ba6d2
commit 4c91d077d2

View file

@ -4,7 +4,7 @@ import re
import json
import frontmatter
from pathlib import Path
from typing import Optional, cast
from typing import Any, Optional
from dataclasses import dataclass
from .constants import Keys
@ -59,38 +59,36 @@ class Metadata:
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.author = cast(Optional[str], metadata_override.get(Keys.General.AUTHOR , metadata.author )) # noqa: E202
metadata.version = cast(Optional[str], metadata_override.get(Keys.General.VERSION , metadata.version )) # noqa: E202
metadata.organization = cast(Optional[str], metadata_override.get(Keys.General.ORGANIZATION , metadata.organization )) # noqa: E202
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
metadata.finetune = cast(Optional[str], metadata_override.get(Keys.General.FINETUNE , metadata.finetune )) # noqa: E202
metadata.basename = cast(Optional[str], metadata_override.get(Keys.General.BASENAME , metadata.basename )) # noqa: E202
metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
metadata.description = cast(Optional[str], metadata_override.get(Keys.General.DESCRIPTION , metadata.description )) # noqa: E202
metadata.quantized_by = cast(Optional[str], metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by )) # noqa: E202
metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
metadata.parameter_class_attribute = cast(Optional[str], metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)) # noqa: E202
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
metadata.license = cast(Optional[str], metadata_override.get(Keys.General.LICENSE , metadata.license )) # noqa: E202
metadata.license_name = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name )) # noqa: E202
metadata.license_link = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link )) # noqa: E202
metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
metadata.url = cast(Optional[str], metadata_override.get(Keys.General.URL , metadata.url )) # noqa: E202
metadata.doi = cast(Optional[str], metadata_override.get(Keys.General.DOI , metadata.doi )) # noqa: E202
metadata.uuid = cast(Optional[str], metadata_override.get(Keys.General.UUID , metadata.uuid )) # noqa: E202
metadata.repo_url = cast(Optional[str], metadata_override.get(Keys.General.REPO_URL , metadata.repo_url )) # noqa: E202
metadata.source_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url )) # noqa: E202
metadata.source_doi = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi )) # noqa: E202
metadata.source_uuid = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid )) # noqa: E202
metadata.source_repo_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url )) # noqa: E202
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
# Base Models is received here as an array of models
metadata.base_models = cast("Optional[list[dict]]", metadata_override.get("general.base_models" , metadata.base_models )) # noqa: E202
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
metadata.tags = cast("Optional[list[str]]", metadata_override.get(Keys.General.TAGS , metadata.tags )) # noqa: E202
metadata.languages = cast("Optional[list[str]]", metadata_override.get(Keys.General.LANGUAGES , metadata.languages )) # noqa: E202
metadata.datasets = cast("Optional[list[str]]", metadata_override.get(Keys.General.DATASETS , metadata.datasets )) # noqa: E202
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
# Direct Metadata Override (via direct cli argument)
if model_name is not None:
@ -99,7 +97,7 @@ class Metadata:
return metadata
@staticmethod
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, object]:
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
if metadata_override_path is None or not metadata_override_path.exists():
return {}