convert-*.py: pyright type fixes

This commit is contained in:
brian khuu 2024-07-10 23:39:09 +10:00
parent 59a01df784
commit dd14b8fdb1
5 changed files with 90 additions and 86 deletions

View file

@ -59,14 +59,14 @@ class Model:
tensor_map: gguf.TensorNameMap tensor_map: gguf.TensorNameMap
tensor_names: set[str] | None tensor_names: set[str] | None
fname_out: Path fname_out: Path
fname_default: Path fname_default: str
gguf_writer: gguf.GGUFWriter gguf_writer: gguf.GGUFWriter
metadata: gguf.Metadata metadata: gguf.Metadata
# subclasses should define this! # subclasses should define this!
model_arch: gguf.MODEL_ARCH model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: gguf.Metadata, def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: gguf.Metadata,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
if type(self) is Model: if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@ -107,11 +107,11 @@ class Model:
# Generate parameter weight class (useful for leader boards) if not yet determined # Generate parameter weight class (useful for leader boards) if not yet determined
if self.metadata.parameter_class_attribute is None: if self.metadata.parameter_class_attribute is None:
expert_count = self.hparams["num_local_experts"] if "num_local_experts" in self.hparams else None expert_count = self.hparams.get("num_local_experts", 0)
sum_weight_estimate = self.calculate_total_weight_count() sum_weight_estimate = self.calculate_total_weight_count()
# Calculate weight estimate per model # Calculate weight estimate per model
per_model_weight_estimate = (sum_weight_estimate / expert_count) if expert_count is not None and (expert_count > 0) else sum_weight_estimate per_model_weight_estimate: int = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
self.metadata.parameter_class_attribute = gguf.parameter_class_attribute(expert_count, per_model_weight_estimate) self.metadata.parameter_class_attribute = gguf.parameter_class_attribute(expert_count, per_model_weight_estimate)
@ -400,7 +400,7 @@ class Model:
def write(self): def write(self):
self.prepare_tensors_for_writing() self.prepare_tensors_for_writing()
self.prepare_key_value_store() self.prepare_key_value_store()
self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_header_to_file()
self.gguf_writer.write_kv_data_to_file() self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True) self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close() self.gguf_writer.close()
@ -414,7 +414,7 @@ class Model:
self.prepare_tensors_for_writing() self.prepare_tensors_for_writing()
self.prepare_key_value_store() self.prepare_key_value_store()
self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_header_to_file()
self.gguf_writer.write_kv_data_to_file() self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.close() self.gguf_writer.close()
@ -2525,7 +2525,6 @@ class Gemma2Model(Model):
hparams = self.hparams hparams = self.hparams
block_count = hparams["num_hidden_layers"] block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(block_count)
@ -2778,7 +2777,6 @@ class OpenELMModel(Model):
assert self.block_count == len(self._num_query_heads) assert self.block_count == len(self._num_query_heads)
assert self.block_count == len(self._ffn_dims) assert self.block_count == len(self._ffn_dims)
self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["max_context_length"]) self.gguf_writer.add_context_length(self.hparams["max_context_length"])
self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_embedding_length(n_embd)
@ -3618,11 +3616,10 @@ def main() -> None:
logger.error("Error: Cannot use temp file when splitting") logger.error("Error: Cannot use temp file when splitting")
sys.exit(1) sys.exit(1)
fname_out = None
if args.outfile is not None: if args.outfile is not None:
fname_out = args.outfile fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / 'ggml-model-{ftype}.gguf'
logger.info(f"Loading model: {dir_model.name}") logger.info(f"Loading model: {dir_model.name}")
@ -3638,8 +3635,9 @@ def main() -> None:
logger.error(f"Model {hparams['architectures'][0]} is not supported") logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1) sys.exit(1)
model_instance = model_class(dir_model, output_type, fname_out, args.bigendian, args.use_temp_file, model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out,
args.no_lazy, metadata, split_max_tensors=args.split_max_tensors, is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
eager=args.no_lazy, metadata=metadata, split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split) small_first_shard=args.no_tensor_first_split)

View file

@ -25,6 +25,7 @@ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar
from _collections_abc import dict_items
import numpy as np import numpy as np
@ -773,7 +774,7 @@ class OutputFile:
def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None: def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
# Metadata About The Model And Its Provenance # Metadata About The Model And Its Provenance
name = "LLaMA" name = "LLaMA"
if metadata.name is not None: if metadata is not None and metadata.name is not None:
name = metadata.name name = metadata.name
elif params.path_model is not None: elif params.path_model is not None:
name = params.path_model.name name = params.path_model.name
@ -783,52 +784,52 @@ class OutputFile:
self.gguf.add_name(name) self.gguf.add_name(name)
if metadata.author is not None: if metadata is not None and metadata.author is not None:
self.gguf.add_author(metadata.author) self.gguf.add_author(metadata.author)
if metadata.version is not None: if metadata is not None and metadata.version is not None:
self.gguf.add_version(metadata.version) self.gguf.add_version(metadata.version)
if metadata.organization is not None: if metadata is not None and metadata.organization is not None:
self.gguf.add_organization(metadata.organization) self.gguf.add_organization(metadata.organization)
if metadata.finetune is not None: if metadata is not None and metadata.finetune is not None:
self.gguf.add_finetune(metadata.finetune) self.gguf.add_finetune(metadata.finetune)
if metadata.basename is not None: if metadata is not None and metadata.basename is not None:
self.gguf.add_basename(metadata.basename) self.gguf.add_basename(metadata.basename)
if metadata.description is not None: if metadata is not None and metadata.description is not None:
self.gguf.add_description(metadata.description) self.gguf.add_description(metadata.description)
if metadata.quantized_by is not None: if metadata is not None and metadata.quantized_by is not None:
self.gguf.add_quantized_by(metadata.quantized_by) self.gguf.add_quantized_by(metadata.quantized_by)
if metadata.parameter_class_attribute is not None: if metadata is not None and metadata.parameter_class_attribute is not None:
self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute) self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute)
if metadata.license is not None: if metadata is not None and metadata.license is not None:
self.gguf.add_license(metadata.license) self.gguf.add_license(metadata.license)
if metadata.license_name is not None: if metadata is not None and metadata.license_name is not None:
self.gguf.add_license_name(metadata.license_name) self.gguf.add_license_name(metadata.license_name)
if metadata.license_link is not None: if metadata is not None and metadata.license_link is not None:
self.gguf.add_license_link(metadata.license_link) self.gguf.add_license_link(metadata.license_link)
if metadata.url is not None: if metadata is not None and metadata.url is not None:
self.gguf.add_url(metadata.url) self.gguf.add_url(metadata.url)
if metadata.doi is not None: if metadata is not None and metadata.doi is not None:
self.gguf.add_doi(metadata.doi) self.gguf.add_doi(metadata.doi)
if metadata.uuid is not None: if metadata is not None and metadata.uuid is not None:
self.gguf.add_uuid(metadata.uuid) self.gguf.add_uuid(metadata.uuid)
if metadata.repo_url is not None: if metadata is not None and metadata.repo_url is not None:
self.gguf.add_repo_url(metadata.repo_url) self.gguf.add_repo_url(metadata.repo_url)
if metadata.source_url is not None: if metadata is not None and metadata.source_url is not None:
self.gguf.add_source_url(metadata.source_url) self.gguf.add_source_url(metadata.source_url)
if metadata.source_doi is not None: if metadata is not None and metadata.source_doi is not None:
self.gguf.add_source_doi(metadata.source_doi) self.gguf.add_source_doi(metadata.source_doi)
if metadata.source_uuid is not None: if metadata is not None and metadata.source_uuid is not None:
self.gguf.add_source_uuid(metadata.source_uuid) self.gguf.add_source_uuid(metadata.source_uuid)
if metadata.source_repo_url is not None: if metadata is not None and metadata.source_repo_url is not None:
self.gguf.add_source_repo_url(metadata.source_repo_url) self.gguf.add_source_repo_url(metadata.source_repo_url)
if metadata.base_models is not None: if metadata is not None and metadata.base_models is not None:
self.gguf.add_base_model_count(len(metadata.base_models)) self.gguf.add_base_model_count(len(metadata.base_models))
for key, base_model_entry in enumerate(metadata.base_models): for key, base_model_entry in enumerate(metadata.base_models):
if "name" in base_model_entry: if "name" in base_model_entry:
@ -848,11 +849,11 @@ class OutputFile:
if "repo_url" in base_model_entry: if "repo_url" in base_model_entry:
self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"]) self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
if metadata.tags is not None: if metadata is not None and metadata.tags is not None:
self.gguf.add_tags(metadata.tags) self.gguf.add_tags(metadata.tags)
if metadata.languages is not None: if metadata is not None and metadata.languages is not None:
self.gguf.add_languages(metadata.languages) self.gguf.add_languages(metadata.languages)
if metadata.datasets is not None: if metadata is not None and metadata.datasets is not None:
self.gguf.add_datasets(metadata.datasets) self.gguf.add_datasets(metadata.datasets)
def add_meta_arch(self, params: Params) -> None: def add_meta_arch(self, params: Params) -> None:
@ -1041,16 +1042,16 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
raise ValueError(f"Unexpected combination of types: {name_to_type}") raise ValueError(f"Unexpected combination of types: {name_to_type}")
def per_model_weight_count_estimation(tensors: dict[str, LazyTensor], expert_count:int) -> int: def per_model_weight_count_estimation(tensors: dict_items[str, LazyTensor], expert_count:int | None) -> int:
# TODO: Ensure parameter count is accurate across the various model types # TODO: Ensure parameter count is accurate across the various model types
sum_weight_estimate = 0 sum_weight_estimate: int = 0
for name, lazy_tensor in tensors: for name, lazy_tensor in tensors:
# We don't need these # We don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue continue
# Got A Tensor # Got A Tensor
sum_weights_in_tensor = 1 sum_weights_in_tensor: int = 1
# Tensor Volume # Tensor Volume
for dim in lazy_tensor.shape: for dim in lazy_tensor.shape:
@ -1059,10 +1060,14 @@ def per_model_weight_count_estimation(tensors: dict[str, LazyTensor], expert_cou
# Add Tensor Volume To Running Count # Add Tensor Volume To Running Count
sum_weight_estimate += sum_weights_in_tensor sum_weight_estimate += sum_weights_in_tensor
# Calculate weight estimate per model if expert_count is None:
per_model_weight_estimate = (sum_weight_estimate / expert_count) if expert_count is not None and (expert_count > 0) else sum_weight_estimate return sum_weight_estimate
return per_model_weight_estimate if expert_count is not None and expert_count == 0:
return sum_weight_estimate
# Calculate weight estimate per model
return int(sum_weight_estimate / expert_count)
def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@ -1244,7 +1249,7 @@ class VocabFactory:
return vocab, special_vocab return vocab, special_vocab
def default_convention_outfile(file_type: GGMLFileType, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> str: def default_convention_outfile(file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> str:
name = metadata.name if metadata.name is not None else None name = metadata.name if metadata.name is not None else None
basename = metadata.basename if metadata.basename is not None else None basename = metadata.basename if metadata.basename is not None else None
finetune = metadata.finetune if metadata.finetune is not None else None finetune = metadata.finetune if metadata.finetune is not None else None
@ -1260,7 +1265,7 @@ def default_convention_outfile(file_type: GGMLFileType, expert_count:int, model_
return gguf.naming_convention(name, basename, finetune, version, parameter_class_attribute, output_type) return gguf.naming_convention(name, basename, finetune, version, parameter_class_attribute, output_type)
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> Path: def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> Path:
default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata) default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
ret = model_paths[0].parent / f"{default_filename}.gguf" ret = model_paths[0].parent / f"{default_filename}.gguf"
if ret in model_paths: if ret in model_paths:
@ -1326,7 +1331,7 @@ def main(args_in: list[str] | None = None) -> None:
model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts) model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
ftype = pick_output_type(model, args.outtype) ftype = pick_output_type(model, args.outtype)
if metadata.name is None: if (metadata is None or metadata.name is None) and params.path_model is not None:
metadata.name = params.path_model.name metadata.name = params.path_model.name
print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100 print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100
@ -1407,7 +1412,7 @@ def main(args_in: list[str] | None = None) -> None:
assert params is not None assert params is not None
if metadata.name is None: if metadata.name is None and params.path_model is not None:
metadata.name = params.path_model.name metadata.name = params.path_model.name
model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts) model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)

View file

@ -369,7 +369,7 @@ class GGUFWriter:
self.state = WriterState.WEIGHTS self.state = WriterState.WEIGHTS
def generate_tensors_uuid(self) -> None: def generate_tensors_uuid(self) -> str:
uuidv5_sha1 = hashlib.sha1() uuidv5_sha1 = hashlib.sha1()
uuidv5_sha1.update(uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5').bytes) uuidv5_sha1.update(uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5').bytes)
@ -520,28 +520,28 @@ class GGUFWriter:
self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count) self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
def add_base_model_name(self, source_id: int, name: str) -> None: def add_base_model_name(self, source_id: int, name: str) -> None:
self.add_string(Keys.General.BASE_MODEL_NAME.format(id=self.source_id), name) self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
def add_base_model_author(self, source_id: int, author: str) -> None: def add_base_model_author(self, source_id: int, author: str) -> None:
self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=self.source_id), author) self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
def add_base_model_version(self, source_id: int, version: str) -> None: def add_base_model_version(self, source_id: int, version: str) -> None:
self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=self.source_id), version) self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
def add_base_model_organization(self, source_id: int, organization: str) -> None: def add_base_model_organization(self, source_id: int, organization: str) -> None:
self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=self.source_id), organization) self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
def add_base_model_url(self, source_id: int, url: str) -> None: def add_base_model_url(self, source_id: int, url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_URL.format(id=self.source_id), url) self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
def add_base_model_doi(self, source_id: int, doi: str) -> None: def add_base_model_doi(self, source_id: int, doi: str) -> None:
self.add_string(Keys.General.BASE_MODEL_DOI.format(id=self.source_id), doi) self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
def add_base_model_uuid(self, source_id: int, uuid: str) -> None: def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
self.add_string(Keys.General.BASE_MODEL_UUID.format(id=self.source_id), uuid) self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None: def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=self.source_id), repo_url) self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
def add_tags(self, tags: Sequence[str]) -> None: def add_tags(self, tags: Sequence[str]) -> None:
self.add_array(Keys.General.TAGS, tags) self.add_array(Keys.General.TAGS, tags)

View file

@ -4,7 +4,7 @@ import re
import json import json
import frontmatter import frontmatter
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, cast
from dataclasses import dataclass from dataclasses import dataclass
from .constants import Keys from .constants import Keys
@ -59,38 +59,38 @@ class Metadata:
# This is based on LLM_KV_NAMES mapping in llama.cpp # This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path) metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.name = metadata_override.get(Keys.General.NAME , metadata.name ) # noqa: E202 metadata.author = cast(Optional[str], metadata_override.get(Keys.General.AUTHOR , metadata.author )) # noqa: E202
metadata.author = metadata_override.get(Keys.General.AUTHOR , metadata.author ) # noqa: E202 metadata.version = cast(Optional[str], metadata_override.get(Keys.General.VERSION , metadata.version )) # noqa: E202
metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202 metadata.organization = cast(Optional[str], metadata_override.get(Keys.General.ORGANIZATION , metadata.organization )) # noqa: E202
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION , metadata.organization ) # noqa: E202
metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202 metadata.finetune = cast(Optional[str], metadata_override.get(Keys.General.FINETUNE , metadata.finetune )) # noqa: E202
metadata.basename = metadata_override.get(Keys.General.BASENAME , metadata.basename ) # noqa: E202 metadata.basename = cast(Optional[str], metadata_override.get(Keys.General.BASENAME , metadata.basename )) # noqa: E202
metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202 metadata.description = cast(Optional[str], metadata_override.get(Keys.General.DESCRIPTION , metadata.description )) # noqa: E202
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by ) # noqa: E202 metadata.quantized_by = cast(Optional[str], metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by )) # noqa: E202
metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202 metadata.parameter_class_attribute = cast(Optional[str], metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)) # noqa: E202
metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202 metadata.license = cast(Optional[str], metadata_override.get(Keys.General.LICENSE , metadata.license )) # noqa: E202
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name ) # noqa: E202 metadata.license_name = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name )) # noqa: E202
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link ) # noqa: E202 metadata.license_link = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link )) # noqa: E202
metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202 metadata.url = cast(Optional[str], metadata_override.get(Keys.General.URL , metadata.url )) # noqa: E202
metadata.doi = metadata_override.get(Keys.General.DOI , metadata.doi ) # noqa: E202 metadata.doi = cast(Optional[str], metadata_override.get(Keys.General.DOI , metadata.doi )) # noqa: E202
metadata.uuid = metadata_override.get(Keys.General.UUID , metadata.uuid ) # noqa: E202 metadata.uuid = cast(Optional[str], metadata_override.get(Keys.General.UUID , metadata.uuid )) # noqa: E202
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL , metadata.repo_url ) # noqa: E202 metadata.repo_url = cast(Optional[str], metadata_override.get(Keys.General.REPO_URL , metadata.repo_url )) # noqa: E202
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202 metadata.source_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url )) # noqa: E202
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi ) # noqa: E202 metadata.source_doi = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi )) # noqa: E202
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid ) # noqa: E202 metadata.source_uuid = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid )) # noqa: E202
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url ) # noqa: E202 metadata.source_repo_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url )) # noqa: E202
metadata.base_models = metadata_override.get("general.base_models" , metadata.base_models ) # noqa: E202 # Base Models is received here as an array of models
metadata.base_models = cast("Optional[list[dict]]", metadata_override.get("general.base_models" , metadata.base_models )) # noqa: E202
metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202 metadata.tags = cast("Optional[list[str]]", metadata_override.get(Keys.General.TAGS , metadata.tags )) # noqa: E202
metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202 metadata.languages = cast("Optional[list[str]]", metadata_override.get(Keys.General.LANGUAGES , metadata.languages )) # noqa: E202
metadata.datasets = metadata_override.get(Keys.General.DATASETS , metadata.datasets ) # noqa: E202 metadata.datasets = cast("Optional[list[str]]", metadata_override.get(Keys.General.DATASETS , metadata.datasets )) # noqa: E202
# Direct Metadata Override (via direct cli argument) # Direct Metadata Override (via direct cli argument)
if model_name is not None: if model_name is not None:
@ -117,7 +117,7 @@ class Metadata:
return {} return {}
with open(model_card_path, "r", encoding="utf-8") as f: with open(model_card_path, "r", encoding="utf-8") as f:
return frontmatter.load(f) return cast("dict[str, object]", frontmatter.load(f))
@staticmethod @staticmethod
def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, object]: def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, object]:
@ -138,7 +138,7 @@ class Metadata:
return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()]) return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
@staticmethod @staticmethod
def get_model_id_components(model_id: Optional[str] = None) -> dict[str, object]: def get_model_id_components(model_id: Optional[str] = None) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
# Hugging Face often stores the model id as '<org>/<model name>' # Hugging Face often stores the model id as '<org>/<model name>'
# so let's parse it and apply some heuristics if possible for model name components # so let's parse it and apply some heuristics if possible for model name components
@ -334,6 +334,7 @@ class Metadata:
return metadata return metadata
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
assert self.name is not None
gguf_writer.add_name(self.name) gguf_writer.add_name(self.name)
if self.author is not None: if self.author is not None:

View file

@ -34,7 +34,7 @@ def model_weight_count_rounded_notation(model_params_count: int) -> str:
return f"{round(scaled_model_params)}{scale_suffix}" return f"{round(scaled_model_params)}{scale_suffix}"
def parameter_class_attribute(expert_count_int:int, model_params_count: int) -> str: def parameter_class_attribute(expert_count_int:int | None, model_params_count: int) -> str:
per_model_rounded_weight_estimate = model_weight_count_rounded_notation(model_params_count) per_model_rounded_weight_estimate = model_weight_count_rounded_notation(model_params_count)
if expert_count_int is not None and expert_count_int > 0: if expert_count_int is not None and expert_count_int > 0:
@ -45,7 +45,7 @@ def parameter_class_attribute(expert_count_int:int, model_params_count: int) ->
return size_class return size_class
def naming_convention(model_name: str, base_name: str, finetune_string:str, version_string:str, parameter_class_attribute: str, output_type: str) -> str: def naming_convention(model_name: str | None, base_name: str | None, finetune_string:str | None, version_string:str | None, parameter_class_attribute: str | None, output_type: str | None) -> str:
# Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
if base_name is not None: if base_name is not None:
@ -61,6 +61,6 @@ def naming_convention(model_name: str, base_name: str, finetune_string:str, vers
version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else "" version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
precision = f"-{output_type.strip().replace(' ', '-').upper()}" precision = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
return f"{name}{parameters}{finetune}{version}{precision}" return f"{name}{parameters}{finetune}{version}{precision}"