Compare commits

Comparing `master` ... `compilade/` (5 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 1932a1b871 | |
| | bf8e71b0c0 | |
| | a3d154b260 | |
| | 912e6fa5c6 | |
| | 2164c9deb3 | |

5 changed files with 112 additions and 44 deletions
convert_hf_to_gguf.py

```diff
@@ -48,7 +48,7 @@ class Model:
 
     dir_model: Path
     ftype: gguf.LlamaFileType
-    fname_out: Path | None
+    fname_out: Path
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
```
```diff
@@ -62,11 +62,12 @@ class Model:
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
+    dir_model_card: Path
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
```
```diff
@@ -90,6 +91,7 @@ class Model:
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
+        self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
 
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
```
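The purpose of the new `dir_model_card` attribute follows from the other hunks in this diff: `gguf.Metadata.load()` below reads from it, and `LoraModel.__init__` in convert_lora_to_gguf.py overrides it. A minimal sketch, with invented paths, of how the two converters end up using it:

```python
from pathlib import Path

dir_base_model = Path("models/Meta-Llama-3-8B")            # hypothetical base model directory
dir_lora = Path("adapters/Llama-3-Instruct-abliteration")  # hypothetical LoRA adapter directory

# convert_hf_to_gguf.py: the model card sits next to the tensors
dir_model_card = dir_base_model
# convert_lora_to_gguf.py: LoraModel.__init__ (later in this diff) overrides it, so
# gguf.Metadata.load() reads the adapter's model card instead of the base model's
dir_model_card = dir_lora
```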
```diff
@@ -345,7 +347,7 @@ class Model:
 
         total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
 
-        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params)
+        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)
 
         # Fallback to model directory name if metadata name is still missing
         if self.metadata.name is None:
```
```diff
@@ -359,27 +361,22 @@ class Model:
         output_type: str = self.ftype.name.partition("_")[2]
 
         # Filename Output
-        # Note: `not is_dir()` is used because `.is_file()` will not detect
-        #       file template strings as it doesn't actually exist as a file
-        if self.fname_out is not None and not self.fname_out.is_dir():
-            # Output path is a custom defined templated filename
-
-            # Process templated file name with the output ftype, useful with the "auto" ftype
-            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
-        else:
+        if self.fname_out.is_dir():
             # Generate default filename based on model specification and available metadata
             if not vocab_only:
                 fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None)
             else:
                 fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab")
 
-            # Check if preferred output directory path was provided
-            if self.fname_out is not None and self.fname_out.is_dir():
-                # output path is a directory
-                self.fname_out = self.fname_out / f"{fname_default}.gguf"
-            else:
-                # output in the same directory as the model by default
-                self.fname_out = self.dir_model / f"{fname_default}.gguf"
+            # Use the default filename
+            self.fname_out = self.fname_out / f"{fname_default}.gguf"
+        else:
+            # Output path is a custom defined templated filename
+            # Note: `not is_dir()` is used because `.is_file()` will not detect
+            #       file template strings as it doesn't actually exist as a file
+
+            # Process templated file name with the output ftype, useful with the "auto" ftype
+            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
 
         self.set_type()
```
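As a rough standalone sketch of the two branches above (not part of the diff), assuming `gguf.fill_templated_filename` substitutes `{ftype}`-style placeholders as it does for names like `ggml-lora-{ftype}.gguf`:

```python
from pathlib import Path
import gguf

output_type = "F16"
fname_out = Path("out")  # hypothetical --outfile value

if fname_out.is_dir():
    # a directory: generate the default name from the metadata and place the file inside it
    fname_default = gguf.naming_convention("Mistral Nemo", "Mistral-Nemo", "Instruct", "2407", None, output_type)
    fname_out = fname_out / f"{fname_default}.gguf"
else:
    # a templated file name, e.g. "my-model-{ftype}.gguf": only the placeholder is filled in
    fname_out = fname_out.parent / gguf.fill_templated_filename(fname_out.name, output_type)
```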
```diff
@@ -3624,10 +3621,10 @@ def main() -> None:
         logger.error("Error: Cannot use temp file when splitting")
         sys.exit(1)
 
-    fname_out = None
-
     if args.outfile is not None:
         fname_out = args.outfile
+    else:
+        fname_out = dir_model
 
     logger.info(f"Loading model: {dir_model.name}")
 
```
```diff
@@ -3658,7 +3655,6 @@ def main() -> None:
     else:
         logger.info("Exporting model...")
         model_instance.write()
-        assert model_instance.fname_out is not None
         out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
         logger.info(f"Model successfully exported to {out_path}")
 
```
convert_lora_to_gguf.py

```diff
@@ -290,7 +290,7 @@ if __name__ == '__main__':
         fname_out = args.outfile
     else:
         # output in the same directory as the model by default
-        fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
+        fname_out = dir_lora
 
     if os.path.exists(input_model):
         # lazy import load_file only if lora is in safetensors format.
```
```diff
@@ -304,12 +304,6 @@ if __name__ == '__main__':
     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
     hparams = Model.load_hparams(dir_base_model)
 
-    with open(lora_config, "r") as f:
-        lparams: dict[str, Any] = json.load(f)
-
-    alpha: float = lparams["lora_alpha"]
-
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
```
```diff
@@ -320,12 +314,21 @@ if __name__ == '__main__':
         class LoraModel(model_class):
             model_arch = model_class.model_arch
 
+            lora_alpha: float
+
+            def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
+
+                super().__init__(*args, **kwargs)
+
+                self.dir_model_card = dir_lora_model
+                self.lora_alpha = float(lora_alpha)
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
 
             def set_gguf_parameters(self):
-                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
+                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
                 super().set_gguf_parameters()
 
             def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
```
```diff
@@ -368,6 +371,11 @@ if __name__ == '__main__':
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)
 
+        with open(lora_config, "r") as f:
+            lparams: dict[str, Any] = json.load(f)
+
+        alpha: float = lparams["lora_alpha"]
+
         model_instance = LoraModel(
             dir_base_model,
             ftype,
```
```diff
@@ -376,6 +384,8 @@ if __name__ == '__main__':
             use_temp_file=False,
             eager=args.no_lazy,
             dry_run=args.dry_run,
+            dir_lora_model=dir_lora,
+            lora_alpha=alpha,
         )
 
         logger.info("Exporting model...")
```
gguf-py/gguf/metadata.py

```diff
@@ -54,6 +54,7 @@ class Metadata:
 
         model_card = Metadata.load_model_card(model_path)
         hf_params = Metadata.load_hf_parameters(model_path)
+        # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
 
         # heuristics
         metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
```
```diff
@@ -177,6 +178,12 @@ class Metadata:
             org_component = None
 
         name_parts: list[str] = model_full_name_component.split('-')
+
+        # Remove empty parts
+        for i in reversed(range(len(name_parts))):
+            if len(name_parts[i]) == 0:
+                del name_parts[i]
+
         name_types: list[
             set[Literal["basename", "size_label", "finetune", "version", "type"]]
         ] = [set() for _ in name_parts]
```
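To see what the new loop handles, a standalone illustration using the `-Mistral--Nemo-Base-2407-` case added to the tests further down:

```python
# Leading/trailing dashes and doubled dashes produce empty parts when splitting:
name_parts = "-Mistral--Nemo-Base-2407-".split('-')
assert name_parts == ['', 'Mistral', '', 'Nemo', 'Base', '2407', '']

# Deleting in reverse keeps the remaining indices valid while removing in place:
for i in reversed(range(len(name_parts))):
    if len(name_parts[i]) == 0:
        del name_parts[i]
assert name_parts == ['Mistral', 'Nemo', 'Base', '2407']
```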
```diff
@@ -223,9 +230,19 @@ class Metadata:
                 name_parts[i] = part
             # Some easy to recognize finetune names
             elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
-                name_types[i].add("finetune")
-                if part.lower() == "lora":
-                    name_parts[i] = "LoRA"
+                if total_params < 0 and part.lower() == "lora":
+                    # ignore redundant "lora" in the finetune part when the output is a lora adapter
+                    name_types[i].add("type")
+                else:
+                    name_types[i].add("finetune")
+
+        # Ignore word-based size labels when there is at least a number-based one present
+        # TODO: should word-based size labels always be removed instead?
+        if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
+            for n, t in zip(name_parts, name_types):
+                if "size_label" in t:
+                    if all(c.isalpha() for c in n):
+                        t.remove("size_label")
 
         at_start = True
         # Find the basename through the annotated name
```
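A small self-contained sketch of the new branch (the helper name is made up here); the negative-parameter-count convention for marking LoRA adapters is the one used by the new test cases:

```python
def classify_lora_part(part: str, total_params: int) -> str:
    # hypothetical helper mirroring the branch above
    if total_params < 0 and part.lower() == "lora":
        # redundant "lora" marker when the output is itself a LoRA adapter
        return "type"
    return "finetune"

assert classify_lora_part("LoRA", -1234) == "type"          # LoRA adapter output
assert classify_lora_part("LoRA", 8 * 10**9) == "finetune"  # regular model
```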
```diff
@@ -240,18 +257,18 @@
 
         # Remove the basename annotation from trailing version
         for part, t in zip(reversed(name_parts), reversed(name_types)):
-            if "basename" in t:
-                if len(t) > 1:
-                    t.remove("basename")
+            if "basename" in t and len(t) > 1:
+                t.remove("basename")
             else:
                 break
 
         basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
-        size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
+        # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys)
+        size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
         finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
         # TODO: should the basename version always be excluded?
-        # TODO: should multiple versions be joined together?
-        version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
+        # NOTE: multiple finetune versions are joined together
+        version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None
 
         if size_label is None and finetune is None and version is None:
             # Too ambiguous, output nothing
```
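Two standalone one-liners illustrating the joins above, with values taken from the new test cases:

```python
# Order-preserving deduplication of size labels; a plain set would not keep insertion order:
assert "-".join(dict.fromkeys(['7B', '7B']).keys()) == '7B'

# Multiple version parts are now joined instead of keeping only the last one,
# e.g. "Mini-InternVL-Chat-2B-V1-5" yields version 'V1-5':
assert "-".join(['V1', '5']) == 'V1-5'
```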
gguf-py/gguf/utility.py

```diff
@@ -50,15 +50,15 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
     # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
 
     if base_name is not None:
-        name = base_name.strip().title().replace(' ', '-').replace('/', '-')
+        name = base_name.strip().replace(' ', '-').replace('/', '-')
     elif model_name is not None:
-        name = model_name.strip().title().replace(' ', '-').replace('/', '-')
+        name = model_name.strip().replace(' ', '-').replace('/', '-')
     else:
         name = "ggml-model"
 
     parameters = f"-{size_label}" if size_label is not None else ""
 
-    finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
+    finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
 
     version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
 
```
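A quick standalone check of why `.title()` was dropped here: it lowercases acronyms and mixed-case basenames before they reach the output file name.

```python
# .title() mangles mixed-case names such as these (both appear in the test file):
assert "SFR-Iterative-DPO-LLaMA-3".title() == "Sfr-Iterative-Dpo-Llama-3"
assert "Qwen1.5-MoE".title() == "Qwen1.5-Moe"

# Without it, only spaces and slashes are normalized:
assert "SFR-Iterative-DPO-LLaMA-3".strip().replace(' ', '-').replace('/', '-') == "SFR-Iterative-DPO-LLaMA-3"
```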
gguf-py/tests/test_metadata.py

```diff
@@ -54,7 +54,7 @@ class TestMetadataMethod(unittest.TestCase):
         self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
                          ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
 
-        # Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
+        # Non standard naming
         self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
                          ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
 
```
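For reading the assertions in this file: judging from the existing cases, `get_model_id_components` returns `(model_full_name_component, org_component, basename, finetune, version, size_label)` in that order. A minimal sketch:

```python
import gguf

name, org, basename, finetune, version, size_label = \
    gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B")
assert (org, basename, size_label) == ("NousResearch", "Meta-Llama-3", "8B")
```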
```diff
@@ -71,7 +71,7 @@ class TestMetadataMethod(unittest.TestCase):
         self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
                          ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
 
-        # None standard and not easy to disambiguate
+        # Non standard and not easy to disambiguate
         self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
                          ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
 
```
```diff
@@ -123,6 +123,51 @@ class TestMetadataMethod(unittest.TestCase):
         self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
                          ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
 
+        # Ignore full-text size labels when there are number-based ones, and deduplicate size labels
+        self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
+                         ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))
+
+        # Instruct in a name without a size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"),
+                         ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None))
+
+        # Non-obvious splitting relying on 'chat' keyword
+        self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"),
+                         ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None))
+
+        # Multiple versions
+        self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"),
+                         ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B'))
+
+        # TODO: DPO in the name
+        self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"),
+                         ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B'))
+
+        # DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"),
+                         ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B'))
+
+        # Too ambiguous
+        # TODO: should "base" be a 'finetune' or 'size_label'?
+        # (in this case it should be a size label, but other models use it to signal that they are not finetuned)
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"),
+                         ('Florence-2-base', 'microsoft', None, None, None, None))
+
+        ## Invalid cases ##
+
+        # Start with a dash and has dashes in rows
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
+                         ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
+
+        ## LoRA ##
+
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B'))
+
+        # Negative size --> output is a LoRA adapter --> prune "LoRA" out of the name to avoid redundancy with the suffix
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B'))
+
     def test_apply_metadata_heuristic_from_model_card(self):
         model_card = {
             'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
```
```diff
@@ -134,7 +179,7 @@ class TestMetadataMethod(unittest.TestCase):
         }
         got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
         expect = gguf.Metadata()
-        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
+        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
         expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
         expect.languages=['en']
         expect.datasets=['teknium/OpenHermes-2.5']
```