From 640039106ff45dab5914fa903e5c5d4a451fc135 Mon Sep 17 00:00:00 2001 From: brian khuu Date: Tue, 6 Aug 2024 00:42:27 +1000 Subject: [PATCH] py: let users add full base model and dataset to model_card --- gguf-py/gguf/metadata.py | 4 ++++ gguf-py/tests/test_metadata.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index 9f3a1ecb7..4d901f04d 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -400,6 +400,8 @@ class Metadata: if org_component is not None and model_full_name_component is not None: base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" + elif isinstance(model_id, dict): + base_model = model_id else: logger.error(f"base model entry '{str(model_id)}' not in a known format") metadata.base_models.append(base_model) @@ -454,6 +456,8 @@ class Metadata: if org_component is not None and dataset_name_component is not None: dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}" + elif isinstance(dataset_id, dict): + dataset = dataset_id else: logger.error(f"dataset entry '{str(dataset_id)}' not in a known format") metadata.datasets.append(dataset) diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py index a9ab1da88..40d484f4e 100755 --- a/gguf-py/tests/test_metadata.py +++ b/gguf-py/tests/test_metadata.py @@ -197,6 +197,12 @@ class TestMetadataMethod(unittest.TestCase): got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) self.assertEqual(got, expect) + # Base Model spec is given directly + model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]} + expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + # Dataset spec is inferred from model id model_card = {'datasets': 'teknium/OpenHermes-2.5'} expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) @@ -209,6 +215,12 @@ class TestMetadataMethod(unittest.TestCase): got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) self.assertEqual(got, expect) + # Dataset spec is given directly + model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]} + expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + def test_apply_metadata_heuristic_from_hf_parameters(self): hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"} got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)