Compare commits

...
Sign in to create a new pull request.

26 commits

Author SHA1 Message Date
Francis Couture-Harpin
bffdaf4010 Merge branch 'master' into compilade/lazy-convert-hf 2024-05-08 10:56:03 -04:00
Francis Couture-Harpin
94e667a9d8 gguf-py : add tqdm as a dependency
It's small, and used for a progress bar
in GGUFWriter.write_tensors_to_file
2024-05-06 09:08:09 -04:00
Francis Couture-Harpin
68c5ac628f Merge branch 'master' into compilade/lazy-convert-hf 2024-05-06 08:07:08 -04:00
Francis Couture-Harpin
62303e7f77 convert-hf : minor changes for consistency 2024-05-05 16:49:53 -04:00
Francis Couture-Harpin
bc78bf4cdb convert-hf : faster model parts loading
Instead of pre-loading them all into a dict, iterate on the tensors
in the model parts progressively as needed in Model.write_tensors

Conversion for some architectures relies on checking for the presence
of specific tensor names, so for multi-part models, the weight map is read
from the relevant json file to quickly get these names up-front.
2024-05-05 12:58:32 -04:00
Francis Couture-Harpin
98db4347e8 convert-hf : remove einops requirement for InternLM2 2024-05-04 23:57:22 -04:00
Francis Couture-Harpin
0c3833286e convert-hf : flake8 doesn't like lowercase L as a variable name 2024-05-04 23:56:43 -04:00
Francis Couture-Harpin
f09674fbbd convert-hf : save memory with lazy evaluation 2024-05-04 23:56:43 -04:00
Francis Couture-Harpin
215a0d38c8 convert-hf : fix Refact conversion 2024-05-04 23:55:42 -04:00
Francis Couture-Harpin
f2099c50ab convert-hf : align the message logged for converted tensors 2024-05-04 09:20:01 -04:00
Francis Couture-Harpin
98f2d0e0d7 convert-hf : more consistent formatting of cmdline args 2024-05-03 23:05:41 -04:00
Francis Couture-Harpin
3e5e0dced5 Merge branch 'master' into compilade/convert-hf-refactor 2024-05-03 16:20:54 -04:00
Francis Couture-Harpin
6a54973d82 Merge branch 'master' into compilade/convert-hf-refactor 2024-05-02 20:02:46 -04:00
Francis Couture-Harpin
13f4cf70db convert-hf : use a plain class for Model, and forbid direct instantiation
There are no abstract methods used anyway,
so using ABC isn't really necessary.
2024-05-02 15:52:19 -04:00
Francis Couture-Harpin
ce067af118 convert-hf : use an ABC for Model again
It seems Protocol can't be used as a statically type-checked ABC,
because its subclasses also can't be instantiated. (why did it seem to work?)

At least there's still a way to throw an error when forgetting to define
the `model_arch` property of any registered Model subclasses.
2024-05-02 15:10:52 -04:00
Francis Couture-Harpin
644c2696d0 convert-hf : sort model part names
`os.listdir` is said to list files in arbitrary order.
Sorting the file names should let "model-00009-of-00042.safetensors"
be loaded before "model-00010-of-00042.safetensors".
2024-05-01 19:17:44 -04:00
Francis Couture-Harpin
639b374b1a convert-hf : convert norms to f32 by default 2024-05-01 19:03:58 -04:00
Francis Couture-Harpin
21068b6bdf convert-hf : display tensor shape 2024-05-01 16:59:21 -04:00
Francis Couture-Harpin
dcd8dfa1b5 convert : use a string for the SentencePiece tokenizer path 2024-05-01 13:07:10 -04:00
Francis Couture-Harpin
3870164f47 convert-hf : allow unusual model part names
For example, loading `model-00001-of-00001.safetensors` now works.

* convert-hf : fix stacking MoE expert tensors

`torch.stack` and `torch.cat` don't do the same thing.

* convert-hf : fix Mamba conversion

Tested to work even with a SentencePiece-based tokenizer.
2024-05-01 12:30:20 -04:00
Francis Couture-Harpin
56f60f5d69 convert-hf : flake8 linter doesn't like semicolons 2024-05-01 11:38:47 -04:00
Francis Couture-Harpin
cde9ea65e8 convert-hf : simplify MoE weights stacking 2024-04-30 18:12:01 -04:00
Francis Couture-Harpin
698f0b3479 convert-hf : remove unused n_dims in extra_*_tensors 2024-04-30 15:02:34 -04:00
Francis Couture-Harpin
c33775bcc7 convert : upgrade to sentencepiece v0.2.0 2024-04-30 15:01:23 -04:00
Francis Couture-Harpin
0d720acb91 Merge branch 'master' into compilade/convert-hf-refactor 2024-04-30 14:08:05 -04:00
Francis Couture-Harpin
47e02eb7bc convert-hf : begin refactoring write_tensor 2024-04-30 14:07:28 -04:00
14 changed files with 873 additions and 1285 deletions

File diff suppressed because it is too large Load diff

View file

@ -284,6 +284,7 @@ class Params:
n_experts = None
n_experts_used = None
f_rope_freq_base = None
n_ff = None
# hack to determine LLaMA v1 vs v2 vs CodeLlama
if config.get("moe"):
@ -308,6 +309,8 @@ class Params:
n_experts_used = config["moe"]["num_experts_per_tok"]
f_rope_freq_base = 1e6
assert n_ff is not None
return Params(
n_vocab = model["tok_embeddings.weight"].shape[0],
n_embd = config["dim"],
@ -462,7 +465,8 @@ class SentencePieceVocab(Vocab):
# not found in alternate location either
raise FileNotFoundError('Cannot find tokenizer.model')
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
self.sentencepiece_tokenizer = SentencePieceProcessor()
self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
vocab_size = self.sentencepiece_tokenizer.vocab_size()
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
@ -482,23 +486,23 @@ class SentencePieceVocab(Vocab):
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.sentencepiece_tokenizer
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
piece = tokenizer.IdToPiece(i)
text = piece.encode("utf-8")
score: float = tokenizer.get_score(i)
score: float = tokenizer.GetScore(i)
toktype = gguf.TokenType.NORMAL
if tokenizer.is_unknown(i):
if tokenizer.IsUnknown(i):
toktype = gguf.TokenType.UNKNOWN
if tokenizer.is_control(i):
if tokenizer.IsControl(i):
toktype = gguf.TokenType.CONTROL
# NOTE: I think added_tokens are user defined.
# ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
# if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
if tokenizer.is_unused(i):
if tokenizer.IsUnused(i):
toktype = gguf.TokenType.UNUSED
if tokenizer.is_byte(i):
if tokenizer.IsByte(i):
toktype = gguf.TokenType.BYTE
yield text, score, toktype
@ -906,7 +910,7 @@ class LazyUnpickler(pickle.Unpickler):
def rebuild_from_type_v2(func, new_type, args, state):
return func(*args)
CLASSES = {
CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = {
# getattr used here as a workaround for mypy not being smart enough to determine
# the staticmethods have a __func__ attribute.
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),

View file

@ -939,7 +939,7 @@ async def oai_chat_completions(user_prompt,
while event_received:
event_received = False
async for line_in_bytes in response.content:
line = line_in_bytes.decode('utf8')
line = line_in_bytes.decode('utf-8')
line = line.rstrip('\n').rstrip('\r')
if line == '':
continue

View file

@ -860,7 +860,7 @@ class GGUFValueType(IntEnum):
# Note: Does not support GGML_QKK_64
QK_K = 256
# Items here are (block size, type size)
GGML_QUANT_SIZES = {
GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F32: (1, 4),
GGMLQuantizationType.F16: (1, 2),
GGMLQuantizationType.Q4_0: (32, 2 + 16),

View file

@ -65,7 +65,7 @@ class ReaderTensor(NamedTuple):
class GGUFReader:
# I - same as host, S - swapped
byte_order: Literal['I' | 'S'] = 'I'
byte_order: Literal['I'] | Literal['S'] = 'I'
alignment: int = GGUF_DEFAULT_ALIGNMENT
# Note: Internal helper, API may change.
@ -83,7 +83,7 @@ class GGUFReader:
GGUFValueType.BOOL: np.bool_,
}
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
self.data = np.memmap(path, mode = mode)
offs = 0
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
@ -128,7 +128,7 @@ class GGUFReader:
return self.tensors[idx]
def _get(
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
) -> npt.NDArray[Any]:
count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize)
@ -250,7 +250,7 @@ class GGUFReader:
raise ValueError(f'Found duplicated tensor with name {tensor_name}')
tensor_names.add(tensor_name)
ggml_type = GGMLQuantizationType(raw_dtype[0])
n_elems = np.prod(dims)
n_elems = int(np.prod(dims))
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
n_bytes = n_elems * type_size // block_size
data_offs = int(start_offs + offset_tensor[0])

View file

@ -7,7 +7,7 @@ import struct
import tempfile
from enum import Enum, auto
from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping
from typing import IO, Any, Callable, Sequence, Mapping
from string import ascii_letters, digits
import numpy as np
@ -28,6 +28,47 @@ from .constants import (
logger = logging.getLogger(__name__)
class LazyTensor:
data: Callable[[], np.ndarray[Any, Any]]
# to avoid too deep recursion
functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
dtype: np.dtype[Any]
shape: tuple[int, ...]
def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
self.data = data
self.functions = []
self.dtype = np.dtype(dtype)
self.shape = shape
def astype(self, dtype: type, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.astype(dtype, **kwargs))
self.dtype = np.dtype(dtype)
return self
@property
def nbytes(self) -> int:
size = 1
for n in self.shape:
size *= n
return size * self.dtype.itemsize
def tofile(self, *args, **kwargs) -> None:
data = self.data()
for f in self.functions:
data = f(data)
assert data.shape == self.shape
assert data.dtype == self.dtype
assert data.nbytes == self.nbytes
self.functions = []
self.data = lambda: data
data.tofile(*args, **kwargs)
def byteswap(self, *args, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.byteswap(*args, **kwargs))
return self
class WriterState(Enum):
EMPTY = auto()
HEADER = auto()
@ -38,7 +79,7 @@ class WriterState(Enum):
class GGUFWriter:
fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: list[np.ndarray[Any, Any]]
tensors: list[np.ndarray[Any, Any] | LazyTensor]
_simple_value_packing = {
GGUFValueType.UINT8: "B",
GGUFValueType.INT8: "b",
@ -176,7 +217,7 @@ class GGUFWriter:
if pack_fmt is not None:
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
elif vtype == GGUFValueType.STRING:
encoded_val = val.encode("utf8") if isinstance(val, str) else val
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
self.kv_data += self._pack("Q", len(encoded_val))
self.kv_data += encoded_val
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
@ -205,7 +246,7 @@ class GGUFWriter:
raise ValueError(f'Duplicated tensor name {name}')
self.ti_names.add(name)
encoded_name = name.encode("utf8")
encoded_name = name.encode("utf-8")
self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name
n_dims = len(tensor_shape)
@ -237,7 +278,7 @@ class GGUFWriter:
self.ti_data_count += 1
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.endianess == GGUFEndian.BIG:
@ -262,7 +303,7 @@ class GGUFWriter:
if pad != 0:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
@ -272,15 +313,33 @@ class GGUFWriter:
tensor.tofile(self.fout)
self.write_padding(self.fout, tensor.nbytes)
def write_tensors_to_file(self) -> None:
def write_tensors_to_file(self, *, progress: bool = False) -> None:
self.write_ti_data_to_file()
self.write_padding(self.fout, self.fout.tell())
if self.temp_file is None:
self.tensors.reverse() # to pop from the "beginning" in constant time
if progress:
from tqdm import tqdm
total_bytes = sum(t.nbytes for t in self.tensors)
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
while True:
try:
tensor = self.tensors.pop()
except IndexError:
break
tensor.tofile(self.fout)
bar.update(tensor.nbytes)
self.write_padding(self.fout, tensor.nbytes)
return
while True:
try:
tensor = self.tensors.pop(0)
tensor = self.tensors.pop()
except IndexError:
break
tensor.tofile(self.fout)
@ -479,7 +538,7 @@ class GGUFWriter:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
if isinstance(value, list):
if not isinstance(value, str):
template_default = None
template_names = set()

View file

@ -4,7 +4,7 @@ import logging
import json
import os
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, Sequence, Mapping, Iterable
from .gguf_writer import GGUFWriter
@ -15,11 +15,11 @@ class SpecialVocab:
merges: list[str]
add_special_token: dict[str, bool]
special_token_ids: dict[str, int]
chat_template: str | None
chat_template: str | Sequence[Mapping[str, str]] | None
def __init__(
self, path: str | os.PathLike[str], load_merges: bool = False,
special_token_types: tuple[str, ...] | None = None,
special_token_types: Iterable[str] | None = None,
n_vocab: int | None = None,
):
self.special_token_ids = {}

View file

@ -21,6 +21,7 @@ classifiers = [
[tool.poetry.dependencies]
python = ">=3.8"
numpy = ">=1.17"
tqdm = ">=4.27"
[tool.poetry.dev-dependencies]
pytest = "^5.2"

View file

@ -47,7 +47,7 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
if len(field.types) == 1:
curr_type = field.types[0]
if curr_type == GGUFValueType.STRING:
log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60]))
log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]))
elif field.types[0] in reader.gguf_scalar_to_np:
log_message += ' = {0}'.format(field.parts[-1][0])
print(log_message) # noqa: NP100

View file

@ -7,7 +7,7 @@ import json
from pathlib import Path
import numpy as np
from typing import Any, Mapping, Sequence
from typing import Any, Sequence
# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@ -34,7 +34,7 @@ def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
return host_endian
def decode_field(field: gguf.ReaderField) -> Any:
def decode_field(field: gguf.ReaderField | None) -> Any:
if field and field.types:
main_type = field.types[0]
@ -42,11 +42,11 @@ def decode_field(field: gguf.ReaderField) -> Any:
sub_type = field.types[-1]
if sub_type == gguf.GGUFValueType.STRING:
return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
else:
return [pv for idx in field.data for pv in field.parts[idx].tolist()]
if main_type == gguf.GGUFValueType.STRING:
return str(bytes(field.parts[-1]), encoding='utf8')
return str(bytes(field.parts[-1]), encoding='utf-8')
else:
return field.parts[-1][0]
@ -59,7 +59,7 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
return decode_field(field)
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
for field in reader.fields.values():
# Suppress virtual fields and fields written by GGUFWriter
if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@ -101,7 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
for tensor in reader.tensors:
# Dimensions are written in reverse order, so flip them first
shape = np.flipud(tensor.shape)
shape = np.flipud(tensor.shape).tolist()
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
writer.write_header_to_file()

3
pyrightconfig.json Normal file
View file

@ -0,0 +1,3 @@
{
"extraPaths": ["gguf-py"],
}

View file

@ -1,3 +1,2 @@
-r ./requirements-convert.txt
torch~=2.1.1
einops~=0.7.0

View file

@ -1,3 +1,2 @@
-r ./requirements-convert.txt
torch~=2.1.1
einops~=0.7.0

View file

@ -1,5 +1,5 @@
numpy~=1.24.4
sentencepiece~=0.1.98
sentencepiece~=0.2.0
transformers>=4.40.1,<5.0.0
gguf>=0.1.0
protobuf>=4.21.0,<5.0.0