py: let users add full base model and dataset to model_card

This commit is contained in:
brian khuu 2024-08-06 00:42:27 +10:00
parent d32c74d1f2
commit 640039106f
2 changed files with 16 additions and 0 deletions

View file

@ -400,6 +400,8 @@ class Metadata:
if org_component is not None and model_full_name_component is not None: if org_component is not None and model_full_name_component is not None:
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
elif isinstance(model_id, dict):
base_model = model_id
else: else:
logger.error(f"base model entry '{str(model_id)}' not in a known format") logger.error(f"base model entry '{str(model_id)}' not in a known format")
metadata.base_models.append(base_model) metadata.base_models.append(base_model)
@ -454,6 +456,8 @@ class Metadata:
if org_component is not None and dataset_name_component is not None: if org_component is not None and dataset_name_component is not None:
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}" dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
elif isinstance(dataset_id, dict):
dataset = dataset_id
else: else:
logger.error(f"dataset entry '{str(dataset_id)}' not in a known format") logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
metadata.datasets.append(dataset) metadata.datasets.append(dataset)

View file

@ -197,6 +197,12 @@ class TestMetadataMethod(unittest.TestCase):
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
self.assertEqual(got, expect) self.assertEqual(got, expect)
# Base Model spec is given directly
model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
self.assertEqual(got, expect)
# Dataset spec is inferred from model id # Dataset spec is inferred from model id
model_card = {'datasets': 'teknium/OpenHermes-2.5'} model_card = {'datasets': 'teknium/OpenHermes-2.5'}
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
@ -209,6 +215,12 @@ class TestMetadataMethod(unittest.TestCase):
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
self.assertEqual(got, expect) self.assertEqual(got, expect)
# Dataset spec is given directly
model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
self.assertEqual(got, expect)
def test_apply_metadata_heuristic_from_hf_parameters(self): def test_apply_metadata_heuristic_from_hf_parameters(self):
hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"} hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None) got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)