convert-hf : begin refactoring write_tensor
This commit is contained in:
parent
b8a7a5a90f
commit
47e02eb7bc
10 changed files with 386 additions and 852 deletions
|
@ -861,7 +861,7 @@ class GGUFValueType(IntEnum):
|
|||
# Note: Does not support GGML_QKK_64
|
||||
QK_K = 256
|
||||
# Items here are (block size, type size)
|
||||
GGML_QUANT_SIZES = {
|
||||
GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
|
||||
GGMLQuantizationType.F32: (1, 4),
|
||||
GGMLQuantizationType.F16: (1, 2),
|
||||
GGMLQuantizationType.Q4_0: (32, 2 + 16),
|
||||
|
|
|
@ -63,7 +63,7 @@ class ReaderTensor(NamedTuple):
|
|||
|
||||
class GGUFReader:
|
||||
# I - same as host, S - swapped
|
||||
byte_order: Literal['I' | 'S'] = 'I'
|
||||
byte_order: Literal['I'] | Literal['S'] = 'I'
|
||||
alignment: int = GGUF_DEFAULT_ALIGNMENT
|
||||
|
||||
# Note: Internal helper, API may change.
|
||||
|
@ -81,7 +81,7 @@ class GGUFReader:
|
|||
GGUFValueType.BOOL: np.bool_,
|
||||
}
|
||||
|
||||
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
|
||||
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
|
||||
self.data = np.memmap(path, mode = mode)
|
||||
offs = 0
|
||||
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
|
||||
|
@ -126,7 +126,7 @@ class GGUFReader:
|
|||
return self.tensors[idx]
|
||||
|
||||
def _get(
|
||||
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
|
||||
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
|
||||
) -> npt.NDArray[Any]:
|
||||
count = int(count)
|
||||
itemsize = int(np.empty([], dtype = dtype).itemsize)
|
||||
|
@ -248,7 +248,7 @@ class GGUFReader:
|
|||
raise ValueError(f'Found duplicated tensor with name {tensor_name}')
|
||||
tensor_names.add(tensor_name)
|
||||
ggml_type = GGMLQuantizationType(raw_dtype[0])
|
||||
n_elems = np.prod(dims)
|
||||
n_elems = int(np.prod(dims))
|
||||
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
|
||||
n_bytes = n_elems * type_size // block_size
|
||||
data_offs = int(start_offs + offset_tensor[0])
|
||||
|
|
|
@ -173,7 +173,7 @@ class GGUFWriter:
|
|||
if pack_fmt is not None:
|
||||
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
|
||||
elif vtype == GGUFValueType.STRING:
|
||||
encoded_val = val.encode("utf8") if isinstance(val, str) else val
|
||||
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
|
||||
self.kv_data += self._pack("Q", len(encoded_val))
|
||||
self.kv_data += encoded_val
|
||||
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
|
||||
|
@ -202,7 +202,7 @@ class GGUFWriter:
|
|||
raise ValueError(f'Duplicated tensor name {name}')
|
||||
self.ti_names.add(name)
|
||||
|
||||
encoded_name = name.encode("utf8")
|
||||
encoded_name = name.encode("utf-8")
|
||||
self.ti_data += self._pack("Q", len(encoded_name))
|
||||
self.ti_data += encoded_name
|
||||
n_dims = len(tensor_shape)
|
||||
|
@ -476,7 +476,7 @@ class GGUFWriter:
|
|||
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
|
||||
|
||||
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
|
||||
if isinstance(value, list):
|
||||
if not isinstance(value, str):
|
||||
template_default = None
|
||||
template_names = set()
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import json
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Sequence, Mapping, Iterable
|
||||
|
||||
from .gguf_writer import GGUFWriter
|
||||
|
||||
|
@ -13,11 +13,11 @@ class SpecialVocab:
|
|||
merges: list[str]
|
||||
add_special_token: dict[str, bool]
|
||||
special_token_ids: dict[str, int]
|
||||
chat_template: str | None
|
||||
chat_template: str | Sequence[Mapping[str, str]] | None
|
||||
|
||||
def __init__(
|
||||
self, path: str | os.PathLike[str], load_merges: bool = False,
|
||||
special_token_types: tuple[str, ...] | None = None,
|
||||
special_token_types: Iterable[str] | None = None,
|
||||
n_vocab: int | None = None,
|
||||
):
|
||||
self.special_token_ids = {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue