convert-*.py: more rigorous regexp for get_model_id_components()

commit f98f1098f9 (parent 4e3761109d)
Author: brian khuu
Date:   2024-07-14 16:28:52 +10:00
2 changed files with 68 additions and 11 deletions

--- a/gguf-py/gguf/metadata.py
+++ b/gguf-py/gguf/metadata.py
@@ -170,18 +170,28 @@ class Metadata:
         # Regular expression to extract model name components
         # Heuristic to match against cases such as 'Mixtral-8x7B-Instruct-v0.1' or 'Codestral-22B-v0.1'
-        regex_match = re.compile(r'^(?P<basename>[A-Za-z0-9\s]*(?:(?:-(?:(?:[A-Za-z\s][A-Za-z0-9\s]*)|(?:[0-9\s]*)))*))'
-                                 r'(?:-(?P<size_label>(?:\d+x)?\d+[A-Za-z]+)(?:-(?P<finetune>[A-Za-z0-9\s-]+))?)?'
-                                 r'(?:-(?P<version>v\d+(?:\.\d+)*))?$').match(model_full_name_component)
+        regex_match = re.compile(r'^'
+                                 r'(?P<basename>[A-Za-z0-9\s]*(?:(?:-(?:(?:[A-Za-z\s][A-Za-z0-9\s]*)|(?:[0-9\s]*)))*))'
+                                 r'(?:-(?P<size_label>(?:\d+x)?\d+[A-Za-z](?:-[A-Za-z]+(?:\d+x)?\d+[A-Za-z]+)?)(?:-(?P<finetune>[A-Za-z0-9\s-]+))?)?'
+                                 r'(?:-(?P<version>v\d+(?:\.\d+)*))?'
+                                 r'$').match(model_full_name_component)
 
         if not regex_match:
             return model_full_name_component, org_component, None, None, None, None
 
         components = regex_match.groupdict()
         basename = components.get("basename")
-        size_label = components.get("size_label")
         finetune = components.get("finetune")
         version = components.get("version")
+        size_label = components.get("size_label")
+
+        # Base name required at a minimum
+        if basename is None:
+            return model_full_name_component, None, None, None, None, None
+
+        # Need to capture at least one component that is not basename
+        if size_label is None and version is None and finetune is None:
+            return model_full_name_component, None, None, None, None, None
 
         return model_full_name_component, org_component, basename, finetune, version, size_label
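
For reference, the tightened pattern can be exercised end to end through the public helper. The snippet below is a minimal sketch and not part of the commit: it assumes the gguf-py package from this repository is importable as gguf, and the expected tuples in the comments are copied from the updated tests further down.

    import gguf

    # Expected per the updated tests:
    # ('Mixtral-8x7B-Instruct-v0.1', 'Mistral', 'Mixtral', 'Instruct', 'v0.1', '8x7B')
    print(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"))

    # The 'sub size label' A14B is kept inside size_label:
    # ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B')
    print(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"))

    # Too ambiguous, so the new guard clauses reject everything except the name:
    # ('Mixtral', None, None, None, None, None)
    print(gguf.Metadata.get_model_id_components("Mixtral"))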

--- a/gguf-py/tests/test_metadata.py
+++ b/gguf-py/tests/test_metadata.py
@@ -13,36 +13,83 @@ class TestMetadataMethod(unittest.TestCase):
         self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
 
     def test_get_model_id_components(self):
+        # This is the basic standard form with organization marker
         self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
                          ('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
+
+        # Similar to basic standard form but without organization marker
         self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
                          ('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
+
+        # Missing version
         self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
                          ('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
+
+        # Missing finetune
         self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
                          ('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
+
+        # Base name and size label only
         self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
                          ('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
-        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
-                         ('Mixtral', None, 'Mixtral', None, None, None))
+
+        # Base name and version only
         self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
                          ('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
-        self.assertEqual(gguf.Metadata.get_model_id_components("hermes-2-pro-llama-3-8b-DPO"),
-                         ('hermes-2-pro-llama-3-8b-DPO', None, 'hermes-2-pro-llama-3', 'DPO', None, '8b'))
+
+        ## Edge Cases ##
+
+        # This is too ambiguous... best to err on caution and output nothing
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
+                         ('Mixtral', None, None, None, None, None))
+
+        # Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename
         self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
-                         ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, "8B"))
+                         ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
+
+        # Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
+                         ('Qwen1.5-MoE-A2.7B-Chat', None, None, None, None, None))
+
+        # Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
+                         ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))
+
+        # Check that it can handle a real model id with no version code
+        # Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct"),
+                         ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3-mini', 'instruct', None, '4k'))
+
+        # There is some legitimate models with only thousands of parameters
+        self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k"),
+                         ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50k'))
+
+        # None standard and not easy to disambiguate, best to err in caution and output nothing
+        self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
+                         ('DeepSeek-Coder-V2-Lite-Instruct', None, None, None, None, None))
+
+        # This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
+        # Not able to easily reject '2dpo' while keeping to simple regexp, so best to reject
+        self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
+                         ('daybreak-kunoichi-2dpo-7b', 'crestf411', None, None, None, None))
 
     def test_apply_metadata_heuristic_from_model_card(self):
         model_card = {
             'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
-            'model-index': [{'name': 'Hermes-2-Pro-Llama-3-8B', 'results': []}],
+            'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
             'language': ['en'],
             'datasets': ['teknium/OpenHermes-2.5'],
             'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
             'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
         }
         got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
-        expect = gguf.Metadata(name=None, author=None, version=None, organization=None, finetune=None, basename=None, description=None, quantized_by=None, size_label=None, url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}], tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], languages=['en'], datasets=['teknium/OpenHermes-2.5'])
+        expect = gguf.Metadata()
+        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1'}]
+        expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
+        expect.languages=['en']
+        expect.datasets=['teknium/OpenHermes-2.5']
         self.assertEqual(got, expect)
 
     def test_apply_metadata_heuristic_from_hf_parameters(self):
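
The heuristic covered by this test can also be driven directly. The snippet below is a minimal sketch and not part of the commit: it assumes gguf is importable and reuses the call shape gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) from the test above; the model card is a trimmed-down, hypothetical example, so it also assumes the heuristic tolerates the missing optional keys.

    import gguf

    model_card = {
        'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
        'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"],
        'language': ['en'],
    }

    # Start from an empty Metadata and let the heuristic fill in fields from the
    # model card only (no Hugging Face parameters, no model path), as the test does.
    meta = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)

    # Per the updated expectation above, base model entries carry organization and
    # repo_url only when the repo id parses cleanly under the stricter regexp.
    print(meta.base_models)
    print(meta.languages)   # ['en']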