convert-*.py: pyright type fixes
parent 59a01df784
commit dd14b8fdb1

5 changed files with 90 additions and 86 deletions
@@ -59,14 +59,14 @@ class Model:
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     fname_out: Path
-    fname_default: Path
+    fname_default: str
     gguf_writer: gguf.GGUFWriter
     metadata: gguf.Metadata

     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: gguf.Metadata,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: gguf.Metadata,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
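Note: the `fname_out: Path | None` loosening above means callers may omit the output path entirely and let the converter pick a default name later. Under pyright, a value that can legitimately be absent has to be annotated as optional, and every use must be narrowed first. A minimal sketch of the pattern (the helper name and default-name scheme here are illustrative, not this commit's code):

    from pathlib import Path

    def resolve_outfile(fname_out: Path | None, dir_model: Path, fname_default: str) -> Path:
        # pyright rejects using fname_out as a Path until the None case is handled
        if fname_out is None:
            # fall back to a conventional name next to the model
            return dir_model / f"{fname_default}.gguf"
        return fname_out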
@@ -107,11 +107,11 @@ class Model:

         # Generate parameter weight class (useful for leader boards) if not yet determined
         if self.metadata.parameter_class_attribute is None:
-            expert_count = self.hparams["num_local_experts"] if "num_local_experts" in self.hparams else None
+            expert_count = self.hparams.get("num_local_experts", 0)
             sum_weight_estimate = self.calculate_total_weight_count()

             # Calculate weight estimate per model
-            per_model_weight_estimate = (sum_weight_estimate / expert_count) if expert_count is not None and (expert_count > 0) else sum_weight_estimate
+            per_model_weight_estimate: int = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate

             self.metadata.parameter_class_attribute = gguf.parameter_class_attribute(expert_count, per_model_weight_estimate)

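Note: replacing the conditional expression with `hparams.get("num_local_experts", 0)` is a typing fix as much as a style one. With a non-None default, the inferred type loses its `None` branch, so the guard collapses from `expert_count is not None and (expert_count > 0)` to a plain `expert_count > 0`. A hedged toy illustration:

    hparams: dict[str, int] = {"hidden_size": 4096}  # toy hparams, no experts key

    # old shape: int | None, so every consumer needs a None check first
    expert_count_maybe = hparams["num_local_experts"] if "num_local_experts" in hparams else None

    # new shape: .get() with a default yields a plain int (0 = not a MoE model)
    expert_count = hparams.get("num_local_experts", 0)
    assert expert_count == 0 and expert_count_maybe is None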
@@ -400,7 +400,7 @@ class Model:
     def write(self):
         self.prepare_tensors_for_writing()
         self.prepare_key_value_store()
-        self.gguf_writer.write_header_to_file(self.fname_out)
+        self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
@@ -414,7 +414,7 @@ class Model:
         self.prepare_tensors_for_writing()

         self.prepare_key_value_store()
-        self.gguf_writer.write_header_to_file(self.fname_out)
+        self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()

@@ -2525,7 +2525,6 @@ class Gemma2Model(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]

-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -2778,7 +2777,6 @@ class OpenELMModel(Model):
         assert self.block_count == len(self._num_query_heads)
         assert self.block_count == len(self._ffn_dims)

-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_context_length(self.hparams["max_context_length"])
         self.gguf_writer.add_embedding_length(n_embd)
@@ -3618,11 +3616,10 @@ def main() -> None:
         logger.error("Error: Cannot use temp file when splitting")
         sys.exit(1)

+    fname_out = None
+
     if args.outfile is not None:
         fname_out = args.outfile
-    else:
-        # output in the same directory as the model by default
-        fname_out = dir_model / 'ggml-model-{ftype}.gguf'

     logger.info(f"Loading model: {dir_model.name}")

@@ -3638,8 +3635,9 @@ def main() -> None:
         logger.error(f"Model {hparams['architectures'][0]} is not supported")
         sys.exit(1)

-    model_instance = model_class(dir_model, output_type, fname_out, args.bigendian, args.use_temp_file,
-                                 args.no_lazy, metadata, split_max_tensors=args.split_max_tensors,
+    model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out,
+                                 is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
+                                 eager=args.no_lazy, metadata=metadata, split_max_tensors=args.split_max_tensors,
                                  split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                  small_first_shard=args.no_tensor_first_split)

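Note: two things happen in the hunk above: `fname_out` starts as `None` (matching the new `Path | None` parameter), and the `model_class(...)` call switches to keyword arguments. Keywords let pyright check each argument against the parameter it names rather than by position, so a reordered signature becomes a static error instead of a silent mix-up. Roughly (toy class, not the converter's):

    class Sketch:
        def __init__(self, dir_model: str, ftype: int, fname_out: str | None) -> None:
            self.fname_out = fname_out

    # positional swaps of ftype/fname_out would only surface at runtime, if at all;
    # with keywords, pyright validates every name/type pairing at check time
    inst = Sketch(dir_model="models/foo", ftype=1, fname_out=None)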
@@ -25,6 +25,7 @@ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar
+from _collections_abc import dict_items

 import numpy as np

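Note: `_collections_abc` is the private CPython module behind `collections.abc`, and `dict_items` is the concrete type `dict.items()` returns; importing it gives pyright a precise item-pair type for the parameter retyped in a later hunk. The public spelling would be `ItemsView` (or just `Iterable[tuple[str, LazyTensor]]`); a sketch of that alternative with a stand-in tensor type:

    from collections.abc import ItemsView

    class LazyTensorStub:  # stand-in for the script's LazyTensor
        shape: tuple[int, ...] = (2, 3)

    def count_params(tensors: ItemsView[str, LazyTensorStub]) -> int:
        total = 0
        for _name, tensor in tensors:
            n = 1
            for dim in tensor.shape:
                n *= dim  # tensor volume = product of its dimensions
            total += n
        return total

    print(count_params({"w": LazyTensorStub()}.items()))  # 6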
@@ -773,7 +774,7 @@ class OutputFile:
     def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
         # Metadata About The Model And Its Provenence
         name = "LLaMA"
-        if metadata.name is not None:
+        if metadata is not None and metadata.name is not None:
             name = metadata.name
         elif params.path_model is not None:
             name = params.path_model.name
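Note: `add_meta_model` now accepts `metadata: gguf.Metadata | None`, which is why every `metadata.x is not None` test in the next hunks gains a `metadata is not None and` prefix: pyright flags attribute access on a possibly-None value. The repetition could equally be avoided by narrowing once up front; a hedged sketch of that alternative shape:

    from dataclasses import dataclass

    @dataclass
    class Meta:  # stand-in for gguf.Metadata
        author: str | None = None
        version: str | None = None

    def add_meta_model(metadata: Meta | None) -> list[str]:
        written: list[str] = []
        if metadata is None:
            return written  # narrowed once; metadata is a plain Meta below
        if metadata.author is not None:
            written.append(f"author={metadata.author}")
        if metadata.version is not None:
            written.append(f"version={metadata.version}")
        return written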
@@ -783,52 +784,52 @@ class OutputFile:

         self.gguf.add_name(name)

-        if metadata.author is not None:
+        if metadata is not None and metadata.author is not None:
             self.gguf.add_author(metadata.author)
-        if metadata.version is not None:
+        if metadata is not None and metadata.version is not None:
             self.gguf.add_version(metadata.version)
-        if metadata.organization is not None:
+        if metadata is not None and metadata.organization is not None:
             self.gguf.add_organization(metadata.organization)

-        if metadata.finetune is not None:
+        if metadata is not None and metadata.finetune is not None:
             self.gguf.add_finetune(metadata.finetune)
-        if metadata.basename is not None:
+        if metadata is not None and metadata.basename is not None:
             self.gguf.add_basename(metadata.basename)

-        if metadata.description is not None:
+        if metadata is not None and metadata.description is not None:
             self.gguf.add_description(metadata.description)
-        if metadata.quantized_by is not None:
+        if metadata is not None and metadata.quantized_by is not None:
             self.gguf.add_quantized_by(metadata.quantized_by)

-        if metadata.parameter_class_attribute is not None:
+        if metadata is not None and metadata.parameter_class_attribute is not None:
             self.gguf.add_parameter_class_attribute(metadata.parameter_class_attribute)

-        if metadata.license is not None:
+        if metadata is not None and metadata.license is not None:
             self.gguf.add_license(metadata.license)
-        if metadata.license_name is not None:
+        if metadata is not None and metadata.license_name is not None:
             self.gguf.add_license_name(metadata.license_name)
-        if metadata.license_link is not None:
+        if metadata is not None and metadata.license_link is not None:
             self.gguf.add_license_link(metadata.license_link)

-        if metadata.url is not None:
+        if metadata is not None and metadata.url is not None:
             self.gguf.add_url(metadata.url)
-        if metadata.doi is not None:
+        if metadata is not None and metadata.doi is not None:
             self.gguf.add_doi(metadata.doi)
-        if metadata.uuid is not None:
+        if metadata is not None and metadata.uuid is not None:
             self.gguf.add_uuid(metadata.uuid)
-        if metadata.repo_url is not None:
+        if metadata is not None and metadata.repo_url is not None:
             self.gguf.add_repo_url(metadata.repo_url)

-        if metadata.source_url is not None:
+        if metadata is not None and metadata.source_url is not None:
             self.gguf.add_source_url(metadata.source_url)
-        if metadata.source_doi is not None:
+        if metadata is not None and metadata.source_doi is not None:
             self.gguf.add_source_doi(metadata.source_doi)
-        if metadata.source_uuid is not None:
+        if metadata is not None and metadata.source_uuid is not None:
             self.gguf.add_source_uuid(metadata.source_uuid)
-        if metadata.source_repo_url is not None:
+        if metadata is not None and metadata.source_repo_url is not None:
             self.gguf.add_source_repo_url(metadata.source_repo_url)

-        if metadata.base_models is not None:
+        if metadata is not None and metadata.base_models is not None:
             self.gguf.add_base_model_count(len(metadata.base_models))
             for key, base_model_entry in enumerate(metadata.base_models):
                 if "name" in base_model_entry:
@@ -848,11 +849,11 @@ class OutputFile:
             if "repo_url" in base_model_entry:
                 self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])

-        if metadata.tags is not None:
+        if metadata is not None and metadata.tags is not None:
             self.gguf.add_tags(metadata.tags)
-        if metadata.languages is not None:
+        if metadata is not None and metadata.languages is not None:
             self.gguf.add_languages(metadata.languages)
-        if metadata.datasets is not None:
+        if metadata is not None and metadata.datasets is not None:
             self.gguf.add_datasets(metadata.datasets)

     def add_meta_arch(self, params: Params) -> None:
@@ -1041,16 +1042,16 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
     raise ValueError(f"Unexpected combination of types: {name_to_type}")


-def per_model_weight_count_estimation(tensors: dict[str, LazyTensor], expert_count:int) -> int:
+def per_model_weight_count_estimation(tensors: dict_items[str, LazyTensor], expert_count:int | None) -> int:
     # TODO: Ensure parameter count is accurate throughout various model type
-    sum_weight_estimate = 0
+    sum_weight_estimate: int = 0
     for name, lazy_tensor in tensors:
         # We don't need these
         if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
             continue

         # Got A Tensor
-        sum_weights_in_tensor = 1
+        sum_weights_in_tensor: int = 1

         # Tensor Volume
         for dim in lazy_tensor.shape:
@@ -1059,10 +1060,14 @@ def per_model_weight_count_estimation(tensors: dict[str, LazyTensor], expert_cou
         # Add Tensor Volume To Running Count
         sum_weight_estimate += sum_weights_in_tensor

-    # Calculate weight estimate per model
-    per_model_weight_estimate = (sum_weight_estimate / expert_count) if expert_count is not None and (expert_count > 0) else sum_weight_estimate
+    if expert_count is None:
+        return sum_weight_estimate

-    return per_model_weight_estimate
+    if expert_count is not None and expert_count == 0:
+        return sum_weight_estimate
+
+    # Calculate weight estimate per model
+    return int(sum_weight_estimate / expert_count)


 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
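Note: the rewritten tail of `per_model_weight_count_estimation` does two jobs. The early returns make the None and zero cases explicit, and the closing `int(...)` is what actually satisfies the declared `-> int`, because `/` is true (float) division in Python 3; the second `expert_count is not None` test is redundant after the first return, but harmless. In miniature:

    def per_expert_count(total: int, expert_count: int | None) -> int:
        if expert_count is None or expert_count == 0:
            return total
        # true division yields float; int() restores the annotated return type
        return int(total / expert_count)

    assert per_expert_count(8_000_000, None) == 8_000_000
    assert per_expert_count(8_000_000, 8) == 1_000_000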
@@ -1244,7 +1249,7 @@ class VocabFactory:
         return vocab, special_vocab


-def default_convention_outfile(file_type: GGMLFileType, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> str:
+def default_convention_outfile(file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> str:
     name = metadata.name if metadata.name is not None else None
     basename = metadata.basename if metadata.basename is not None else None
     finetune = metadata.finetune if metadata.finetune is not None else None
@@ -1260,7 +1265,7 @@ def default_convention_outfile(file_type: GGMLFileType, expert_count:int, model_
     return gguf.naming_convention(name, basename, finetune, version, parameter_class_attribute, output_type)


-def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count:int, model_params_count: int, metadata: gguf.Metadata) -> Path:
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> Path:
     default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
     ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
@@ -1326,7 +1331,7 @@ def main(args_in: list[str] | None = None) -> None:
         model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
         ftype = pick_output_type(model, args.outtype)

-        if metadata.name is None:
+        if (metadata is None or metadata.name is None) and params.path_model is not None:
             metadata.name = params.path_model.name

         print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100
@@ -1407,7 +1412,7 @@ def main(args_in: list[str] | None = None) -> None:

     assert params is not None

-    if metadata.name is None:
+    if metadata.name is None and params.path_model is not None:
         metadata.name = params.path_model.name

     model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
@@ -369,7 +369,7 @@ class GGUFWriter:

         self.state = WriterState.WEIGHTS

-    def generate_tensors_uuid(self) -> None:
+    def generate_tensors_uuid(self) -> str:
         uuidv5_sha1 = hashlib.sha1()
         uuidv5_sha1.update(uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5').bytes)

@@ -520,28 +520,28 @@ class GGUFWriter:
         self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)

     def add_base_model_name(self, source_id: int, name: str) -> None:
-        self.add_string(Keys.General.BASE_MODEL_NAME.format(id=self.source_id), name)
+        self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)

     def add_base_model_author(self, source_id: int, author: str) -> None:
-        self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=self.source_id), author)
+        self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)

     def add_base_model_version(self, source_id: int, version: str) -> None:
-        self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=self.source_id), version)
+        self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)

     def add_base_model_organization(self, source_id: int, organization: str) -> None:
-        self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=self.source_id), organization)
+        self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)

     def add_base_model_url(self, source_id: int, url: str) -> None:
-        self.add_string(Keys.General.BASE_MODEL_URL.format(id=self.source_id), url)
+        self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)

     def add_base_model_doi(self, source_id: int, doi: str) -> None:
-        self.add_string(Keys.General.BASE_MODEL_DOI.format(id=self.source_id), doi)
+        self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)

     def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
-        self.add_string(Keys.General.BASE_MODEL_UUID.format(id=self.source_id), uuid)
+        self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)

     def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
-        self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=self.source_id), repo_url)
+        self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)

     def add_tags(self, tags: Sequence[str]) -> None:
         self.add_array(Keys.General.TAGS, tags)
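Note: the `GGUFWriter` hunk above is a genuine bug fix surfaced by type checking: each `add_base_model_*` method formatted its key with `self.source_id`, an attribute the class never defines, instead of the `source_id` parameter, so any call would have raised AttributeError at runtime. pyright reports unknown instance attributes statically; a minimal reproduction of the shape (toy writer and key string, not the gguf ones):

    class Writer:
        def add_string(self, key: str, value: str) -> None:
            print(key, "=", value)

        def add_base_model_name(self, source_id: int, name: str) -> None:
            # buggy original referenced `self.source_id`, which does not exist:
            # self.add_string(f"general.base_model.{self.source_id}.name", name)
            self.add_string(f"general.base_model.{source_id}.name", name)

    Writer().add_base_model_name(0, "LLaMA")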
@@ -4,7 +4,7 @@ import re
 import json
 import frontmatter
 from pathlib import Path
-from typing import Optional
+from typing import Optional, cast
 from dataclasses import dataclass

 from .constants import Keys

@@ -59,38 +59,38 @@ class Metadata:
         # This is based on LLM_KV_NAMES mapping in llama.cpp
         metadata_override = Metadata.load_metadata_override(metadata_override_path)

         metadata.name = metadata_override.get(Keys.General.NAME , metadata.name ) # noqa: E202
-        metadata.author = metadata_override.get(Keys.General.AUTHOR , metadata.author ) # noqa: E202
-        metadata.version = metadata_override.get(Keys.General.VERSION , metadata.version ) # noqa: E202
-        metadata.organization = metadata_override.get(Keys.General.ORGANIZATION , metadata.organization ) # noqa: E202
+        metadata.author = cast(Optional[str], metadata_override.get(Keys.General.AUTHOR , metadata.author )) # noqa: E202
+        metadata.version = cast(Optional[str], metadata_override.get(Keys.General.VERSION , metadata.version )) # noqa: E202
+        metadata.organization = cast(Optional[str], metadata_override.get(Keys.General.ORGANIZATION , metadata.organization )) # noqa: E202

-        metadata.finetune = metadata_override.get(Keys.General.FINETUNE , metadata.finetune ) # noqa: E202
-        metadata.basename = metadata_override.get(Keys.General.BASENAME , metadata.basename ) # noqa: E202
+        metadata.finetune = cast(Optional[str], metadata_override.get(Keys.General.FINETUNE , metadata.finetune )) # noqa: E202
+        metadata.basename = cast(Optional[str], metadata_override.get(Keys.General.BASENAME , metadata.basename )) # noqa: E202

-        metadata.description = metadata_override.get(Keys.General.DESCRIPTION , metadata.description ) # noqa: E202
-        metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by ) # noqa: E202
+        metadata.description = cast(Optional[str], metadata_override.get(Keys.General.DESCRIPTION , metadata.description )) # noqa: E202
+        metadata.quantized_by = cast(Optional[str], metadata_override.get(Keys.General.QUANTIZED_BY , metadata.quantized_by )) # noqa: E202

-        metadata.parameter_class_attribute = metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute) # noqa: E202
+        metadata.parameter_class_attribute = cast(Optional[str], metadata_override.get(Keys.General.PARAMETER_CLASS_ATTRIBUTE, metadata.parameter_class_attribute)) # noqa: E202

-        metadata.license = metadata_override.get(Keys.General.LICENSE , metadata.license ) # noqa: E202
-        metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name ) # noqa: E202
-        metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link ) # noqa: E202
+        metadata.license = cast(Optional[str], metadata_override.get(Keys.General.LICENSE , metadata.license )) # noqa: E202
+        metadata.license_name = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_NAME , metadata.license_name )) # noqa: E202
+        metadata.license_link = cast(Optional[str], metadata_override.get(Keys.General.LICENSE_LINK , metadata.license_link )) # noqa: E202

-        metadata.url = metadata_override.get(Keys.General.URL , metadata.url ) # noqa: E202
-        metadata.doi = metadata_override.get(Keys.General.DOI , metadata.doi ) # noqa: E202
-        metadata.uuid = metadata_override.get(Keys.General.UUID , metadata.uuid ) # noqa: E202
-        metadata.repo_url = metadata_override.get(Keys.General.REPO_URL , metadata.repo_url ) # noqa: E202
+        metadata.url = cast(Optional[str], metadata_override.get(Keys.General.URL , metadata.url )) # noqa: E202
+        metadata.doi = cast(Optional[str], metadata_override.get(Keys.General.DOI , metadata.doi )) # noqa: E202
+        metadata.uuid = cast(Optional[str], metadata_override.get(Keys.General.UUID , metadata.uuid )) # noqa: E202
+        metadata.repo_url = cast(Optional[str], metadata_override.get(Keys.General.REPO_URL , metadata.repo_url )) # noqa: E202

-        metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url ) # noqa: E202
-        metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi ) # noqa: E202
-        metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid ) # noqa: E202
-        metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url ) # noqa: E202
+        metadata.source_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_URL , metadata.source_url )) # noqa: E202
+        metadata.source_doi = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_DOI , metadata.source_doi )) # noqa: E202
+        metadata.source_uuid = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_UUID , metadata.source_uuid )) # noqa: E202
+        metadata.source_repo_url = cast(Optional[str], metadata_override.get(Keys.General.SOURCE_REPO_URL , metadata.source_repo_url )) # noqa: E202

-        metadata.base_models = metadata_override.get("general.base_models" , metadata.base_models ) # noqa: E202
+        # Base Models is received here as an array of models
+        metadata.base_models = cast("Optional[list[dict]]", metadata_override.get("general.base_models" , metadata.base_models )) # noqa: E202

-        metadata.tags = metadata_override.get(Keys.General.TAGS , metadata.tags ) # noqa: E202
-        metadata.languages = metadata_override.get(Keys.General.LANGUAGES , metadata.languages ) # noqa: E202
-        metadata.datasets = metadata_override.get(Keys.General.DATASETS , metadata.datasets ) # noqa: E202
+        metadata.tags = cast("Optional[list[str]]", metadata_override.get(Keys.General.TAGS , metadata.tags )) # noqa: E202
+        metadata.languages = cast("Optional[list[str]]", metadata_override.get(Keys.General.LANGUAGES , metadata.languages )) # noqa: E202
+        metadata.datasets = cast("Optional[list[str]]", metadata_override.get(Keys.General.DATASETS , metadata.datasets )) # noqa: E202

         # Direct Metadata Override (via direct cli argument)
         if model_name is not None:
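Note: `metadata_override` is a `dict[str, object]`-style mapping, so `.get()` comes back typed as `object`; assigning that to a field declared `Optional[str]` is an error under pyright. `typing.cast` resolves this: at runtime it returns its second argument unchanged and performs no validation, it only asserts the type to the checker. Compactly:

    from typing import Optional, cast

    override: dict[str, object] = {"general.author": "Someone"}

    # .get() is typed object | None here; cast() narrows it for the checker
    # without any runtime check or conversion
    author: Optional[str] = cast(Optional[str], override.get("general.author"))
    print(author)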
@@ -117,7 +117,7 @@ class Metadata:
             return {}

         with open(model_card_path, "r", encoding="utf-8") as f:
-            return frontmatter.load(f)
+            return cast("dict[str, object]", frontmatter.load(f))

     @staticmethod
     def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, object]:
|
|||
return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
|
||||
|
||||
@staticmethod
|
||||
def get_model_id_components(model_id: Optional[str] = None) -> dict[str, object]:
|
||||
def get_model_id_components(model_id: Optional[str] = None) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
|
||||
# Huggingface often store model id as '<org>/<model name>'
|
||||
# so let's parse it and apply some heuristics if possible for model name components
|
||||
|
||||
|
@@ -334,6 +334,7 @@ class Metadata:
         return metadata

     def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
+        assert self.name is not None
         gguf_writer.add_name(self.name)

         if self.author is not None:
@@ -34,7 +34,7 @@ def model_weight_count_rounded_notation(model_params_count: int) -> str:
     return f"{round(scaled_model_params)}{scale_suffix}"


-def parameter_class_attribute(expert_count_int:int, model_params_count: int) -> str:
+def parameter_class_attribute(expert_count_int:int | None, model_params_count: int) -> str:
     per_model_rounded_weight_estimate = model_weight_count_rounded_notation(model_params_count)

     if expert_count_int is not None and expert_count_int > 0:
@@ -45,7 +45,7 @@ def parameter_class_attribute(expert_count_int:int, model_params_count: int) ->
     return size_class


-def naming_convention(model_name: str, base_name: str, finetune_string:str, version_string:str, parameter_class_attribute: str, output_type: str) -> str:
+def naming_convention(model_name: str | None, base_name: str | None, finetune_string:str | None, version_string:str | None, parameter_class_attribute: str | None, output_type: str | None) -> str:
     # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention

     if base_name is not None:
@@ -61,6 +61,6 @@ def naming_convention(model_name: str, base_name: str, finetune_string:str, vers

     version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""

-    precision = f"-{output_type.strip().replace(' ', '-').upper()}"
+    precision = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""

     return f"{name}{parameters}{finetune}{version}{precision}"
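Note: with every component of `naming_convention` now optional, each piece contributes its segment only when present, and the `precision` fix above brings `output_type` in line with the others. A hedged mini-version showing the resulting behavior (simplified, not the gguf implementation):

    def name_parts(model_name: str | None, version: str | None, output_type: str | None) -> str:
        name = (model_name or "ggml-model").strip().replace(' ', '-')
        ver = f"-{version.strip().replace(' ', '-')}" if version is not None else ""
        precision = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
        return f"{name}{ver}{precision}"

    assert name_parts("Mixtral", "v0.1", "f16") == "Mixtral-v0.1-F16"
    assert name_parts(None, None, None) == "ggml-model"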