convert-*.py: cast not required if Metadata.load_metadata_override returned a dict[str, Any] instead of a dict[str, object]
Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
parent
74383ba6d2
commit
4c91d077d2
1 changed files with 24 additions and 26 deletions
|
@ -4,7 +4,7 @@ import re
|
|||
import json
|
||||
import frontmatter
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .constants import Keys
|
||||
|
@ -59,38 +59,36 @@ class Metadata:
|
|||
# This is based on LLM_KV_NAMES mapping in llama.cpp
|
||||
metadata_override = Metadata.load_metadata_override(metadata_override_path)
|
||||
|
||||
metadata.author = cast(Optional[str], metadata_override.get(Keys.General.AUTHOR , metadata.author )) # noqa: E202
|
||||
metadata.version = cast(Optional[str], metadata_override.get(Keys.General.VERSION , metadata.version )) # noqa: E202
|
||||
metadata.organization = cast(Optional[str], metadata_override.get(Keys.General.ORGANIZATION , metadata.organization )) # noqa: E202
|
||||
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
|
||||
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
|
||||
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
|
||||
|
||||
metadata.finetune = cast(Optional[str], metadata_override.get(Keys.General.FINETUNE , metadata.finetune )) # noqa: E202
|
||||
metadata.basename = cast(Optional[str], metadata_override.get(Keys.General.BASENAME , metadata.basename )) # noqa: E202
|
||||
metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
|
||||
metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
|
||||
|
||||
metadata.description = cast(Optional[str], metadata_override.get(Keys.General.DESCRIPTION , metadata.description )) # noqa: E202
|
||||
metadata.quantized_by = cast(Optional[str], metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by )) # noqa: E202
|
||||
metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
|
||||
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
|
||||
|
||||
metadata.parameter_class_attribute = cast(Optional[str], metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)) # noqa: E202
|
||||
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)
|
||||
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
|
||||
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
|
||||
|
||||
metadata.license = cast(Optional[str], metadata_override.get(Keys.General.LICENSE , metadata.license )) # noqa: E202
|
||||
metadata.license_name = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name )) # noqa: E202
|
||||
metadata.license_link = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link )) # noqa: E202
|
||||
metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
|
||||
metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
|
||||
metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
|
||||
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
|
||||
|
||||
metadata.url = cast(Optional[str], metadata_override.get(Keys.General.URL , metadata.url )) # noqa: E202
|
||||
metadata.doi = cast(Optional[str], metadata_override.get(Keys.General.DOI , metadata.doi )) # noqa: E202
|
||||
metadata.uuid = cast(Optional[str], metadata_override.get(Keys.General.UUID , metadata.uuid )) # noqa: E202
|
||||
metadata.repo_url = cast(Optional[str], metadata_override.get(Keys.General.REPO_URL , metadata.repo_url )) # noqa: E202
|
||||
|
||||
metadata.source_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url )) # noqa: E202
|
||||
metadata.source_doi = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi )) # noqa: E202
|
||||
metadata.source_uuid = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid )) # noqa: E202
|
||||
metadata.source_repo_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url )) # noqa: E202
|
||||
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
|
||||
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
|
||||
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
|
||||
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
|
||||
|
||||
# Base Models is received here as an array of models
|
||||
metadata.base_models = cast("Optional[list[dict]]", metadata_override.get("general.base_models" , metadata.base_models )) # noqa: E202
|
||||
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
|
||||
|
||||
metadata.tags = cast("Optional[list[str]]", metadata_override.get(Keys.General.TAGS , metadata.tags )) # noqa: E202
|
||||
metadata.languages = cast("Optional[list[str]]", metadata_override.get(Keys.General.LANGUAGES , metadata.languages )) # noqa: E202
|
||||
metadata.datasets = cast("Optional[list[str]]", metadata_override.get(Keys.General.DATASETS , metadata.datasets )) # noqa: E202
|
||||
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
|
||||
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
|
||||
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
|
||||
|
||||
# Direct Metadata Override (via direct cli argument)
|
||||
if model_name is not None:
|
||||
|
@ -99,7 +97,7 @@ class Metadata:
|
|||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, object]:
|
||||
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
|
||||
if metadata_override_path is None or not metadata_override_path.exists():
|
||||
return {}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue