Merge branch 'master' into compilade/refactor-kv-cache

commit bc320ef66d by Francis Couture-Harpin, 2024-08-31 21:06:32 -04:00
395 changed files with 57725 additions and 169970 deletions

View file

@ -5,3 +5,5 @@ from .gguf_writer import *
from .quants import *
from .tensor_mapping import *
from .vocab import *
from .utility import *
from .metadata import *

View file

@ -19,18 +19,60 @@ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
class Keys:
class General:
ARCHITECTURE = "general.architecture"
QUANTIZATION_VERSION = "general.quantization_version"
ALIGNMENT = "general.alignment"
NAME = "general.name"
AUTHOR = "general.author"
VERSION = "general.version"
URL = "general.url"
DESCRIPTION = "general.description"
LICENSE = "general.license"
SOURCE_URL = "general.source.url"
SOURCE_HF_REPO = "general.source.huggingface.repository"
FILE_TYPE = "general.file_type"
TYPE = "general.type"
ARCHITECTURE = "general.architecture"
QUANTIZATION_VERSION = "general.quantization_version"
ALIGNMENT = "general.alignment"
FILE_TYPE = "general.file_type"
# Authorship Metadata
NAME = "general.name"
AUTHOR = "general.author"
VERSION = "general.version"
ORGANIZATION = "general.organization"
FINETUNE = "general.finetune"
BASENAME = "general.basename"
DESCRIPTION = "general.description"
QUANTIZED_BY = "general.quantized_by"
SIZE_LABEL = "general.size_label"
# Licensing details
LICENSE = "general.license"
LICENSE_NAME = "general.license.name"
LICENSE_LINK = "general.license.link"
# Typically represents the converted GGUF repo (Unless native)
URL = "general.url" # Model Website/Paper
DOI = "general.doi"
UUID = "general.uuid"
REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
# Model Source during conversion
SOURCE_URL = "general.source.url" # Model Website/Paper
SOURCE_DOI = "general.source.doi"
SOURCE_UUID = "general.source.uuid"
SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
# Base Model Source. There can be more than one source if it's a merged
# model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
# tracing the lineage of models as they are fine-tuned or merged over time.
BASE_MODEL_COUNT = "general.base_model.count"
BASE_MODEL_NAME = "general.base_model.{id}.name"
BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
BASE_MODEL_VERSION = "general.base_model.{id}.version"
BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
BASE_MODEL_DOI = "general.base_model.{id}.doi"
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
# Array based KV stores
TAGS = "general.tags"
LANGUAGES = "general.languages"
DATASETS = "general.datasets"
class LLM:
VOCAB_SIZE = "{arch}.vocab_size"
@ -88,6 +130,7 @@ class Keys:
INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
class Tokenizer:
MODEL = "tokenizer.ggml.model"
@ -119,13 +162,22 @@ class Keys:
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
EOT_ID = "tokenizer.ggml.eot_token_id"
EOM_ID = "tokenizer.ggml.eom_token_id"
class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
#
# recommended mapping of model tensor names for storage in gguf
#
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
class MODEL_ARCH(IntEnum):
LLAMA = auto()
FALCON = auto()
@ -164,9 +216,13 @@ class MODEL_ARCH(IntEnum):
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK2 = auto()
CHATGLM = auto()
BITNET = auto()
T5 = auto()
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
class MODEL_TENSOR(IntEnum):
@ -294,9 +350,13 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -960,6 +1020,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.CHATGLM : [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
@ -1007,6 +1079,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.T5ENCODER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ENC_ATTN_NORM,
MODEL_TENSOR.ENC_ATTN_Q,
MODEL_TENSOR.ENC_ATTN_K,
MODEL_TENSOR.ENC_ATTN_V,
MODEL_TENSOR.ENC_ATTN_OUT,
MODEL_TENSOR.ENC_ATTN_REL_B,
MODEL_TENSOR.ENC_FFN_NORM,
MODEL_TENSOR.ENC_FFN_GATE,
MODEL_TENSOR.ENC_FFN_DOWN,
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.JAIS: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -1019,6 +1106,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.NEMOTRON: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.EXAONE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}
@ -1056,6 +1174,13 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.ROPE_FREQS,
],
MODEL_ARCH.NEMOTRON: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
}
#
@ -1114,6 +1239,9 @@ class GGMLQuantizationType(IntEnum):
F64 = 28
IQ1_M = 29
BF16 = 30
Q4_0_4_4 = 31
Q4_0_4_8 = 32
Q4_0_8_8 = 33
# TODO: add GGMLFileType from ggml_ftype in ggml.h
@ -1126,7 +1254,7 @@ class LlamaFileType(IntEnum):
MOSTLY_F16 = 1 # except 1d tensors
MOSTLY_Q4_0 = 2 # except 1d tensors
MOSTLY_Q4_1 = 3 # except 1d tensors
MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
# MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
# MOSTLY_Q4_2 = 5 # support has been removed
# MOSTLY_Q4_3 = 6 # support has been removed
MOSTLY_Q8_0 = 7 # except 1d tensors
@ -1155,6 +1283,9 @@ class LlamaFileType(IntEnum):
MOSTLY_IQ4_XS = 30 # except 1d tensors
MOSTLY_IQ1_M = 31 # except 1d tensors
MOSTLY_BF16 = 32 # except 1d tensors
MOSTLY_Q4_0_4_4 = 33 # except 1d tensors
MOSTLY_Q4_0_4_8 = 34 # except 1d tensors
MOSTLY_Q4_0_8_8 = 35 # except 1d tensors
GUESSED = 1024 # not specified in the model file
@ -1228,6 +1359,9 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F64: (1, 8),
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
GGMLQuantizationType.BF16: (1, 2),
GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
}
@ -1243,7 +1377,6 @@ KEY_GENERAL_URL = Keys.General.URL
KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
KEY_GENERAL_LICENSE = Keys.General.LICENSE
KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
# LLM
@ -1276,6 +1409,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
@ -1297,3 +1431,4 @@ KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID
KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID
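A minimal sketch (not part of this diff) of how the new constants are meant to be consumed: the base-model keys are '{id}'-templated strings, and the Q4_0_4_x entries follow the same (block size, type size) convention as the other quantization types.

from gguf.constants import Keys, GGMLQuantizationType, GGML_QUANT_SIZES

# Per-entry key names are produced by formatting the templates:
print(Keys.General.ORGANIZATION)                      # general.organization
print(Keys.General.BASE_MODEL_NAME.format(id=0))      # general.base_model.0.name
print(Keys.General.BASE_MODEL_REPO_URL.format(id=1))  # general.base_model.1.repo_url

# Q4_0_4_4 packs 32 elements into 2 + 16 = 18 bytes, so a 4096 x 4096 tensor takes:
block_size, type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q4_0_4_4]
print(4096 * 4096 // block_size * type_size)          # 9437184 bytes (9 MiB)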

View file

@ -67,7 +67,7 @@ class ReaderTensor(NamedTuple):
class GGUFReader:
# I - same as host, S - swapped
byte_order: Literal['I'] | Literal['S'] = 'I'
byte_order: Literal['I', 'S'] = 'I'
alignment: int = GGUF_DEFAULT_ALIGNMENT
data_offset: int
@ -86,7 +86,7 @@ class GGUFReader:
GGUFValueType.BOOL: np.bool_,
}
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
self.data = np.memmap(path, mode = mode)
offs = 0
@ -140,7 +140,7 @@ class GGUFReader:
return self.tensors[idx]
def _get(
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
) -> npt.NDArray[Any]:
count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize)
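The reader changes above are typing cleanups only (Literal['I'] | Literal['S'] becomes the equivalent Literal['I', 'S']); usage is unchanged. A minimal reading sketch, assuming the usual ReaderTensor fields (name, shape), which this hunk does not show, and a placeholder path:

from gguf.gguf_reader import GGUFReader

reader = GGUFReader('model.gguf', mode='r')   # 'model.gguf' is a placeholder
print(reader.byte_order)                      # 'I' (same as host) or 'S' (byte-swapped)
for tensor in reader.tensors:
    print(tensor.name, tensor.shape)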

View file

@ -7,6 +7,7 @@ import struct
import tempfile
from dataclasses import dataclass
from enum import Enum, auto
from math import prod
from pathlib import Path
from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping
@ -106,6 +107,53 @@ class GGUFWriter:
self.add_architecture()
def get_total_parameter_count(self) -> tuple[int, int, int, int]:
total_params = 0
shared_params = 0
expert_params = 0
expert_sum = 0
n_expert_tensors = 0
last_lora_a: tuple[str, TensorInfo] | None = None
for tensors in self.tensors:
for name, info in tensors.items():
shape = info.shape
if name.endswith(".lora_a"):
last_lora_a = (name, info)
continue
elif name.endswith(".lora_b"):
if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
# Bail when the LoRA pair can't be found trivially
logger.warning("can't measure LoRA size correctly, tensor order is unusual")
return 0, 0, 0, 0
else:
shape = (*shape[:-1], last_lora_a[1].shape[-1])
size = prod(shape)
if "_exps." in name:
expert_params += (size // shape[-3])
expert_sum += shape[-3]
n_expert_tensors += 1
else:
shared_params += size
total_params += size
# Hopefully this should work even for variable-expert-count models
expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
# Negate the total to signal it's likely not exact
if last_lora_a is not None:
total_params = -total_params
# NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
return total_params, shared_params, expert_params, expert_count
def format_shard_names(self, path: Path) -> list[Path]:
if len(self.tensors) == 1:
return [path]
@ -115,6 +163,7 @@ class GGUFWriter:
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
# allow calling this multiple times as long as the path is the same
return
if self.state is not WriterState.NO_FILE:
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
@ -136,6 +185,8 @@ class GGUFWriter:
if self.dry_run:
logger.info("Dry run, not writing files")
for name in filenames:
print(name) # noqa: NP100
exit()
return filenames
@ -261,6 +312,8 @@ class GGUFWriter:
self.add_key_value(key, val, GGUFValueType.STRING)
def add_array(self, key: str, val: Sequence[Any]) -> None:
if len(val) == 0:
return
self.add_key_value(key, val, GGUFValueType.ARRAY)
@staticmethod
@ -424,32 +477,18 @@ class GGUFWriter:
fout.close()
self.fout = None
def add_type(self, type_name: str) -> None:
self.add_string(Keys.General.TYPE, type_name)
def add_architecture(self) -> None:
self.add_string(Keys.General.ARCHITECTURE, self.arch)
def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author)
def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version)
def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)
def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description)
def add_licence(self, licence: str) -> None:
self.add_string(Keys.General.LICENSE, licence)
def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url)
def add_source_hf_repo(self, repo: str) -> None:
self.add_string(Keys.General.SOURCE_HF_REPO, repo)
def add_custom_alignment(self, alignment: int) -> None:
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype)
@ -457,13 +496,101 @@ class GGUFWriter:
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)
def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(
Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author)
def add_custom_alignment(self, alignment: int) -> None:
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version)
def add_organization(self, organization: str) -> None:
self.add_string(Keys.General.ORGANIZATION, organization)
def add_finetune(self, finetune: str) -> None:
self.add_string(Keys.General.FINETUNE, finetune)
def add_basename(self, basename: str) -> None:
self.add_string(Keys.General.BASENAME, basename)
def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description)
def add_quantized_by(self, quantized: str) -> None:
self.add_string(Keys.General.QUANTIZED_BY, quantized)
def add_size_label(self, size_label: str) -> None:
self.add_string(Keys.General.SIZE_LABEL, size_label)
def add_license(self, license: str) -> None:
self.add_string(Keys.General.LICENSE, license)
def add_license_name(self, license: str) -> None:
self.add_string(Keys.General.LICENSE_NAME, license)
def add_license_link(self, license: str) -> None:
self.add_string(Keys.General.LICENSE_LINK, license)
def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)
def add_doi(self, doi: str) -> None:
self.add_string(Keys.General.DOI, doi)
def add_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.UUID, uuid)
def add_repo_url(self, repo_url: str) -> None:
self.add_string(Keys.General.REPO_URL, repo_url)
def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url)
def add_source_doi(self, doi: str) -> None:
self.add_string(Keys.General.SOURCE_DOI, doi)
def add_source_uuid(self, uuid: str) -> None:
self.add_string(Keys.General.SOURCE_UUID, uuid)
def add_source_repo_url(self, repo_url: str) -> None:
self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
def add_base_model_count(self, source_count: int) -> None:
self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
def add_base_model_name(self, source_id: int, name: str) -> None:
self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
def add_base_model_author(self, source_id: int, author: str) -> None:
self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
def add_base_model_version(self, source_id: int, version: str) -> None:
self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
def add_base_model_organization(self, source_id: int, organization: str) -> None:
self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
def add_base_model_url(self, source_id: int, url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
def add_base_model_doi(self, source_id: int, doi: str) -> None:
self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
def add_tags(self, tags: Sequence[str]) -> None:
self.add_array(Keys.General.TAGS, tags)
def add_languages(self, languages: Sequence[str]) -> None:
self.add_array(Keys.General.LANGUAGES, languages)
def add_datasets(self, datasets: Sequence[str]) -> None:
self.add_array(Keys.General.DATASETS, datasets)
def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_vocab_size(self, size: int) -> None:
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
@ -603,6 +730,9 @@ class GGUFWriter:
def add_ssm_time_step_rank(self, value: int) -> None:
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)
@ -701,6 +831,9 @@ class GGUFWriter:
def add_eot_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.EOT_ID, id)
def add_eom_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.EOM_ID, id)
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:
@ -720,7 +853,14 @@ class GGUFWriter:
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
kv_data += self._pack("Q", len(encoded_val))
kv_data += encoded_val
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
elif vtype == GGUFValueType.ARRAY:
if not isinstance(val, Sequence):
raise ValueError("Invalid GGUF metadata array, expecting sequence")
if len(val) == 0:
raise ValueError("Invalid GGUF metadata array. Empty array")
if isinstance(val, bytes):
ltype = GGUFValueType.UINT8
else:
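A minimal sketch (not part of this diff) of how the reordered and newly added authorship setters might be used; the path, architecture, and all metadata strings below are placeholders. Note that get_total_parameter_count() negates the total when a .lora_a/.lora_b pair is seen, to signal that the count is likely not exact.

from gguf import GGUFWriter

writer = GGUFWriter("out.gguf", arch="llama")   # placeholder path and architecture
writer.add_organization("ExampleOrg")
writer.add_finetune("Instruct")
writer.add_basename("example-model")
writer.add_size_label("7B")
writer.add_license("apache-2.0")
writer.add_base_model_count(1)
writer.add_base_model_name(0, "Example Base Model")
writer.add_base_model_repo_url(0, "https://huggingface.co/example-org/example-base-model")
writer.add_tags(["text-generation"])

total, shared, expert, expert_count = writer.get_total_parameter_count()
# a negative 'total' would mean a LoRA pair was detected and the count is approximate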

View file

@ -3,10 +3,8 @@ from abc import ABC, ABCMeta, abstractmethod
import logging
from typing import Any, Callable
from collections import deque
import numpy as np
from numpy._typing import _Shape
from numpy.typing import DTypeLike
@ -16,16 +14,16 @@ logger = logging.getLogger(__name__)
class LazyMeta(ABCMeta):
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
def __getattr__(self, __name: str) -> Any:
meta_attr = getattr(self._meta, __name)
def __getattr__(self, name: str) -> Any:
meta_attr = getattr(self._meta, name)
if callable(meta_attr):
return type(self)._wrap_fn(
(lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
(lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
use_self=self,
)
elif isinstance(meta_attr, self._tensor_type):
# e.g. self.T with torch.Tensor should still be wrapped
return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
else:
# no need to wrap non-tensor properties,
# and they likely don't depend on the actual contents of the tensor
@ -75,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
_tensor_type: type
_meta: Any
_data: Any | None
_lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
_args: tuple
_func: Callable[[tuple], Any] | None
_kwargs: dict[str, Any]
_func: Callable[[Any], Any] | None
def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
super().__init__()
self._meta = meta
self._data = data
self._lazy = lazy if lazy is not None else deque()
self._args = args
self._kwargs = kwargs if kwargs is not None else {}
self._func = func
assert self._func is not None or self._data is not None
if self._data is None:
self._lazy.append(self)
def __init_subclass__(cls) -> None:
if "_tensor_type" not in cls.__dict__:
@ -118,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
args = ((use_self,) if use_self is not None else ()) + args
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
# TODO: maybe handle tensors in kwargs too
if isinstance(meta_noop, bool) and not meta_noop:
try:
@ -141,21 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
if isinstance(res, cls._tensor_type):
def collect_replace(t: LazyBase):
if collect_replace.shared_lazy is None:
collect_replace.shared_lazy = t._lazy
else:
collect_replace.shared_lazy.extend(t._lazy)
t._lazy = collect_replace.shared_lazy
# emulating a static variable
collect_replace.shared_lazy = None
LazyBase._recurse_apply(args, collect_replace)
shared_lazy = collect_replace.shared_lazy
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
else:
del res # not needed
# non-tensor return likely relies on the contents of the args
@ -167,25 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
@classmethod
def to_eager(cls, t: Any) -> Any:
def simple_to_eager(_t: LazyBase) -> Any:
def already_eager_to_eager(_t: LazyBase) -> Any:
assert _t._data is not None
if _t._data is not None:
return _t._data
while _t._data is None:
lt = _t._lazy.popleft()
if lt._data is not None:
# Lazy tensor did not belong in the lazy queue.
# Weirdly only happens with Bloom models...
# likely because tensors aren't unique in the queue.
# The final output is still the same as in eager mode,
# so it's safe to ignore this.
continue
assert lt._func is not None
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
lt._data = lt._func(lt._args)
# sanity check
assert lt._data.dtype == lt._meta.dtype
assert lt._data.shape == lt._meta.shape
# NOTE: there's a recursion limit in Python (usually 1000)
assert _t._func is not None
_t._args = cls._recurse_apply(_t._args, simple_to_eager)
_t._data = _t._func(*_t._args, **_t._kwargs)
# sanity check
assert _t._data is not None
assert _t._data.dtype == _t._meta.dtype
assert _t._data.shape == _t._meta.shape
return _t._data
@ -204,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
@classmethod
def from_eager(cls, t: Any) -> Any:
if type(t) is cls:
# already eager
# already lazy
return t
elif isinstance(t, cls._tensor_type):
return cls(meta=cls.eager_to_meta(t), data=t)
@ -215,8 +191,10 @@ class LazyBase(ABC, metaclass=LazyMeta):
class LazyNumpyTensor(LazyBase):
_tensor_type = np.ndarray
shape: tuple[int, ...] # Makes the type checker happy in quants.py
@classmethod
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]:
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
# The initial idea was to use np.nan as the fill value,
# but non-float types like np.int16 can't use that.
# So zero it is.
@ -226,8 +204,7 @@ class LazyNumpyTensor(LazyBase):
def astype(self, dtype, *args, **kwargs):
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
full_args = (self, dtype,) + args
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)
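The rewrite above drops the shared _lazy deque: each lazy tensor now stores the wrapped function together with its args and kwargs, and to_eager() resolves the chain recursively (hence the note about Python's recursion limit). A minimal NumPy sketch, not part of this diff:

import numpy as np
from gguf.lazy import LazyNumpyTensor

lazy = LazyNumpyTensor.from_eager(np.arange(8, dtype=np.float32))
half = lazy.astype(np.float16)            # deferred: only meta information is built here
eager = LazyNumpyTensor.to_eager(half)    # the stored func/args/kwargs are applied here
print(eager.dtype, eager.shape)           # float16 (8,)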

gguf-py/gguf/metadata.py (new file, 510 lines)
View file

@ -0,0 +1,510 @@
from __future__ import annotations
import re
import json
import yaml
import logging
from pathlib import Path
from typing import Any, Literal, Optional
from dataclasses import dataclass
from .constants import Keys
import gguf
logger = logging.getLogger("metadata")
@dataclass
class Metadata:
# Authorship Metadata to be written to GGUF KV Store
name: Optional[str] = None
author: Optional[str] = None
version: Optional[str] = None
organization: Optional[str] = None
finetune: Optional[str] = None
basename: Optional[str] = None
description: Optional[str] = None
quantized_by: Optional[str] = None
size_label: Optional[str] = None
url: Optional[str] = None
doi: Optional[str] = None
uuid: Optional[str] = None
repo_url: Optional[str] = None
source_url: Optional[str] = None
source_doi: Optional[str] = None
source_uuid: Optional[str] = None
source_repo_url: Optional[str] = None
license: Optional[str] = None
license_name: Optional[str] = None
license_link: Optional[str] = None
base_models: Optional[list[dict]] = None
tags: Optional[list[str]] = None
languages: Optional[list[str]] = None
datasets: Optional[list[str]] = None
@staticmethod
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
# This grabs as much contextual authorship metadata as possible from the model repository,
# converting it as required to match the GGUF KV store metadata format,
# and gives users the ability to override any authorship metadata that may be incorrect
# Create a new Metadata instance
metadata = Metadata()
model_card = Metadata.load_model_card(model_path)
hf_params = Metadata.load_hf_parameters(model_path)
# TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
# heuristics
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
# Metadata Override File Provided
# This is based on LLM_KV_NAMES mapping in llama.cpp
metadata_override = Metadata.load_metadata_override(metadata_override_path)
metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
# Base Models is received here as an array of models
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
# Direct Metadata Override (via direct cli argument)
if model_name is not None:
metadata.name = model_name
return metadata
@staticmethod
def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
if metadata_override_path is None or not metadata_override_path.is_file():
return {}
with open(metadata_override_path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
model_card_path = model_path / "README.md"
if not model_card_path.is_file():
return {}
# The model card metadata is assumed to always be in YAML
# ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
with open(model_card_path, "r", encoding="utf-8") as f:
if f.readline() == "---\n":
raw = f.read().partition("---\n")[0]
data = yaml.safe_load(raw)
if isinstance(data, dict):
return data
else:
logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
return {}
else:
return {}
@staticmethod
def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
if model_path is None or not model_path.is_dir():
return {}
config_path = model_path / "config.json"
if not config_path.is_file():
return {}
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def id_to_title(string):
# Convert capitalization into title form unless acronym or version number
return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
@staticmethod
def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
# Hugging Face often stores model IDs as '<org>/<model name>'
# so let's parse it and apply some heuristics if possible for model name components
if model_id is None:
# model ID missing
return None, None, None, None, None, None
if ' ' in model_id:
# model ID is actually a normal human sentence
# which means it's most likely a normal model name only
# not part of the hugging face naming standard, but whatever
return model_id, None, None, None, None, None
if '/' in model_id:
# model ID (huggingface style)
org_component, model_full_name_component = model_id.split('/', 1)
else:
# model ID but missing org components
org_component, model_full_name_component = None, model_id
# Check if we erroneously matched against './' or '../' etc...
if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
org_component = None
name_parts: list[str] = model_full_name_component.split('-')
# Remove empty parts
for i in reversed(range(len(name_parts))):
if len(name_parts[i]) == 0:
del name_parts[i]
name_types: list[
set[Literal["basename", "size_label", "finetune", "version", "type"]]
] = [set() for _ in name_parts]
# Annotate the name
for i, part in enumerate(name_parts):
# Version
if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
name_types[i].add("version")
# Quant type (should not be there for base models, but still annotated)
elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
name_types[i].add("type")
name_parts[i] = part.upper()
# Model size
elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
part = part.replace("_", ".")
# Handle weird bloom-7b1 notation
if part[-1].isdecimal():
part = part[:-2] + "." + part[-1] + part[-2]
# Normalize the size suffixes
if len(part) > 1 and part[-2].isdecimal():
if part[-1] in "kmbt":
part = part[:-1] + part[-1].upper()
if total_params != 0:
try:
label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
# Only use it as a size label if it's close or bigger than the model size
# Note that LoRA adapters don't necessarily include all layers,
# so this is why bigger label sizes are accepted.
# Do not use the size label when it's smaller than 1/8 of the model size
if (total_params < 0 and label_params < abs(total_params) // 8) or (
# Check both directions when the current model isn't a LoRA adapter
total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
):
# Likely a context length
name_types[i].add("finetune")
# Lowercase the size when it's a context length
part = part[:-1] + part[-1].lower()
except ValueError:
# Failed to convert the size label to float, use it anyway
pass
if len(name_types[i]) == 0:
name_types[i].add("size_label")
name_parts[i] = part
# Some easy to recognize finetune names
elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
if total_params < 0 and part.lower() == "lora":
# ignore redundant "lora" in the finetune part when the output is a lora adapter
name_types[i].add("type")
else:
name_types[i].add("finetune")
# Ignore word-based size labels when there is at least a number-based one present
# TODO: should word-based size labels always be removed instead?
if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
for n, t in zip(name_parts, name_types):
if "size_label" in t:
if all(c.isalpha() for c in n):
t.remove("size_label")
at_start = True
# Find the basename through the annotated name
for part, t in zip(name_parts, name_types):
if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
t.add("basename")
else:
if at_start:
at_start = False
if len(t) == 0:
t.add("finetune")
# Remove the basename annotation from trailing version
for part, t in zip(reversed(name_parts), reversed(name_types)):
if "basename" in t and len(t) > 1:
t.remove("basename")
else:
break
basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
# Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys)
size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
# TODO: should the basename version always be excluded?
# NOTE: multiple finetune versions are joined together
version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None
if size_label is None and finetune is None and version is None:
# Too ambiguous, output nothing
basename = None
return model_full_name_component, org_component, basename, finetune, version, size_label
@staticmethod
def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
# Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Model Card Heuristics
########################
if model_card is not None:
def use_model_card_metadata(metadata_key: str, model_card_key: str):
if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
setattr(metadata, metadata_key, model_card.get(model_card_key))
def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
# Note: Will append rather than replace if already exist
tags_value = model_card.get(model_card_key, None)
if tags_value is None:
return
current_value = getattr(metadata, metadata_key, None)
if current_value is None:
current_value = []
if isinstance(tags_value, str):
current_value.append(tags_value)
elif isinstance(tags_value, list):
current_value.extend(tags_value)
setattr(metadata, metadata_key, current_value)
# LLAMA.cpp's direct internal convention
# (Definitely not part of hugging face formal/informal standard)
#########################################
use_model_card_metadata("name", "name")
use_model_card_metadata("author", "author")
use_model_card_metadata("version", "version")
use_model_card_metadata("organization", "organization")
use_model_card_metadata("description", "description")
use_model_card_metadata("finetune", "finetune")
use_model_card_metadata("basename", "basename")
use_model_card_metadata("size_label", "size_label")
use_model_card_metadata("source_url", "url")
use_model_card_metadata("source_doi", "doi")
use_model_card_metadata("source_uuid", "uuid")
use_model_card_metadata("source_repo_url", "repo_url")
# LLAMA.cpp's huggingface style convention
# (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style)
###########################################
use_model_card_metadata("name", "model_name")
use_model_card_metadata("author", "model_author")
use_model_card_metadata("version", "model_version")
use_model_card_metadata("organization", "model_organization")
use_model_card_metadata("description", "model_description")
use_model_card_metadata("finetune", "model_finetune")
use_model_card_metadata("basename", "model_basename")
use_model_card_metadata("size_label", "model_size_label")
use_model_card_metadata("source_url", "model_url")
use_model_card_metadata("source_doi", "model_doi")
use_model_card_metadata("source_uuid", "model_uuid")
use_model_card_metadata("source_repo_url", "model_repo_url")
# Hugging Face Direct Convention
#################################
# Not part of the Hugging Face model card standard, but some model creators use it
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
use_model_card_metadata("name", "model_name")
use_model_card_metadata("author", "model_creator")
use_model_card_metadata("basename", "model_type")
if "base_model" in model_card:
# This represents the parent models that this is based on
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
metadata_base_models = []
base_model_value = model_card.get("base_model", None)
if base_model_value is not None:
if isinstance(base_model_value, str):
metadata_base_models.append(base_model_value)
elif isinstance(base_model_value, list):
metadata_base_models.extend(base_model_value)
if metadata.base_models is None:
metadata.base_models = []
for model_id in metadata_base_models:
# NOTE: model size of base model is assumed to be similar to the size of the current model
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
base_model = {}
if model_full_name_component is not None:
base_model["name"] = Metadata.id_to_title(model_full_name_component)
if org_component is not None:
base_model["organization"] = Metadata.id_to_title(org_component)
if version is not None:
base_model["version"] = version
if org_component is not None and model_full_name_component is not None:
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
metadata.base_models.append(base_model)
use_model_card_metadata("license", "license")
use_model_card_metadata("license_name", "license_name")
use_model_card_metadata("license_link", "license_link")
use_array_model_card_metadata("tags", "tags")
use_array_model_card_metadata("tags", "pipeline_tag")
use_array_model_card_metadata("languages", "languages")
use_array_model_card_metadata("languages", "language")
use_array_model_card_metadata("datasets", "datasets")
use_array_model_card_metadata("datasets", "dataset")
# Hugging Face Parameter Heuristics
####################################
if hf_params is not None:
hf_name_or_path = hf_params.get("_name_or_path")
if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
# Use _name_or_path only if it's actually a model name and not some computer path
# e.g. 'meta-llama/Llama-2-7b-hf'
model_id = hf_name_or_path
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
if metadata.name is None and model_full_name_component is not None:
metadata.name = Metadata.id_to_title(model_full_name_component)
if metadata.organization is None and org_component is not None:
metadata.organization = Metadata.id_to_title(org_component)
if metadata.basename is None and basename is not None:
metadata.basename = basename
if metadata.finetune is None and finetune is not None:
metadata.finetune = finetune
if metadata.version is None and version is not None:
metadata.version = version
if metadata.size_label is None and size_label is not None:
metadata.size_label = size_label
# Directory Folder Name Fallback Heuristics
############################################
if model_path is not None:
model_id = model_path.name
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
if metadata.name is None and model_full_name_component is not None:
metadata.name = Metadata.id_to_title(model_full_name_component)
if metadata.organization is None and org_component is not None:
metadata.organization = Metadata.id_to_title(org_component)
if metadata.basename is None and basename is not None:
metadata.basename = basename
if metadata.finetune is None and finetune is not None:
metadata.finetune = finetune
if metadata.version is None and version is not None:
metadata.version = version
if metadata.size_label is None and size_label is not None:
metadata.size_label = size_label
return metadata
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
assert self.name is not None
gguf_writer.add_name(self.name)
if self.author is not None:
gguf_writer.add_author(self.author)
if self.version is not None:
gguf_writer.add_version(self.version)
if self.organization is not None:
gguf_writer.add_organization(self.organization)
if self.finetune is not None:
gguf_writer.add_finetune(self.finetune)
if self.basename is not None:
gguf_writer.add_basename(self.basename)
if self.description is not None:
gguf_writer.add_description(self.description)
if self.quantized_by is not None:
gguf_writer.add_quantized_by(self.quantized_by)
if self.size_label is not None:
gguf_writer.add_size_label(self.size_label)
if self.license is not None:
gguf_writer.add_license(self.license)
if self.license_name is not None:
gguf_writer.add_license_name(self.license_name)
if self.license_link is not None:
gguf_writer.add_license_link(self.license_link)
if self.url is not None:
gguf_writer.add_url(self.url)
if self.doi is not None:
gguf_writer.add_doi(self.doi)
if self.uuid is not None:
gguf_writer.add_uuid(self.uuid)
if self.repo_url is not None:
gguf_writer.add_repo_url(self.repo_url)
if self.source_url is not None:
gguf_writer.add_source_url(self.source_url)
if self.source_doi is not None:
gguf_writer.add_source_doi(self.source_doi)
if self.source_uuid is not None:
gguf_writer.add_source_uuid(self.source_uuid)
if self.source_repo_url is not None:
gguf_writer.add_source_repo_url(self.source_repo_url)
if self.base_models is not None:
gguf_writer.add_base_model_count(len(self.base_models))
for key, base_model_entry in enumerate(self.base_models):
if "name" in base_model_entry:
gguf_writer.add_base_model_name(key, base_model_entry["name"])
if "author" in base_model_entry:
gguf_writer.add_base_model_author(key, base_model_entry["author"])
if "version" in base_model_entry:
gguf_writer.add_base_model_version(key, base_model_entry["version"])
if "organization" in base_model_entry:
gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
if "url" in base_model_entry:
gguf_writer.add_base_model_url(key, base_model_entry["url"])
if "doi" in base_model_entry:
gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
if "uuid" in base_model_entry:
gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
if "repo_url" in base_model_entry:
gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
if self.tags is not None:
gguf_writer.add_tags(self.tags)
if self.languages is not None:
gguf_writer.add_languages(self.languages)
if self.datasets is not None:
gguf_writer.add_datasets(self.datasets)
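A minimal sketch (not part of this diff) of the heuristics above; the model ID and path are illustrative only, and the annotated tuple is what the splitting rules should produce for that input:

from pathlib import Path
from gguf.metadata import Metadata

# (name, org, basename, finetune, version, size_label)
print(Metadata.get_model_id_components("mistralai/Mixtral-8x7B-Instruct-v0.1"))
# -> ('Mixtral-8x7B-Instruct-v0.1', 'mistralai', 'Mixtral', 'Instruct', 'v0.1', '8x7B')

# Typical conversion flow: gather metadata, then write it with an existing GGUFWriter
metadata = Metadata.load(metadata_override_path=None, model_path=Path("Mixtral-8x7B-Instruct-v0.1"), total_params=0)
# metadata.set_gguf_meta_model(gguf_writer)   # gguf_writer: a gguf.GGUFWriter created elsewhere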

File diff suppressed because it is too large

View file

@ -10,10 +10,10 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
"model.embed_tokens", # llama-hf nemotron
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
@ -24,6 +24,7 @@ class TensorNameMap:
"backbone.embedding", # mamba
"backbone.embeddings", # mamba-hf
"transformer.in_out_embed", # Grok
"embedding.word_embeddings", # chatglm
"transformer.token_embeddings", # openelm
"shared", # t5
),
@ -51,16 +52,17 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
"output_layer", # chatglm
),
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
@ -71,12 +73,15 @@ class TensorNameMap:
"model.norm_f", # mamba-qbert
"backbone.norm_f", # mamba
"transformer.rms_norm", # Grok
"encoder.final_layernorm", # chatglm
"transformer.norm", # openelm
"model.norm", # nemotron
),
# Rope frequencies
MODEL_TENSOR.ROPE_FREQS: (
"rope.freqs", # llama-pth
"rotary_pos_emb.inv_freq", # chatglm
),
}
@ -84,12 +89,12 @@ class TensorNameMap:
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf
"model.layers.{bid}.input_layernorm", # llama-hf nemotron
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
@ -101,6 +106,7 @@ class TensorNameMap:
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
),
@ -124,42 +130,46 @@ class TensorNameMap:
"transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
),
# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
"transformer.h.{bid}.attn.attention.q_proj", # exaone
),
# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
"transformer.h.{bid}.attn.k", # refact
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
"transformer.h.{bid}.attn.attention.k_proj", # exaone
),
# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
"transformer.h.{bid}.attn.v", # refact
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
"model.layers.{bid}.attention.wv", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
"transformer.h.{bid}.attn.attention.v_proj", # exaone
),
# Attention output
@ -169,7 +179,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
@ -182,7 +192,9 @@ class TensorNameMap:
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
),
# Attention output norm
@ -208,16 +220,17 @@ class TensorNameMap:
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron
"layers.{bid}.ffn_norm", # llama-pth
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
"h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.pre_ff_layernorm", # jamba
"model.layers.{bid}.pre_moe_layernorm", # mini-jamba
@ -253,7 +266,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"h.{bid}.mlp.dense_h_to_4h", # bloom
"model.layers.{bid}.mlp.up_proj", # llama-hf refact
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron
"layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
@ -272,6 +285,8 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
"model.layers.{bid}.feed_forward.up_proj", # jamba
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
),
MODEL_TENSOR.FFN_UP_EXP: (
@ -304,6 +319,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"model.layers.{bid}.feed_forward.gate_proj", # jamba
"transformer.h.{bid}.mlp.c_fc_0", # exaone
),
MODEL_TENSOR.FFN_GATE_EXP: (
@ -325,7 +341,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
"model.layers.{bid}.mlp.down_proj", # llama-hf
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron
"layers.{bid}.feed_forward.w2", # llama-pth
"encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j
@ -343,6 +359,8 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"model.layers.{bid}.feed_forward.down_proj", # jamba
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
),
MODEL_TENSOR.FFN_DOWN_EXP: (
@ -619,14 +637,12 @@ class TensorNameMap:
for tensor, keys in self.block_mappings_cfg.items():
if tensor not in MODEL_TENSORS[arch]:
continue
# TODO: make this configurable
n_experts = 160
for xid in range(n_experts):
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
self.mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
key = key.format(bid = bid, xid = xid)
self.mapping[key] = (tensor, tensor_name)
tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
self.mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
key = key.format(bid = bid)
self.mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key)
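A minimal sketch (not part of this diff) of looking up one of the newly mapped names; the TensorNameMap constructor, the layer count, and the 'blk.{bid}.attn_q' target name are not shown in this hunk, so treat them as assumptions:

from gguf.constants import MODEL_ARCH, MODEL_TENSOR
from gguf.tensor_mapping import TensorNameMap

tmap = TensorNameMap(MODEL_ARCH.EXAONE, 32)   # 32 is a placeholder layer count
print(tmap.get_type_and_name("transformer.h.0.attn.attention.q_proj", try_suffixes=(".weight", ".bias")))
# expected: (MODEL_TENSOR.ATTN_Q, 'blk.0.attn_q')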

gguf-py/gguf/utility.py (new file, 69 lines)
View file

@ -0,0 +1,69 @@
from __future__ import annotations
from typing import Literal
def fill_templated_filename(filename: str, output_type: str | None) -> str:
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
ftype_lowercase: str = output_type.lower() if output_type is not None else ""
ftype_uppercase: str = output_type.upper() if output_type is not None else ""
return filename.format(ftype_lowercase,
outtype=ftype_lowercase, ftype=ftype_lowercase,
OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
if model_params_count > 1e12 :
# Trillions Of Parameters
scaled_model_params = model_params_count * 1e-12
scale_suffix = "T"
elif model_params_count > 1e9 :
# Billions Of Parameters
scaled_model_params = model_params_count * 1e-9
scale_suffix = "B"
elif model_params_count > 1e6 :
# Millions Of Parameters
scaled_model_params = model_params_count * 1e-6
scale_suffix = "M"
else:
# Thousands Of Parameters
scaled_model_params = model_params_count * 1e-3
scale_suffix = "K"
fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
return f"{scaled_model_params:.{fix}f}{scale_suffix}"
def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
if expert_count > 0:
pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
size_class = f"{expert_count}x{pretty_size}"
else:
size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
return size_class
def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
# Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
if base_name is not None:
name = base_name.strip().replace(' ', '-').replace('/', '-')
elif model_name is not None:
name = model_name.strip().replace(' ', '-').replace('/', '-')
else:
name = "ggml-model"
parameters = f"-{size_label}" if size_label is not None else ""
finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
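A minimal sketch (not part of this diff) of how these helpers compose; the parameter counts are illustrative placeholders:

from gguf.utility import model_weight_count_rounded_notation, size_label, naming_convention

print(model_weight_count_rounded_notation(7_241_732_096))            # '7.2B'
print(size_label(46_700_000_000, 1_820_000_000, 5_610_000_000, 8))   # '8x7.4B' (expert count x shared+expert size)
print(naming_convention("Mixtral", "Mixtral", "Instruct", "v0.1", "8x7B", "Q4_K_M"))
# -> 'Mixtral-8x7B-Instruct-v0.1-Q4_K_M'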