Allow convert.py to convert to q8_0
Fix issue with bounded_parallel_map greedily consuming the iterator.
Display elapsed time during conversion.
parent 01f2224682
commit f54daa0735
1 changed file with 91 additions and 26 deletions
convert.py (117 changed lines: +91, -26)
@@ -3,6 +3,7 @@
 import gguf
 import argparse
 import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 import copy
 import enum
 import faulthandler
@@ -17,6 +18,7 @@ import re
 import signal
 import struct
 import sys
+import time
 import zipfile
 import numpy as np
 
@@ -50,7 +52,13 @@ DT_F32 = UnquantizedDataType('F32')
 DT_I32 = UnquantizedDataType('I32')
 DT_BF16 = UnquantizedDataType('BF16')
 
-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class QuantizedDataType:
+    name: str
+
+DT_Q8_0 = QuantizedDataType('Q8_0')
+
+DataType = Union[UnquantizedDataType, QuantizedDataType]
 
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_BF16: np.dtype(np.uint16),
@@ -73,8 +81,9 @@ SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
 # TODO: rename to LLAMAFileType
 # TODO: move to `gguf.py`
 class GGMLFileType(enum.IntEnum):
-    AllF32 = 0
-    MostlyF16 = 1 # except 1d tensors
+    AllF32 = 0
+    MostlyF16 = 1 # except 1d tensors
+    MostlyQ8_0 = 7 # except 1d tensors
 
     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
         if len(tensor.shape) == 1:
@@ -84,6 +93,8 @@ class GGMLFileType(enum.IntEnum):
             return DT_F32
         elif self == GGMLFileType.MostlyF16:
             return DT_F16
+        elif self == GGMLFileType.MostlyQ8_0:
+            return DT_Q8_0
         else:
             raise ValueError(self)
 
@@ -391,7 +402,10 @@ class UnquantizedTensor(Tensor):
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
     def astype(self, data_type: DataType) -> Tensor:
-        dtype = DATA_TYPE_TO_NUMPY[data_type]
+        if data_type == DT_Q8_0:
+            dtype = DATA_TYPE_TO_NUMPY[DT_F32]
+        else:
+            dtype = DATA_TYPE_TO_NUMPY[data_type]
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -455,7 +469,7 @@ class LazyTensor:
 
     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        assert ret.data_type == self.data_type or (self.data_type is DT_Q8_0 and ret.data_type is DT_F32), (self.data_type, ret.data_type, self.description)
         return ret
 
     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -699,23 +713,32 @@ def lazy_load_file(path: Path) -> ModelPlus:
 In = TypeVar('In')
 Out = TypeVar('Out')
 
-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
     '''Parallel map, but with backpressure. If the caller doesn't call `next`
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory. Specifically, there is a max of one
     output value buffered per thread.'''
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    iterable = iter(iterable)
+    with factory(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
-        items_rev = list(iterable)[::-1]
-        for i in range(min(concurrency, len(items_rev))):
-            futures.append(executor.submit(func, items_rev.pop()))
-        while futures:
+        done = False
+        for i in range(concurrency):
+            try:
+                nexti = next(iterable)
+            except StopIteration:
+                break
+            futures.append(executor.submit(func, nexti))
+        while not done or futures:
             result = futures.pop(0).result()
-            if items_rev:
-                futures.append(executor.submit(func, items_rev.pop()))
+            while len(futures) < concurrency:
+                try:
+                    nexti = next(iterable)
+                except StopIteration:
+                    done = True
+                    break
+                futures.append(executor.submit(func, nexti))
             yield result
 
 
 def check_vocab_size(params: Params, vocab: Vocab) -> None:
     if params.n_vocab != vocab.vocab_size:
         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
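Note: bounded_parallel_map previously drained the whole input with list(iterable), which defeated the backpressure it promises; it now pulls items lazily, keeps at most `concurrency` results in flight, and lets callers swap in a process pool. A minimal usage sketch, assuming the imports and the bounded_parallel_map definition from the patched convert.py (the worker function and inputs are illustrative only):

    def slow_double(x: int) -> int:
        time.sleep(0.01)  # stand-in for real per-tensor work
        return x * 2

    # Default: thread pool; results are yielded in submission order.
    for y in bounded_parallel_map(slow_double, range(32), concurrency=8):
        print(y)

    # CPU-bound stages can opt into processes instead:
    # bounded_parallel_map(slow_double, range(32), concurrency=8,
    #                      max_workers=8, factory=ProcessPoolExecutor)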
@@ -732,6 +755,22 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None:
             msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
         raise Exception(msg)
 
 
+#### Mini Q8_0 quantization in Python
+QK8_0 = 32
+BLOCK_Q8_0 = np.dtype([('d', '<f2'), ('qs', 'i1', (QK8_0,))])
+def quantize_array_q8_0(arr):
+    assert arr.size % QK8_0 == 0 and arr.size != 0, f'Bad array size {arr.size}'
+    assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+    n_blocks = arr.size // QK8_0
+    blocks = arr.reshape((n_blocks, QK8_0))
+    return np.fromiter(map(quantize_block_q8_0, blocks), count = n_blocks, dtype = BLOCK_Q8_0)
+
+def quantize_block_q8_0(blk, zero = np.float32(0), one = np.float32(1), onetwentyseven = np.float32(127), zero_chunk = (np.int8(0),) * QK8_0):
+    d = abs(blk).max() / onetwentyseven
+    if d == zero:
+        return (np.float16(d), zero_chunk)
+    return (np.float16(d), (blk * (one / d)).round())
+
+
 class OutputFile:
     def __init__(self, fname_out: Path) -> None:
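Note: each Q8_0 block packs 32 values as a float16 scale d plus 32 signed 8-bit quants, so one block is 2 + 32 = 34 bytes. A round-trip sketch to check the quantizer, assuming the definitions above; dequantize_array_q8_0 is a hypothetical helper written here only for the check, not part of the patch:

    def dequantize_array_q8_0(blocks):
        # Hypothetical helper: scale each int8 quant by its block's fp16 scale.
        d  = blocks['d'].astype(np.float32)[:, None]   # (n_blocks, 1)
        qs = blocks['qs'].astype(np.float32)           # (n_blocks, QK8_0)
        return (d * qs).reshape(-1)

    arr = np.random.randn(4 * QK8_0).astype(np.float32)
    q   = quantize_array_q8_0(arr)
    print(q.nbytes)                                       # 4 blocks * 34 bytes = 136
    print(np.abs(dequantize_array_q8_0(q) - arr).max())   # small quantization error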
@@ -777,9 +816,16 @@
         n_elements = 1
         for dim in tensor.shape:
             n_elements *= dim
-        data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-        data_nbytes = n_elements * data_type.itemsize
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+        if tensor.data_type == DT_Q8_0:
+            assert n_elements > 0 and n_elements % QK8_0 == 0, f'Cannot quantize as Q8_0, {n_elements} not a multiple of block size {QK8_0}'
+            data_type = BLOCK_Q8_0
+            raw_dtype = gguf.GGMLQuantizationType.Q8_0
+            data_nbytes = n_elements + (n_elements // QK8_0) * 2
+        else:
+            data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
+            data_nbytes = n_elements * data_type.itemsize
+            raw_dtype = None
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
 
     def write_meta(self) -> None:
         self.gguf.write_header_to_file()
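Note: the Q8_0 size is one byte per element plus two bytes of fp16 scale for every QK8_0 = 32 elements, i.e. 34 bytes per block, which matches BLOCK_Q8_0.itemsize. A quick arithmetic check under that assumption (the 4096 x 4096 shape is only an example):

    n_elements = 4096 * 4096
    q8_nbytes  = n_elements + (n_elements // QK8_0) * 2
    assert q8_nbytes == (n_elements // QK8_0) * BLOCK_Q8_0.itemsize  # 34 bytes per block
    print(q8_nbytes / (n_elements * 4))  # ~0.27x the size of the same tensor in f32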
@@ -805,7 +851,19 @@
         of.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def do_item(item: Tuple[str, LazyTensor]) -> (DataType, NDArray):
+        name, lazy_tensor = item
+        tensor = lazy_tensor.load().to_ggml()
+        return (lazy_tensor.data_type, tensor.ndarray)
+
+    @staticmethod
+    def maybe_do_quant(item: Tuple[DataType, NDArray]) -> NDArray:
+        if item[0] == DT_Q8_0:
+            return quantize_array_q8_0(item[1])
+        return item[1]
+
+    @staticmethod
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
 
         of = OutputFile(fname_out)
@@ -821,16 +879,19 @@
         of.write_meta()
         of.write_tensor_info()
 
-        def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
-            name, lazy_tensor = item
-            return lazy_tensor.load().to_ggml().ndarray
-
         # tensor data
-        ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+        ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = 8)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quant, ndarrays, concurrency = 8, max_workers = 8, factory = ProcessPoolExecutor)
+        else:
+            ndarrays = map(OutputFile.maybe_do_quant, ndarrays)
+
+        start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
             size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
             padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:6} | T+{int(elapsed):4}")
             of.gguf.write_tensor_data(ndarray)
 
         of.close()
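Note: tensor writing is now a two-stage pipeline: do_item loads and converts tensors on a thread pool, and when the target is Q8_0 the quantization stage is pushed to a process pool (presumably so the Python-level quantize loop is not serialized by the GIL). A stripped-down sketch of the same shape, with placeholder stages standing in for the real ones:

    def load_stage(x):    # stands in for OutputFile.do_item
        return x * 10

    def quant_stage(x):   # stands in for OutputFile.maybe_do_quant
        return x + 1

    loaded  = bounded_parallel_map(load_stage, range(16), concurrency=8)
    results = bounded_parallel_map(quant_stage, loaded, concurrency=8,
                                   max_workers=8, factory=ProcessPoolExecutor)
    for r in results:     # consumed lazily and in order, as in write_all
        print(r)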
@@ -842,6 +903,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
         return GGMLFileType.MostlyF16
+    if output_type_str == "q8_0":
+        return GGMLFileType.MostlyQ8_0
 
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
 
@@ -993,6 +1056,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32: "f32",
         GGMLFileType.MostlyF16: "f16",
+        GGMLFileType.MostlyQ8_0:"q8_0",
     }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     if ret in model_paths:
@@ -1016,7 +1080,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
-    parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+    parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format (default: based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
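Note: with the new --outtype choice wired through, a q8_0 conversion would be invoked along these lines (the model path is illustrative):

    python convert.py models/7B/ --outtype q8_0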
@@ -1043,6 +1107,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = {
             "f32": GGMLFileType.AllF32,
             "f16": GGMLFileType.MostlyF16,
+            "q8_0": GGMLFileType.MostlyQ8_0,
         }[args.outtype]
 
     print(f"params = {params}")
@@ -1074,7 +1139,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = ftype
         print(f"Writing {outfile}, format {ftype}")
 
-        OutputFile.write_all(outfile, params, model, vocab)
+        OutputFile.write_all(outfile, ftype, params, model, vocab)
         print(f"Wrote {outfile}")
 