gguf-py : extract metadata from model name more resiliently
Use more than one regex to annotate the parts of the name. This way the order of the parts doesn't have to be fixed, and more edge cases are handled correctly. The total parameter count of the model is also used to figure out when a size label is not actually a size label but a context size.

* convert_lora : fix duplicate model type key
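For a feel of what the reworked heuristic extracts, here is a short sketch based on two of the cases from the updated tests in this commit; the returned tuple is (name, org, basename, finetune, version, size_label):

import gguf

# '0.5B' is parsed as a size label even without a parameter count
print(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"))
# ('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B')

# With a ~4B total parameter count supplied, '4k' is recognized as a context
# length rather than a size label, and folds into the finetune component
print(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4*10**9))
# ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini')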
parent 7e9271cabf
commit 2c18a9a4d4

5 changed files with 161 additions and 73 deletions
convert_hf_to_gguf.py

@@ -60,19 +60,19 @@ class Model:
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
-    metadata: gguf.Metadata
+    model_name: str | None
+    metadata_override: Path | None

     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata: gguf.Metadata | None = None,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False,
+                 use_temp_file: bool = False, eager: bool = False,
+                 metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")

-        if metadata is None:
-            metadata = gguf.Metadata()
-
         self.dir_model = dir_model
         self.ftype = ftype
         self.fname_out = fname_out

@@ -88,7 +88,8 @@ class Model:
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
-        self.metadata = metadata
+        self.metadata_override = metadata_override
+        self.model_name = model_name

         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:

@@ -337,18 +338,22 @@ class Model:

                 self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)

+    def set_type(self):
+        self.gguf_writer.add_type(gguf.GGUFType.MODEL)
+
     def prepare_metadata(self, vocab_only: bool):

+        total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
+
+        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params)
+
         # Fallback to model directory name if metadata name is still missing
         if self.metadata.name is None:
             self.metadata.name = self.dir_model.name

         # Generate parameter weight class (useful for leader boards) if not yet determined
-        if self.metadata.size_label is None:
-            total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
-
-            if (total_params > 0):
-                self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
+        if self.metadata.size_label is None and total_params > 0:
+            self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)

         # Filename Output
         if self.fname_out is not None and not self.fname_out.is_dir():

@@ -383,11 +388,12 @@ class Model:
             # output in the same directory as the model by default
             self.fname_out = self.dir_model / f"{fname_default}.gguf"

+        self.set_type()
+
         logger.info("Set meta model")
         self.metadata.set_gguf_meta_model(self.gguf_writer)

         logger.info("Set model parameters")
-        self.gguf_writer.add_type(gguf.GGUFType.MODEL)
         self.set_gguf_parameters()

         logger.info("Set model tokenizer")

@@ -3607,11 +3613,8 @@ def main() -> None:
     else:
         logging.basicConfig(level=logging.INFO)

-    model_name = args.model_name
     dir_model = args.model

-    metadata = gguf.Metadata.load(args.metadata, dir_model, model_name)
-
     if not dir_model.is_dir():
         logger.error(f'Error: {args.model} is not a directory')
         sys.exit(1)

@@ -3650,7 +3653,9 @@ def main() -> None:
         model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out,
                                      is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
-                                     eager=args.no_lazy, metadata=metadata, split_max_tensors=args.split_max_tensors,
+                                     eager=args.no_lazy,
+                                     metadata_override=args.metadata, model_name=args.model_name,
+                                     split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split)
convert_lora_to_gguf.py

@@ -251,6 +251,10 @@ def parse_args() -> argparse.Namespace:
         "--verbose", action="store_true",
         help="increase output verbosity",
     )
+    parser.add_argument(
+        "--dry-run", action="store_true",
+        help="only print out what will be done, without writing any new files",
+    )
     parser.add_argument(
         "--base", type=Path, required=True,
         help="directory containing base model file",

@@ -300,6 +304,12 @@ if __name__ == '__main__':
    # load base model
    logger.info(f"Loading base model: {dir_base_model.name}")
    hparams = Model.load_hparams(dir_base_model)
+
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
+    alpha: float = lparams["lora_alpha"]
+
    with torch.inference_mode():
        try:
            model_class = Model.from_model_architecture(hparams["architectures"][0])

@@ -310,6 +320,14 @@ if __name__ == '__main__':
        class LoraModel(model_class):
            model_arch = model_class.model_arch

+            def set_type(self):
+                self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
+                self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
+
+            def set_gguf_parameters(self):
+                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
+                super().set_gguf_parameters()
+
            def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                tensor_map: dict[str, PartialLoraTensor] = {}

@@ -356,18 +374,10 @@ if __name__ == '__main__':
            fname_out,
            is_big_endian=args.bigendian,
            use_temp_file=False,
-            eager=args.no_lazy
+            eager=args.no_lazy,
+            dry_run=args.dry_run,
        )

-        with open(lora_config, "r") as f:
-            lparams: dict[str, Any] = json.load(f)
-
-        alpha = lparams["lora_alpha"]
-
-        model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER)
-        model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
-        model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
-        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
        logger.info("Exporting model...")
        model_instance.write()
        logger.info(f"Model successfully exported to {model_instance.fname_out}")
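With the new flag, a LoRA conversion can be rehearsed without writing any file. A hypothetical invocation (the script's positional arguments are outside this diff, so the exact form is an assumption):

python convert_lora_to_gguf.py --base ./base-model ./lora-adapter --dry-run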
gguf-py/gguf/gguf_writer.py

@@ -7,6 +7,7 @@ import struct
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, auto
+from math import prod
 from pathlib import Path
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping

@@ -17,7 +18,6 @@ import numpy as np
 from .constants import (
     GGUF_DEFAULT_ALIGNMENT,
     GGUF_MAGIC,
-    GGML_QUANT_SIZES,
     GGUF_VERSION,
     GGMLQuantizationType,
     GGUFEndian,

@@ -115,16 +115,29 @@ class GGUFWriter:
         expert_sum = 0
         n_expert_tensors = 0

+        last_lora_a: tuple[str, TensorInfo] | None = None
+
         for tensors in self.tensors:
             for name, info in tensors.items():

-                block_size, type_size = GGML_QUANT_SIZES[info.dtype]
+                shape = info.shape

-                size = (info.nbytes // type_size) * block_size
+                if name.endswith(".lora_a"):
+                    last_lora_a = (name, info)
+                    continue
+                elif name.endswith(".lora_b"):
+                    if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
+                        # Bail when the LoRA pair can't be found trivially
+                        logger.warning("can't measure LoRA size correctly, tensor order is unusual")
+                        return 0, 0, 0, 0
+                    else:
+                        shape = (*shape[:-1], last_lora_a[1].shape[-1])
+
+                size = prod(shape)

                 if "_exps." in name:
-                    expert_params += (size // info.shape[-3])
-                    expert_sum += info.shape[-3]
+                    expert_params += (size // shape[-3])
+                    expert_sum += shape[-3]
                     n_expert_tensors += 1
                 else:
                     shared_params += size
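A standalone sketch of what the .lora_a/.lora_b pairing above computes, assuming PEFT-style shapes where lora_A is (rank, in_features) and lora_B is (out_features, rank); the concrete numbers are hypothetical:

from math import prod

lora_a_shape = (16, 4096)  # (rank, in_features)
lora_b_shape = (4096, 16)  # (out_features, rank)

# Replacing lora_B's last dim (the rank) with lora_A's last dim recovers
# the shape of the full weight matrix the adapter pair stands in for.
implied_shape = (*lora_b_shape[:-1], lora_a_shape[-1])
assert implied_shape == (4096, 4096)

print(prod(implied_shape))  # 16777216 parameters counted for the pair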
gguf-py/gguf/metadata.py

@@ -5,7 +5,7 @@ import json
 import yaml
 import logging
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 from dataclasses import dataclass

 from .constants import Keys

@@ -44,7 +44,7 @@ class Metadata:
     datasets: Optional[list[str]] = None

     @staticmethod
-    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None) -> Metadata:
+    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
         # This grabs as many contextual authorship metadata as possible from the model repository
         # making any conversion as required to match the gguf kv store metadata format
         # as well as giving users the ability to override any authorship metadata that may be incorrect

@@ -56,7 +56,7 @@ class Metadata:
         hf_params = Metadata.load_hf_parameters(model_path)

         # heuristics
-        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path)
+        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)

         # Metadata Override File Provided
         # This is based on LLM_KV_NAMES mapping in llama.cpp

@@ -150,7 +150,7 @@ class Metadata:
         return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])

     @staticmethod
-    def get_model_id_components(model_id: Optional[str] = None) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
+    def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
         # Huggingface often store model id as '<org>/<model name>'
         # so let's parse it and apply some heuristics if possible for model name components

@@ -175,35 +175,82 @@ class Metadata:
         if org_component is not None and org_component[0] == '.':
             org_component = None

-        # Regular expression to extract model name components
-        # Heuristic to match against cases such as 'Mixtral-8x7B-Instruct-v0.1' or 'Codestral-22B-v0.1'
-        regex_match = re.compile(r'^'
-                                 r'(?P<basename>[A-Za-z0-9\s]*(?:(?:-(?:(?:[A-Za-z\s][A-Za-z0-9\s]*)|(?:[0-9\s]*)))*))'
-                                 r'(?:-(?P<size_label>(?:\d+x)?(\d+\.)?\d+[A-Za-z](?:-[A-Za-z]+(\d+\.)?\d+[A-Za-z]+)?)(?:-(?P<finetune>[A-Za-z0-9\s-]+))?)?'
-                                 r'(?:-(?P<version>v\d+(?:\.\d+)*))?'
-                                 r'$').match(model_full_name_component)
-
-        if not regex_match:
-            return model_full_name_component, org_component, None, None, None, None
-
-        components = regex_match.groupdict()
-        basename = components.get("basename")
-        size_label = components.get("size_label")
-        finetune = components.get("finetune")
-        version = components.get("version")
-
-        # Base name required at a minimum
-        if basename is None:
-            return model_full_name_component, None, None, None, None, None
-
-        # Need to capture at least one component that is not basename
-        if size_label is None and version is None and finetune is None:
-            return model_full_name_component, None, None, None, None, None
+        name_parts: list[str] = model_full_name_component.split('-')
+        name_types: list[
+            set[Literal["basename", "size_label", "finetune", "version", "type"]]
+        ] = [set() for _ in name_parts]
+
+        # Annotate the name
+        for i, part in enumerate(name_parts):
+            # Version
+            if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
+                name_types[i].add("version")
+            # Quant type (should not be there for base models, but still annotated)
+            elif re.fullmatch(r'[iI]?[qQ]\d(_\w)*', part):
+                name_types[i].add("type")
+                name_parts[i] = part.upper()
+            # Model size
+            elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[kMBT]|small|mini|medium|large|xl)', part, re.IGNORECASE):
+                part = part.replace("_", ".")
+                if len(part) > 1 and part[-2].isdecimal():
+                    if part[-1] in "mbt":
+                        part = part[:-1] + part[-1].upper()
+                    elif part[-1] in "k":
+                        part = part[:-1] + part[-1].lower()
+                if total_params > 0:
+                    try:
+                        label_params = float(part[:-1]) * pow(1000, " kMBT".find(part[-1]))
+                        # Only use it as a size label if it's close or bigger than the model size
+                        # Note that LoRA adapters don't necessarily include all layers,
+                        # so this is why bigger label sizes are accepted.
+                        # Do not use the size label when it's smaller than 3/4 of the model size
+                        if total_params - label_params > total_params // 4:
+                            # Likely a context length
+                            name_types[i].add("finetune")
+                    except ValueError:
+                        # Failed to convert the size label to float, use it anyway
+                        pass
+                if len(name_types[i]) == 0:
+                    name_types[i].add("size_label")
+                name_parts[i] = part
+            # Some easy to recognize finetune names
+            elif i > 0 and re.fullmatch(r'chat|instruct|vision', part, re.IGNORECASE):
+                name_types[i].add("finetune")
+
+        at_start = True
+        # Find the basename through the annotated name
+        for part, t in zip(name_parts, name_types):
+            if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
+                t.add("basename")
+            else:
+                if at_start:
+                    at_start = False
+                if len(t) == 0:
+                    t.add("finetune")
+
+        # Remove the basename annotation from trailing version
+        for part, t in zip(reversed(name_parts), reversed(name_types)):
+            if "basename" in t:
+                if len(t) > 1:
+                    t.remove("basename")
+            else:
+                break
+
+        basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
+        size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
+        finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
+        # TODO: should the basename version always be excluded?
+        # TODO: should multiple versions be joined together?
+        version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
+
+        if size_label is None and finetune is None and version is None:
+            # Too ambiguous, output nothing
+            basename = None

         return model_full_name_component, org_component, basename, finetune, version, size_label

     @staticmethod
-    def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None) -> Metadata:
+    def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
         # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1

         # Model Card Heuristics

@@ -242,7 +289,8 @@ class Metadata:
                 metadata.base_models = []

             for model_id in metadata_base_models:
-                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id)
+                # NOTE: model size of base model is assumed to be similar to the size of the current model
+                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
                 base_model = {}
                 if model_full_name_component is not None:
                     base_model["name"] = Metadata.id_to_title(model_full_name_component)

@@ -317,7 +365,7 @@ class Metadata:
                 # Use _name_or_path only if its actually a model name and not some computer path
                 # e.g. 'meta-llama/Llama-2-7b-hf'
                 model_id = hf_name_or_path
-                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id)
+                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
                 if metadata.name is None and model_full_name_component is not None:
                     metadata.name = Metadata.id_to_title(model_full_name_component)
                 if metadata.organization is None and org_component is not None:

@@ -335,7 +383,7 @@ class Metadata:
         ############################################
         if model_path is not None:
             model_id = model_path.name
-            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id)
+            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
             if metadata.name is None and model_full_name_component is not None:
                 metadata.name = Metadata.id_to_title(model_full_name_component)
             if metadata.organization is None and org_component is not None:
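As a worked example of the new size-label check (a self-contained restatement of the logic above, not part of the patch):

def looks_like_size_label(part: str, total_params: int) -> bool:
    # " kMBT".find('k') == 1, so '4k' -> 4 * 1000**1 = 4000 parameters
    label_params = float(part[:-1]) * pow(1000, " kMBT".find(part[-1]))
    # A label much smaller than the model is likely a context length;
    # larger-than-model labels are kept because LoRA adapters may omit layers
    return not (total_params - label_params > total_params // 4)

print(looks_like_size_label("4k", 4 * 10**9))    # False: '4k' is a context length
print(looks_like_size_label("0.5B", 5 * 10**8))  # True: '0.5B' is a real size label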
gguf-py/tests/test_metadata.py

@@ -1,9 +1,15 @@
 #!/usr/bin/env python3

 import unittest
-import gguf # noqa: F401
 from pathlib import Path
+import os
+import sys
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import gguf

 class TestMetadataMethod(unittest.TestCase):

@@ -49,7 +55,7 @@ class TestMetadataMethod(unittest.TestCase):

         # Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
         self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
-                         ('Qwen1.5-MoE-A2.7B-Chat', None, None, None, None, None))
+                         ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))

         # Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
         self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),

@@ -57,26 +63,29 @@ class TestMetadataMethod(unittest.TestCase):

         # Check that it can handle a real model id with no version code
         # Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count
-        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct"),
-                         ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3-mini', 'instruct', None, '4k'))
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4*10**9),
+                         ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))

         # There is some legitimate models with only thousands of parameters
-        self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k"),
+        self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50*10**3),
                          ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50k'))

-        # None standard and not easy to disambiguate, best to err in caution and output nothing
+        # None standard and not easy to disambiguate
         self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
-                         ('DeepSeek-Coder-V2-Lite-Instruct', None, None, None, None, None))
+                         ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))

         # This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
-        # Not able to easily reject '2dpo' while keeping to simple regexp, so best to reject
         self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
-                         ('daybreak-kunoichi-2dpo-7b', 'crestf411', None, None, None, None))
+                         ('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))

         # This is a real model id where the weight size has a decimal point
         self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
                          ('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))

+        # Uses an underscore in the size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
+                         ('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))
+
     def test_apply_metadata_heuristic_from_model_card(self):
         model_card = {
             'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],

@@ -88,7 +97,7 @@ class TestMetadataMethod(unittest.TestCase):
         }
         got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
         expect = gguf.Metadata()
-        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1'}]
+        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
         expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
         expect.languages=['en']
         expect.datasets=['teknium/OpenHermes-2.5']

@@ -97,12 +106,15 @@ class TestMetadataMethod(unittest.TestCase):

     def test_apply_metadata_heuristic_from_hf_parameters(self):
         hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
-        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), None, hf_params, None)
-        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', author=None, version=None, organization=None, finetune='DPO', basename='hermes-2-pro-llama-3', description=None, quantized_by=None, size_label='8b', url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=None, tags=None, languages=None, datasets=None)
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
+        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
         self.assertEqual(got, expect)

     def test_apply_metadata_heuristic_from_model_dir(self):
         model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
-        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), None, None, model_dir_path)
-        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', author=None, version=None, organization=None, finetune='DPO', basename='hermes-2-pro-llama-3', description=None, quantized_by=None, size_label='8b', url=None, doi=None, uuid=None, repo_url=None, license=None, license_name=None, license_link=None, base_models=None, tags=None, languages=None, datasets=None)
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
+        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
         self.assertEqual(got, expect)

+
+if __name__ == "__main__":
+    unittest.main()
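With the new __main__ guard, the test file can also be run directly rather than through a test runner; presumably from the repository root:

python gguf-py/tests/test_metadata.py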