convert-*.py: metadata class moved to utility

commit 4d5f18a0e6 (parent 916872f72f)
5 changed files with 128 additions and 149 deletions
convert-hf-to-gguf.py

@@ -31,46 +31,6 @@ import gguf
 logger = logging.getLogger("hf-to-gguf")
 
-
-@dataclass
-class Metadata:
-    name: Optional[str] = None
-    basename: Optional[str] = None
-    finetune: Optional[str] = None
-    author: Optional[str] = None
-    version: Optional[str] = None
-    url: Optional[str] = None
-    description: Optional[str] = None
-    licence: Optional[str] = None
-    source_url: Optional[str] = None
-    source_hf_repo: Optional[str] = None
-
-    @staticmethod
-    def load(metadata_path: Path) -> Metadata:
-        if metadata_path is None or not metadata_path.exists():
-            return Metadata()
-
-        with open(metadata_path, 'r') as file:
-            data = json.load(file)
-
-        # Create a new Metadata instance
-        metadata = Metadata()
-
-        # Assigning values to Metadata attributes if they exist in the JSON file
-        # This is based on LLM_KV_NAMES mapping in llama.cpp
-        metadata.name = data.get("general.name")
-        metadata.basename = data.get("general.basename")
-        metadata.finetune = data.get("general.finetune")
-        metadata.author = data.get("general.author")
-        metadata.version = data.get("general.version")
-        metadata.url = data.get("general.url")
-        metadata.description = data.get("general.description")
-        metadata.license = data.get("general.license")
-        metadata.source_url = data.get("general.source.url")
-        metadata.source_hf_repo = data.get("general.source.huggingface.repository")
-
-        return metadata
-
-
 ###### MODEL DEFINITIONS ######
 
 class SentencePieceTokenTypes(IntEnum):
@@ -105,12 +65,12 @@ class Model:
     fname_out: Path
     fname_default: Path
     gguf_writer: gguf.GGUFWriter
-    metadata: Metadata
+    metadata: gguf.Metadata
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: Metadata,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: gguf.Metadata,
                  model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@@ -164,72 +124,23 @@ class Model:
         if self.model_card is not None and "license" in self.model_card:
             self.metadata.source_hf_repo = self.model_card["license"]
 
-        def get_model_name(metadata, huggingface_parameters, dir_model, model_arch):
-            if metadata is not None and metadata.name is not None:
-                # Explicit Metadata Was Provided By User
-                return metadata.name
-            elif huggingface_parameters is not None and "_name_or_path" in huggingface_parameters:
-                # Hugging Face Parameters Model Name or Model Folder Name is Provided
-                return huggingface_parameters["_name_or_path"]
-            elif huggingface_parameters is not None and "model_type" in huggingface_parameters:
-                # Hugging Face Parameters Model Type is Provided
-                return huggingface_parameters["model_type"]
-            elif dir_model is not None and dir_model.name is not None:
-                # Use directory folder name
-                return dir_model.name
-            else:
-                return gguf.MODEL_ARCH_NAMES[model_arch]
-        self.model_name = get_model_name(self.metadata, self.hparams, self.dir_model, self.model_arch)
+        # Set model name based on latest metadata either provided or calculated from environment
+        self.model_name = Model.get_model_name(self.metadata, self.hparams, self.dir_model, self.model_arch)
 
         # Extracts and converts the encoding scheme from the given file type name. e.g. 'gguf.LlamaFileType.ALL_F32' --> 'F32'
-        encodingScheme = self.ftype.name.partition("_")[2]
+        encoding_scheme = self.ftype.name.partition("_")[2]
 
         # Get Expert Count From huggingface_parameters
         expert_count = self.hparams["num_local_experts"] if "num_local_experts" in self.hparams else None
 
-        def per_model_weight_count_estimation(tensors, expert_count):
-            # TODO: Ensure parameter count is accurate throughout various model type
-            # May currently overestimate parameter count in Mamba model because
-            # output weights is tied with token embeddings.
-            sum_weight_estimate = 0
-            for name, data_torch in tensors:
-                # Got A Tensor
-
-                # We don't need these
-                if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
-                    continue
-
-                # Calculate Tensor Volume
-                sum_weights_in_tensor = 1
-                for dim in data_torch.shape:
-                    sum_weights_in_tensor *= dim
-
-                # Add Tensor Volume To Running Count
-                sum_weight_estimate += sum_weights_in_tensor
-
-            # Calculate weight estimate per model
-            per_model_weight_estimate = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
-
-            return per_model_weight_estimate
-
-        weight_estimate = per_model_weight_count_estimation(model_tensors, expert_count)
+        weight_estimate = gguf.per_model_weight_count_estimation(model_tensors, expert_count)
 
         # Generate default filename based on model specification and available metadata
-        self.fname_default = gguf.naming_convention(self.model_name, self.metadata.basename, self.metadata.finetune, self.metadata.version, expert_count, weight_estimate, encodingScheme)
+        self.fname_default = gguf.naming_convention(self.model_name, self.metadata.basename, self.metadata.finetune, self.metadata.version, expert_count, weight_estimate, encoding_scheme)
 
         # Filename Output
         if fname_out is not None:
             # custom defined filename and path was provided
-            def fill_templated_filename(filename: str, encodingScheme: str):
-                # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
-                ftype_uppercase: str = encodingScheme.upper()
-                ftype_lowercase: str = encodingScheme.lower()
-                return filename.format(ftype_lowercase,
-                                       outtype=ftype_lowercase, ftype=ftype_lowercase,
-                                       OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
-
-            self.fname_out = fname_out.parent / fill_templated_filename(fname_out.name, encodingScheme)
+            self.fname_out = fname_out.parent / gguf.fill_templated_filename(fname_out.name, encoding_scheme)
         else:
             # output in the same directory as the model by default
             self.fname_out = dir_model.parent / self.fname_default
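Both filename paths in __init__ now go through gguf-py helpers. A minimal sketch of the templated path, assuming this branch's gguf package is importable (the template semantics follow fill_templated_filename as added below; the default path instead uses gguf.naming_convention, whose exact output string is the helper's choice):

    import gguf

    # A user-supplied output name may contain type templates; {ftype}/{outtype}
    # bind the lowercase encoding, {FTYPE}/{OUTTYPE} the uppercase one:
    fname_out = gguf.fill_templated_filename("some-model-name.{ftype}.gguf", "F16")
    print(fname_out)  # 'some-model-name.f16.gguf'

    # The default name mirrors the call in Model.__init__ above:
    # gguf.naming_convention(model_name, basename, finetune, version,
    #                        expert_count, weight_estimate, encoding_scheme)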
@@ -499,6 +410,24 @@ class Model:
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
 
+    # Set model name based on latest metadata either provided or calculated from environment
+    @staticmethod
+    def get_model_name(metadata, huggingface_parameters, dir_model, model_arch):
+        if metadata is not None and metadata.name is not None:
+            # Explicit Metadata Was Provided By User
+            return metadata.name
+        elif huggingface_parameters is not None and "_name_or_path" in huggingface_parameters:
+            # Hugging Face Parameters Model Name or Model Folder Name is Provided
+            return huggingface_parameters["_name_or_path"]
+        elif huggingface_parameters is not None and "model_type" in huggingface_parameters:
+            # Hugging Face Parameters Model Type is Provided
+            return huggingface_parameters["model_type"]
+        elif dir_model is not None and dir_model.name is not None:
+            # Use directory folder name
+            return dir_model.name
+        else:
+            return gguf.MODEL_ARCH_NAMES[model_arch]
+
     @staticmethod
     def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
         part_names: list[str] = []
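Because the relocated get_model_name is a plain static method, the fallback chain is easy to exercise directly. A sketch of the precedence, assuming the Model class from this script is in scope and using made-up names:

    from pathlib import Path
    import gguf

    meta = gguf.Metadata(name="MyModel-v0.2")  # hypothetical override
    hparams = {"_name_or_path": "org/my-model", "model_type": "llama"}

    Model.get_model_name(meta, hparams, Path("models/my-model"), gguf.MODEL_ARCH.LLAMA)
    # -> 'MyModel-v0.2'   (explicit metadata wins)
    Model.get_model_name(None, hparams, Path("models/my-model"), gguf.MODEL_ARCH.LLAMA)
    # -> 'org/my-model'   (then _name_or_path, then model_type, then the
    #                      directory name, finally gguf.MODEL_ARCH_NAMES[model_arch])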
@@ -3682,7 +3611,7 @@ def main() -> None:
     else:
         logging.basicConfig(level=logging.INFO)
 
-    metadata = Metadata.load(args.metadata)
+    metadata = gguf.Metadata.load(args.metadata)
     dir_model = args.model
 
     if not dir_model.is_dir():
@@ -3713,7 +3642,7 @@ def main() -> None:
     hparams = Model.load_hparams(dir_model)
 
     with torch.inference_mode():
-        encodingScheme = ftype_map[args.outtype]
+        encoding_scheme = ftype_map[args.outtype]
         model_architecture = hparams["architectures"][0]
 
         try:
convert.py

@@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar, Optional
+from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar
 
 import numpy as np
@@ -346,46 +346,6 @@ class Params:
         return params
 
-
-@dataclass
-class Metadata:
-    name: Optional[str] = None
-    basename: Optional[str] = None
-    finetune: Optional[str] = None
-    author: Optional[str] = None
-    version: Optional[str] = None
-    url: Optional[str] = None
-    description: Optional[str] = None
-    license: Optional[str] = None
-    source_url: Optional[str] = None
-    source_hf_repo: Optional[str] = None
-
-    @staticmethod
-    def load(metadata_path: Path) -> Metadata:
-        if metadata_path is None or not metadata_path.exists():
-            return Metadata()
-
-        with open(metadata_path, 'r') as file:
-            data = json.load(file)
-
-        # Create a new Metadata instance
-        metadata = Metadata()
-
-        # Assigning values to Metadata attributes if they exist in the JSON file
-        # This is based on LLM_KV_NAMES mapping in llama.cpp
-        metadata.name = data.get("general.name")
-        metadata.basename = data.get("general.basename")
-        metadata.finetune = data.get("general.finetune")
-        metadata.author = data.get("general.author")
-        metadata.version = data.get("general.version")
-        metadata.url = data.get("general.url")
-        metadata.description = data.get("general.description")
-        metadata.license = data.get("general.license")
-        metadata.source_url = data.get("general.source.url")
-        metadata.source_hf_repo = data.get("general.source.huggingface.repository")
-
-        return metadata
-
-
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
@@ -810,7 +770,7 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
-    def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
+    def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
         # Metadata About The Model And Its Provenance
         name = "LLaMA"
         if metadata is not None and metadata.name is not None:
@@ -952,7 +912,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -986,7 +946,7 @@ class OutputFile:
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
-        metadata: Metadata | None = None,
+        metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -1029,10 +989,10 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
         raise ValueError(f"Unexpected combination of types: {name_to_type}")
 
 
-def per_model_weight_count_estimation(model: LazyModel, expert_count: int) -> int:
+def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]], expert_count: int) -> int:
     # TODO: Ensure parameter count is accurate throughout various model type
     sum_weight_estimate = 0
-    for name, lazy_tensor in model.items():
+    for name, lazy_tensor in tensors:
         # We don't need these
         if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
             continue
@@ -1232,7 +1192,7 @@ class VocabFactory:
         return vocab, special_vocab
 
 
-def default_convention_outfile(file_type: GGMLFileType, model_name: str, expert_count: int, model_params_count: int, metadata: Metadata) -> str:
+def default_convention_outfile(file_type: GGMLFileType, model_name: str, expert_count: int, model_params_count: int, metadata: gguf.Metadata) -> str:
     name = metadata.name if metadata is not None and metadata.name is not None else model_name
     basename = metadata.basename if metadata is not None and metadata.basename is not None else None
     finetune = metadata.finetune if metadata is not None and metadata.finetune is not None else None
@@ -1247,7 +1207,7 @@ def default_convention_outfile(file_type: GGMLFileType, model_name:str, expert_c
     return gguf.naming_convention(name, basename, finetune, version, expert_count, model_params_count, encodingScheme)
 
 
-def default_outfile(model_paths: list[Path], file_type: GGMLFileType, model_name: str, expert_count: int, model_params_count: int, metadata: Metadata) -> Path:
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, model_name: str, expert_count: int, model_params_count: int, metadata: gguf.Metadata) -> Path:
     default_filename = default_convention_outfile(file_type, model_name, expert_count, model_params_count, metadata)
     ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
@@ -1300,13 +1260,13 @@ def main(args_in: list[str] | None = None) -> None:
     else:
         logging.basicConfig(level=logging.INFO)
 
-    metadata = Metadata.load(args.metadata)
+    metadata = gguf.Metadata.load(args.metadata)
 
     if args.get_outfile:
         model_plus = load_some_model(args.model)
         params = Params.load(model_plus)
         model = convert_model_names(model_plus.model, params, args.skip_unknown)
-        model_params_count = per_model_weight_count_estimation(model_plus.model, params.n_experts)
+        model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
         ftype = pick_output_type(model, args.outtype)
         print(f"{default_convention_outfile(ftype, params.path_model.name, params.n_experts, model_params_count, metadata)}")  # noqa: NP100
         return
@@ -1324,7 +1284,7 @@ def main(args_in: list[str] | None = None) -> None:
     else:
         model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
 
-    model_params_count = per_model_weight_count_estimation(model_plus.model, params.n_experts)
+    model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
     logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count)})")
 
     if args.dump:
gguf-py/gguf/__init__.py

@@ -6,3 +6,4 @@ from .quants import *
 from .tensor_mapping import *
 from .vocab import *
 from .utility import *
+from .metadata import *
gguf-py/gguf/metadata.py (new file, 49 lines)

@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+from typing import Optional
+from dataclasses import dataclass
+
+from .constants import Keys
+
+
+@dataclass
+class Metadata:
+    name: Optional[str] = None
+    basename: Optional[str] = None
+    finetune: Optional[str] = None
+    author: Optional[str] = None
+    version: Optional[str] = None
+    url: Optional[str] = None
+    description: Optional[str] = None
+    license: Optional[str] = None
+    source_url: Optional[str] = None
+    source_hf_repo: Optional[str] = None
+
+    @staticmethod
+    def load(metadata_path: Path) -> Metadata:
+        if metadata_path is None or not metadata_path.exists():
+            return Metadata()
+
+        with open(metadata_path, 'r') as file:
+            data = json.load(file)
+
+        # Create a new Metadata instance
+        metadata = Metadata()
+
+        # Assigning values to Metadata attributes if they exist in the JSON file
+        # This is based on LLM_KV_NAMES mapping in llama.cpp
+        metadata.name = data.get(Keys.General.NAME)
+        metadata.basename = data.get(Keys.General.BASENAME)
+        metadata.finetune = data.get(Keys.General.FINETUNE)
+        metadata.author = data.get(Keys.General.AUTHOR)
+        metadata.version = data.get(Keys.General.VERSION)
+        metadata.url = data.get(Keys.General.URL)
+        metadata.description = data.get(Keys.General.DESCRIPTION)
+        metadata.license = data.get(Keys.General.LICENSE)
+        metadata.source_url = data.get(Keys.General.SOURCE_URL)
+        metadata.source_hf_repo = data.get(Keys.General.SOURCE_HF_REPO)
+
+        return metadata
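Since Keys.General.NAME and friends map to the same "general.*" strings the scripts used before, an existing override file keeps its shape. A small usage sketch, assuming this branch's gguf package is installed and a hypothetical metadata.json on disk (the convert scripts load the same file through their --metadata argument):

    # metadata.json (hypothetical):
    # {
    #     "general.name": "MyModel",
    #     "general.author": "Jane Doe",
    #     "general.version": "v1.0"
    # }
    from pathlib import Path
    import gguf

    metadata = gguf.Metadata.load(Path("metadata.json"))
    print(metadata.name, metadata.author, metadata.version)  # MyModel Jane Doe v1.0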
gguf-py/gguf/utility.py

@@ -1,5 +1,45 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING, Iterator
+
+if TYPE_CHECKING:
+    from torch import Tensor
+
+
+def fill_templated_filename(filename: str, encoding_scheme: str):
+    # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
+    ftype_uppercase: str = encoding_scheme.upper()
+    ftype_lowercase: str = encoding_scheme.lower()
+    return filename.format(ftype_lowercase,
+                           outtype=ftype_lowercase, ftype=ftype_lowercase,
+                           OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
+
+
+def per_model_weight_count_estimation(tensors: Iterator[tuple[str, Tensor]], expert_count: int) -> int:
+    # TODO: Ensure parameter count is accurate throughout various model type
+    # May currently overestimate parameter count in Mamba model because
+    # output weights is tied with token embeddings.
+    sum_weight_estimate = 0
+    for name, data_torch in tensors:
+        # Got A Tensor
+
+        # We don't need these
+        if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
+            continue
+
+        # Calculate Tensor Volume
+        sum_weights_in_tensor = 1
+        for dim in data_torch.shape:
+            sum_weights_in_tensor *= dim
+
+        # Add Tensor Volume To Running Count
+        sum_weight_estimate += sum_weights_in_tensor
+
+    # Calculate weight estimate per model
+    per_model_weight_estimate = (sum_weight_estimate / expert_count) if (expert_count is not None and expert_count > 0) else sum_weight_estimate
+
+    return per_model_weight_estimate
+
+
 def model_weight_count_rounded_notation(model_params_count: int) -> str:
     if model_params_count > 1e15 :
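A quick sketch of the new estimation helper, assuming this branch's gguf package and torch are installed; model_weight_count_rounded_notation then formats the count for log output, as in the convert.py logger line above:

    import torch
    import gguf

    # Rough parameter count over (name, tensor) pairs, as the convert scripts now do:
    tensors = [
        ("blk.0.attn_q.weight", torch.empty(4096, 4096)),
        ("model.layers.0.self_attn.rotary_emb.inv_freq", torch.empty(64)),  # filtered out
    ]
    n = gguf.per_model_weight_count_estimation(iter(tensors), expert_count=0)
    print(n)  # 16777216 (4096 * 4096; the inv_freq tensor is skipped by the suffix filter)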