Various type annotation fixes.

KerfuffleV2 2023-11-07 17:30:11 -07:00
parent 8047aa192f
commit d7688dc937
3 changed files with 174 additions and 171 deletions

View file

@@ -1,8 +1,8 @@
 from __future__ import annotations

 import sys
-from enum import Enum, IntEnum, auto
-from typing import Any, NamedTuple
+from enum import Enum, IntEnum, auto, StrEnum
+from typing import Any, NamedTuple, Type

 #
 # constants
@@ -16,63 +16,63 @@ GGUF_DEFAULT_ALIGNMENT = 32
 # metadata keys
 #
-class GeneralKeys(NamedTuple):
-    ARCHITECTURE = "general.architecture"
-    QUANTIZATION_VERSION = "general.quantization_version"
-    ALIGNMENT = "general.alignment"
-    NAME = "general.name"
-    AUTHOR = "general.author"
-    URL = "general.url"
-    DESCRIPTION = "general.description"
-    LICENSE = "general.license"
-    SOURCE_URL = "general.source.url"
-    SOURCE_HF_REPO = "general.source.huggingface.repository"
-    FILE_TYPE = "general.file_type"
+class GeneralKeys(StrEnum):
+    ARCHITECTURE : str = "general.architecture"
+    QUANTIZATION_VERSION: str = "general.quantization_version"
+    ALIGNMENT : str = "general.alignment"
+    NAME : str = "general.name"
+    AUTHOR : str = "general.author"
+    URL : str = "general.url"
+    DESCRIPTION : str = "general.description"
+    LICENSE : str = "general.license"
+    SOURCE_URL : str = "general.source.url"
+    SOURCE_HF_REPO : str = "general.source.huggingface.repository"
+    FILE_TYPE : str = "general.file_type"

-class AttentionKeys(NamedTuple):
-    HEAD_COUNT = "{arch}.attention.head_count"
-    HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
-    MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
-    CLAMP_KQV = "{arch}.attention.clamp_kqv"
-    LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
-    LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+class AttentionKeys(StrEnum):
+    HEAD_COUNT : str = "{arch}.attention.head_count"
+    HEAD_COUNT_KV : str = "{arch}.attention.head_count_kv"
+    MAX_ALIBI_BIAS : str = "{arch}.attention.max_alibi_bias"
+    CLAMP_KQV : str = "{arch}.attention.clamp_kqv"
+    LAYERNORM_EPS : str = "{arch}.attention.layer_norm_epsilon"
+    LAYERNORM_RMS_EPS: str = "{arch}.attention.layer_norm_rms_epsilon"

-class RopeKeys(NamedTuple):
-    DIMENSION_COUNT = "{arch}.rope.dimension_count"
-    FREQ_BASE = "{arch}.rope.freq_base"
-    SCALING_TYPE = "{arch}.rope.scaling.type"
-    SCALING_FACTOR = "{arch}.rope.scaling.factor"
-    SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
-    SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
+class RopeKeys(StrEnum):
+    DIMENSION_COUNT : str = "{arch}.rope.dimension_count"
+    FREQ_BASE : str = "{arch}.rope.freq_base"
+    SCALING_TYPE : str = "{arch}.rope.scaling.type"
+    SCALING_FACTOR : str = "{arch}.rope.scaling.factor"
+    SCALING_ORIG_CTX_LEN: str = "{arch}.rope.scaling.original_context_length"
+    SCALING_FINETUNED : str = "{arch}.rope.scaling.finetuned"

-class TokenizerKeys(NamedTuple):
-    MODEL = "tokenizer.ggml.model"
-    LIST = "tokenizer.ggml.tokens"
-    TOKEN_TYPE = "tokenizer.ggml.token_type"
-    SCORES = "tokenizer.ggml.scores"
-    MERGES = "tokenizer.ggml.merges"
-    BOS_ID = "tokenizer.ggml.bos_token_id"
-    EOS_ID = "tokenizer.ggml.eos_token_id"
-    UNK_ID = "tokenizer.ggml.unknown_token_id"
-    SEP_ID = "tokenizer.ggml.seperator_token_id"
-    PAD_ID = "tokenizer.ggml.padding_token_id"
-    HF_JSON = "tokenizer.huggingface.json"
-    RWKV = "tokenizer.rwkv.world"
+class TokenizerKeys(StrEnum):
+    MODEL : str = "tokenizer.ggml.model"
+    LIST : str = "tokenizer.ggml.tokens"
+    TOKEN_TYPE: str = "tokenizer.ggml.token_type"
+    SCORES : str = "tokenizer.ggml.scores"
+    MERGES : str = "tokenizer.ggml.merges"
+    BOS_ID : str = "tokenizer.ggml.bos_token_id"
+    EOS_ID : str = "tokenizer.ggml.eos_token_id"
+    UNK_ID : str = "tokenizer.ggml.unknown_token_id"
+    SEP_ID : str = "tokenizer.ggml.seperator_token_id"
+    PAD_ID : str = "tokenizer.ggml.padding_token_id"
+    HF_JSON : str = "tokenizer.huggingface.json"
+    RWKV : str = "tokenizer.rwkv.world"

-class LLMKeys(NamedTuple):
-    CONTEXT_LENGTH = "{arch}.context_length"
-    EMBEDDING_LENGTH = "{arch}.embedding_length"
-    BLOCK_COUNT = "{arch}.block_count"
-    FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
-    USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
-    TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
+class LLMKeys(StrEnum):
+    CONTEXT_LENGTH : str = "{arch}.context_length"
+    EMBEDDING_LENGTH : str = "{arch}.embedding_length"
+    BLOCK_COUNT : str = "{arch}.block_count"
+    FEED_FORWARD_LENGTH : str = "{arch}.feed_forward_length"
+    USE_PARALLEL_RESIDUAL: str = "{arch}.use_parallel_residual"
+    TENSOR_DATA_LAYOUT : str = "{arch}.tensor_data_layout"

 class Keys(NamedTuple):
-    GENERAL = GeneralKeys()
-    LLM = LLMKeys()
-    ATTENTION = AttentionKeys()
-    ROPE = RopeKeys()
-    TOKENIZER = TokenizerKeys()
+    GENERAL : Type[GeneralKeys ] = GeneralKeys
+    LLM : Type[LLMKeys ] = LLMKeys
+    ATTENTION: Type[AttentionKeys] = AttentionKeys
+    ROPE : Type[RopeKeys ] = RopeKeys
+    TOKENIZER: Type[TokenizerKeys] = TokenizerKeys

 KEY = Keys()
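
Note: the switch from NamedTuple subclasses to StrEnum (new in Python 3.11, hence the added import) works because StrEnum members are themselves instances of str, so code that passes these constants anywhere a string is expected keeps working. A minimal sketch of the behavior this relies on (illustrative, not part of the commit):

    from enum import StrEnum

    class GeneralKeys(StrEnum):
        ARCHITECTURE = "general.architecture"

    # Each member is a real str as well as an enum member:
    assert isinstance(GeneralKeys.ARCHITECTURE, str)
    assert GeneralKeys.ARCHITECTURE == "general.architecture"
    # StrEnum formats as its value, so f-strings and .format() behave:
    assert f"{GeneralKeys.ARCHITECTURE}" == "general.architecture"

The Keys NamedTuple correspondingly stores the enum classes themselves, annotated with Type[...], rather than instances, since enum classes are not meant to be instantiated.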
@@ -418,52 +418,52 @@ GGML_QUANT_SIZES = {
 # Aliases for backward compatibility.

 # general
-KEY_GENERAL_ARCHITECTURE = KEY.GENERAL.ARCHITECTURE
-KEY_GENERAL_QUANTIZATION_VERSION = KEY.GENERAL.QUANTIZATION_VERSION
-KEY_GENERAL_ALIGNMENT = KEY.GENERAL.ALIGNMENT
-KEY_GENERAL_NAME = KEY.GENERAL.NAME
-KEY_GENERAL_AUTHOR = KEY.GENERAL.AUTHOR
-KEY_GENERAL_URL = KEY.GENERAL.URL
-KEY_GENERAL_DESCRIPTION = KEY.GENERAL.DESCRIPTION
-KEY_GENERAL_LICENSE = KEY.GENERAL.LICENSE
-KEY_GENERAL_SOURCE_URL = KEY.GENERAL.SOURCE_URL
-KEY_GENERAL_SOURCE_HF_REPO = KEY.GENERAL.SOURCE_HF_REPO
-KEY_GENERAL_FILE_TYPE = KEY.GENERAL.FILE_TYPE
+KEY_GENERAL_ARCHITECTURE : str = KEY.GENERAL.ARCHITECTURE
+KEY_GENERAL_QUANTIZATION_VERSION: str = KEY.GENERAL.QUANTIZATION_VERSION
+KEY_GENERAL_ALIGNMENT : str = KEY.GENERAL.ALIGNMENT
+KEY_GENERAL_NAME : str = KEY.GENERAL.NAME
+KEY_GENERAL_AUTHOR : str = KEY.GENERAL.AUTHOR
+KEY_GENERAL_URL : str = KEY.GENERAL.URL
+KEY_GENERAL_DESCRIPTION : str = KEY.GENERAL.DESCRIPTION
+KEY_GENERAL_LICENSE : str = KEY.GENERAL.LICENSE
+KEY_GENERAL_SOURCE_URL : str = KEY.GENERAL.SOURCE_URL
+KEY_GENERAL_SOURCE_HF_REPO : str = KEY.GENERAL.SOURCE_HF_REPO
+KEY_GENERAL_FILE_TYPE : str = KEY.GENERAL.FILE_TYPE

 # LLM
-KEY_CONTEXT_LENGTH = KEY.LLM.CONTEXT_LENGTH
-KEY_EMBEDDING_LENGTH = KEY.LLM.EMBEDDING_LENGTH
-KEY_BLOCK_COUNT = KEY.LLM.BLOCK_COUNT
-KEY_FEED_FORWARD_LENGTH = KEY.LLM.FEED_FORWARD_LENGTH
-KEY_USE_PARALLEL_RESIDUAL = KEY.LLM.USE_PARALLEL_RESIDUAL
-KEY_TENSOR_DATA_LAYOUT = KEY.LLM.TENSOR_DATA_LAYOUT
+KEY_CONTEXT_LENGTH : str = KEY.LLM.CONTEXT_LENGTH
+KEY_EMBEDDING_LENGTH : str = KEY.LLM.EMBEDDING_LENGTH
+KEY_BLOCK_COUNT : str = KEY.LLM.BLOCK_COUNT
+KEY_FEED_FORWARD_LENGTH : str = KEY.LLM.FEED_FORWARD_LENGTH
+KEY_USE_PARALLEL_RESIDUAL: str = KEY.LLM.USE_PARALLEL_RESIDUAL
+KEY_TENSOR_DATA_LAYOUT : str = KEY.LLM.TENSOR_DATA_LAYOUT

 # attention
-KEY_ATTENTION_HEAD_COUNT = KEY.ATTENTION.HEAD_COUNT
-KEY_ATTENTION_HEAD_COUNT_KV = KEY.ATTENTION.HEAD_COUNT_KV
-KEY_ATTENTION_MAX_ALIBI_BIAS = KEY.ATTENTION.MAX_ALIBI_BIAS
-KEY_ATTENTION_CLAMP_KQV = KEY.ATTENTION.CLAMP_KQV
-KEY_ATTENTION_LAYERNORM_EPS = KEY.ATTENTION.LAYERNORM_EPS
-KEY_ATTENTION_LAYERNORM_RMS_EPS = KEY.ATTENTION.LAYERNORM_RMS_EPS
+KEY_ATTENTION_HEAD_COUNT : str = KEY.ATTENTION.HEAD_COUNT
+KEY_ATTENTION_HEAD_COUNT_KV : str = KEY.ATTENTION.HEAD_COUNT_KV
+KEY_ATTENTION_MAX_ALIBI_BIAS : str = KEY.ATTENTION.MAX_ALIBI_BIAS
+KEY_ATTENTION_CLAMP_KQV : str = KEY.ATTENTION.CLAMP_KQV
+KEY_ATTENTION_LAYERNORM_EPS : str = KEY.ATTENTION.LAYERNORM_EPS
+KEY_ATTENTION_LAYERNORM_RMS_EPS: str = KEY.ATTENTION.LAYERNORM_RMS_EPS

 # RoPE
-KEY_ROPE_DIMENSION_COUNT = KEY.ROPE.DIMENSION_COUNT
-KEY_ROPE_FREQ_BASE = KEY.ROPE.FREQ_BASE
-KEY_ROPE_SCALING_TYPE = KEY.ROPE.SCALING_TYPE
-KEY_ROPE_SCALING_FACTOR = KEY.ROPE.SCALING_FACTOR
-KEY_ROPE_SCALING_ORIG_CTX_LEN = KEY.ROPE.SCALING_ORIG_CTX_LEN
-KEY_ROPE_SCALING_FINETUNED = KEY.ROPE.SCALING_FINETUNED
+KEY_ROPE_DIMENSION_COUNT : str = KEY.ROPE.DIMENSION_COUNT
+KEY_ROPE_FREQ_BASE : str = KEY.ROPE.FREQ_BASE
+KEY_ROPE_SCALING_TYPE : str = KEY.ROPE.SCALING_TYPE
+KEY_ROPE_SCALING_FACTOR : str = KEY.ROPE.SCALING_FACTOR
+KEY_ROPE_SCALING_ORIG_CTX_LEN: str = KEY.ROPE.SCALING_ORIG_CTX_LEN
+KEY_ROPE_SCALING_FINETUNED : str = KEY.ROPE.SCALING_FINETUNED

 # tokenization
-KEY_TOKENIZER_MODEL = KEY.TOKENIZER.MODEL
-KEY_TOKENIZER_LIST = KEY.TOKENIZER.LIST
-KEY_TOKENIZER_TOKEN_TYPE = KEY.TOKENIZER.TOKEN_TYPE
-KEY_TOKENIZER_SCORES = KEY.TOKENIZER.SCORES
-KEY_TOKENIZER_MERGES = KEY.TOKENIZER.MERGES
-KEY_TOKENIZER_BOS_ID = KEY.TOKENIZER.BOS_ID
-KEY_TOKENIZER_EOS_ID = KEY.TOKENIZER.EOS_ID
-KEY_TOKENIZER_UNK_ID = KEY.TOKENIZER.UNK_ID
-KEY_TOKENIZER_SEP_ID = KEY.TOKENIZER.SEP_ID
-KEY_TOKENIZER_PAD_ID = KEY.TOKENIZER.PAD_ID
-KEY_TOKENIZER_HF_JSON = KEY.TOKENIZER.HF_JSON
-KEY_TOKENIZER_RWKV = KEY.TOKENIZER.RWKV
+KEY_TOKENIZER_MODEL : str = KEY.TOKENIZER.MODEL
+KEY_TOKENIZER_LIST : str = KEY.TOKENIZER.LIST
+KEY_TOKENIZER_TOKEN_TYPE: str = KEY.TOKENIZER.TOKEN_TYPE
+KEY_TOKENIZER_SCORES : str = KEY.TOKENIZER.SCORES
+KEY_TOKENIZER_MERGES : str = KEY.TOKENIZER.MERGES
+KEY_TOKENIZER_BOS_ID : str = KEY.TOKENIZER.BOS_ID
+KEY_TOKENIZER_EOS_ID : str = KEY.TOKENIZER.EOS_ID
+KEY_TOKENIZER_UNK_ID : str = KEY.TOKENIZER.UNK_ID
+KEY_TOKENIZER_SEP_ID : str = KEY.TOKENIZER.SEP_ID
+KEY_TOKENIZER_PAD_ID : str = KEY.TOKENIZER.PAD_ID
+KEY_TOKENIZER_HF_JSON : str = KEY.TOKENIZER.HF_JSON
+KEY_TOKENIZER_RWKV : str = KEY.TOKENIZER.RWKV

View file

@@ -2,7 +2,7 @@ from __future__ import annotations

 import os
 from collections import OrderedDict
-from typing import TypeVar, NamedTuple
+from typing import Any, TypeVar, NamedTuple, Dict, Type, Literal

 import numpy as np
 import numpy.typing as npt
@@ -35,14 +35,14 @@ class ReaderField(NamedTuple):
     # Data parts. Some types have multiple components, such as strings
     # that consist of a length followed by the string data.
-    parts: [npt.NDArray] = []
+    parts: list[npt.NDArray[Any]] = []

     # Indexes into parts that we can call the actual data. For example
     # an array of strings will be populated with indexes to the actual
     # string data.
-    data: [int] = [-1]
+    data: list[int] = [-1]

-    types: [GGUFValueType] = []
+    types: list[GGUFValueType] = []


 class ReaderTensor(NamedTuple):
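
The bracketed annotations being replaced here ([npt.NDArray], [int]) are list literals, not types; type checkers reject them, whereas list[...] is the valid generic form. One caveat about the retained defaults is worth a sketch (leading field layout assumed from context):

    from typing import Any, NamedTuple
    import numpy.typing as npt

    class ReaderField(NamedTuple):
        offset: int
        name: str
        parts: list[npt.NDArray[Any]] = []
        data: list[int] = [-1]

    # NamedTuple defaults are evaluated once at class creation, so
    # instances built without an explicit value share one list object:
    a = ReaderField(0, 'a')
    b = ReaderField(1, 'b')
    assert a.parts is b.parts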
@@ -52,17 +52,17 @@ class ReaderTensor(NamedTuple):
     n_elements: int
     n_bytes: int
     data_offset: int
-    data: npt.NDArray
+    data: npt.NDArray[Any]
     field: ReaderField


 class GGUFReader:
-    byte_order: str = 'I'
-    fields: 'OrderedDict[str, ReaderField]' = {}
-    tensors: [ReaderTensor] = []
+    byte_order: Literal['I' | 'S' | '<'] = 'I'
+    fields: 'OrderedDict[str, ReaderField]' = OrderedDict()
+    tensors: list[ReaderTensor] = []
     alignment: int = GGUF_DEFAULT_ALIGNMENT
-    _simple_value_map = {
+    _simple_value_map: Dict[GGUFValueType, Type[Any]] = {
         GGUFValueType.UINT8: np.uint8,
         GGUFValueType.INT8: np.int8,
         GGUFValueType.UINT16: np.uint16,
@@ -76,10 +76,12 @@ class GGUFReader:
         GGUFValueType.BOOL: np.bool_,
     }

-    _DT = TypeVar('T', bound = npt.DTypeLike)
+    _DT = TypeVar('_DT', bound = npt.DTypeLike)

-    def _get(self, offset: int, dtype: _DT, count: int = 1, override_order: None | str = None) -> 'npt.NDArray[_DT]':
-        end_offs = np.uint64(offset + np.uint64(dtype().nbytes * count))
-        return (self.data[np.uint64(offset):end_offs]
+    def _get(self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None) -> npt.NDArray[Any]:
+        count = int(count)
+        itemsize = int(np.empty([], dtype = dtype).itemsize)
+        end_offs = offset + itemsize * count
+        return (self.data[offset:end_offs]
             .view(dtype = dtype)[:count]
             .newbyteorder(override_order or self.byte_order))
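
The old body called dtype() to learn the element size, which only works when dtype is a concrete scalar type such as np.uint32; with the parameter widened to npt.DTypeLike (which also admits dtype strings like '<u4' and np.dtype objects), the size has to be derived without calling it. A sketch of the trick used (np.dtype(dtype).itemsize would be an equivalent spelling):

    import numpy as np

    def item_size(dtype) -> int:
        # A zero-dimensional empty array carries the dtype's itemsize
        # without allocating meaningful data.
        return int(np.empty([], dtype = dtype).itemsize)

    assert item_size(np.uint32) == 4
    assert item_size('<u8') == 8
    assert item_size(np.dtype(np.float16)) == 2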
@@ -87,80 +89,80 @@ class GGUFReader:
         if field.name in self.fields:
             raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
         self.fields[field.name] = field
-        return 0 if skip_sum else sum(part.nbytes for part in field.parts)
+        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

-    def _get_str(self, offset: int) -> (npt.NDArray[np.uint64], npt.NDArray[np.uint8]):
+    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
         slen = self._get(offset, np.uint64)
         return (slen, self._get(offset + 8, np.uint8, slen[0]))

-    def _get_field_parts(self, orig_offs: int, raw_type: int) -> (int, [np.NDArray], [int], [GGUFValueType]):
+    def _get_field_parts(self, orig_offs: int, raw_type: int) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
         offs = orig_offs
-        types = []
+        types: list[GGUFValueType] = []
         gtype = GGUFValueType(raw_type)
         types.append(gtype)
         # Handle strings.
         if gtype == GGUFValueType.STRING:
-            parts = list(self._get_str(offs))
-            size = sum(part.nbytes for part in parts)
-            return (size, parts, [1], types)
+            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
+            size = sum(int(part.nbytes) for part in sparts)
+            return (size, sparts, [1], types)
         # Check if it's a simple scalar type.
         nptype = self._simple_value_map.get(gtype)
         if nptype is not None:
             val = self._get(offs, nptype)
-            return (val.nbytes, [val], [0], types)
+            return (int(val.nbytes), [val], [0], types)
         # Handle arrays.
         if gtype == GGUFValueType.ARRAY:
             raw_itype = self._get(offs, np.uint32)
-            offs += raw_itype.nbytes
+            offs += int(raw_itype.nbytes)
             alen = self._get(offs, np.uint64)
-            offs += alen.nbytes
-            parts = [raw_itype, alen]
-            data_idxs = []
+            offs += int(alen.nbytes)
+            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
+            data_idxs: list[int] = []
             for idx in range(alen[0]):
                 curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                 if idx == 0:
                     types += curr_types
-                idxs_offs = len(parts)
-                parts += curr_parts
+                idxs_offs = len(aparts)
+                aparts += curr_parts
                 data_idxs += (idx + idxs_offs for idx in curr_idxs)
                 offs += curr_size
-            return (offs - orig_offs, parts, data_idxs, types)
+            return (offs - orig_offs, aparts, data_idxs, types)
         # We can't deal with this one.
         raise ValueError('Unknown/unhandled field type {gtype}')

     def _get_tensor(self, orig_offs: int) -> ReaderField:
-        offs = np.uint64(orig_offs)
+        offs = orig_offs
         name_len, name_data = self._get_str(offs)
-        offs += name_len.nbytes + name_data.nbytes
+        offs += int(name_len.nbytes + name_data.nbytes)
         n_dims = self._get(offs, np.uint32)
-        offs += n_dims.nbytes
+        offs += int(n_dims.nbytes)
         dims = self._get(offs, np.uint64, n_dims[0])
-        offs += dims.nbytes
+        offs += int(dims.nbytes)
         raw_dtype = self._get(offs, np.uint32)
-        offs += raw_dtype.nbytes
+        offs += int(raw_dtype.nbytes)
         offset_tensor = self._get(offs, np.uint64)
-        offs += offset_tensor.nbytes
+        offs += int(offset_tensor.nbytes)
         return ReaderField(
             orig_offs,
-            str(name_data, encoding = 'utf-8'),
+            str(bytes(name_data), encoding = 'utf-8'),
             [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
             [1, 3, 4, 5],
         )

-    def _build_fields(self, offs, count) -> int:
+    def _build_fields(self, offs: int, count: int) -> int:
         for _ in range(count):
             orig_offs = offs
             kv_klen, kv_kdata = self._get_str(offs)
-            offs += kv_klen.nbytes + kv_kdata.nbytes
+            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
             raw_kv_type = self._get(offs, np.uint32)
-            offs += raw_kv_type.nbytes
-            parts = [kv_klen, kv_kdata, raw_kv_type]
+            offs += int(raw_kv_type.nbytes)
+            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
             idxs_offs = len(parts)
             field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
             parts += field_parts
             self._push_field(ReaderField(
                 orig_offs,
-                str(kv_kdata, encoding = 'utf-8'),
+                str(bytes(kv_kdata), encoding = 'utf-8'),
                 parts,
                 list(idx + idxs_offs for idx in field_idxs),
                 field_types,
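
The recurring str(bytes(...), encoding = 'utf-8') pattern in this hunk addresses a typing issue rather than a runtime one: a uint8 array already supports the buffer protocol, so the old str(name_data, encoding = ...) decoded correctly, but str's stubs only accept bytes-like types. A small illustration:

    import numpy as np

    name_data = np.frombuffer(b'token_embd.weight', dtype = np.uint8)
    # old form worked at runtime but failed type checking:
    #   str(name_data, encoding = 'utf-8')
    name = str(bytes(name_data), encoding = 'utf-8')
    assert name == 'token_embd.weight'

Likewise the int(part.nbytes) wrappers convert NumPy integer scalars to plain int so that offsets annotated as int stay int for the checker.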
@@ -168,23 +170,24 @@ class GGUFReader:
             offs += field_size
         return offs

-    def _build_tensors_fields(self, offs, count) -> (int, [ReaderField]):
+    def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
         tensor_fields = []
         for _ in range(count):
             field = self._get_tensor(offs)
-            offs += sum(part.nbytes for part in field.parts)
+            offs += sum(int(part.nbytes) for part in field.parts)
             tensor_fields.append(field)
         return (offs, tensor_fields)

-    def _build_tensors(self, start_offs: int, fields: [ReaderField]) -> None:
+    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
         tensors = []
         for field in fields:
             _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
             ggml_type = GGMLQuantizationType(raw_dtype[0])
             n_elems = np.prod(dims)
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
-            n_bytes = np.uint64(np.uint64(n_elems) * np.uint64(type_size)) // np.uint64(block_size)
-            data_offs = start_offs + offset_tensor[0]
+            n_bytes = n_elems * type_size // block_size
+            data_offs = int(start_offs + offset_tensor[0])
+            item_type: npt.DTypeLike
             if ggml_type == GGMLQuantizationType.F32:
                 item_count = n_elems
                 item_type = np.float32
@@ -195,7 +198,7 @@ class GGUFReader:
                 item_count = n_bytes
                 item_type = np.uint8
             tensors.append(ReaderTensor(
-                name = str(name_data, encoding = 'utf-8'),
+                name = str(bytes(name_data), encoding = 'utf-8'),
                 tensor_type = ggml_type,
                 shape = dims,
                 n_elements = n_elems,
@@ -207,31 +210,31 @@ class GGUFReader:
         self.tensors = tensors

-    def __init__(self, path: os.PathLike[str] | str, mode: str = 'r') -> None:
+    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r') -> None:
         self.data = np.memmap(path, mode = mode)
         offs = 0
         if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
             raise ValueError('GGUF magic invalid')
         offs += 4
-        temp = self._get(offs, np.uint32)
-        if temp[0] > 2000:
+        temp_version = self._get(offs, np.uint32)
+        if temp_version[0] > 2000:
             self.byte_order = 'S'
-            temp = temp.newbyteorder(self.byte_order)
-        version = temp[0]
+            temp_version = temp_version.newbyteorder(self.byte_order)
+        version = temp_version[0]
         if version not in READER_SUPPORTED_VERSIONS:
             raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
-        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp], [0], [GGUFValueType.UINT32]))
-        temp = self._get(offs, np.uint64, 2)
-        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp[:1]], [0], [GGUFValueType.UINT64]))
-        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp[1:]], [0], [GGUFValueType.UINT64]))
-        tensor_count, kv_count = temp
+        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
+        temp_counts = self._get(offs, np.uint64, 2)
+        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
+        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
+        tensor_count, kv_count = temp_counts
         offs = self._build_fields(offs, kv_count)
         offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
         new_align = self.fields.get('general.alignment')
         if new_align is not None:
             if new_align.types != [GGUFValueType.UINT64]:
                 raise ValueError('Bad type for general.alignment field')
-            self.alignment = new_align.parts[-1]
+            self.alignment = new_align.parts[-1][0]
         padding = offs % self.alignment
         if padding != 0:
             offs += self.alignment - padding
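
Splitting temp into temp_version and temp_counts gives each read a single consistent element type for the checker, and leaves the endianness sniff easier to follow: the version is first read little-endian, and a value over 2000 is taken to mean the file was written in the opposite byte order ('S' for swapped). A standalone sketch of that heuristic:

    import struct

    raw = struct.pack('>I', 3)              # version 3, written big-endian
    version = struct.unpack('<I', raw)[0]   # read assuming little-endian
    if version > 2000:                      # 3 byte-swapped reads as 50331648
        version = struct.unpack('>I', raw)[0]
    assert version == 3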
@@ -258,7 +261,7 @@ if __name__ == "__main__":
         if len(field.types) == 1:
             curr_type = field.types[0]
             if curr_type == GGUFValueType.STRING:
-                print(' = {0}'.format(repr(str(field.parts[-1], encoding='utf8')[:60])), end = '')
+                print(' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60])), end = '')
             elif field.types[0] in reader._simple_value_map:
                 print(' = {0}'.format(field.parts[-1][0]), end = '')
         print()
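
For context, a minimal usage sketch of the reader as it stands after this commit (import path and file name assumed for illustration):

    from gguf.gguf_reader import GGUFReader  # import path assumed

    reader = GGUFReader('model.gguf')        # hypothetical file
    print('alignment:', reader.alignment)
    for name, field in reader.fields.items():
        print(name, [t.name for t in field.types])
    for tensor in reader.tensors:
        print(tensor.name, tensor.tensor_type.name, tensor.shape, tensor.n_bytes)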

View file

@@ -231,7 +231,7 @@ class GGUFWriter:
         tensor.tofile(self.temp_file)
         self.write_padding(self.temp_file, tensor.nbytes)

-    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None):
+    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
         pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
         if pad != 0:
             fp.write(bytes([0] * pad))
@@ -280,7 +280,7 @@ class GGUFWriter:
         self.add_string(KEY.GENERAL.AUTHOR, author)

     def add_tensor_data_layout(self, layout: str) -> None:
-        self.add_string(KEY.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+        self.add_string(KEY.LLM.TENSOR_DATA_LAYOUT.value.format(arch=self.arch), layout)

     def add_url(self, url: str) -> None:
         self.add_string(KEY.GENERAL.URL, url)
@@ -310,66 +310,66 @@ class GGUFWriter:
     def add_context_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
+            KEY.LLM.CONTEXT_LENGTH.value.format(arch=self.arch), length)

     def add_embedding_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
+            KEY.LLM.EMBEDDING_LENGTH.value.format(arch=self.arch), length)

     def add_block_count(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.BLOCK_COUNT.format(arch=self.arch), length)
+            KEY.LLM.BLOCK_COUNT.value.format(arch=self.arch), length)

     def add_feed_forward_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+            KEY.LLM.FEED_FORWARD_LENGTH.value.format(arch=self.arch), length)

     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(
-            KEY.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+            KEY.LLM.USE_PARALLEL_RESIDUAL.value.format(arch=self.arch), use)

     def add_head_count(self, count: int) -> None:
         self.add_uint32(
-            KEY.ATTENTION.HEAD_COUNT.format(arch=self.arch), count)
+            KEY.ATTENTION.HEAD_COUNT.value.format(arch=self.arch), count)

     def add_head_count_kv(self, count: int) -> None:
         self.add_uint32(
-            KEY.ATTENTION.HEAD_COUNT_KV.format(arch=self.arch), count)
+            KEY.ATTENTION.HEAD_COUNT_KV.value.format(arch=self.arch), count)

     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
+            KEY.ATTENTION.MAX_ALIBI_BIAS.value.format(arch=self.arch), bias)

     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.CLAMP_KQV.format(arch=self.arch), value)
+            KEY.ATTENTION.CLAMP_KQV.value.format(arch=self.arch), value)

     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.LAYERNORM_EPS.format(arch=self.arch), value)
+            KEY.ATTENTION.LAYERNORM_EPS.value.format(arch=self.arch), value)

     def add_layer_norm_rms_eps(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
+            KEY.ATTENTION.LAYERNORM_RMS_EPS.value.format(arch=self.arch), value)

     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(
-            KEY.ROPE.DIMENSION_COUNT.format(arch=self.arch), count)
+            KEY.ROPE.DIMENSION_COUNT.value.format(arch=self.arch), count)

     def add_rope_freq_base(self, value: float) -> None:
-        self.add_float32(KEY.ROPE.FREQ_BASE.format(arch=self.arch), value)
+        self.add_float32(KEY.ROPE.FREQ_BASE.value.format(arch=self.arch), value)

     def add_rope_scaling_type(self, value: RopeScalingType) -> None:
-        self.add_string(KEY.ROPE.SCALING_TYPE.format(arch=self.arch), value.value)
+        self.add_string(KEY.ROPE.SCALING_TYPE.value.format(arch=self.arch), value.value)

     def add_rope_scaling_factor(self, value: float) -> None:
-        self.add_float32(KEY.ROPE.SCALING_FACTOR.format(arch=self.arch), value)
+        self.add_float32(KEY.ROPE.SCALING_FACTOR.value.format(arch=self.arch), value)

     def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
-        self.add_uint32(KEY.ROPE.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
+        self.add_uint32(KEY.ROPE.SCALING_ORIG_CTX_LEN.value.format(arch=self.arch), value)

     def add_rope_scaling_finetuned(self, value: bool) -> None:
-        self.add_bool(KEY.ROPE.SCALING_FINETUNED.format(arch=self.arch), value)
+        self.add_bool(KEY.ROPE.SCALING_FINETUNED.value.format(arch=self.arch), value)

     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(KEY.TOKENIZER.MODEL, model)
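
The added .value hops are what make the writer changes mechanical: each key is now a StrEnum member, and going through .value hands .format() an ordinary str that type checkers fully understand (calling .format() on the member directly would also work at runtime, since StrEnum subclasses str). A worked example (import path assumed):

    from gguf.constants import KEY  # import path assumed

    key = KEY.LLM.CONTEXT_LENGTH.value.format(arch = 'llama')
    assert key == 'llama.context_length'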