Compare commits

...
Sign in to create a new pull request.

19 commits

Author SHA1 Message Date
Brian
c32d39cefb
Merge branch 'master' into compilade/convert-hf-refactor 2024-05-06 19:33:38 +10:00
Francis Couture-Harpin
215a0d38c8 convert-hf : fix Refact conversion 2024-05-04 23:55:42 -04:00
Francis Couture-Harpin
f2099c50ab convert-hf : align the message logged for converted tensors 2024-05-04 09:20:01 -04:00
Francis Couture-Harpin
98f2d0e0d7 convert-hf : more consistent formatting of cmdline args 2024-05-03 23:05:41 -04:00
Francis Couture-Harpin
3e5e0dced5 Merge branch 'master' into compilade/convert-hf-refactor 2024-05-03 16:20:54 -04:00
Francis Couture-Harpin
6a54973d82 Merge branch 'master' into compilade/convert-hf-refactor 2024-05-02 20:02:46 -04:00
Francis Couture-Harpin
13f4cf70db convert-hf : use a plain class for Model, and forbid direct instantiation
There are no abstract methods used anyway,
so using ABC isn't really necessary.
2024-05-02 15:52:19 -04:00
Francis Couture-Harpin
ce067af118 convert-hf : use an ABC for Model again
It seems Protocol can't be used as a statically type-checked ABC,
because its subclasses also can't be instantiated. (why did it seem to work?)

At least there's still a way to throw an error when forgetting to define
the `model_arch` property of any registered Model subclasses.
2024-05-02 15:10:52 -04:00
Francis Couture-Harpin
644c2696d0 convert-hf : sort model part names
`os.listdir` is said to list files in arbitrary order.
Sorting the file names should let "model-00009-of-00042.safetensors"
be loaded before "model-00010-of-00042.safetensors".
2024-05-01 19:17:44 -04:00
Francis Couture-Harpin
639b374b1a convert-hf : convert norms to f32 by default 2024-05-01 19:03:58 -04:00
Francis Couture-Harpin
21068b6bdf convert-hf : display tensor shape 2024-05-01 16:59:21 -04:00
Francis Couture-Harpin
dcd8dfa1b5 convert : use a string for the SentencePiece tokenizer path 2024-05-01 13:07:10 -04:00
Francis Couture-Harpin
3870164f47 convert-hf : allow unusual model part names
For example, loading `model-00001-of-00001.safetensors` now works.

* convert-hf : fix stacking MoE expert tensors

`torch.stack` and `torch.cat` don't do the same thing.

* convert-hf : fix Mamba conversion

Tested to work even with a SentencePiece-based tokenizer.
2024-05-01 12:30:20 -04:00
Francis Couture-Harpin
56f60f5d69 convert-hf : flake8 linter doesn't like semicolons 2024-05-01 11:38:47 -04:00
Francis Couture-Harpin
cde9ea65e8 convert-hf : simplify MoE weights stacking 2024-04-30 18:12:01 -04:00
Francis Couture-Harpin
698f0b3479 convert-hf : remove unused n_dims in extra_*_tensors 2024-04-30 15:02:34 -04:00
Francis Couture-Harpin
c33775bcc7 convert : upgrade to sentencepiece v0.2.0 2024-04-30 15:01:23 -04:00
Francis Couture-Harpin
0d720acb91 Merge branch 'master' into compilade/convert-hf-refactor 2024-04-30 14:08:05 -04:00
Francis Couture-Harpin
47e02eb7bc convert-hf : begin refactoring write_tensor 2024-04-30 14:07:28 -04:00
12 changed files with 649 additions and 1278 deletions

View file

@ -86,6 +86,7 @@ let
# TODO(Green-Sky): find a better way to opt-into the heavy ml python runtime
llama-python-extra = python3.withPackages (
ps: [
ps.einops
ps.numpy
ps.sentencepiece
ps.tiktoken

File diff suppressed because it is too large Load diff

View file

@ -284,6 +284,7 @@ class Params:
n_experts = None
n_experts_used = None
f_rope_freq_base = None
n_ff = None
# hack to determine LLaMA v1 vs v2 vs CodeLlama
if config.get("moe"):
@ -308,6 +309,8 @@ class Params:
n_experts_used = config["moe"]["num_experts_per_tok"]
f_rope_freq_base = 1e6
assert n_ff is not None
return Params(
n_vocab = model["tok_embeddings.weight"].shape[0],
n_embd = config["dim"],
@ -462,7 +465,8 @@ class SentencePieceVocab(Vocab):
# not found in alternate location either
raise FileNotFoundError('Cannot find tokenizer.model')
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
self.sentencepiece_tokenizer = SentencePieceProcessor()
self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
vocab_size = self.sentencepiece_tokenizer.vocab_size()
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
@ -482,23 +486,23 @@ class SentencePieceVocab(Vocab):
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.sentencepiece_tokenizer
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
piece = tokenizer.IdToPiece(i)
text = piece.encode("utf-8")
score: float = tokenizer.get_score(i)
score: float = tokenizer.GetScore(i)
toktype = gguf.TokenType.NORMAL
if tokenizer.is_unknown(i):
if tokenizer.IsUnknown(i):
toktype = gguf.TokenType.UNKNOWN
if tokenizer.is_control(i):
if tokenizer.IsControl(i):
toktype = gguf.TokenType.CONTROL
# NOTE: I think added_tokens are user defined.
# ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
# if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
if tokenizer.is_unused(i):
if tokenizer.IsUnused(i):
toktype = gguf.TokenType.UNUSED
if tokenizer.is_byte(i):
if tokenizer.IsByte(i):
toktype = gguf.TokenType.BYTE
yield text, score, toktype
@ -906,7 +910,7 @@ class LazyUnpickler(pickle.Unpickler):
def rebuild_from_type_v2(func, new_type, args, state):
return func(*args)
CLASSES = {
CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = {
# getattr used here as a workaround for mypy not being smart enough to determine
# the staticmethods have a __func__ attribute.
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),

View file

@ -911,7 +911,7 @@ async def oai_chat_completions(user_prompt,
while event_received:
event_received = False
async for line_in_bytes in response.content:
line = line_in_bytes.decode('utf8')
line = line_in_bytes.decode('utf-8')
line = line.rstrip('\n').rstrip('\r')
if line == '':
continue

View file

@ -859,7 +859,7 @@ class GGUFValueType(IntEnum):
# Note: Does not support GGML_QKK_64
QK_K = 256
# Items here are (block size, type size)
GGML_QUANT_SIZES = {
GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F32: (1, 4),
GGMLQuantizationType.F16: (1, 2),
GGMLQuantizationType.Q4_0: (32, 2 + 16),

View file

@ -65,7 +65,7 @@ class ReaderTensor(NamedTuple):
class GGUFReader:
# I - same as host, S - swapped
byte_order: Literal['I' | 'S'] = 'I'
byte_order: Literal['I'] | Literal['S'] = 'I'
alignment: int = GGUF_DEFAULT_ALIGNMENT
# Note: Internal helper, API may change.
@ -83,7 +83,7 @@ class GGUFReader:
GGUFValueType.BOOL: np.bool_,
}
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
self.data = np.memmap(path, mode = mode)
offs = 0
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
@ -128,7 +128,7 @@ class GGUFReader:
return self.tensors[idx]
def _get(
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
) -> npt.NDArray[Any]:
count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize)
@ -250,7 +250,7 @@ class GGUFReader:
raise ValueError(f'Found duplicated tensor with name {tensor_name}')
tensor_names.add(tensor_name)
ggml_type = GGMLQuantizationType(raw_dtype[0])
n_elems = np.prod(dims)
n_elems = int(np.prod(dims))
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
n_bytes = n_elems * type_size // block_size
data_offs = int(start_offs + offset_tensor[0])

View file

@ -176,7 +176,7 @@ class GGUFWriter:
if pack_fmt is not None:
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
elif vtype == GGUFValueType.STRING:
encoded_val = val.encode("utf8") if isinstance(val, str) else val
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
self.kv_data += self._pack("Q", len(encoded_val))
self.kv_data += encoded_val
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
@ -205,7 +205,7 @@ class GGUFWriter:
raise ValueError(f'Duplicated tensor name {name}')
self.ti_names.add(name)
encoded_name = name.encode("utf8")
encoded_name = name.encode("utf-8")
self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name
n_dims = len(tensor_shape)
@ -479,7 +479,7 @@ class GGUFWriter:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
if isinstance(value, list):
if not isinstance(value, str):
template_default = None
template_names = set()

View file

@ -4,7 +4,7 @@ import logging
import json
import os
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, Sequence, Mapping, Iterable
from .gguf_writer import GGUFWriter
@ -15,11 +15,11 @@ class SpecialVocab:
merges: list[str]
add_special_token: dict[str, bool]
special_token_ids: dict[str, int]
chat_template: str | None
chat_template: str | Sequence[Mapping[str, str]] | None
def __init__(
self, path: str | os.PathLike[str], load_merges: bool = False,
special_token_types: tuple[str, ...] | None = None,
special_token_types: Iterable[str] | None = None,
n_vocab: int | None = None,
):
self.special_token_ids = {}

View file

@ -47,7 +47,7 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
if len(field.types) == 1:
curr_type = field.types[0]
if curr_type == GGUFValueType.STRING:
log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60]))
log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]))
elif field.types[0] in reader.gguf_scalar_to_np:
log_message += ' = {0}'.format(field.parts[-1][0])
print(log_message) # noqa: NP100

View file

@ -7,7 +7,7 @@ import json
from pathlib import Path
import numpy as np
from typing import Any, Mapping, Sequence
from typing import Any, Sequence
# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@ -34,7 +34,7 @@ def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
return host_endian
def decode_field(field: gguf.ReaderField) -> Any:
def decode_field(field: gguf.ReaderField | None) -> Any:
if field and field.types:
main_type = field.types[0]
@ -42,11 +42,11 @@ def decode_field(field: gguf.ReaderField) -> Any:
sub_type = field.types[-1]
if sub_type == gguf.GGUFValueType.STRING:
return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
else:
return [pv for idx in field.data for pv in field.parts[idx].tolist()]
if main_type == gguf.GGUFValueType.STRING:
return str(bytes(field.parts[-1]), encoding='utf8')
return str(bytes(field.parts[-1]), encoding='utf-8')
else:
return field.parts[-1][0]
@ -59,7 +59,7 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
return decode_field(field)
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
for field in reader.fields.values():
# Suppress virtual fields and fields written by GGUFWriter
if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
@ -101,7 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
for tensor in reader.tensors:
# Dimensions are written in reverse order, so flip them first
shape = np.flipud(tensor.shape)
shape = np.flipud(tensor.shape).tolist()
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
writer.write_header_to_file()

3
pyrightconfig.json Normal file
View file

@ -0,0 +1,3 @@
{
"extraPaths": ["gguf-py"],
}

View file

@ -1,5 +1,5 @@
numpy~=1.24.4
sentencepiece~=0.1.98
sentencepiece~=0.2.0
transformers>=4.40.1,<5.0.0
gguf>=0.1.0
protobuf>=4.21.0,<5.0.0