gguf-py : fix some metadata name extraction edge cases
* convert_lora : use the lora dir for the model card path
This commit is contained in:
parent
87e397d00b
commit
2164c9deb3
4 changed files with 57 additions and 14 deletions
|
@ -44,7 +44,7 @@ class Metadata:
|
|||
datasets: Optional[list[str]] = None
|
||||
|
||||
@staticmethod
|
||||
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
|
||||
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, model_card_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
|
||||
# This grabs as many contextual authorship metadata as possible from the model repository
|
||||
# making any conversion as required to match the gguf kv store metadata format
|
||||
# as well as giving users the ability to override any authorship metadata that may be incorrect
|
||||
|
@ -52,11 +52,14 @@ class Metadata:
|
|||
# Create a new Metadata instance
|
||||
metadata = Metadata()
|
||||
|
||||
model_card = Metadata.load_model_card(model_path)
|
||||
if model_card_path is None:
|
||||
model_card_path = model_path
|
||||
|
||||
model_card = Metadata.load_model_card(model_card_path)
|
||||
hf_params = Metadata.load_hf_parameters(model_path)
|
||||
|
||||
# heuristics
|
||||
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
|
||||
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_card_path, total_params)
|
||||
|
||||
# Metadata Override File Provided
|
||||
# This is based on LLM_KV_NAMES mapping in llama.cpp
|
||||
|
@ -177,6 +180,12 @@ class Metadata:
|
|||
org_component = None
|
||||
|
||||
name_parts: list[str] = model_full_name_component.split('-')
|
||||
|
||||
# Remove empty parts
|
||||
for i in reversed(range(len(name_parts))):
|
||||
if len(name_parts[i]) == 0:
|
||||
del name_parts[i]
|
||||
|
||||
name_types: list[
|
||||
set[Literal["basename", "size_label", "finetune", "version", "type"]]
|
||||
] = [set() for _ in name_parts]
|
||||
|
@ -227,6 +236,13 @@ class Metadata:
|
|||
if part.lower() == "lora":
|
||||
name_parts[i] = "LoRA"
|
||||
|
||||
# Ignore word-based size labels when there is at least a number-based one present
|
||||
if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
|
||||
for n, t in zip(name_parts, name_types):
|
||||
if "size_label" in t:
|
||||
if all(c.isalpha() for c in n):
|
||||
t.remove("size_label")
|
||||
|
||||
at_start = True
|
||||
# Find the basename through the annotated name
|
||||
for part, t in zip(name_parts, name_types):
|
||||
|
@ -247,7 +263,8 @@ class Metadata:
|
|||
break
|
||||
|
||||
basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
|
||||
size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
|
||||
# Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys)
|
||||
size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
|
||||
finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
|
||||
# TODO: should the basename version always be excluded?
|
||||
# TODO: should multiple versions be joined together?
|
||||
|
|
|
@ -54,7 +54,7 @@ class TestMetadataMethod(unittest.TestCase):
|
|||
self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
|
||||
('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
|
||||
|
||||
# Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
|
||||
# Non standard naming
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
|
||||
('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
|
||||
|
||||
|
@ -71,7 +71,7 @@ class TestMetadataMethod(unittest.TestCase):
|
|||
self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
|
||||
('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
|
||||
|
||||
# None standard and not easy to disambiguate
|
||||
# Non standard and not easy to disambiguate
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
|
||||
('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
|
||||
|
||||
|
@ -123,6 +123,20 @@ class TestMetadataMethod(unittest.TestCase):
|
|||
self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
|
||||
('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
|
||||
|
||||
# Ignore full-text size labels when there are number-based ones, and deduplicate size labels
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
|
||||
('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))
|
||||
|
||||
# Version at the end with a long basename
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Base-2407"),
|
||||
('Mistral-Nemo-Base-2407', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
|
||||
|
||||
## Invalid cases ##
|
||||
|
||||
# Start with a dash and has dashes in rows
|
||||
self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
|
||||
('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
|
||||
|
||||
def test_apply_metadata_heuristic_from_model_card(self):
|
||||
model_card = {
|
||||
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue