convert-*.py: refactor per-model weight count estimation

brian khuu 2024-07-10 20:20:54 +10:00
parent 2a976e1211
commit 59a01df784

@@ -108,8 +108,12 @@ class Model:
         # Generate parameter weight class (useful for leader boards) if not yet determined
         if self.metadata.parameter_class_attribute is None:
             expert_count = self.hparams["num_local_experts"] if "num_local_experts" in self.hparams else None
-            weight_estimate = self.per_model_weight_count_estimation(self.get_tensors(), expert_count)
-            self.metadata.parameter_class_attribute = gguf.parameter_class_attribute(expert_count, weight_estimate)
+            sum_weight_estimate = self.calculate_total_weight_count()
+
+            # Calculate weight estimate per model
+            per_model_weight_estimate = (sum_weight_estimate / expert_count) if expert_count is not None and (expert_count > 0) else sum_weight_estimate
+
+            self.metadata.parameter_class_attribute = gguf.parameter_class_attribute(expert_count, per_model_weight_estimate)
 
         # Extracts and converts the encoding scheme from the given file type name. e.g. 'gguf.LlamaFileType.ALL_F32' --> 'F32'
         output_type = self.ftype.name.partition("_")[2]
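
The inlined replacement divides the summed parameter count by the expert count, so mixture-of-experts models report the weight count of a single expert pathway rather than of the entire checkpoint. A minimal standalone sketch of that logic (the function name is illustrative, not part of the commit; note the diff itself uses true division "/", which yields a float, while this sketch uses floor division to keep an integer):

    def per_model_estimate(sum_weight_estimate: int, expert_count: int | None) -> int:
        # For MoE models, approximate the size of a single expert pathway;
        # for dense models the total count is already the per-model count.
        if expert_count is not None and expert_count > 0:
            return sum_weight_estimate // expert_count
        return sum_weight_estimate

    # e.g. 8 experts totalling 46.7B parameters -> ~5.8B per model
    assert per_model_estimate(46_700_000_000, 8) == 5_837_500_000
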
@@ -187,6 +191,23 @@ class Model:
         if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
             raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
 
+    def calculate_total_weight_count(self) -> int:
+        sum_weight_estimate = 0
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
+                continue
+
+            # Calculate Tensor Volume
+            sum_weights_in_tensor = 1
+            for dim in data_torch.shape:
+                sum_weights_in_tensor *= dim
+
+            # Add Tensor Volume To Running Count
+            sum_weight_estimate += sum_weights_in_tensor
+
+        return sum_weight_estimate
+
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
             raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
@@ -397,31 +418,6 @@ class Model:
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
 
-    def per_model_weight_count_estimation(self, tensors: Iterator[tuple[str, Tensor]], expert_count: int) -> int:
-        # TODO: Ensure parameter count is accurate throughout various model type
-        #       May currently overestimate parameter count in Mamba model because
-        #       output weights is tied with token embeddings.
-        sum_weight_estimate = 0
-        for name, data_torch in tensors:
-            # Got A Tensor
-
-            # We don't need these
-            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
-                continue
-
-            # Calculate Tensor Volume
-            sum_weights_in_tensor = 1
-            for dim in data_torch.shape:
-                sum_weights_in_tensor *= dim
-
-            # Add Tensor Volume To Running Count
-            sum_weight_estimate += sum_weights_in_tensor
-
-        # Calculate weight estimate per model
-        per_model_weight_estimate = (sum_weight_estimate / expert_count) if expert_count is not None and (expert_count > 0) else sum_weight_estimate
-
-        return per_model_weight_estimate
-
     @staticmethod
     def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
         part_names: list[str] = []
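
With the old method removed, summation and the per-expert division are now separate steps: callers that need the raw total call calculate_total_weight_count() directly, and only the metadata path applies the mixture-of-experts adjustment. A hypothetical caller, assuming an already-constructed Model subclass instance named model:

    # mirrors the "num_local_experts" lookup in the metadata path above
    total_weights = model.calculate_total_weight_count()
    expert_count = model.hparams.get("num_local_experts")  # None for dense models

    per_model = total_weights / expert_count if expert_count else total_weights
    print(f"total: {total_weights}, per model: {per_model}")
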