convert_hf : use GGUFWriter to count model parameters
This commit is contained in:
parent
78a42fbee5
commit
417d7a7c62
4 changed files with 101 additions and 97 deletions
|
@ -48,6 +48,7 @@ class Model:
|
|||
|
||||
dir_model: Path
|
||||
ftype: gguf.LlamaFileType
|
||||
fname_out: Path | None
|
||||
is_big_endian: bool
|
||||
endianess: gguf.GGUFEndian
|
||||
use_temp_file: bool
|
||||
|
@ -58,8 +59,6 @@ class Model:
|
|||
block_count: int
|
||||
tensor_map: gguf.TensorNameMap
|
||||
tensor_names: set[str] | None
|
||||
fname_out: Path
|
||||
fname_default: str
|
||||
gguf_writer: gguf.GGUFWriter
|
||||
metadata: gguf.Metadata
|
||||
|
||||
|
@ -76,6 +75,7 @@ class Model:
|
|||
|
||||
self.dir_model = dir_model
|
||||
self.ftype = ftype
|
||||
self.fname_out = fname_out
|
||||
self.is_big_endian = is_big_endian
|
||||
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||
self.use_temp_file = use_temp_file
|
||||
|
@ -101,37 +101,8 @@ class Model:
|
|||
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
|
||||
|
||||
# Fallback to model directory name if metadata name is still missing
|
||||
if self.metadata.name is None:
|
||||
self.metadata.name = dir_model.name
|
||||
|
||||
# Generate parameter weight class (useful for leader boards) if not yet determined
|
||||
if self.metadata.size_label is None:
|
||||
expert_count = self.hparams.get("num_local_experts", 0)
|
||||
sum_weight_estimate = self.calculate_total_weight_count()
|
||||
|
||||
# Calculate weight estimate per model
|
||||
per_model_weight_estimate: int = (sum_weight_estimate / expert_count) if (expert_count > 0) else sum_weight_estimate
|
||||
|
||||
self.metadata.size_label = gguf.size_label(expert_count, per_model_weight_estimate)
|
||||
|
||||
# Extracts and converts the encoding scheme from the given file type name. e.g. 'gguf.LlamaFileType.ALL_F32' --> 'F32'
|
||||
output_type = self.ftype.name.partition("_")[2]
|
||||
|
||||
# Generate default filename based on model specification and available metadata
|
||||
self.fname_default = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type)
|
||||
|
||||
# Filename Output
|
||||
if fname_out is not None:
|
||||
# custom defined filename and path was provided
|
||||
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
||||
self.fname_out = fname_out.parent / gguf.fill_templated_filename(fname_out.name, output_type)
|
||||
else:
|
||||
# output in the same directory as the model by default
|
||||
self.fname_out = dir_model / f"{self.fname_default}.gguf"
|
||||
|
||||
# Configure GGUF Writer
|
||||
self.gguf_writer = gguf.GGUFWriter(path=self.fname_out, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
|
||||
|
||||
@classmethod
|
||||
|
@ -191,23 +162,6 @@ class Model:
|
|||
if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
|
||||
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
|
||||
|
||||
def calculate_total_weight_count(self) -> int:
|
||||
sum_weight_estimate = 0
|
||||
for name, data_torch in self.get_tensors():
|
||||
# we don't need these
|
||||
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
continue
|
||||
|
||||
# Calculate Tensor Volume
|
||||
sum_weights_in_tensor = 1
|
||||
for dim in data_torch.shape:
|
||||
sum_weights_in_tensor *= dim
|
||||
|
||||
# Add Tensor Volume To Running Count
|
||||
sum_weight_estimate += sum_weights_in_tensor
|
||||
|
||||
return sum_weight_estimate
|
||||
|
||||
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
||||
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
||||
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
|
||||
|
@ -376,7 +330,37 @@ class Model:
|
|||
|
||||
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
|
||||
|
||||
def prepare_key_value_store(self):
|
||||
def prepare_metadata(self):
|
||||
|
||||
# Fallback to model directory name if metadata name is still missing
|
||||
if self.metadata.name is None:
|
||||
self.metadata.name = self.dir_model.name
|
||||
|
||||
# Generate parameter weight class (useful for leader boards) if not yet determined
|
||||
if self.metadata.size_label is None:
|
||||
total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
|
||||
|
||||
if (total_params > 0):
|
||||
self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
|
||||
|
||||
# Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
|
||||
output_type = self.ftype.name.partition("_")[2]
|
||||
|
||||
# Generate default filename based on model specification and available metadata
|
||||
fname_default = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type)
|
||||
|
||||
# Filename Output
|
||||
if self.fname_out is not None:
|
||||
if not self.fname_out.is_dir():
|
||||
# custom defined filename and path was provided
|
||||
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
||||
self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
|
||||
else:
|
||||
# the target file is a directory
|
||||
self.fname_out = self.fname_out / f"{fname_default}.gguf"
|
||||
else:
|
||||
# output in the same directory as the model by default
|
||||
self.fname_out = self.dir_model / f"{fname_default}.gguf"
|
||||
|
||||
logger.info("Set meta model")
|
||||
self.metadata.set_gguf_meta_model(self.gguf_writer)
|
||||
|
@ -393,8 +377,8 @@ class Model:
|
|||
|
||||
def write(self):
|
||||
self.prepare_tensors()
|
||||
self.prepare_key_value_store()
|
||||
self.gguf_writer.write_header_to_file()
|
||||
self.prepare_metadata()
|
||||
self.gguf_writer.write_header_to_file(path=self.fname_out)
|
||||
self.gguf_writer.write_kv_data_to_file()
|
||||
self.gguf_writer.write_tensors_to_file(progress=True)
|
||||
self.gguf_writer.close()
|
||||
|
@ -403,8 +387,8 @@ class Model:
|
|||
if len(self.gguf_writer.tensors) != 1:
|
||||
raise ValueError('Splitting the vocabulary is not supported')
|
||||
|
||||
self.prepare_key_value_store()
|
||||
self.gguf_writer.write_header_to_file()
|
||||
self.prepare_metadata()
|
||||
self.gguf_writer.write_header_to_file(path=self.fname_out)
|
||||
self.gguf_writer.write_kv_data_to_file()
|
||||
self.gguf_writer.close()
|
||||
|
||||
|
@ -3545,10 +3529,6 @@ def parse_args() -> argparse.Namespace:
|
|||
"--metadata", type=Path,
|
||||
help="Specify the path for an authorship metadata override file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--get-outfile", action="store_true",
|
||||
help="print calculated output file name then exit"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -3576,9 +3556,6 @@ def main() -> None:
|
|||
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
elif args.get_outfile:
|
||||
# Avoid printing anything besides the dump output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
@ -3629,10 +3606,6 @@ def main() -> None:
|
|||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split)
|
||||
|
||||
if args.get_outfile:
|
||||
print(f"{model_instance.fname_default}") # noqa: NP100
|
||||
return
|
||||
|
||||
if args.vocab_only:
|
||||
logger.info("Exporting model vocab...")
|
||||
model_instance.write_vocab()
|
||||
|
@ -3640,6 +3613,7 @@ def main() -> None:
|
|||
else:
|
||||
logger.info("Exporting model...")
|
||||
model_instance.write()
|
||||
assert model_instance.fname_out is not None
|
||||
out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
|
||||
logger.info(f"Model successfully exported to {out_path}")
|
||||
|
||||
|
|
|
@ -1042,9 +1042,11 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
|
|||
raise ValueError(f"Unexpected combination of types: {name_to_type}")
|
||||
|
||||
|
||||
def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]], expert_count:int | None) -> int:
|
||||
# TODO: Ensure parameter count is accurate throughout various model type
|
||||
sum_weight_estimate: int = 0
|
||||
def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]:
|
||||
total_params = 0
|
||||
shared_params = 0
|
||||
expert_params = 0
|
||||
|
||||
for name, lazy_tensor in tensors:
|
||||
# We don't need these
|
||||
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
|
@ -1057,17 +1059,15 @@ def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]],
|
|||
for dim in lazy_tensor.shape:
|
||||
sum_weights_in_tensor *= dim
|
||||
|
||||
# Add Tensor Volume To Running Count
|
||||
sum_weight_estimate += sum_weights_in_tensor
|
||||
if ".experts." in name:
|
||||
if ".experts.0." in name:
|
||||
expert_params += sum_weights_in_tensor
|
||||
else:
|
||||
shared_params += sum_weights_in_tensor
|
||||
|
||||
if expert_count is None:
|
||||
return sum_weight_estimate
|
||||
total_params += sum_weights_in_tensor
|
||||
|
||||
if expert_count is not None and expert_count == 0:
|
||||
return sum_weight_estimate
|
||||
|
||||
# Calculate weight estimate per model
|
||||
return int(sum_weight_estimate / expert_count)
|
||||
return total_params, shared_params, expert_params
|
||||
|
||||
|
||||
def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
|
||||
|
@ -1249,12 +1249,12 @@ class VocabFactory:
|
|||
return vocab, special_vocab
|
||||
|
||||
|
||||
def default_convention_outfile(file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> str:
|
||||
def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str:
|
||||
name = metadata.name if metadata.name is not None else None
|
||||
basename = metadata.basename if metadata.basename is not None else None
|
||||
finetune = metadata.finetune if metadata.finetune is not None else None
|
||||
version = metadata.version if metadata.version is not None else None
|
||||
size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(expert_count, model_params_count)
|
||||
size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0)
|
||||
|
||||
output_type = {
|
||||
GGMLFileType.AllF32: "F32",
|
||||
|
@ -1265,7 +1265,7 @@ def default_convention_outfile(file_type: GGMLFileType, expert_count:int | None,
|
|||
return gguf.naming_convention(name, basename, finetune, version, size_label, output_type)
|
||||
|
||||
|
||||
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count:int | None, model_params_count: int, metadata: gguf.Metadata) -> Path:
|
||||
def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path:
|
||||
default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
|
||||
ret = model_paths[0].parent / f"{default_filename}.gguf"
|
||||
if ret in model_paths:
|
||||
|
@ -1328,7 +1328,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
model_plus = load_some_model(dir_model)
|
||||
params = Params.load(model_plus)
|
||||
model = convert_model_names(model_plus.model, params, args.skip_unknown)
|
||||
model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
|
||||
model_params_count = per_model_weight_count_estimation(model_plus.model.items())
|
||||
ftype = pick_output_type(model, args.outtype)
|
||||
|
||||
if (metadata is None or metadata.name is None) and params.path_model is not None:
|
||||
|
@ -1415,8 +1415,8 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
if metadata.name is None and params.path_model is not None:
|
||||
metadata.name = params.path_model.name
|
||||
|
||||
model_params_count = per_model_weight_count_estimation(model_plus.model.items(), params.n_experts)
|
||||
logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count)})")
|
||||
model_params_count = per_model_weight_count_estimation(model_plus.model.items())
|
||||
logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})")
|
||||
|
||||
logger.info(f"Vocab info: {vocab}")
|
||||
logger.info(f"Special vocab info: {special_vocab}")
|
||||
|
@ -1426,7 +1426,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
model = convert_to_output_type(model, ftype)
|
||||
outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata)
|
||||
|
||||
metadata.size_label = gguf.size_label(params.n_experts, model_params_count)
|
||||
metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0)
|
||||
|
||||
params.ftype = ftype
|
||||
logger.info(f"Writing {outfile}, format {ftype}")
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|||
from .constants import (
|
||||
GGUF_DEFAULT_ALIGNMENT,
|
||||
GGUF_MAGIC,
|
||||
GGML_QUANT_SIZES,
|
||||
GGUF_VERSION,
|
||||
GGMLQuantizationType,
|
||||
GGUFEndian,
|
||||
|
@ -106,6 +107,36 @@ class GGUFWriter:
|
|||
|
||||
self.add_architecture()
|
||||
|
||||
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
|
||||
total_params = 0
|
||||
shared_params = 0
|
||||
expert_params = 0
|
||||
|
||||
expert_sum = 0
|
||||
n_expert_tensors = 0
|
||||
|
||||
for tensors in self.tensors:
|
||||
for name, info in tensors.items():
|
||||
|
||||
block_size, type_size = GGML_QUANT_SIZES[info.dtype]
|
||||
|
||||
size = (info.nbytes // type_size) * block_size
|
||||
|
||||
if "_exps." in name:
|
||||
expert_params += (size // info.shape[-3])
|
||||
expert_sum += info.shape[-3]
|
||||
n_expert_tensors += 1
|
||||
else:
|
||||
shared_params += size
|
||||
|
||||
total_params += size
|
||||
|
||||
# Hopefully this should work even for variable-expert-count models
|
||||
expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
|
||||
|
||||
# NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
|
||||
return total_params, shared_params, expert_params, expert_count
|
||||
|
||||
def format_shard_names(self, path: Path) -> list[Path]:
|
||||
if len(self.tensors) == 1:
|
||||
return [path]
|
||||
|
|
|
@ -10,12 +10,8 @@ def fill_templated_filename(filename: str, output_type: str):
|
|||
OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
|
||||
|
||||
|
||||
def model_weight_count_rounded_notation(model_params_count: int) -> str:
|
||||
if model_params_count > 1e15 :
|
||||
# Quadrillion Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-15
|
||||
scale_suffix = "Q"
|
||||
elif model_params_count > 1e12 :
|
||||
def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
|
||||
if model_params_count > 1e12 :
|
||||
# Trillions Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-12
|
||||
scale_suffix = "T"
|
||||
|
@ -31,16 +27,19 @@ def model_weight_count_rounded_notation(model_params_count: int) -> str:
|
|||
# Thousands Of Parameters
|
||||
scaled_model_params = model_params_count * 1e-3
|
||||
scale_suffix = "K"
|
||||
return f"{round(scaled_model_params)}{scale_suffix}"
|
||||
|
||||
fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
|
||||
|
||||
return f"{scaled_model_params:.{fix}f}{scale_suffix}"
|
||||
|
||||
|
||||
def size_label(expert_count_int:int | None, model_params_count: int) -> str:
|
||||
per_model_rounded_weight_estimate = model_weight_count_rounded_notation(model_params_count)
|
||||
def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
|
||||
|
||||
if expert_count_int is not None and expert_count_int > 0:
|
||||
size_class = f"{expert_count_int}x{per_model_rounded_weight_estimate}"
|
||||
if expert_count > 0:
|
||||
pretty_size = model_weight_count_rounded_notation(shared_params + expert_params, min_digits=2)
|
||||
size_class = f"{expert_count}x{pretty_size}"
|
||||
else:
|
||||
size_class = f"{per_model_rounded_weight_estimate}"
|
||||
size_class = model_weight_count_rounded_notation(total_params, min_digits=2)
|
||||
|
||||
return size_class
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue