Merge branch 'master' into refactor-convert-py

brian khuu 2024-07-16 23:58:56 +10:00
commit 5da16bb1d7
9 changed files with 142 additions and 138 deletions


@@ -153,9 +153,16 @@ class Model:
                 tensor_names_from_parts.update(model_part.keys())
 
                 for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
-                    if self.lazy:
-                        data = LazyTorchTensor.from_eager(data)
+                    if self.is_safetensors:
+                        if self.lazy:
+                            data = model_part.get_slice(name)
+                            data = LazyTorchTensor.from_safetensors_slice(data)
+                        else:
+                            data = model_part.get_tensor(name)
+                    else:
+                        data = model_part[name]
+                        if self.lazy:
+                            data = LazyTorchTensor.from_eager(data)
                     yield name, data
 
         # only verify tensor name presence; it doesn't matter if they are not in the right files
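
For safetensors inputs, the lazy path above now takes a slice handle via get_slice() instead of materializing the tensor immediately with get_tensor() (model_part is the open safetensors part). A minimal sketch of that difference, assuming safetensors is installed and using placeholder file and tensor names:

# Sketch only (placeholder path and tensor name): contrasts eager get_tensor()
# with deferred get_slice() from the safetensors Python API.
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    eager = f.get_tensor("tok_embeddings.weight")  # reads the tensor data now
    sl = f.get_slice("tok_embeddings.weight")      # cheap handle, nothing read yet
    print(sl.get_shape(), sl.get_dtype())          # metadata without touching the data
    later = sl[:]                                  # bytes are only read here
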
@@ -3444,19 +3451,46 @@ class LazyTorchTensor(gguf.LazyBase):
         torch.float32: np.float32,
     }
 
+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
     def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
             meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
-            lazy=self._lazy,
             args=(self,),
-            func=(lambda s: s[0].numpy())
+            func=(lambda s: s.numpy())
         )
 
     @classmethod
-    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
+    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
         return torch.empty(size=shape, dtype=dtype, device="meta")
 
+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape: tuple[int, ...] = tuple(st_slice.get_shape())
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         del types  # unused

@@ -3467,7 +3501,7 @@ class LazyTorchTensor(gguf.LazyBase):
         if func is torch.Tensor.numpy:
             return args[0].numpy()
 
-        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
+        return cls._wrap_fn(func)(*args, **kwargs)
 
 
 def parse_args() -> argparse.Namespace:
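
The from_safetensors_slice() classmethod added above follows the same pattern as the rest of the converter's lazy tensors: keep a cheap meta tensor describing dtype and shape, remember the slice handle in args, and only run func (here lambda s: s[:]) when the value is actually needed. A stripped-down illustration of that pattern, using a hypothetical class rather than the real gguf.LazyBase:

# Hypothetical stand-in for the lazy-wrapper pattern, not gguf.LazyBase itself.
from typing import Any, Callable

class TinyLazy:
    def __init__(self, meta: Any, args: tuple, func: Callable[..., Any]):
        self.meta = meta    # dtype/shape stand-in, cheap to build
        self.args = args    # e.g. a safetensors slice handle
        self.func = func    # e.g. lambda s: s[:]
        self.data = None    # filled in on first materialization

    def to_eager(self) -> Any:
        if self.data is None:
            self.data = self.func(*self.args)  # the actual read happens here
        return self.data

# usage: wrap a slice handle without reading it
# lazy = TinyLazy(meta=("F32", (4096, 4096)), args=(st_slice,), func=lambda s: s[:])
# tensor = lazy.to_eager()  # data is read only at this point
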


@@ -190,6 +190,9 @@ static bool export_lora_params_parse(int argc, char ** argv, struct export_lora_
             if (params->n_threads <= 0) {
                 params->n_threads = std::thread::hardware_concurrency();
             }
+        } else if (arg == "-h" || arg == "--help") {
+            export_lora_print_usage(argc, argv, &default_params);
+            exit(0);
         } else {
             fprintf(stderr, "error: unknown argument: '%s'\n", arg.c_str());
             export_lora_print_usage(argc, argv, &default_params);


@@ -201,6 +201,6 @@ Verification results for test.gguf.manifest - Success
 
 These micro c libraries dependencies was installed via the [clib c package manager](https://github.com/clibs)
 
-- https://github.com/mofosyne/xxHash (From: https://github.com/Cyan4973/xxHash)
+- https://github.com/Cyan4973/xxHash
 - https://github.com/clibs/sha1/
 - https://github.com/jb55/sha256.c


@@ -1,7 +1,7 @@
 {
   "name": "xxhash",
   "version": "0.8.2",
-  "repo": "mofosyne/xxhash",
+  "repo": "Cyan4973/xxhash",
   "description": "Extremely fast non-cryptographic hash algorithm",
   "keywords": ["xxhash", "hashing"],
   "license": "BSD-2-Clause",


@@ -16,44 +16,44 @@ struct quant_option {
 };
 
 static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q4_0",     LLAMA_FTYPE_MOSTLY_Q4_0,     " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
     { "Q4_1",     LLAMA_FTYPE_MOSTLY_Q4_1,     " 4.78G, +0.4511 ppl @ Llama-3-8B",  },
     { "Q5_0",     LLAMA_FTYPE_MOSTLY_Q5_0,     " 5.21G, +0.1316 ppl @ Llama-3-8B",  },
     { "Q5_1",     LLAMA_FTYPE_MOSTLY_Q5_1,     " 5.65G, +0.1062 ppl @ Llama-3-8B",  },
-    { "IQ2_XXS",LLAMA_FTYPE_MOSTLY_IQ2_XXS," 2.06 bpw quantization",            },
+    { "IQ2_XXS",  LLAMA_FTYPE_MOSTLY_IQ2_XXS,  " 2.06 bpw quantization",            },
     { "IQ2_XS",   LLAMA_FTYPE_MOSTLY_IQ2_XS,   " 2.31 bpw quantization",            },
     { "IQ2_S",    LLAMA_FTYPE_MOSTLY_IQ2_S,    " 2.5 bpw quantization",             },
     { "IQ2_M",    LLAMA_FTYPE_MOSTLY_IQ2_M,    " 2.7 bpw quantization",             },
     { "IQ1_S",    LLAMA_FTYPE_MOSTLY_IQ1_S,    " 1.56 bpw quantization",            },
     { "IQ1_M",    LLAMA_FTYPE_MOSTLY_IQ1_M,    " 1.75 bpw quantization",            },
     { "Q2_K",     LLAMA_FTYPE_MOSTLY_Q2_K,     " 2.96G, +3.5199 ppl @ Llama-3-8B",  },
     { "Q2_K_S",   LLAMA_FTYPE_MOSTLY_Q2_K_S,   " 2.96G, +3.1836 ppl @ Llama-3-8B",  },
-    { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization",            },
+    { "IQ3_XXS",  LLAMA_FTYPE_MOSTLY_IQ3_XXS,  " 3.06 bpw quantization",            },
     { "IQ3_S",    LLAMA_FTYPE_MOSTLY_IQ3_S,    " 3.44 bpw quantization",            },
     { "IQ3_M",    LLAMA_FTYPE_MOSTLY_IQ3_M,    " 3.66 bpw quantization mix",        },
     { "Q3_K",     LLAMA_FTYPE_MOSTLY_Q3_K_M,   "alias for Q3_K_M"                   },
     { "IQ3_XS",   LLAMA_FTYPE_MOSTLY_IQ3_XS,   " 3.3 bpw quantization",             },
     { "Q3_K_S",   LLAMA_FTYPE_MOSTLY_Q3_K_S,   " 3.41G, +1.6321 ppl @ Llama-3-8B",  },
     { "Q3_K_M",   LLAMA_FTYPE_MOSTLY_Q3_K_M,   " 3.74G, +0.6569 ppl @ Llama-3-8B",  },
     { "Q3_K_L",   LLAMA_FTYPE_MOSTLY_Q3_K_L,   " 4.03G, +0.5562 ppl @ Llama-3-8B",  },
     { "IQ4_NL",   LLAMA_FTYPE_MOSTLY_IQ4_NL,   " 4.50 bpw non-linear quantization", },
     { "IQ4_XS",   LLAMA_FTYPE_MOSTLY_IQ4_XS,   " 4.25 bpw non-linear quantization", },
     { "Q4_K",     LLAMA_FTYPE_MOSTLY_Q4_K_M,   "alias for Q4_K_M",                  },
     { "Q4_K_S",   LLAMA_FTYPE_MOSTLY_Q4_K_S,   " 4.37G, +0.2689 ppl @ Llama-3-8B",  },
     { "Q4_K_M",   LLAMA_FTYPE_MOSTLY_Q4_K_M,   " 4.58G, +0.1754 ppl @ Llama-3-8B",  },
     { "Q5_K",     LLAMA_FTYPE_MOSTLY_Q5_K_M,   "alias for Q5_K_M",                  },
     { "Q5_K_S",   LLAMA_FTYPE_MOSTLY_Q5_K_S,   " 5.21G, +0.1049 ppl @ Llama-3-8B",  },
     { "Q5_K_M",   LLAMA_FTYPE_MOSTLY_Q5_K_M,   " 5.33G, +0.0569 ppl @ Llama-3-8B",  },
     { "Q6_K",     LLAMA_FTYPE_MOSTLY_Q6_K,     " 6.14G, +0.0217 ppl @ Llama-3-8B",  },
     { "Q8_0",     LLAMA_FTYPE_MOSTLY_Q8_0,     " 7.96G, +0.0026 ppl @ Llama-3-8B",  },
     { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
     { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
     { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
     { "F16",      LLAMA_FTYPE_MOSTLY_F16,      "14.00G, +0.0020 ppl @ Mistral-7B",  },
     { "BF16",     LLAMA_FTYPE_MOSTLY_BF16,     "14.00G, -0.0050 ppl @ Mistral-7B",  },
     { "F32",      LLAMA_FTYPE_ALL_F32,         "26.00G @ 7B",                       },
     // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
     { "COPY",     LLAMA_FTYPE_ALL_F32,         "only copy tensors, no quantizing",  },
 };
 
 static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";


@@ -3,7 +3,6 @@ from abc import ABC, ABCMeta, abstractmethod
 
 import logging
 from typing import Any, Callable
-from collections import deque
 
 import numpy as np
 from numpy.typing import DTypeLike
@@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _tensor_type: type
     _meta: Any
     _data: Any | None
-    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
     _args: tuple
-    _func: Callable[[tuple], Any] | None
+    _kwargs: dict[str, Any]
+    _func: Callable[[Any], Any] | None
 
-    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
         super().__init__()
         self._meta = meta
         self._data = data
-        self._lazy = lazy if lazy is not None else deque()
         self._args = args
+        self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
         assert self._func is not None or self._data is not None
-        if self._data is None:
-            self._lazy.append(self)
 
     def __init_subclass__(cls) -> None:
         if "_tensor_type" not in cls.__dict__:
@@ -117,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
             args = ((use_self,) if use_self is not None else ()) + args
 
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+            # TODO: maybe handle tensors in kwargs too
 
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
@@ -140,23 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                         res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
-                class CollectSharedLazy:
-                    # emulating a static variable
-                    shared_lazy: None | deque[LazyBase] = None
-
-                    @staticmethod
-                    def collect_replace(t: LazyBase):
-                        if CollectSharedLazy.shared_lazy is None:
-                            CollectSharedLazy.shared_lazy = t._lazy
-                        else:
-                            CollectSharedLazy.shared_lazy.extend(t._lazy)
-                            t._lazy = CollectSharedLazy.shared_lazy
-
-                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
-
-                shared_lazy = CollectSharedLazy.shared_lazy
-
-                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
             else:
                 del res  # not needed
                 # non-tensor return likely relies on the contents of the args
@@ -168,26 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def to_eager(cls, t: Any) -> Any:
         def simple_to_eager(_t: LazyBase) -> Any:
-            def already_eager_to_eager(_t: LazyBase) -> Any:
-                assert _t._data is not None
+            if _t._data is not None:
                 return _t._data
 
-            while _t._data is None:
-                lt = _t._lazy.popleft()
-                if lt._data is not None:
-                    # Lazy tensor did not belong in the lazy queue.
-                    # Weirdly only happens with Bloom models...
-                    # likely because tensors aren't unique in the queue.
-                    # The final output is still the same as in eager mode,
-                    # so it's safe to ignore this.
-                    continue
-                assert lt._func is not None
-                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
-                lt._data = lt._func(lt._args)
-                # sanity check
-                assert lt._data is not None
-                assert lt._data.dtype == lt._meta.dtype
-                assert lt._data.shape == lt._meta.shape
+            # NOTE: there's a recursion limit in Python (usually 1000)
+
+            assert _t._func is not None
+            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+            _t._data = _t._func(*_t._args, **_t._kwargs)
+            # sanity check
+            assert _t._data is not None
+            assert _t._data.dtype == _t._meta.dtype
+            assert _t._data.shape == _t._meta.shape
 
             return _t._data
 
@@ -206,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def from_eager(cls, t: Any) -> Any:
         if type(t) is cls:
-            # already eager
+            # already lazy
            return t
         elif isinstance(t, cls._tensor_type):
             return cls(meta=cls.eager_to_meta(t), data=t)
@@ -228,8 +202,7 @@ class LazyNumpyTensor(LazyBase):
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
-        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
 
     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
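
The net effect of the lazy.py changes is that lazy graphs are now materialized by plain recursion over _args (hence the NOTE about Python's recursion limit), with kwargs stored alongside, instead of draining a shared deque. A rough standalone sketch of that evaluation shape, with a hypothetical Node class rather than the real gguf.LazyBase/_recurse_apply machinery:

# Hypothetical Node class sketching recursion-based materialization.
from __future__ import annotations
from typing import Any, Callable

class Node:
    def __init__(self, func: Callable[..., Any] | None = None,
                 args: tuple = (), kwargs: dict[str, Any] | None = None,
                 data: Any | None = None):
        self.func, self.args, self.kwargs, self.data = func, args, kwargs or {}, data

    def to_eager(self) -> Any:
        if self.data is not None:
            return self.data
        # recursion replaces the old shared-deque bookkeeping
        args = tuple(a.to_eager() if isinstance(a, Node) else a for a in self.args)
        assert self.func is not None
        self.data = self.func(*args, **self.kwargs)
        return self.data

# usage: (2 + 3) * 4, evaluated only when requested
inner = Node(func=lambda x, y: x + y, args=(Node(data=2), 3))
outer = Node(func=lambda a, b: a * b, args=(inner, 4))
print(outer.to_eager())  # 20
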


@@ -602,14 +602,12 @@ class TensorNameMap:
         for tensor, keys in self.block_mappings_cfg.items():
             if tensor not in MODEL_TENSORS[arch]:
                 continue
-            # TODO: make this configurable
-            n_experts = 160
-            for xid in range(n_experts):
-                tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
-                self.mapping[tensor_name] = (tensor, tensor_name)
-                for key in keys:
-                    key = key.format(bid = bid, xid = xid)
-                    self.mapping[key] = (tensor, tensor_name)
+
+            tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
+            self.mapping[tensor_name] = (tensor, tensor_name)
+            for key in keys:
+                key = key.format(bid = bid)
+                self.mapping[key] = (tensor, tensor_name)
 
     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
         result = self.mapping.get(key)
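
With the hard-coded n_experts/xid loop removed, each block mapping is now a single substitution of the block index into the GGUF-side tensor name and into every candidate source-model key. A toy version of what one iteration produces, using simplified placeholder tables rather than the real TENSOR_NAMES and block_mappings_cfg:

# Toy illustration of building the name map for one block (bid = 0).
# The two dicts are simplified placeholders, not the real gguf-py tables.
TENSOR_NAMES = {"ATTN_Q": "blk.{bid}.attn_q"}
block_mappings_cfg = {"ATTN_Q": ("model.layers.{bid}.self_attn.q_proj",)}

mapping: dict[str, tuple[str, str]] = {}
bid = 0
for tensor, keys in block_mappings_cfg.items():
    tensor_name = TENSOR_NAMES[tensor].format(bid=bid)
    mapping[tensor_name] = (tensor, tensor_name)
    for key in keys:
        mapping[key.format(bid=bid)] = (tensor, tensor_name)

print(mapping)
# {'blk.0.attn_q': ('ATTN_Q', 'blk.0.attn_q'),
#  'model.layers.0.self_attn.q_proj': ('ATTN_Q', 'blk.0.attn_q')}
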


@@ -133,7 +133,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_F16          = 1,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0         = 2,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_1         = 3,  // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
+        // LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
         // LLAMA_FTYPE_MOSTLY_Q4_2       = 5,  // support has been removed
         // LLAMA_FTYPE_MOSTLY_Q4_3       = 6,  // support has been removed
         LLAMA_FTYPE_MOSTLY_Q8_0         = 7,  // except 1d tensors


@@ -4510,40 +4510,36 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
     }
 
     switch (ftype) {
         case LLAMA_FTYPE_ALL_F32:         return "all F32";
         case LLAMA_FTYPE_MOSTLY_F16:      return "F16";
         case LLAMA_FTYPE_MOSTLY_BF16:     return "BF16";
         case LLAMA_FTYPE_MOSTLY_Q4_0:     return "Q4_0";
         case LLAMA_FTYPE_MOSTLY_Q4_1:     return "Q4_1";
-        case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
-                                          return "Q4_1, some F16";
         case LLAMA_FTYPE_MOSTLY_Q5_0:     return "Q5_0";
         case LLAMA_FTYPE_MOSTLY_Q5_1:     return "Q5_1";
         case LLAMA_FTYPE_MOSTLY_Q8_0:     return "Q8_0";
-
-        // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K:     return "Q2_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:   return "Q2_K - Small";
         case LLAMA_FTYPE_MOSTLY_Q3_K_S:   return "Q3_K - Small";
         case LLAMA_FTYPE_MOSTLY_Q3_K_M:   return "Q3_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q3_K_L:   return "Q3_K - Large";
         case LLAMA_FTYPE_MOSTLY_Q4_K_S:   return "Q4_K - Small";
         case LLAMA_FTYPE_MOSTLY_Q4_K_M:   return "Q4_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q5_K_S:   return "Q5_K - Small";
         case LLAMA_FTYPE_MOSTLY_Q5_K_M:   return "Q5_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q6_K:     return "Q6_K";
         case LLAMA_FTYPE_MOSTLY_IQ2_XXS:  return "IQ2_XXS - 2.0625 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ2_XS:   return "IQ2_XS - 2.3125 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ2_S:    return "IQ2_S - 2.5 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ2_M:    return "IQ2_M - 2.7 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ3_XS:   return "IQ3_XS - 3.3 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ3_XXS:  return "IQ3_XXS - 3.0625 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ1_S:    return "IQ1_S - 1.5625 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ1_M:    return "IQ1_M - 1.75 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ4_NL:   return "IQ4_NL - 4.5 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ4_XS:   return "IQ4_XS - 4.25 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ3_S:    return "IQ3_S - 3.4375 bpw";
         case LLAMA_FTYPE_MOSTLY_IQ3_M:    return "IQ3_S mix - 3.66 bpw";
         case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
         case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8";
         case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8";
@@ -18069,10 +18065,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
     //}
     bool convert_incompatible_tensor = false;
     if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
         new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
         new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
         new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
         new_type == GGML_TYPE_IQ1_M) {
         int nx = tensor->ne[0];
         int ny = tensor->ne[1];