convert-hf : save memory with lazy evaluation (#7075)

* convert-hf : begin refactoring write_tensor

* convert : upgrade to sentencepiece v0.2.0

* convert-hf : remove unused n_dims in extra_*_tensors

* convert-hf : simplify MoE weights stacking

* convert-hf : flake8 linter doesn't like semicolons

* convert-hf : allow unusual model part names

For example, loading `model-00001-of-00001.safetensors` now works.

* convert-hf : fix stacking MoE expert tensors

`torch.stack` and `torch.cat` don't do the same thing.

* convert-hf : fix Mamba conversion

Tested to work even with a SentencePiece-based tokenizer.

* convert : use a string for the SentencePiece tokenizer path

* convert-hf : display tensor shape

* convert-hf : convert norms to f32 by default

* convert-hf : sort model part names

`os.listdir` is said to list files in arbitrary order.
Sorting the file names should let "model-00009-of-00042.safetensors"
be loaded before "model-00010-of-00042.safetensors".

* convert-hf : use an ABC for Model again

It seems Protocol can't be used as a statically type-checked ABC,
because its subclasses also can't be instantiated. (why did it seem to work?)

At least there's still a way to throw an error when forgetting to define
the `model_arch` property of any registered Model subclasses.

* convert-hf : use a plain class for Model, and forbid direct instantiation

There are no abstract methods used anyway,
so using ABC isn't really necessary.

* convert-hf : more consistent formatting of cmdline args

* convert-hf : align the message logged for converted tensors

* convert-hf : fix Refact conversion

* convert-hf : save memory with lazy evaluation

* convert-hf : flake8 doesn't like lowercase L as a variable name

* convert-hf : remove einops requirement for InternLM2

* convert-hf : faster model parts loading

Instead of pre-loading them all into a dict, iterate on the tensors
in the model parts progressively as needed in Model.write_tensors

Conversion for some architectures relies on checking for the presence
of specific tensor names, so for multi-part models, the weight map is read
from the relevant json file to quickly get these names up-front.

* convert-hf : minor changes for consistency

* gguf-py : add tqdm as a dependency

It's small, and used for a progress bar
in GGUFWriter.write_tensors_to_file
This commit is contained in:
compilade 2024-05-08 18:16:38 -04:00 committed by GitHub
parent bc4bba364f
commit f98eb31c51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 873 additions and 1285 deletions

View file

@ -7,7 +7,7 @@ import struct
import tempfile
from enum import Enum, auto
from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping
from typing import IO, Any, Callable, Sequence, Mapping
from string import ascii_letters, digits
import numpy as np
@ -28,6 +28,47 @@ from .constants import (
logger = logging.getLogger(__name__)
class LazyTensor:
data: Callable[[], np.ndarray[Any, Any]]
# to avoid too deep recursion
functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
dtype: np.dtype[Any]
shape: tuple[int, ...]
def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
self.data = data
self.functions = []
self.dtype = np.dtype(dtype)
self.shape = shape
def astype(self, dtype: type, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.astype(dtype, **kwargs))
self.dtype = np.dtype(dtype)
return self
@property
def nbytes(self) -> int:
size = 1
for n in self.shape:
size *= n
return size * self.dtype.itemsize
def tofile(self, *args, **kwargs) -> None:
data = self.data()
for f in self.functions:
data = f(data)
assert data.shape == self.shape
assert data.dtype == self.dtype
assert data.nbytes == self.nbytes
self.functions = []
self.data = lambda: data
data.tofile(*args, **kwargs)
def byteswap(self, *args, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.byteswap(*args, **kwargs))
return self
class WriterState(Enum):
EMPTY = auto()
HEADER = auto()
@ -38,7 +79,7 @@ class WriterState(Enum):
class GGUFWriter:
fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: list[np.ndarray[Any, Any]]
tensors: list[np.ndarray[Any, Any] | LazyTensor]
_simple_value_packing = {
GGUFValueType.UINT8: "B",
GGUFValueType.INT8: "b",
@ -176,7 +217,7 @@ class GGUFWriter:
if pack_fmt is not None:
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
elif vtype == GGUFValueType.STRING:
encoded_val = val.encode("utf8") if isinstance(val, str) else val
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
self.kv_data += self._pack("Q", len(encoded_val))
self.kv_data += encoded_val
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
@ -205,7 +246,7 @@ class GGUFWriter:
raise ValueError(f'Duplicated tensor name {name}')
self.ti_names.add(name)
encoded_name = name.encode("utf8")
encoded_name = name.encode("utf-8")
self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name
n_dims = len(tensor_shape)
@ -237,7 +278,7 @@ class GGUFWriter:
self.ti_data_count += 1
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.endianess == GGUFEndian.BIG:
@ -262,7 +303,7 @@ class GGUFWriter:
if pad != 0:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
@ -272,15 +313,33 @@ class GGUFWriter:
tensor.tofile(self.fout)
self.write_padding(self.fout, tensor.nbytes)
def write_tensors_to_file(self) -> None:
def write_tensors_to_file(self, *, progress: bool = False) -> None:
self.write_ti_data_to_file()
self.write_padding(self.fout, self.fout.tell())
if self.temp_file is None:
self.tensors.reverse() # to pop from the "beginning" in constant time
if progress:
from tqdm import tqdm
total_bytes = sum(t.nbytes for t in self.tensors)
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
while True:
try:
tensor = self.tensors.pop()
except IndexError:
break
tensor.tofile(self.fout)
bar.update(tensor.nbytes)
self.write_padding(self.fout, tensor.nbytes)
return
while True:
try:
tensor = self.tensors.pop(0)
tensor = self.tensors.pop()
except IndexError:
break
tensor.tofile(self.fout)
@ -479,7 +538,7 @@ class GGUFWriter:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
if isinstance(value, list):
if not isinstance(value, str):
template_default = None
template_names = set()