gguf-py : fix some metadata name extraction edge cases

* convert_lora : use the lora dir for the model card path
This commit is contained in:
Francis Couture-Harpin 2024-07-19 12:30:37 -04:00
parent 87e397d00b
commit 2164c9deb3
4 changed files with 57 additions and 14 deletions

View file

@ -44,7 +44,7 @@ class Metadata:
datasets: Optional[list[str]] = None
@staticmethod
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, model_card_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
# This grabs as much contextual authorship metadata as possible from the model repository
# making any conversion as required to match the gguf kv store metadata format
# as well as giving users the ability to override any authorship metadata that may be incorrect
@ -52,11 +52,14 @@ class Metadata:
# Create a new Metadata instance
metadata = Metadata()
model_card = Metadata.load_model_card(model_path)
if model_card_path is None:
model_card_path = model_path
model_card = Metadata.load_model_card(model_card_path)
hf_params = Metadata.load_hf_parameters(model_path)
# heuristics
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_card_path, total_params)
# Metadata Override File Provided
# This is based on LLM_KV_NAMES mapping in llama.cpp
@ -177,6 +180,12 @@ class Metadata:
org_component = None
name_parts: list[str] = model_full_name_component.split('-')
# Remove empty parts
for i in reversed(range(len(name_parts))):
if len(name_parts[i]) == 0:
del name_parts[i]
name_types: list[
set[Literal["basename", "size_label", "finetune", "version", "type"]]
] = [set() for _ in name_parts]
@ -227,6 +236,13 @@ class Metadata:
if part.lower() == "lora":
name_parts[i] = "LoRA"
# Ignore word-based size labels when at least one number-based size label is present
if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
for n, t in zip(name_parts, name_types):
if "size_label" in t:
if all(c.isalpha() for c in n):
t.remove("size_label")
at_start = True
# Find the basename through the annotated name
for part, t in zip(name_parts, name_types):
@ -247,7 +263,8 @@ class Metadata:
break
basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
# Deduplicate size labels using an insertion-ordered 'dict' (a 'set' is unordered and would not preserve the original order)
size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
# TODO: should the basename version always be excluded?
# TODO: should multiple versions be joined together?