Merge branch 'master' into compilade/bitnet-ternary

Author: Francis Couture-Harpin
Date:   2024-08-11 15:52:29 -04:00
Commit: d911cd1f13

138 changed files with 7065 additions and 1937 deletions

gguf-py/examples/writer.py

@@ -15,7 +15,6 @@ def writer_example() -> None:
     # Example usage with a file
     gguf_writer = GGUFWriter("example.gguf", "llama")

-    gguf_writer.add_architecture()
     gguf_writer.add_block_count(12)
     gguf_writer.add_uint32("answer", 42)  # Write a 32-bit integer
     gguf_writer.add_float32("answer_in_float", 42.0)  # Write a 32-bit float
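Note: the remainder of writer_example() is untouched by this merge; roughly, it goes on to add tensors and flush everything to disk. A condensed sketch of that flow, for context:

    import numpy as np
    from gguf import GGUFWriter

    gguf_writer = GGUFWriter("example.gguf", "llama")
    gguf_writer.add_block_count(12)
    gguf_writer.add_uint32("answer", 42)
    gguf_writer.add_float32("answer_in_float", 42.0)

    tensor1 = np.ones((32,), dtype=np.float32) * 100.0
    gguf_writer.add_tensor("tensor1", tensor1)

    gguf_writer.write_header_to_file()
    gguf_writer.write_kv_data_to_file()
    gguf_writer.write_tensors_to_file()
    gguf_writer.close()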

gguf-py/gguf/constants.py

@@ -161,6 +161,7 @@ class Keys:
         SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
         MIDDLE_ID = "tokenizer.ggml.middle_token_id"
         EOT_ID = "tokenizer.ggml.eot_token_id"
+        EOM_ID = "tokenizer.ggml.eom_token_id"

     class Adapter:
         TYPE = "adapter.type"
@@ -216,6 +217,7 @@ class MODEL_ARCH(IntEnum):
     CHATGLM = auto()
     BITNET = auto()
     T5 = auto()
+    T5ENCODER = auto()
     JAIS = auto()
@@ -343,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.CHATGLM: "chatglm",
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
+    MODEL_ARCH.T5ENCODER: "t5encoder",
     MODEL_ARCH.JAIS: "jais",
 }
@@ -1035,6 +1038,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ENC_FFN_UP,
         MODEL_TENSOR.ENC_OUTPUT_NORM,
     ],
+    MODEL_ARCH.T5ENCODER: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ENC_ATTN_NORM,
+        MODEL_TENSOR.ENC_ATTN_Q,
+        MODEL_TENSOR.ENC_ATTN_K,
+        MODEL_TENSOR.ENC_ATTN_V,
+        MODEL_TENSOR.ENC_ATTN_OUT,
+        MODEL_TENSOR.ENC_ATTN_REL_B,
+        MODEL_TENSOR.ENC_FFN_NORM,
+        MODEL_TENSOR.ENC_FFN_GATE,
+        MODEL_TENSOR.ENC_FFN_DOWN,
+        MODEL_TENSOR.ENC_FFN_UP,
+        MODEL_TENSOR.ENC_OUTPUT_NORM,
+    ],
     MODEL_ARCH.JAIS: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -1162,7 +1180,7 @@ class LlamaFileType(IntEnum):
     MOSTLY_F16 = 1  # except 1d tensors
     MOSTLY_Q4_0 = 2  # except 1d tensors
     MOSTLY_Q4_1 = 3  # except 1d tensors
-    MOSTLY_Q4_1_SOME_F16 = 4  # tok_embeddings.weight and output.weight are F16
+    # MOSTLY_Q4_1_SOME_F16 = 4  # tok_embeddings.weight and output.weight are F16
     # MOSTLY_Q4_2 = 5  # support has been removed
     # MOSTLY_Q4_3 = 6  # support has been removed
     MOSTLY_Q8_0 = 7  # except 1d tensors
@@ -1342,3 +1360,4 @@ KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID
 KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
 KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
 KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
+KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID
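Note: taken together, the constants.py changes register a new "t5encoder" architecture and a new tokenizer.ggml.eom_token_id key. A minimal sketch of writing both with the local gguf package (the token id 107 is a made-up placeholder, not from this commit):

    from gguf import GGUFWriter
    from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES

    writer = GGUFWriter("t5-encoder.gguf", MODEL_ARCH_NAMES[MODEL_ARCH.T5ENCODER])
    writer.add_architecture()     # general.architecture = "t5encoder"
    writer.add_eom_token_id(107)  # tokenizer.ggml.eom_token_id = 107 (new key in this merge)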

gguf-py/gguf/gguf_writer.py

@@ -312,6 +312,8 @@ class GGUFWriter:
         self.add_key_value(key, val, GGUFValueType.STRING)

     def add_array(self, key: str, val: Sequence[Any]) -> None:
+        if len(val) == 0:
+            return
         self.add_key_value(key, val, GGUFValueType.ARRAY)

     @staticmethod
@@ -826,6 +828,9 @@ class GGUFWriter:
     def add_eot_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.EOT_ID, id)

+    def add_eom_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.EOM_ID, id)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
@@ -845,7 +850,14 @@ class GGUFWriter:
             encoded_val = val.encode("utf-8") if isinstance(val, str) else val
             kv_data += self._pack("Q", len(encoded_val))
             kv_data += encoded_val
-        elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
+        elif vtype == GGUFValueType.ARRAY:
+            if not isinstance(val, Sequence):
+                raise ValueError("Invalid GGUF metadata array, expecting sequence")
+            if len(val) == 0:
+                raise ValueError("Invalid GGUF metadata array. Empty array")
+            if isinstance(val, bytes):
+                ltype = GGUFValueType.UINT8
+            else:
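Note: the net effect of these two gguf_writer.py hunks is that empty arrays are silently skipped at the add_array() boundary, while any empty or non-sequence ARRAY value that still reaches the packer now raises instead of writing a malformed KV entry. A small sketch of the new behavior (file name illustrative):

    from gguf import GGUFWriter

    w = GGUFWriter("demo.gguf", "llama")
    w.add_array("tokenizer.ggml.tokens", [])                # no-op: empty arrays are skipped
    w.add_array("tokenizer.ggml.tokens", ["<unk>", "<s>"])  # written as before
    # If an empty ARRAY value is forced in anyway (e.g. via add_key_value), packing the
    # metadata now raises ValueError("Invalid GGUF metadata array. Empty array").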

gguf-py/gguf/lazy.py

@@ -191,6 +191,8 @@ class LazyBase(ABC, metaclass=LazyMeta):

 class LazyNumpyTensor(LazyBase):
     _tensor_type = np.ndarray

+    shape: tuple[int, ...]  # Makes the type checker happy in quants.py
+
     @classmethod
     def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
         # The initial idea was to use np.nan as the fill value,

gguf-py/gguf/metadata.py

@@ -174,7 +174,7 @@ class Metadata:
             org_component, model_full_name_component = None, model_id

         # Check if we erroneously matched against './' or '../' etc...
-        if org_component is not None and org_component[0] == '.':
+        if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
             org_component = None

         name_parts: list[str] = model_full_name_component.split('-')
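Note: the added len(org_component) > 0 guard matters because the organization part of a model id can end up as an empty string, and indexing it would crash. A hypothetical repro of what the old check could hit:

    org_component = ""       # e.g. nothing before the '/' in a malformed model_id
    org_component[0] == '.'  # IndexError: string index out of range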
@@ -284,20 +284,67 @@ class Metadata:
         ########################
         if model_card is not None:

-            if "model_name" in model_card and metadata.name is None:
-                # Not part of huggingface model card standard but notice some model creator using it
-                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
-                metadata.name = model_card.get("model_name")
+            def use_model_card_metadata(metadata_key: str, model_card_key: str):
+                if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
+                    setattr(metadata, metadata_key, model_card.get(model_card_key))

-            if "model_creator" in model_card and metadata.author is None:
-                # Not part of huggingface model card standard but notice some model creator using it
-                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
-                metadata.author = model_card.get("model_creator")
+            def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
+                # Note: Will append rather than replace if already exist
+                tags_value = model_card.get(model_card_key, None)
+                if tags_value is None:
+                    return

-            if "model_type" in model_card and metadata.basename is None:
-                # Not part of huggingface model card standard but notice some model creator using it
-                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
-                metadata.basename = model_card.get("model_type")
+                current_value = getattr(metadata, metadata_key, None)
+                if current_value is None:
+                    current_value = []
+
+                if isinstance(tags_value, str):
+                    current_value.append(tags_value)
+                elif isinstance(tags_value, list):
+                    current_value.extend(tags_value)
+
+                setattr(metadata, metadata_key, current_value)
+
+            # LLAMA.cpp's direct internal convention
+            # (Definitely not part of hugging face formal/informal standard)
+            #########################################
+            use_model_card_metadata("name", "name")
+            use_model_card_metadata("author", "author")
+            use_model_card_metadata("version", "version")
+            use_model_card_metadata("organization", "organization")
+            use_model_card_metadata("description", "description")
+            use_model_card_metadata("finetune", "finetune")
+            use_model_card_metadata("basename", "basename")
+            use_model_card_metadata("size_label", "size_label")
+            use_model_card_metadata("source_url", "url")
+            use_model_card_metadata("source_doi", "doi")
+            use_model_card_metadata("source_uuid", "uuid")
+            use_model_card_metadata("source_repo_url", "repo_url")
+
+            # LLAMA.cpp's huggingface style convention
+            # (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style)
+            ###########################################
+            use_model_card_metadata("name", "model_name")
+            use_model_card_metadata("author", "model_author")
+            use_model_card_metadata("version", "model_version")
+            use_model_card_metadata("organization", "model_organization")
+            use_model_card_metadata("description", "model_description")
+            use_model_card_metadata("finetune", "model_finetune")
+            use_model_card_metadata("basename", "model_basename")
+            use_model_card_metadata("size_label", "model_size_label")
+            use_model_card_metadata("source_url", "model_url")
+            use_model_card_metadata("source_doi", "model_doi")
+            use_model_card_metadata("source_uuid", "model_uuid")
+            use_model_card_metadata("source_repo_url", "model_repo_url")
+
+            # Hugging Face Direct Convention
+            #################################
+            # Not part of huggingface model card standard but notice some model creator using it
+            # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
+            use_model_card_metadata("name", "model_name")
+            use_model_card_metadata("author", "model_creator")
+            use_model_card_metadata("basename", "model_type")

             if "base_model" in model_card:
                 # This represents the parent models that this is based on
@@ -329,58 +376,18 @@ class Metadata:
                     base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
                 metadata.base_models.append(base_model)

-            if "license" in model_card and metadata.license is None:
-                metadata.license = model_card.get("license")
+            use_model_card_metadata("license", "license")
+            use_model_card_metadata("license_name", "license_name")
+            use_model_card_metadata("license_link", "license_link")

-            if "license_name" in model_card and metadata.license_name is None:
-                metadata.license_name = model_card.get("license_name")
+            use_array_model_card_metadata("tags", "tags")
+            use_array_model_card_metadata("tags", "pipeline_tag")

-            if "license_link" in model_card and metadata.license_link is None:
-                metadata.license_link = model_card.get("license_link")
+            use_array_model_card_metadata("languages", "languages")
+            use_array_model_card_metadata("languages", "language")

-            tags_value = model_card.get("tags", None)
-            if tags_value is not None:
-                if metadata.tags is None:
-                    metadata.tags = []
-                if isinstance(tags_value, str):
-                    metadata.tags.append(tags_value)
-                elif isinstance(tags_value, list):
-                    metadata.tags.extend(tags_value)
-
-            pipeline_tags_value = model_card.get("pipeline_tag", None)
-            if pipeline_tags_value is not None:
-                if metadata.tags is None:
-                    metadata.tags = []
-                if isinstance(pipeline_tags_value, str):
-                    metadata.tags.append(pipeline_tags_value)
-                elif isinstance(pipeline_tags_value, list):
-                    metadata.tags.extend(pipeline_tags_value)
-
-            language_value = model_card.get("languages", model_card.get("language", None))
-            if language_value is not None:
-                if metadata.languages is None:
-                    metadata.languages = []
-                if isinstance(language_value, str):
-                    metadata.languages.append(language_value)
-                elif isinstance(language_value, list):
-                    metadata.languages.extend(language_value)
-
-            dataset_value = model_card.get("datasets", model_card.get("dataset", None))
-            if dataset_value is not None:
-                if metadata.datasets is None:
-                    metadata.datasets = []
-                if isinstance(dataset_value, str):
-                    metadata.datasets.append(dataset_value)
-                elif isinstance(dataset_value, list):
-                    metadata.datasets.extend(dataset_value)
+            use_array_model_card_metadata("datasets", "datasets")
+            use_array_model_card_metadata("datasets", "dataset")

         # Hugging Face Parameter Heuristics
         ####################################
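Note: the refactor replaces the repeated "if key in model_card and field is None" blocks with two closures over metadata and model_card: use_model_card_metadata() is first-wins (it never overwrites an already-set scalar field), while use_array_model_card_metadata() appends to list fields. A standalone sketch of those semantics, with SimpleNamespace standing in for the real Metadata class:

    from types import SimpleNamespace

    metadata = SimpleNamespace(name=None, author="existing-author", tags=None)
    model_card = {"model_name": "Example-7B", "model_creator": "ignored", "tags": ["gguf"]}

    def use_model_card_metadata(metadata_key: str, model_card_key: str):
        # first-wins: only fills fields that are still unset
        if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
            setattr(metadata, metadata_key, model_card.get(model_card_key))

    def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
        # append semantics: extends an existing list rather than replacing it
        value = model_card.get(model_card_key, None)
        if value is None:
            return
        current = getattr(metadata, metadata_key, None) or []
        current.extend(value if isinstance(value, list) else [value])
        setattr(metadata, metadata_key, current)

    use_model_card_metadata("name", "model_name")       # name: None -> "Example-7B"
    use_model_card_metadata("author", "model_creator")  # author already set, left unchanged
    use_array_model_card_metadata("tags", "tags")       # tags: None -> ["gguf"]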

File diff suppressed because it is too large.

gguf-py/tests/test_quants.py (new executable file, +237 lines)

@@ -0,0 +1,237 @@
#!/usr/bin/env python3

# Test gguf.quants so that it exactly matches the C implementation of the (de)quantization
# NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.

from __future__ import annotations

import argparse
from math import prod
import os
import sys
from pathlib import Path
import ctypes
import logging

import numpy as np

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))

import gguf
from gguf.constants import GGMLQuantizationType


logger = logging.getLogger("test-quants")


c_float_p = ctypes.POINTER(ctypes.c_float)


class ggml_init_params(ctypes.Structure):
    _fields_ = [
        ("mem_size", ctypes.c_size_t),
        ("mem_buffer", ctypes.c_void_p),
        ("no_alloc", ctypes.c_bool),
    ]


class GGMLQuants:
    libggml: ctypes.CDLL

    def __init__(self, libggml: Path):
        self.libggml = ctypes.CDLL(str(libggml))
        self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t
        # enum ggml_type type,
        # const float * src,
        # void * dst,
        # int64_t start,
        # int64_t nrows,
        # int64_t n_per_row,
        # const float * imatrix) {
        self.libggml.ggml_quantize_chunk.argtypes = (
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_float),
            ctypes.c_void_p,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.POINTER(ctypes.c_float),
        )

        self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
        self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)

        for t in (
            "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
            "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
            "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
            "iq4_nl", "iq4_xs",
        ):
            dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
            dequant_func.restype = None
            dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)

        self.libggml.ggml_fp16_to_fp32_row.restype = None
        self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
        self.libggml.ggml_bf16_to_fp32_row.restype = None
        self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)

        self.libggml.ggml_init.argtypes = (ggml_init_params,)

        self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))

    def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
        result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
        if qtype == GGMLQuantizationType.F32:
            # no-op
            result = tensor.view(np.float32)
        elif qtype == GGMLQuantizationType.F16:
            self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
        elif qtype == GGMLQuantizationType.BF16:
            self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
        else:
            lw_qname = qtype.name.lower()
            if lw_qname[-1] == "k":
                lw_qname = lw_qname[:-1] + "K"
            dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname)
            dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size)
        return result

    def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
        result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
        if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
            # TODO: is a column-wise sum of squares appropriate?
            qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
        else:
            qw = ctypes.cast(0, c_float_p)

        result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
        assert result.size == result_size
        return result


def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
    same = np.array_equal(t1, t2)
    if same:
        return True
    else:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
        if t1.dtype == np.float32:
            t1 = t1.reshape((-1, block_size))
            t2 = t2.reshape((-1, block_size))
        else:
            t1 = t1.reshape((-1, type_size))
            t2 = t2.reshape((-1, type_size))
        x = t1.view(np.uint8) ^ t2.view(np.uint8)
        diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
        num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
        if num_bad_blocks == 0 and t1.shape == t2.shape:
            logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
            return True
        logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
        bad_block_id = np.argmax(diff_bits, axis=0)
        logger.debug(f"Worst block id: {bad_block_id}")
        logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")

        sum_diff_bits = np.sum(diff_bits)
        logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
        return False


def do_test(libggml_path: Path, quick: bool = False):
    ggml_quants = GGMLQuants(libggml_path)

    np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})

    r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)

    for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
        has_dequantize = False
        has_quantize = False

        try:
            gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype)
            has_dequantize = True
        except (NotImplementedError, AssertionError) as e:
            if isinstance(e, AssertionError):
                logger.error(f"Error with {qtype.name}: {e}")
                raise e
        try:
            gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype)
            has_quantize = True
        except (NotImplementedError, AssertionError) as e:
            if isinstance(e, AssertionError):
                logger.error(f"Error with {qtype.name}: {e}")
                raise e

        if not has_dequantize and not has_quantize:
            continue

        logger.info(f"Testing {qtype.name}")

        rc = r.copy(order="C")

        pyq = None
        ggq = None

        if has_quantize:
            logger.debug(f"Quantizing to {qtype.name} with Python")
            pyq = gguf.quants.quantize(rc, qtype)

            logger.debug(f"Quantizing to {qtype.name} with C")
            ggq = ggml_quants.quantize(rc, qtype)

            if qtype == GGMLQuantizationType.F16:
                pyq = pyq.view(np.uint8)
            quant_equal = compare_tensors(pyq, ggq, qtype)

            if not quant_equal:
                logger.error(f"Quantization to {qtype.name} does not match ❌")
            else:
                logger.info(f"Quantization to {qtype.name} matches exactly ✅")

        if has_dequantize:
            if ggq is None and not quick:
                logger.debug(f"Quantizing to {qtype.name} with C")
                ggq = ggml_quants.quantize(rc, qtype)

            if ggq is not None:
                logger.debug(f"Dequantizing from {qtype.name} with Python")
                pydq = gguf.quants.dequantize(ggq, qtype)
                logger.debug(f"Dequantizing from {qtype.name} with C")
                ggdq = ggml_quants.dequantize(ggq, qtype)

                dequant_equal = compare_tensors(pydq, ggdq, qtype)

                if not dequant_equal:
                    logger.error(f"Dequantization from {qtype.name} does not match ❌")
                else:
                    logger.info(f"Dequantization from {qtype.name} matches exactly ✅")

            rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
            rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)

            logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
            pydq = gguf.quants.dequantize(rq, qtype)
            logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
            ggdq = ggml_quants.dequantize(rq, qtype)

            dequant_equal = compare_tensors(pydq, ggdq, qtype)

            if not dequant_equal:
                logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
            else:
                logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
    parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG)

    do_test(args.libggml, args.quick)
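Note: compare_tensors() above measures disagreement at the bit level: it XORs the two buffers as uint8, unpacks each byte into individual bits, and counts nonzero bits per block. On toy data:

    import numpy as np

    a = np.array([[0x00, 0xFF]], dtype=np.uint8)
    b = np.array([[0x01, 0xFF]], dtype=np.uint8)
    x = a ^ b                                                    # -> [[0x01, 0x00]]
    diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
    print(diff_bits)                                             # [1]: one differing bit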
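Note: to run the test, build libggml (the --libggml default assumes a CMake build tree at build/ggml/src/libggml.so) and invoke gguf-py/tests/test_quants.py --libggml <path>. The GGMLQuants wrapper can also be driven directly; a minimal sketch reusing the imports from the file above:

    libggml = Path("build") / "ggml" / "src" / "libggml.so"
    quants = GGMLQuants(libggml)

    data = np.random.randn(1, 256).astype(np.float32)
    q8 = quants.quantize(data, GGMLQuantizationType.Q8_0)      # raw quantized bytes
    back = quants.dequantize(q8, GGMLQuantizationType.Q8_0)    # float32 again, shape (1, 256)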