convert-hf : begin refactoring write_tensor
parent b8a7a5a90f
commit 47e02eb7bc

10 changed files with 386 additions and 852 deletions
@@ -86,6 +86,7 @@ let
       # TODO(Green-Sky): find a better way to opt-into the heavy ml python runtime
       llama-python-extra = python3.withPackages (
         ps: [
+          ps.einops
           ps.numpy
           ps.sentencepiece
           ps.tiktoken
(File diff suppressed because it is too large.)
@@ -882,7 +882,7 @@ async def oai_chat_completions(user_prompt,
     while event_received:
         event_received = False
         async for line_in_bytes in response.content:
-            line = line_in_bytes.decode('utf8')
+            line = line_in_bytes.decode('utf-8')
             line = line.rstrip('\n').rstrip('\r')
             if line == '':
                 continue
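A note on the repeated 'utf8' to 'utf-8' edits throughout this commit: the change is purely cosmetic, since both spellings name the same codec in Python; the commit simply standardizes on the canonical name. A small sketch (not part of the diff) confirming this:

    import codecs

    # 'utf8' is an alias that normalizes to the canonical codec name 'utf-8'
    print(codecs.lookup('utf8').name)                         # -> utf-8
    print('héllo'.encode('utf8') == 'héllo'.encode('utf-8'))  # -> True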
@@ -861,7 +861,7 @@ class GGUFValueType(IntEnum):
 # Note: Does not support GGML_QKK_64
 QK_K = 256
 # Items here are (block size, type size)
-GGML_QUANT_SIZES = {
+GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
     GGMLQuantizationType.F32: (1, 4),
     GGMLQuantizationType.F16: (1, 2),
     GGMLQuantizationType.Q4_0: (32, 2 + 16),
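For context on the (block size, type size) pairs that this annotation documents, here is a small sketch (not part of the diff; the tensor size is made up, and it assumes the gguf-py package is importable) of how those sizes turn into the byte size of a quantized tensor, matching the n_bytes arithmetic in the reader hunk further down:

    from gguf.constants import GGML_QUANT_SIZES, GGMLQuantizationType

    n_elems = 4096 * 4096                                                # example tensor, ~16.8M elements
    block_size, type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q4_0]  # (32, 18)
    n_bytes = n_elems * type_size // block_size                          # 18 bytes per 32-element block
    print(n_bytes)                                                       # -> 9437184 (9 MiB)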
@@ -63,7 +63,7 @@ class ReaderTensor(NamedTuple):

 class GGUFReader:
     # I - same as host, S - swapped
-    byte_order: Literal['I' | 'S'] = 'I'
+    byte_order: Literal['I'] | Literal['S'] = 'I'
     alignment: int = GGUF_DEFAULT_ALIGNMENT

     # Note: Internal helper, API may change.
@@ -81,7 +81,7 @@ class GGUFReader:
         GGUFValueType.BOOL: np.bool_,
     }

-    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
+    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
         self.data = np.memmap(path, mode = mode)
         offs = 0
         if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
@@ -126,7 +126,7 @@ class GGUFReader:
         return self.tensors[idx]

     def _get(
-        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
+        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
     ) -> npt.NDArray[Any]:
         count = int(count)
         itemsize = int(np.empty([], dtype = dtype).itemsize)
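The three Literal rewrites above (byte_order, the memmap mode, and override_order) fix the same pattern: `Literal['a' | 'b']` puts a single `|` expression inside the brackets, which is not a valid literal type and gets flagged by pyright (see the pyrightconfig.json added at the end of this commit). A minimal sketch of the accepted spellings, not taken from the repo:

    from __future__ import annotations
    from typing import Literal

    # equivalent, type-checker-friendly ways to say "one of these exact strings"
    def read_mode_a(mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r') -> None: ...
    def read_mode_b(mode: Literal['r', 'r+', 'c'] = 'r') -> None: ...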
@@ -248,7 +248,7 @@ class GGUFReader:
                 raise ValueError(f'Found duplicated tensor with name {tensor_name}')
             tensor_names.add(tensor_name)
             ggml_type = GGMLQuantizationType(raw_dtype[0])
-            n_elems = np.prod(dims)
+            n_elems = int(np.prod(dims))
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
             n_bytes = n_elems * type_size // block_size
             data_offs = int(start_offs + offset_tensor[0])
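A short aside on the int(...) cast above (illustrative, not from the diff): np.prod returns a NumPy scalar, and converting it to a plain Python int keeps the later size and offset arithmetic in ordinary, arbitrary-precision integers:

    import numpy as np

    dims = [4096, 32000]
    print(type(np.prod(dims)))       # a NumPy integer scalar, e.g. <class 'numpy.int64'>
    print(type(int(np.prod(dims))))  # -> <class 'int'>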
@@ -173,7 +173,7 @@ class GGUFWriter:
         if pack_fmt is not None:
             self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
         elif vtype == GGUFValueType.STRING:
-            encoded_val = val.encode("utf8") if isinstance(val, str) else val
+            encoded_val = val.encode("utf-8") if isinstance(val, str) else val
             self.kv_data += self._pack("Q", len(encoded_val))
             self.kv_data += encoded_val
         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
@@ -202,7 +202,7 @@ class GGUFWriter:
             raise ValueError(f'Duplicated tensor name {name}')
         self.ti_names.add(name)

-        encoded_name = name.encode("utf8")
+        encoded_name = name.encode("utf-8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
         n_dims = len(tensor_shape)
@@ -476,7 +476,7 @@ class GGUFWriter:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)

     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
-        if isinstance(value, list):
+        if not isinstance(value, str):
             template_default = None
             template_names = set()

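The isinstance change above widens the non-string branch of add_chat_template: the annotation already allows any Sequence of Mappings, not only a list, so checking `not isinstance(value, str)` also catches tuples and other sequences. A rough sketch of the idea (the helper name and template contents are made up, not repo code):

    from __future__ import annotations
    from typing import Mapping, Sequence

    def pick_default_template(value: str | Sequence[Mapping[str, str]]) -> str | None:
        if not isinstance(value, str):
            # multiple named templates: keep the one called 'default', if any
            for choice in value:
                if choice.get('name') == 'default':
                    return choice.get('template')
            return None
        return value

    print(pick_default_template('{{ messages }}'))
    print(pick_default_template(({'name': 'default', 'template': '{{ messages }}'},)))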
@@ -4,7 +4,7 @@ import json
 import os
 import sys
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any, Callable, Sequence, Mapping, Iterable

 from .gguf_writer import GGUFWriter

@@ -13,11 +13,11 @@ class SpecialVocab:
     merges: list[str]
     add_special_token: dict[str, bool]
     special_token_ids: dict[str, int]
-    chat_template: str | None
+    chat_template: str | Sequence[Mapping[str, str]] | None

     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
-        special_token_types: tuple[str, ...] | None = None,
+        special_token_types: Iterable[str] | None = None,
         n_vocab: int | None = None,
     ):
         self.special_token_ids = {}
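Loosening special_token_types from tuple[str, ...] to Iterable[str] lets callers pass whatever iterable is convenient. A hypothetical call, assuming gguf-py is installed (the path and token-type list are illustrative):

    import gguf

    special_vocab = gguf.SpecialVocab(
        'path/to/hf-model',                                # directory with tokenizer files
        load_merges=True,
        special_token_types=['bos', 'eos', 'unk', 'pad'],  # a plain list now type-checks
    )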
@@ -43,7 +43,7 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
         if len(field.types) == 1:
             curr_type = field.types[0]
             if curr_type == GGUFValueType.STRING:
-                print(' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60])), end = '')
+                print(' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60])), end = '')
             elif field.types[0] in reader.gguf_scalar_to_np:
                 print(' = {0}'.format(field.parts[-1][0]), end = '')
         print()
@@ -34,7 +34,7 @@ def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
     return host_endian


-def decode_field(field: gguf.ReaderField) -> Any:
+def decode_field(field: gguf.ReaderField | None) -> Any:
     if field and field.types:
         main_type = field.types[0]

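The `| None` on decode_field matches how the script calls it: GGUFReader.get_field returns None when a key is absent, and the existing `if field and field.types:` guard already handles that case. A small sketch of the calling pattern (the key constant exists in gguf-py; the helper function is made up):

    import gguf

    def has_chat_template(reader: gguf.GGUFReader) -> bool:
        field = reader.get_field(gguf.Keys.Tokenizer.CHAT_TEMPLATE)  # ReaderField | None
        return field is not None and bool(field.types)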
@@ -42,11 +42,11 @@ def decode_field(field: gguf.ReaderField) -> Any:
            sub_type = field.types[-1]

            if sub_type == gguf.GGUFValueType.STRING:
-                return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
+                return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
            else:
                return [pv for idx in field.data for pv in field.parts[idx].tolist()]
        if main_type == gguf.GGUFValueType.STRING:
-            return str(bytes(field.parts[-1]), encoding='utf8')
+            return str(bytes(field.parts[-1]), encoding='utf-8')
        else:
            return field.parts[-1][0]

@@ -59,7 +59,7 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
     return decode_field(field)


-def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
+def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
     for field in reader.fields.values():
         # Suppress virtual fields and fields written by GGUFWriter
         if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
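Narrowing new_metadata from Mapping[str, str] to dict[str, str] is presumably about mutability: typing.Mapping is read-only to a type checker, so a function that drops keys it has already written needs the concrete dict type. An illustrative snippet, not taken from the script:

    from __future__ import annotations

    def consume_one(meta: dict[str, str], key: str) -> str | None:
        value = meta.get(key)
        if key in meta:
            del meta[key]  # allowed for dict; Mapping has no __delitem__
        return value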
@@ -101,7 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new

     for tensor in reader.tensors:
         # Dimensions are written in reverse order, so flip them first
-        shape = np.flipud(tensor.shape)
+        shape = np.flipud(tensor.shape).tolist()
         writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)

     writer.write_header_to_file()
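The .tolist() above turns the flipped shape from a NumPy array of unsigned integers into a plain list of Python ints before it is handed to add_tensor_info. Illustrative only (the shape values and dtype are made up):

    import numpy as np

    shape = np.array([32000, 4096], dtype=np.uint32)  # example shape as a NumPy array
    print(np.flipud(shape))                           # -> [ 4096 32000] (NumPy uint32)
    print(np.flipud(shape).tolist())                  # -> [4096, 32000] (plain Python ints)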
pyrightconfig.json (new file, 3 lines)

@@ -0,0 +1,3 @@
+{
+    "extraPaths": ["gguf-py"],
+}