Refactor types

KerfuffleV2 2023-08-26 09:23:06 -06:00
parent 8ee186c0aa
commit 5f23d41faa

@@ -25,7 +25,7 @@ import numpy as np
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
 from sentencepiece import SentencePieceProcessor # type: ignore
 
 if TYPE_CHECKING:
@@ -45,31 +45,64 @@ DEFAULT_CONCURRENCY = 8
 #
 
 @dataclass(frozen=True)
-class UnquantizedDataType:
+class DataType:
     name: str
+    dtype: 'np.dtype[Any]'
+    valid_conversions: List[str]
 
-DT_F16 = UnquantizedDataType('F16')
-DT_F32 = UnquantizedDataType('F32')
-DT_I32 = UnquantizedDataType('I32')
-DT_BF16 = UnquantizedDataType('BF16')
+    def elements_to_bytes(self, n_elements: int) -> int:
+        return n_elements * self.dtype.itemsize
 
 @dataclass(frozen=True)
-class QuantizedDataType:
-    name: str
+class UnquantizedDataType(DataType):
+    pass
 
-DT_Q8_0 = QuantizedDataType('Q8_0')
+DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int32), valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
 
-DataType = Union[UnquantizedDataType, QuantizedDataType]
+@dataclass(frozen=True)
+class QuantizedDataType(DataType):
+    block_size: int
+    quantized_dtype: 'np.dtype[Any]'
+    ggml_type: gguf.GGMLQuantizationType
 
-DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
-    DT_BF16: np.dtype(np.uint16),
-    DT_F16: np.dtype(np.float16),
-    DT_F32: np.dtype(np.float32),
-    DT_I32: np.dtype(np.int32),
-}
+    def quantize(self, arr: NDArray) -> NDArray:
+        raise NotImplementedError(f'Quantization for {self.name} not implemented')
 
-NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
-    {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+    def elements_to_bytes(self, n_elements: int) -> int:
+        assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
+        return self.quantized_dtype.itemsize * (n_elements // self.block_size)
+
+@dataclass(frozen=True)
+class Q8_0QuantizedDataType(QuantizedDataType):
+    # Mini Q8_0 quantization in Python!
+    def quantize(self, arr: NDArray) -> NDArray:
+        assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+        assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+        n_blocks = arr.size // self.block_size
+        blocks = arr.reshape((n_blocks, self.block_size))
+        # Much faster implementation of block quantization contributed by @Cebtenzzre
+        def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
+            d = abs(blocks).max(axis = 1) / np.float32(127)
+            with np.errstate(divide = 'ignore'):
+                qs = (blocks / d[:, None]).round()
+            qs[d == 0] = 0
+            yield from zip(d, qs)
+        return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
+
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+    dtype = np.dtype(np.float32), valid_conversions = [],
+    ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
+    quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
 
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+        raise ValueError(f'Invalid duplicate data type {dt}')
+    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
 
 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'BF16': DT_BF16,
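
A quick sanity check of the new hierarchy (a minimal sketch, not part of the commit, assuming only the definitions above): sizes follow elements_to_bytes(), and Q8_0 packs each block of 32 float32 values into a float16 scale 'd' plus 32 int8 quants, 34 bytes per block.

    import numpy as np

    assert DT_F16.elements_to_bytes(64) == 128    # 64 elements * 2 bytes
    assert DT_Q8_0.elements_to_bytes(64) == 68    # 2 blocks * 34 bytes

    arr = np.linspace(-1.0, 1.0, 64, dtype=np.float32)   # exactly 2 blocks
    q = DT_Q8_0.quantize(arr)
    assert q.shape == (2,) and q.dtype == DT_Q8_0.quantized_dtype
    # Rough round trip: dequantize as d * qs and compare against the input.
    dequant = (q['d'][:, None].astype(np.float32) * q['qs']).reshape(-1)
    assert np.allclose(dequant, arr, atol=1.0 / 127)
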
@@ -87,18 +120,17 @@ class GGMLFileType(enum.IntEnum):
     MostlyQ8_0 = 7  # except 1d tensors
 
     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
-        if len(tensor.shape) == 1:
-            # 1D tensors are always F32.
-            return DT_F32
-        elif self == GGMLFileType.AllF32:
-            return DT_F32
-        elif self == GGMLFileType.MostlyF16:
-            return DT_F16
-        elif self == GGMLFileType.MostlyQ8_0:
-            return DT_Q8_0
-        else:
+        dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+        if dt is None:
             raise ValueError(self)
+        # 1D tensors are always F32.
+        return dt if len(tensor.shape) > 1 else DT_F32
+
+GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
+    GGMLFileType.AllF32    : DT_F32,
+    GGMLFileType.MostlyF16 : DT_F16,
+    GGMLFileType.MostlyQ8_0: DT_Q8_0,
+}
 
 #
 # hparams loading
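
A hypothetical demonstration of the table-driven dispatch above (the stub tensor is illustration only; type_for_tensor() inspects nothing but tensor.shape): the file type picks the data type from GGML_FILE_TYPE_TO_DATA_TYPE, and 1D tensors are forced to F32.

    def demo_tensor(shape):
        return LazyTensor(lambda: None, shape, DT_F32, 'demo')  # type: ignore

    assert GGMLFileType.MostlyQ8_0.type_for_tensor('w', demo_tensor([4096, 4096])) is DT_Q8_0
    assert GGMLFileType.MostlyQ8_0.type_for_tensor('norm', demo_tensor([4096])) is DT_F32
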
@@ -403,10 +435,7 @@ class UnquantizedTensor(Tensor):
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
     def astype(self, data_type: DataType) -> Tensor:
-        if data_type == DT_Q8_0:
-            dtype = DATA_TYPE_TO_NUMPY[DT_F32]
-        else:
-            dtype = DATA_TYPE_TO_NUMPY[data_type]
+        dtype = data_type.dtype
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
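
For reference, bf16_to_fp32() (defined elsewhere in this file) relies on bfloat16 being the upper half of an IEEE float32 bit pattern, which is also why DT_BF16 carries np.uint16 as its in-memory dtype. A minimal sketch of the widening, under that assumption:

    import numpy as np

    def bf16_to_fp32_sketch(bf16_arr: np.ndarray) -> np.ndarray:
        assert bf16_arr.dtype == np.uint16
        # Shift the 16 raw bits into the high half of a uint32, then reinterpret.
        return (bf16_arr.astype(np.uint32) << 16).view(np.float32)

    assert bf16_to_fp32_sketch(np.array([0x3F80], dtype=np.uint16))[0] == 1.0
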
@@ -445,22 +474,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
 GGMLCompatibleTensor = Union[UnquantizedTensor]
 
-class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
-        self.base = base
-        self.n_head = n_head
-        self.n_head_kv = n_head_kv
-        self.data_type = self.base.data_type
-
-    def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
-
-    def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
-
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
-        raise Exception("shouldn't permute twice")
-
 @dataclass
 class LazyTensor:
     _load: Callable[[], Tensor]
@@ -470,7 +483,9 @@ class LazyTensor:
     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type or (self.data_type is DT_Q8_0 and ret.data_type is DT_F32), (self.data_type, ret.data_type, self.description)
+        # Should be okay if it maps to the same numpy type?
+        assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
+            (self.data_type, ret.data_type, self.description)
         return ret
 
     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -481,8 +496,8 @@ class LazyTensor:
         return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
 
     def validate_conversion_to(self, data_type: DataType) -> None:
-        if data_type == self.data_type:
-            return
+        if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
+            raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
 
 LazyModel = Dict[str, LazyTensor]
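
A hypothetical check of the stricter rule: a conversion passes only if the target type equals the source or its name appears in the source's valid_conversions list (the lambda loader is a stub for illustration).

    t = LazyTensor(lambda: None, [8], DT_F16, 'demo')  # type: ignore
    t.validate_conversion_to(DT_F32)    # allowed: 'F32' is in DT_F16.valid_conversions
    try:
        t.validate_conversion_to(DT_I32)
    except ValueError as e:
        print(e)    # Cannot validate conversion from ... to ...
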
@@ -608,9 +623,7 @@ class LazyUnpickler(pickle.Unpickler):
         info = self.zip_file.getinfo(filename)
 
         def load(offset: int, elm_count: int) -> NDArray:
-            dtype = DATA_TYPE_TO_NUMPY.get(data_type)
-            if dtype is None:
-                raise Exception("tensor stored in unsupported format")
+            dtype = data_type.dtype
             fp = self.zip_file.open(info)
             fp.seek(offset * dtype.itemsize)
             size = elm_count * dtype.itemsize
@@ -674,7 +687,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
     def convert(info: Dict[str, Any]) -> LazyTensor:
         data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
-        numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
+        numpy_dtype = data_type.dtype
         shape: List[int] = info['shape']
         begin, end = info['data_offsets']
         assert 0 <= begin <= end <= len(byte_buf)
@@ -719,6 +732,9 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory. Specifically, there is a max of one
     output value buffered per thread.'''
+    if concurrency < 2:
+        yield from map(func, iterable)
+        return
     iterable = iter(iterable)
     with factory(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
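
A quick check of the new serial fallback (hypothetical usage): with concurrency below 2 the function reduces to a plain map and never creates an executor.

    squares = list(bounded_parallel_map(lambda x: x * x, range(5), concurrency=1))
    assert squares == [0, 1, 4, 9, 16]
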
@@ -756,24 +772,6 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None:
         msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
         raise Exception(msg)
 
-#### Mini Q8_0 quantization in Python
-QK8_0 = 32
-BLOCK_Q8_0 = np.dtype([('d', '<f2'), ('qs', 'i1', (QK8_0,))])
-
-def quantize_array_q8_0(arr):
-    assert arr.size % QK8_0 == 0 and arr.size != 0, f'Bad array size {arr.size}'
-    assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
-    n_blocks = arr.size // QK8_0
-    blocks = arr.reshape((n_blocks, QK8_0))
-    return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = BLOCK_Q8_0)
-
-# Much faster implementation of block quantization contributed by @Cebtenzzre
-def quantize_blocks_q8_0(blocks):
-    d = abs(blocks).max(axis = 1) / np.float32(127)
-    with np.errstate(divide = 'ignore'):
-        qs = (blocks / d[:, None]).round()
-        qs[d == 0] = 0
-    yield from zip(np.float16(d), qs)
-
 class OutputFile:
     def __init__(self, fname_out: Path) -> None:
@@ -816,18 +814,10 @@ class OutputFile:
         self.gguf.add_token_types(toktypes)
 
     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
-        n_elements = 1
-        for dim in tensor.shape:
-            n_elements *= dim
-        if tensor.data_type == DT_Q8_0:
-            assert n_elements > 0 and n_elements % QK8_0 == 0, f'Cannot quantize as Q8_0, {n_elements} not a multiple of block size {QK8_0}'
-            data_type= BLOCK_Q8_0
-            raw_dtype = gguf.GGMLQuantizationType.Q8_0
-            data_nbytes = n_elements + (n_elements // QK8_0) * 2
-        else:
-            data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-            data_nbytes = n_elements * data_type.itemsize
-            raw_dtype = None
+        n_elements = int(np.prod(tensor.shape))
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_dtype', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
         self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
 
     def write_meta(self) -> None:
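
Worked size arithmetic for the new elements_to_bytes() path (hypothetical shape): a 4096 x 4096 tensor has 16,777,216 elements; as Q8_0 that is 524,288 blocks of 34 bytes, which matches the old n_elements + (n_elements // QK8_0) * 2 formula exactly.

    n = 4096 * 4096
    assert DT_Q8_0.elements_to_bytes(n) == (n // 32) * 34 == n + (n // 32) * 2 == 17_825_792
    assert DT_F16.elements_to_bytes(n) == n * 2 == 33_554_432
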
@@ -854,16 +844,17 @@ class OutputFile:
         of.close()
 
     @staticmethod
-    def do_item(item: Tuple[str, LazyTensor]) -> (DataType, NDArray):
+    def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
         name, lazy_tensor = item
         tensor = lazy_tensor.load().to_ggml()
         return (lazy_tensor.data_type, tensor.ndarray)
 
     @staticmethod
     def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
-        if item[0] == DT_Q8_0:
-            return quantize_array_q8_0(item[1])
-        return item[1]
+        dt, arr = item
+        if not isinstance(dt, QuantizedDataType):
+            return arr
+        return dt.quantize(arr)
 
     @staticmethod
     def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
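
A hypothetical round trip through the new dispatch: unquantized data passes through untouched, while quantized types call their own quantize().

    import numpy as np

    arr = np.ones(32, dtype=np.float32)
    assert OutputFile.maybe_do_quantize((DT_F32, arr)) is arr
    assert OutputFile.maybe_do_quantize((DT_Q8_0, arr)).dtype == DT_Q8_0.quantized_dtype
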
@@ -883,11 +874,11 @@ class OutputFile:
         of.write_tensor_info()
 
         # tensor data
-        ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
         if ftype == GGMLFileType.MostlyQ8_0:
-            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
         else:
-            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays)
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
 
         start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
@@ -954,7 +945,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
             print(f"skipping tensor {name_new}")
             continue
         else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
         out[name_new] = lazy_tensor
 
     return out