Various type annotation fixes.

parent 8047aa192f
commit d7688dc937

3 changed files with 174 additions and 171 deletions
@@ -1,8 +1,8 @@
 from __future__ import annotations

 import sys
-from enum import Enum, IntEnum, auto
-from typing import Any, NamedTuple
+from enum import Enum, IntEnum, auto, StrEnum
+from typing import Any, NamedTuple, Type

 #
 # constants
@@ -16,63 +16,63 @@ GGUF_DEFAULT_ALIGNMENT = 32
 # metadata keys
 #

-class GeneralKeys(NamedTuple):
-    ARCHITECTURE         = "general.architecture"
-    QUANTIZATION_VERSION = "general.quantization_version"
-    ALIGNMENT            = "general.alignment"
-    NAME                 = "general.name"
-    AUTHOR               = "general.author"
-    URL                  = "general.url"
-    DESCRIPTION          = "general.description"
-    LICENSE              = "general.license"
-    SOURCE_URL           = "general.source.url"
-    SOURCE_HF_REPO       = "general.source.huggingface.repository"
-    FILE_TYPE            = "general.file_type"
+class GeneralKeys(StrEnum):
+    ARCHITECTURE        : str = "general.architecture"
+    QUANTIZATION_VERSION: str = "general.quantization_version"
+    ALIGNMENT           : str = "general.alignment"
+    NAME                : str = "general.name"
+    AUTHOR              : str = "general.author"
+    URL                 : str = "general.url"
+    DESCRIPTION         : str = "general.description"
+    LICENSE             : str = "general.license"
+    SOURCE_URL          : str = "general.source.url"
+    SOURCE_HF_REPO      : str = "general.source.huggingface.repository"
+    FILE_TYPE           : str = "general.file_type"

-class AttentionKeys(NamedTuple):
-    HEAD_COUNT        = "{arch}.attention.head_count"
-    HEAD_COUNT_KV     = "{arch}.attention.head_count_kv"
-    MAX_ALIBI_BIAS    = "{arch}.attention.max_alibi_bias"
-    CLAMP_KQV         = "{arch}.attention.clamp_kqv"
-    LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
-    LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+class AttentionKeys(StrEnum):
+    HEAD_COUNT       : str = "{arch}.attention.head_count"
+    HEAD_COUNT_KV    : str = "{arch}.attention.head_count_kv"
+    MAX_ALIBI_BIAS   : str = "{arch}.attention.max_alibi_bias"
+    CLAMP_KQV        : str = "{arch}.attention.clamp_kqv"
+    LAYERNORM_EPS    : str = "{arch}.attention.layer_norm_epsilon"
+    LAYERNORM_RMS_EPS: str = "{arch}.attention.layer_norm_rms_epsilon"

-class RopeKeys(NamedTuple):
-    DIMENSION_COUNT      = "{arch}.rope.dimension_count"
-    FREQ_BASE            = "{arch}.rope.freq_base"
-    SCALING_TYPE         = "{arch}.rope.scaling.type"
-    SCALING_FACTOR       = "{arch}.rope.scaling.factor"
-    SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
-    SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
+class RopeKeys(StrEnum):
+    DIMENSION_COUNT     : str = "{arch}.rope.dimension_count"
+    FREQ_BASE           : str = "{arch}.rope.freq_base"
+    SCALING_TYPE        : str = "{arch}.rope.scaling.type"
+    SCALING_FACTOR      : str = "{arch}.rope.scaling.factor"
+    SCALING_ORIG_CTX_LEN: str = "{arch}.rope.scaling.original_context_length"
+    SCALING_FINETUNED   : str = "{arch}.rope.scaling.finetuned"

-class TokenizerKeys(NamedTuple):
-    MODEL      = "tokenizer.ggml.model"
-    LIST       = "tokenizer.ggml.tokens"
-    TOKEN_TYPE = "tokenizer.ggml.token_type"
-    SCORES     = "tokenizer.ggml.scores"
-    MERGES     = "tokenizer.ggml.merges"
-    BOS_ID     = "tokenizer.ggml.bos_token_id"
-    EOS_ID     = "tokenizer.ggml.eos_token_id"
-    UNK_ID     = "tokenizer.ggml.unknown_token_id"
-    SEP_ID     = "tokenizer.ggml.seperator_token_id"
-    PAD_ID     = "tokenizer.ggml.padding_token_id"
-    HF_JSON    = "tokenizer.huggingface.json"
-    RWKV       = "tokenizer.rwkv.world"
+class TokenizerKeys(StrEnum):
+    MODEL     : str = "tokenizer.ggml.model"
+    LIST      : str = "tokenizer.ggml.tokens"
+    TOKEN_TYPE: str = "tokenizer.ggml.token_type"
+    SCORES    : str = "tokenizer.ggml.scores"
+    MERGES    : str = "tokenizer.ggml.merges"
+    BOS_ID    : str = "tokenizer.ggml.bos_token_id"
+    EOS_ID    : str = "tokenizer.ggml.eos_token_id"
+    UNK_ID    : str = "tokenizer.ggml.unknown_token_id"
+    SEP_ID    : str = "tokenizer.ggml.seperator_token_id"
+    PAD_ID    : str = "tokenizer.ggml.padding_token_id"
+    HF_JSON   : str = "tokenizer.huggingface.json"
+    RWKV      : str = "tokenizer.rwkv.world"

-class LLMKeys(NamedTuple):
-    CONTEXT_LENGTH        = "{arch}.context_length"
-    EMBEDDING_LENGTH      = "{arch}.embedding_length"
-    BLOCK_COUNT           = "{arch}.block_count"
-    FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
-    USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
-    TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
+class LLMKeys(StrEnum):
+    CONTEXT_LENGTH       : str = "{arch}.context_length"
+    EMBEDDING_LENGTH     : str = "{arch}.embedding_length"
+    BLOCK_COUNT          : str = "{arch}.block_count"
+    FEED_FORWARD_LENGTH  : str = "{arch}.feed_forward_length"
+    USE_PARALLEL_RESIDUAL: str = "{arch}.use_parallel_residual"
+    TENSOR_DATA_LAYOUT   : str = "{arch}.tensor_data_layout"

 class Keys(NamedTuple):
-    GENERAL   = GeneralKeys()
-    LLM       = LLMKeys()
-    ATTENTION = AttentionKeys()
-    ROPE      = RopeKeys()
-    TOKENIZER = TokenizerKeys()
+    GENERAL  : Type[GeneralKeys  ] = GeneralKeys
+    LLM      : Type[LLMKeys      ] = LLMKeys
+    ATTENTION: Type[AttentionKeys] = AttentionKeys
+    ROPE     : Type[RopeKeys     ] = RopeKeys
+    TOKENIZER: Type[TokenizerKeys] = TokenizerKeys

 KEY = Keys()

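The NamedTuple-to-StrEnum switch above is what lets the key constants stay plain strings while the grouping classes no longer need to be instantiated. A minimal sketch of the behavior being relied on (illustrative only, not part of the commit; enum.StrEnum requires Python 3.11+, and DemoKeys is a made-up stand-in):

    from enum import StrEnum

    class DemoKeys(StrEnum):
        CONTEXT_LENGTH = "{arch}.context_length"

    # A StrEnum member is a real str, so string-keyed code keeps working,
    # and .value unwraps it to a plain str for formatting:
    assert DemoKeys.CONTEXT_LENGTH == "{arch}.context_length"
    assert isinstance(DemoKeys.CONTEXT_LENGTH, str)
    assert DemoKeys.CONTEXT_LENGTH.value.format(arch = 'llama') == 'llama.context_length'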
@@ -418,52 +418,52 @@ GGML_QUANT_SIZES = {
 # Aliases for backward compatibility.

 # general
-KEY_GENERAL_ARCHITECTURE         = KEY.GENERAL.ARCHITECTURE
-KEY_GENERAL_QUANTIZATION_VERSION = KEY.GENERAL.QUANTIZATION_VERSION
-KEY_GENERAL_ALIGNMENT            = KEY.GENERAL.ALIGNMENT
-KEY_GENERAL_NAME                 = KEY.GENERAL.NAME
-KEY_GENERAL_AUTHOR               = KEY.GENERAL.AUTHOR
-KEY_GENERAL_URL                  = KEY.GENERAL.URL
-KEY_GENERAL_DESCRIPTION          = KEY.GENERAL.DESCRIPTION
-KEY_GENERAL_LICENSE              = KEY.GENERAL.LICENSE
-KEY_GENERAL_SOURCE_URL           = KEY.GENERAL.SOURCE_URL
-KEY_GENERAL_SOURCE_HF_REPO       = KEY.GENERAL.SOURCE_HF_REPO
-KEY_GENERAL_FILE_TYPE            = KEY.GENERAL.FILE_TYPE
+KEY_GENERAL_ARCHITECTURE        : str = KEY.GENERAL.ARCHITECTURE
+KEY_GENERAL_QUANTIZATION_VERSION: str = KEY.GENERAL.QUANTIZATION_VERSION
+KEY_GENERAL_ALIGNMENT           : str = KEY.GENERAL.ALIGNMENT
+KEY_GENERAL_NAME                : str = KEY.GENERAL.NAME
+KEY_GENERAL_AUTHOR              : str = KEY.GENERAL.AUTHOR
+KEY_GENERAL_URL                 : str = KEY.GENERAL.URL
+KEY_GENERAL_DESCRIPTION         : str = KEY.GENERAL.DESCRIPTION
+KEY_GENERAL_LICENSE             : str = KEY.GENERAL.LICENSE
+KEY_GENERAL_SOURCE_URL          : str = KEY.GENERAL.SOURCE_URL
+KEY_GENERAL_SOURCE_HF_REPO      : str = KEY.GENERAL.SOURCE_HF_REPO
+KEY_GENERAL_FILE_TYPE           : str = KEY.GENERAL.FILE_TYPE

 # LLM
-KEY_CONTEXT_LENGTH        = KEY.LLM.CONTEXT_LENGTH
-KEY_EMBEDDING_LENGTH      = KEY.LLM.EMBEDDING_LENGTH
-KEY_BLOCK_COUNT           = KEY.LLM.BLOCK_COUNT
-KEY_FEED_FORWARD_LENGTH   = KEY.LLM.FEED_FORWARD_LENGTH
-KEY_USE_PARALLEL_RESIDUAL = KEY.LLM.USE_PARALLEL_RESIDUAL
-KEY_TENSOR_DATA_LAYOUT    = KEY.LLM.TENSOR_DATA_LAYOUT
+KEY_CONTEXT_LENGTH       : str = KEY.LLM.CONTEXT_LENGTH
+KEY_EMBEDDING_LENGTH     : str = KEY.LLM.EMBEDDING_LENGTH
+KEY_BLOCK_COUNT          : str = KEY.LLM.BLOCK_COUNT
+KEY_FEED_FORWARD_LENGTH  : str = KEY.LLM.FEED_FORWARD_LENGTH
+KEY_USE_PARALLEL_RESIDUAL: str = KEY.LLM.USE_PARALLEL_RESIDUAL
+KEY_TENSOR_DATA_LAYOUT   : str = KEY.LLM.TENSOR_DATA_LAYOUT

 # attention
-KEY_ATTENTION_HEAD_COUNT        = KEY.ATTENTION.HEAD_COUNT
-KEY_ATTENTION_HEAD_COUNT_KV     = KEY.ATTENTION.HEAD_COUNT_KV
-KEY_ATTENTION_MAX_ALIBI_BIAS    = KEY.ATTENTION.MAX_ALIBI_BIAS
-KEY_ATTENTION_CLAMP_KQV         = KEY.ATTENTION.CLAMP_KQV
-KEY_ATTENTION_LAYERNORM_EPS     = KEY.ATTENTION.LAYERNORM_EPS
-KEY_ATTENTION_LAYERNORM_RMS_EPS = KEY.ATTENTION.LAYERNORM_RMS_EPS
+KEY_ATTENTION_HEAD_COUNT       : str = KEY.ATTENTION.HEAD_COUNT
+KEY_ATTENTION_HEAD_COUNT_KV    : str = KEY.ATTENTION.HEAD_COUNT_KV
+KEY_ATTENTION_MAX_ALIBI_BIAS   : str = KEY.ATTENTION.MAX_ALIBI_BIAS
+KEY_ATTENTION_CLAMP_KQV        : str = KEY.ATTENTION.CLAMP_KQV
+KEY_ATTENTION_LAYERNORM_EPS    : str = KEY.ATTENTION.LAYERNORM_EPS
+KEY_ATTENTION_LAYERNORM_RMS_EPS: str = KEY.ATTENTION.LAYERNORM_RMS_EPS

 # RoPE
-KEY_ROPE_DIMENSION_COUNT      = KEY.ROPE.DIMENSION_COUNT
-KEY_ROPE_FREQ_BASE            = KEY.ROPE.FREQ_BASE
-KEY_ROPE_SCALING_TYPE         = KEY.ROPE.SCALING_TYPE
-KEY_ROPE_SCALING_FACTOR       = KEY.ROPE.SCALING_FACTOR
-KEY_ROPE_SCALING_ORIG_CTX_LEN = KEY.ROPE.SCALING_ORIG_CTX_LEN
-KEY_ROPE_SCALING_FINETUNED    = KEY.ROPE.SCALING_FINETUNED
+KEY_ROPE_DIMENSION_COUNT     : str = KEY.ROPE.DIMENSION_COUNT
+KEY_ROPE_FREQ_BASE           : str = KEY.ROPE.FREQ_BASE
+KEY_ROPE_SCALING_TYPE        : str = KEY.ROPE.SCALING_TYPE
+KEY_ROPE_SCALING_FACTOR      : str = KEY.ROPE.SCALING_FACTOR
+KEY_ROPE_SCALING_ORIG_CTX_LEN: str = KEY.ROPE.SCALING_ORIG_CTX_LEN
+KEY_ROPE_SCALING_FINETUNED   : str = KEY.ROPE.SCALING_FINETUNED

 # tokenization
-KEY_TOKENIZER_MODEL      = KEY.TOKENIZER.MODEL
-KEY_TOKENIZER_LIST       = KEY.TOKENIZER.LIST
-KEY_TOKENIZER_TOKEN_TYPE = KEY.TOKENIZER.TOKEN_TYPE
-KEY_TOKENIZER_SCORES     = KEY.TOKENIZER.SCORES
-KEY_TOKENIZER_MERGES     = KEY.TOKENIZER.MERGES
-KEY_TOKENIZER_BOS_ID     = KEY.TOKENIZER.BOS_ID
-KEY_TOKENIZER_EOS_ID     = KEY.TOKENIZER.EOS_ID
-KEY_TOKENIZER_UNK_ID     = KEY.TOKENIZER.UNK_ID
-KEY_TOKENIZER_SEP_ID     = KEY.TOKENIZER.SEP_ID
-KEY_TOKENIZER_PAD_ID     = KEY.TOKENIZER.PAD_ID
-KEY_TOKENIZER_HF_JSON    = KEY.TOKENIZER.HF_JSON
-KEY_TOKENIZER_RWKV       = KEY.TOKENIZER.RWKV
+KEY_TOKENIZER_MODEL     : str = KEY.TOKENIZER.MODEL
+KEY_TOKENIZER_LIST      : str = KEY.TOKENIZER.LIST
+KEY_TOKENIZER_TOKEN_TYPE: str = KEY.TOKENIZER.TOKEN_TYPE
+KEY_TOKENIZER_SCORES    : str = KEY.TOKENIZER.SCORES
+KEY_TOKENIZER_MERGES    : str = KEY.TOKENIZER.MERGES
+KEY_TOKENIZER_BOS_ID    : str = KEY.TOKENIZER.BOS_ID
+KEY_TOKENIZER_EOS_ID    : str = KEY.TOKENIZER.EOS_ID
+KEY_TOKENIZER_UNK_ID    : str = KEY.TOKENIZER.UNK_ID
+KEY_TOKENIZER_SEP_ID    : str = KEY.TOKENIZER.SEP_ID
+KEY_TOKENIZER_PAD_ID    : str = KEY.TOKENIZER.PAD_ID
+KEY_TOKENIZER_HF_JSON   : str = KEY.TOKENIZER.HF_JSON
+KEY_TOKENIZER_RWKV      : str = KEY.TOKENIZER.RWKV
@@ -2,7 +2,7 @@ from __future__ import annotations

 import os
 from collections import OrderedDict
-from typing import TypeVar, NamedTuple
+from typing import Any, TypeVar, NamedTuple, Dict, Type, Literal

 import numpy as np
 import numpy.typing as npt
@@ -35,14 +35,14 @@ class ReaderField(NamedTuple):

     # Data parts. Some types have multiple components, such as strings
     # that consist of a length followed by the string data.
-    parts: [npt.NDArray] = []
+    parts: list[npt.NDArray[Any]] = []

     # Indexes into parts that we can call the actual data. For example
     # an array of strings will be populated with indexes to the actual
     # string data.
-    data: [int] = [-1]
+    data: list[int] = [-1]

-    types: [GGUFValueType] = []
+    types: list[GGUFValueType] = []


 class ReaderTensor(NamedTuple):
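The annotation changes above are correctness fixes, not just style: parts: [npt.NDArray] annotates with a list literal, not a type, so type checkers reject it; the PEP 585 built-in generic list[...] is the valid spelling. A quick illustration of the contrast (hypothetical names, illustrative only):

    import numpy.typing as npt
    from typing import Any

    bad: [int] = [1]                    # runs, but is not a valid type annotation
    good: list[int] = [1]               # what type checkers expect
    parts: list[npt.NDArray[Any]] = []  # the form used in the diff above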
@@ -52,17 +52,17 @@ class ReaderTensor(NamedTuple):
     n_elements: int
     n_bytes: int
     data_offset: int
-    data: npt.NDArray
+    data: npt.NDArray[Any]
     field: ReaderField


 class GGUFReader:
-    byte_order: str = 'I'
-    fields: 'OrderedDict[str, ReaderField]' = {}
-    tensors: [ReaderTensor] = []
+    byte_order: Literal['I' | 'S' | '<'] = 'I'
+    fields: 'OrderedDict[str, ReaderField]' = OrderedDict()
+    tensors: list[ReaderTensor] = []
     alignment: int = GGUF_DEFAULT_ALIGNMENT

-    _simple_value_map = {
+    _simple_value_map: Dict[GGUFValueType, Type[Any]] = {
         GGUFValueType.UINT8: np.uint8,
         GGUFValueType.INT8: np.int8,
         GGUFValueType.UINT16: np.uint16,
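The newly annotated _simple_value_map above drives the scalar fast path in _get_field_parts: each simple GGUF value type maps to the NumPy scalar type used to read it. A trimmed, self-contained sketch of the pattern (GGUFValueType here is a stand-in limited to the first three tags):

    import numpy as np
    from enum import IntEnum
    from typing import Any, Dict, Type

    class GGUFValueType(IntEnum):
        UINT8  = 0
        INT8   = 1
        UINT16 = 2

    _simple_value_map: Dict[GGUFValueType, Type[Any]] = {
        GGUFValueType.UINT8:  np.uint8,
        GGUFValueType.INT8:   np.int8,
        GGUFValueType.UINT16: np.uint16,
    }

    # .get() returns None for compound types (strings, arrays), which is the
    # signal to fall through to the dedicated handling:
    assert _simple_value_map.get(GGUFValueType.UINT8) is np.uint8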
@@ -76,10 +76,12 @@ class GGUFReader:
         GGUFValueType.BOOL: np.bool_,
     }

-    _DT = TypeVar('T', bound = npt.DTypeLike)
-    def _get(self, offset: int, dtype: _DT, count: int = 1, override_order: None | str = None) -> 'npt.NDArray[_DT]':
-        end_offs = np.uint64(offset + np.uint64(dtype().nbytes * count))
-        return (self.data[np.uint64(offset):end_offs]
+    _DT = TypeVar('_DT', bound = npt.DTypeLike)
+    def _get(self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None) -> npt.NDArray[Any]:
+        count = int(count)
+        itemsize = int(np.empty([], dtype = dtype).itemsize)
+        end_offs = offset + itemsize * count
+        return (self.data[offset:end_offs]
             .view(dtype = dtype)[:count]
             .newbyteorder(override_order or self.byte_order))

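The _get rewrite above stops instantiating the dtype to learn its size: dtype().nbytes only works when dtype is a concrete scalar class like np.uint32, while np.empty([], dtype = dtype).itemsize accepts the full npt.DTypeLike range the new signature advertises. A standalone sketch of the equivalence (illustrative; np.dtype(x).itemsize is another valid spelling):

    import numpy as np

    for dtype_like in (np.uint32, np.dtype(np.uint32), '<u4'):
        # np.empty([], ...) builds a 0-d array; .itemsize is bytes per element.
        assert int(np.empty([], dtype = dtype_like).itemsize) == 4
        assert np.dtype(dtype_like).itemsize == 4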
@@ -87,80 +89,80 @@ class GGUFReader:
         if field.name in self.fields:
             raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
         self.fields[field.name] = field
-        return 0 if skip_sum else sum(part.nbytes for part in field.parts)
+        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

-    def _get_str(self, offset: int) -> (npt.NDArray[np.uint64], npt.NDArray[np.uint8]):
+    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
         slen = self._get(offset, np.uint64)
         return (slen, self._get(offset + 8, np.uint8, slen[0]))

-    def _get_field_parts(self, orig_offs: int, raw_type: int) -> (int, [np.NDArray], [int], [GGUFValueType]):
+    def _get_field_parts(self, orig_offs: int, raw_type: int) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
         offs = orig_offs
-        types = []
+        types: list[GGUFValueType] = []
         gtype = GGUFValueType(raw_type)
         types.append(gtype)
         # Handle strings.
         if gtype == GGUFValueType.STRING:
-            parts = list(self._get_str(offs))
-            size = sum(part.nbytes for part in parts)
-            return (size, parts, [1], types)
+            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
+            size = sum(int(part.nbytes) for part in sparts)
+            return (size, sparts, [1], types)
         # Check if it's a simple scalar type.
         nptype = self._simple_value_map.get(gtype)
         if nptype is not None:
             val = self._get(offs, nptype)
-            return (val.nbytes, [val], [0], types)
+            return (int(val.nbytes), [val], [0], types)
         # Handle arrays.
         if gtype == GGUFValueType.ARRAY:
             raw_itype = self._get(offs, np.uint32)
-            offs += raw_itype.nbytes
+            offs += int(raw_itype.nbytes)
             alen = self._get(offs, np.uint64)
-            offs += alen.nbytes
-            parts = [raw_itype, alen]
-            data_idxs = []
+            offs += int(alen.nbytes)
+            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
+            data_idxs: list[int] = []
             for idx in range(alen[0]):
                 curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                 if idx == 0:
                     types += curr_types
-                idxs_offs = len(parts)
-                parts += curr_parts
+                idxs_offs = len(aparts)
+                aparts += curr_parts
                 data_idxs += (idx + idxs_offs for idx in curr_idxs)
                 offs += curr_size
-            return (offs - orig_offs, parts, data_idxs, types)
+            return (offs - orig_offs, aparts, data_idxs, types)
         # We can't deal with this one.
         raise ValueError('Unknown/unhandled field type {gtype}')

     def _get_tensor(self, orig_offs: int) -> ReaderField:
-        offs = np.uint64(orig_offs)
+        offs = orig_offs
         name_len, name_data = self._get_str(offs)
-        offs += name_len.nbytes + name_data.nbytes
+        offs += int(name_len.nbytes + name_data.nbytes)
         n_dims = self._get(offs, np.uint32)
-        offs += n_dims.nbytes
+        offs += int(n_dims.nbytes)
         dims = self._get(offs, np.uint64, n_dims[0])
-        offs += dims.nbytes
+        offs += int(dims.nbytes)
         raw_dtype = self._get(offs, np.uint32)
-        offs += raw_dtype.nbytes
+        offs += int(raw_dtype.nbytes)
         offset_tensor = self._get(offs, np.uint64)
-        offs += offset_tensor.nbytes
+        offs += int(offset_tensor.nbytes)
         return ReaderField(
             orig_offs,
-            str(name_data, encoding = 'utf-8'),
+            str(bytes(name_data), encoding = 'utf-8'),
             [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
             [1, 3, 4, 5],
         )

-    def _build_fields(self, offs, count) -> int:
+    def _build_fields(self, offs: int, count: int) -> int:
         for _ in range(count):
             orig_offs = offs
             kv_klen, kv_kdata = self._get_str(offs)
-            offs += kv_klen.nbytes + kv_kdata.nbytes
+            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
             raw_kv_type = self._get(offs, np.uint32)
-            offs += raw_kv_type.nbytes
-            parts = [kv_klen, kv_kdata, raw_kv_type]
+            offs += int(raw_kv_type.nbytes)
+            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
             idxs_offs = len(parts)
             field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
             parts += field_parts
             self._push_field(ReaderField(
                 orig_offs,
-                str(kv_kdata, encoding = 'utf-8'),
+                str(bytes(kv_kdata), encoding = 'utf-8'),
                 parts,
                 list(idx + idxs_offs for idx in field_idxs),
                 field_types,
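The str(bytes(...)) changes scattered through this hunk are for the type checker: str(x, encoding = ...) wants a bytes-like argument, and although a uint8 ndarray satisfies the buffer protocol at runtime, its static type does not say so; bytes() makes the copy explicit and well-typed. A sketch (the tensor name is a hypothetical example):

    import numpy as np

    name_data = np.frombuffer(b'token_embd.weight', dtype = np.uint8)
    # Both decode identically at runtime; only the second form type-checks.
    assert str(name_data, encoding = 'utf-8') == 'token_embd.weight'
    assert str(bytes(name_data), encoding = 'utf-8') == 'token_embd.weight'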
@@ -168,23 +170,24 @@ class GGUFReader:
             offs += field_size
         return offs

-    def _build_tensors_fields(self, offs, count) -> (int, [ReaderField]):
+    def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
         tensor_fields = []
         for _ in range(count):
             field = self._get_tensor(offs)
-            offs += sum(part.nbytes for part in field.parts)
+            offs += sum(int(part.nbytes) for part in field.parts)
             tensor_fields.append(field)
         return (offs, tensor_fields)

-    def _build_tensors(self, start_offs: int, fields: [ReaderField]) -> None:
+    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
         tensors = []
         for field in fields:
             _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
             ggml_type = GGMLQuantizationType(raw_dtype[0])
             n_elems = np.prod(dims)
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
-            n_bytes = np.uint64(np.uint64(n_elems) * np.uint64(type_size)) // np.uint64(block_size)
-            data_offs = start_offs + offset_tensor[0]
+            n_bytes = n_elems * type_size // block_size
+            data_offs = int(start_offs + offset_tensor[0])
+            item_type: npt.DTypeLike
             if ggml_type == GGMLQuantizationType.F32:
                 item_count = n_elems
                 item_type = np.float32
|
||||||
item_count = n_bytes
|
item_count = n_bytes
|
||||||
item_type = np.uint8
|
item_type = np.uint8
|
||||||
tensors.append(ReaderTensor(
|
tensors.append(ReaderTensor(
|
||||||
name = str(name_data, encoding = 'utf-8'),
|
name = str(bytes(name_data), encoding = 'utf-8'),
|
||||||
tensor_type = ggml_type,
|
tensor_type = ggml_type,
|
||||||
shape = dims,
|
shape = dims,
|
||||||
n_elements = n_elems,
|
n_elements = n_elems,
|
||||||
|
@@ -207,31 +210,31 @@ class GGUFReader:
         self.tensors = tensors


-    def __init__(self, path: os.PathLike[str] | str, mode: str = 'r') -> None:
+    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r') -> None:
         self.data = np.memmap(path, mode = mode)
         offs = 0
         if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
             raise ValueError('GGUF magic invalid')
         offs += 4
-        temp = self._get(offs, np.uint32)
-        if temp[0] > 2000:
+        temp_version = self._get(offs, np.uint32)
+        if temp_version[0] > 2000:
             self.byte_order = 'S'
-            temp = temp.newbyteorder(self.byte_order)
-        version = temp[0]
+            temp_version = temp_version.newbyteorder(self.byte_order)
+        version = temp_version[0]
         if version not in READER_SUPPORTED_VERSIONS:
             raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
-        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp], [0], [GGUFValueType.UINT32]))
-        temp = self._get(offs, np.uint64, 2)
-        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp[:1]], [0], [GGUFValueType.UINT64]))
-        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp[1:]], [0], [GGUFValueType.UINT64]))
-        tensor_count, kv_count = temp
+        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
+        temp_counts = self._get(offs, np.uint64, 2)
+        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
+        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
+        tensor_count, kv_count = temp_counts
         offs = self._build_fields(offs, kv_count)
         offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
         new_align = self.fields.get('general.alignment')
         if new_align is not None:
             if new_align.types != [GGUFValueType.UINT64]:
                 raise ValueError('Bad type for general.alignment field')
-            self.alignment = new_align.parts[-1]
+            self.alignment = new_align.parts[-1][0]
         padding = offs % self.alignment
         if padding != 0:
             offs += self.alignment - padding
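The version check above doubles as byte-order detection: a sane GGUF version read in the wrong endianness decodes to a huge number, so a value over 2000 flips byte_order to 'S' (swapped) and re-reads. A sketch of the underlying NumPy mechanics, written with the dtype-level spelling so it also works on NumPy 2.0, where ndarray.newbyteorder was removed (assumes a little-endian host, by far the common case):

    import numpy as np

    # A version field of 3 stored big-endian, read on a little-endian host:
    raw = np.frombuffer((3).to_bytes(4, 'big'), dtype = np.uint32)
    assert raw[0] == 3 * 2**24                     # implausibly large, so swap
    fixed = raw.view(raw.dtype.newbyteorder('S'))
    assert fixed[0] == 3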
@@ -258,7 +261,7 @@ if __name__ == "__main__":
         if len(field.types) == 1:
             curr_type = field.types[0]
             if curr_type == GGUFValueType.STRING:
-                print(' = {0}'.format(repr(str(field.parts[-1], encoding='utf8')[:60])), end = '')
+                print(' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60])), end = '')
             elif field.types[0] in reader._simple_value_map:
                 print(' = {0}'.format(field.parts[-1][0]), end = '')
         print()
@@ -231,7 +231,7 @@ class GGUFWriter:
             tensor.tofile(self.temp_file)
             self.write_padding(self.temp_file, tensor.nbytes)

-    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None):
+    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
         pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
         if pad != 0:
             fp.write(bytes([0] * pad))
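write_padding rounds n up to the next multiple of the alignment and emits the difference as zero bytes. The helper it calls is the standard round-up-to-multiple formula; reimplemented here for illustration, under the assumption that it matches the static ggml_pad on GGUFWriter:

    def ggml_pad(x: int, n: int) -> int:
        # Round x up to the next multiple of n.
        return ((x + n - 1) // n) * n

    n, align = 5000, 32
    pad = ggml_pad(n, align) - n
    assert pad == 24 and (n + pad) % align == 0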
@@ -280,7 +280,7 @@ class GGUFWriter:
         self.add_string(KEY.GENERAL.AUTHOR, author)

     def add_tensor_data_layout(self, layout: str) -> None:
-        self.add_string(KEY.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+        self.add_string(KEY.LLM.TENSOR_DATA_LAYOUT.value.format(arch=self.arch), layout)

     def add_url(self, url: str) -> None:
         self.add_string(KEY.GENERAL.URL, url)
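The add_tensor_data_layout change above is a genuine bug fix riding along with the annotations: TENSOR_DATA_LAYOUT lives in the LLM group, so the old KEY.TENSOR_DATA_LAYOUT lookup would raise AttributeError. A self-contained sketch of the fixed call path (trimmed stand-ins mirroring the constants diff; Python 3.11+):

    from enum import StrEnum
    from typing import NamedTuple, Type

    class LLMKeys(StrEnum):
        TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"

    class Keys(NamedTuple):
        LLM: Type[LLMKeys] = LLMKeys

    KEY = Keys()
    # KEY.TENSOR_DATA_LAYOUT would raise AttributeError; the fixed path works:
    assert KEY.LLM.TENSOR_DATA_LAYOUT.value.format(arch = 'llama') == 'llama.tensor_data_layout'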
@@ -310,66 +310,66 @@ class GGUFWriter:

     def add_context_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
+            KEY.LLM.CONTEXT_LENGTH.value.format(arch=self.arch), length)

     def add_embedding_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
+            KEY.LLM.EMBEDDING_LENGTH.value.format(arch=self.arch), length)

     def add_block_count(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.BLOCK_COUNT.format(arch=self.arch), length)
+            KEY.LLM.BLOCK_COUNT.value.format(arch=self.arch), length)

     def add_feed_forward_length(self, length: int) -> None:
         self.add_uint32(
-            KEY.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+            KEY.LLM.FEED_FORWARD_LENGTH.value.format(arch=self.arch), length)

     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(
-            KEY.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+            KEY.LLM.USE_PARALLEL_RESIDUAL.value.format(arch=self.arch), use)

     def add_head_count(self, count: int) -> None:
         self.add_uint32(
-            KEY.ATTENTION.HEAD_COUNT.format(arch=self.arch), count)
+            KEY.ATTENTION.HEAD_COUNT.value.format(arch=self.arch), count)

     def add_head_count_kv(self, count: int) -> None:
         self.add_uint32(
-            KEY.ATTENTION.HEAD_COUNT_KV.format(arch=self.arch), count)
+            KEY.ATTENTION.HEAD_COUNT_KV.value.format(arch=self.arch), count)

     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
+            KEY.ATTENTION.MAX_ALIBI_BIAS.value.format(arch=self.arch), bias)

     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.CLAMP_KQV.format(arch=self.arch), value)
+            KEY.ATTENTION.CLAMP_KQV.value.format(arch=self.arch), value)

     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.LAYERNORM_EPS.format(arch=self.arch), value)
+            KEY.ATTENTION.LAYERNORM_EPS.value.format(arch=self.arch), value)

     def add_layer_norm_rms_eps(self, value: float) -> None:
         self.add_float32(
-            KEY.ATTENTION.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
+            KEY.ATTENTION.LAYERNORM_RMS_EPS.value.format(arch=self.arch), value)

     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(
-            KEY.ROPE.DIMENSION_COUNT.format(arch=self.arch), count)
+            KEY.ROPE.DIMENSION_COUNT.value.format(arch=self.arch), count)

     def add_rope_freq_base(self, value: float) -> None:
-        self.add_float32(KEY.ROPE.FREQ_BASE.format(arch=self.arch), value)
+        self.add_float32(KEY.ROPE.FREQ_BASE.value.format(arch=self.arch), value)

     def add_rope_scaling_type(self, value: RopeScalingType) -> None:
-        self.add_string(KEY.ROPE.SCALING_TYPE.format(arch=self.arch), value.value)
+        self.add_string(KEY.ROPE.SCALING_TYPE.value.format(arch=self.arch), value.value)

     def add_rope_scaling_factor(self, value: float) -> None:
-        self.add_float32(KEY.ROPE.SCALING_FACTOR.format(arch=self.arch), value)
+        self.add_float32(KEY.ROPE.SCALING_FACTOR.value.format(arch=self.arch), value)

     def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
-        self.add_uint32(KEY.ROPE.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
+        self.add_uint32(KEY.ROPE.SCALING_ORIG_CTX_LEN.value.format(arch=self.arch), value)

     def add_rope_scaling_finetuned(self, value: bool) -> None:
-        self.add_bool(KEY.ROPE.SCALING_FINETUNED.format(arch=self.arch), value)
+        self.add_bool(KEY.ROPE.SCALING_FINETUNED.value.format(arch=self.arch), value)

     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(KEY.TOKENIZER.MODEL, model)