metadata.py: account for decimal point in size label within model id components

This commit is contained in:
brian khuu 2024-07-15 19:16:38 +10:00
parent 417d7a7c62
commit 9a925b56a0
2 changed files with 5 additions and 1 deletion

View file

@@ -179,7 +179,7 @@ class Metadata:
# Heuristic to match against cases such as 'Mixtral-8x7B-Instruct-v0.1' or 'Codestral-22B-v0.1' # Heuristic to match against cases such as 'Mixtral-8x7B-Instruct-v0.1' or 'Codestral-22B-v0.1'
regex_match = re.compile(r'^' regex_match = re.compile(r'^'
r'(?P<basename>[A-Za-z0-9\s]*(?:(?:-(?:(?:[A-Za-z\s][A-Za-z0-9\s]*)|(?:[0-9\s]*)))*))' r'(?P<basename>[A-Za-z0-9\s]*(?:(?:-(?:(?:[A-Za-z\s][A-Za-z0-9\s]*)|(?:[0-9\s]*)))*))'
r'(?:-(?P<size_label>(?:\d+x)?\d+[A-Za-z](?:-[A-Za-z]+(?:\d+x)?\d+[A-Za-z]+)?)(?:-(?P<finetune>[A-Za-z0-9\s-]+))?)?' r'(?:-(?P<size_label>(?:\d+x)?(\d+\.)?\d+[A-Za-z](?:-[A-Za-z]+(\d+\.)?\d+[A-Za-z]+)?)(?:-(?P<finetune>[A-Za-z0-9\s-]+))?)?'
r'(?:-(?P<version>v\d+(?:\.\d+)*))?' r'(?:-(?P<version>v\d+(?:\.\d+)*))?'
r'$').match(model_full_name_component) r'$').match(model_full_name_component)

View file

@@ -73,6 +73,10 @@ class TestMetadataMethod(unittest.TestCase):
self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"), self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
('daybreak-kunoichi-2dpo-7b', 'crestf411', None, None, None, None)) ('daybreak-kunoichi-2dpo-7b', 'crestf411', None, None, None, None))
# This is a real model id where the weight size has a decimal point
self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))
def test_apply_metadata_heuristic_from_model_card(self): def test_apply_metadata_heuristic_from_model_card(self):
model_card = { model_card = {
'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], 'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],