style cleanup with flake8

This commit is contained in:
Jared Van Bortel 2023-11-07 21:05:41 -05:00
parent ce865b3ce3
commit f364636b2e
5 changed files with 331 additions and 296 deletions

View file

@ -16,6 +16,7 @@ GGUF_DEFAULT_ALIGNMENT = 32
# metadata keys
#
class GeneralKeys(StrEnum):
ARCHITECTURE: str = "general.architecture"
QUANTIZATION_VERSION: str = "general.quantization_version"
@ -29,6 +30,7 @@ class GeneralKeys(StrEnum):
SOURCE_HF_REPO: str = "general.source.huggingface.repository"
FILE_TYPE: str = "general.file_type"
class AttentionKeys(StrEnum):
HEAD_COUNT: str = "{arch}.attention.head_count"
HEAD_COUNT_KV: str = "{arch}.attention.head_count_kv"
@ -37,6 +39,7 @@ class AttentionKeys(StrEnum):
LAYERNORM_EPS: str = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS: str = "{arch}.attention.layer_norm_rms_epsilon"
class RopeKeys(StrEnum):
DIMENSION_COUNT: str = "{arch}.rope.dimension_count"
FREQ_BASE: str = "{arch}.rope.freq_base"
@ -45,6 +48,7 @@ class RopeKeys(StrEnum):
SCALING_ORIG_CTX_LEN: str = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED: str = "{arch}.rope.scaling.finetuned"
class TokenizerKeys(StrEnum):
MODEL: str = "tokenizer.ggml.model"
LIST: str = "tokenizer.ggml.tokens"
@ -59,6 +63,7 @@ class TokenizerKeys(StrEnum):
HF_JSON: str = "tokenizer.huggingface.json"
RWKV: str = "tokenizer.rwkv.world"
class LLMKeys(StrEnum):
CONTEXT_LENGTH: str = "{arch}.context_length"
EMBEDDING_LENGTH: str = "{arch}.embedding_length"
@ -67,6 +72,7 @@ class LLMKeys(StrEnum):
USE_PARALLEL_RESIDUAL: str = "{arch}.use_parallel_residual"
TENSOR_DATA_LAYOUT: str = "{arch}.tensor_data_layout"
class Keys(NamedTuple):
GENERAL: Type[GeneralKeys] = GeneralKeys
LLM: Type[LLMKeys] = LLMKeys
@ -74,6 +80,7 @@ class Keys(NamedTuple):
ROPE: Type[RopeKeys] = RopeKeys
TOKENIZER: Type[TokenizerKeys] = TokenizerKeys
KEY = Keys()
#
@ -321,13 +328,14 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
],
MODEL_ARCH.PERSIMMON: [
MODEL_TENSOR.ROPE_FREQS,
]
],
}
#
# types
#
class TokenType(IntEnum):
NORMAL = 1
UNKNOWN = 2
@ -336,11 +344,13 @@ class TokenType(IntEnum):
UNUSED = 5
BYTE = 6
class RopeScalingType(Enum):
NONE = 'none'
LINEAR = 'linear'
YARN = 'yarn'
class GGMLQuantizationType(IntEnum):
F32 = 0
F16 = 1
@ -357,6 +367,7 @@ class GGMLQuantizationType(IntEnum):
Q6_K = 14
Q8_K = 15
class GGUFEndian(IntEnum):
LITTLE = 0
BIG = 1
@ -379,7 +390,7 @@ class GGUFValueType(IntEnum):
@staticmethod
def get_type(val: Any) -> GGUFValueType:
if isinstance(val, str) or isinstance(val, bytes) or isinstance(val, bytearray):
if isinstance(val, (str, bytes, bytearray)):
return GGUFValueType.STRING
elif isinstance(val, list):
return GGUFValueType.ARRAY
@ -391,9 +402,10 @@ class GGUFValueType(IntEnum):
return GGUFValueType.INT32
# TODO: need help with 64-bit types in Python
else:
print("Unknown type: "+str(type(val)))
print("Unknown type:", type(val))
sys.exit()
# Note: Does not support GGML_QKK_64
QK_K = 256
# Items here are (block size, type size)

View file

@ -20,7 +20,7 @@ from gguf.constants import (
GGUF_MAGIC,
GGUF_VERSION,
GGMLQuantizationType,
GGUFValueType
GGUFValueType,
)
READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
@ -76,14 +76,49 @@ class GGUFReader:
GGUFValueType.BOOL: np.bool_,
}
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
self.data = np.memmap(path, mode = mode)
offs = 0
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
raise ValueError('GGUF magic invalid')
offs += 4
temp_version = self._get(offs, np.uint32)
if temp_version[0] > 2000:
self.byte_order = 'S'
temp_version = temp_version.newbyteorder(self.byte_order)
version = temp_version[0]
if version not in READER_SUPPORTED_VERSIONS:
raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
temp_counts = self._get(offs, np.uint64, 2)
offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
tensor_count, kv_count = temp_counts
offs = self._build_fields(offs, kv_count)
offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
new_align = self.fields.get('general.alignment')
if new_align is not None:
if new_align.types != [GGUFValueType.UINT64]:
raise ValueError('Bad type for general.alignment field')
self.alignment = new_align.parts[-1][0]
padding = offs % self.alignment
if padding != 0:
offs += self.alignment - padding
self._build_tensors(offs, tensors_fields)
_DT = TypeVar('_DT', bound = npt.DTypeLike)
def _get(self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None) -> npt.NDArray[Any]:
def _get(
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
) -> npt.NDArray[Any]:
count = int(count)
itemsize = int(np.empty([], dtype = dtype).itemsize)
end_offs = offset + itemsize * count
return (self.data[offset:end_offs]
return (
self.data[offset:end_offs]
.view(dtype = dtype)[:count]
.newbyteorder(override_order or self.byte_order))
.newbyteorder(override_order or self.byte_order)
)
def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
if field.name in self.fields:
@ -93,9 +128,11 @@ class GGUFReader:
def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
slen = self._get(offset, np.uint64)
return (slen, self._get(offset + 8, np.uint8, slen[0]))
return slen, self._get(offset + 8, np.uint8, slen[0])
def _get_field_parts(self, orig_offs: int, raw_type: int) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
def _get_field_parts(
self, orig_offs: int, raw_type: int,
) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
offs = orig_offs
types: list[GGUFValueType] = []
gtype = GGUFValueType(raw_type)
@ -104,12 +141,12 @@ class GGUFReader:
if gtype == GGUFValueType.STRING:
sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
size = sum(int(part.nbytes) for part in sparts)
return (size, sparts, [1], types)
return size, sparts, [1], types
# Check if it's a simple scalar type.
nptype = self._simple_value_map.get(gtype)
if nptype is not None:
val = self._get(offs, nptype)
return (int(val.nbytes), [val], [0], types)
return int(val.nbytes), [val], [0], types
# Handle arrays.
if gtype == GGUFValueType.ARRAY:
raw_itype = self._get(offs, np.uint32)
@ -126,7 +163,7 @@ class GGUFReader:
aparts += curr_parts
data_idxs += (idx + idxs_offs for idx in curr_idxs)
offs += curr_size
return (offs - orig_offs, aparts, data_idxs, types)
return offs - orig_offs, aparts, data_idxs, types
# We can't deal with this one.
raise ValueError('Unknown/unhandled field type {gtype}')
@ -164,7 +201,7 @@ class GGUFReader:
orig_offs,
str(bytes(kv_kdata), encoding = 'utf-8'),
parts,
list(idx + idxs_offs for idx in field_idxs),
[idx + idxs_offs for idx in field_idxs],
field_types,
), skip_sum = True)
offs += field_size
@ -176,7 +213,7 @@ class GGUFReader:
field = self._get_tensor(offs)
offs += sum(int(part.nbytes) for part in field.parts)
tensor_fields.append(field)
return (offs, tensor_fields)
return offs, tensor_fields
def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
tensors = []
@ -210,37 +247,6 @@ class GGUFReader:
self.tensors = tensors
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r') -> None:
self.data = np.memmap(path, mode = mode)
offs = 0
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
raise ValueError('GGUF magic invalid')
offs += 4
temp_version = self._get(offs, np.uint32)
if temp_version[0] > 2000:
self.byte_order = 'S'
temp_version = temp_version.newbyteorder(self.byte_order)
version = temp_version[0]
if version not in READER_SUPPORTED_VERSIONS:
raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
temp_counts = self._get(offs, np.uint64, 2)
offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
tensor_count, kv_count = temp_counts
offs = self._build_fields(offs, kv_count)
offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
new_align = self.fields.get('general.alignment')
if new_align is not None:
if new_align.types != [GGUFValueType.UINT64]:
raise ValueError('Bad type for general.alignment field')
self.alignment = new_align.parts[-1][0]
padding = offs % self.alignment
if padding != 0:
offs += self.alignment - padding
self._build_tensors(offs, tensors_fields)
# Example usage:
if __name__ == "__main__":
if len(sys.argv) < 2:
@ -250,7 +256,7 @@ if __name__ == "__main__":
reader = GGUFReader(sys.argv[1], 'r')
print(f'\n* Dumping {len(reader.fields)} key/value pair(s)')
for n, field in enumerate(reader.fields.values(), 1):
if len(field.types) == 0:
if not field.types:
pretty_type = 'N/A'
elif field.types[0] == GGUFValueType.ARRAY:
nest_count = len(field.types) - 1

View file

@ -19,7 +19,7 @@ from .constants import (
GGUFEndian,
GGUFValueType,
RopeScalingType,
TokenType
TokenType,
)
@ -29,6 +29,7 @@ class WriterState(Enum):
KV_DATA = auto()
TI_DATA = auto()
class GGUFWriter:
fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
@ -47,16 +48,10 @@ class GGUFWriter:
GGUFValueType.BOOL: "?",
}
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
return struct.pack(f'{pack_prefix}{fmt}', value)
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
self.fout.write(self._pack(fmt, value, skip_pack_prefix))
def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE) -> None:
def __init__(
self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
endianess: GGUFEndian = GGUFEndian.LITTLE,
):
self.fout = open(path, "wb")
self.arch = arch
self.endianess = endianess
@ -69,8 +64,9 @@ class GGUFWriter:
self.use_temp_file = use_temp_file
self.temp_file = None
self.tensors = []
print("gguf: This GGUF file is for {0} Endian only"
.format("Big" if self.endianess == GGUFEndian.BIG else "Little"))
print("gguf: This GGUF file is for {0} Endian only".format(
"Big" if self.endianess == GGUFEndian.BIG else "Little",
))
self.state = WriterState.EMPTY
self.add_architecture()
@ -150,7 +146,7 @@ class GGUFWriter:
self.add_val(val, GGUFValueType.BOOL)
def add_string(self, key: str, val: str) -> None:
if len(val) == 0:
if not val:
return
self.add_key(key)
self.add_val(val, GGUFValueType.STRING)
@ -177,7 +173,7 @@ class GGUFWriter:
encoded_val = val.encode("utf8") if isinstance(val, str) else val
self.kv_data += self._pack("Q", len(encoded_val))
self.kv_data += encoded_val
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0:
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
ltype = GGUFValueType.get_type(val[0])
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
raise ValueError("All items in a GGUF array should be of the same type")
@ -192,7 +188,10 @@ class GGUFWriter:
def ggml_pad(x: int, n: int) -> int:
return ((x + n - 1) // n) * n
def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None) -> None:
def add_tensor_info(
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected output file to be empty, got {self.state}')
@ -215,7 +214,10 @@ class GGUFWriter:
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
self.ti_data_count += 1
def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None) -> None:
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True)
if self.use_temp_file and self.temp_file is None:
@ -402,3 +404,12 @@ class GGUFWriter:
def add_pad_token_id(self, id: int) -> None:
self.add_uint32(KEY.TOKENIZER.PAD_ID, id)
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
return struct.pack(f'{pack_prefix}{fmt}', value)
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
self.fout.write(self._pack(fmt, value, skip_pack_prefix))

View file

@ -127,7 +127,7 @@ class TensorNameMap:
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense" # persimmon
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
),
# Rotary embeddings
@ -193,7 +193,7 @@ class TensorNameMap:
MODEL_TENSOR.ROPE_FREQS: (
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
)
),
}
mapping: dict[str, tuple[MODEL_TENSOR, str]]
@ -225,7 +225,7 @@ class TensorNameMap:
if key.endswith(suffix):
result = self.mapping.get(key[:-len(suffix)])
if result is not None:
return (result[0], result[1] + suffix)
return result[0], result[1] + suffix
return None
def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
@ -252,5 +252,6 @@ class TensorNameMap:
def __repr__(self) -> str:
return repr(self.mapping)
def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
return TensorNameMap(arch, n_blocks)

View file

@ -28,6 +28,26 @@ class SpecialVocab:
self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad')
self._load(Path(path))
def __repr__(self) -> str:
return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'
def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
if self.merges:
if not quiet:
print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges)
for typ, tokid in self.special_token_ids.items():
handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
if handler is None:
print(
f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping',
file = sys.stderr,
)
continue
if not quiet:
print(f'gguf: Setting special token type {typ} to {tokid}')
handler(tokid)
def _load(self, path: Path) -> None:
if not self._try_load_from_tokenizer_json(path):
self._try_load_from_config_json(path)
@ -38,9 +58,10 @@ class SpecialVocab:
if self.n_vocab is None or tid < self.n_vocab:
self.special_token_ids[typ] = tid
return
print(f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
file = sys.stderr)
print(
f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
file = sys.stderr,
)
def _try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer_file = path / 'tokenizer.json'
@ -50,7 +71,7 @@ class SpecialVocab:
tokenizer = json.load(f)
if self.load_merges:
merges = tokenizer.get('model', {}).get('merges')
if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
if isinstance(merges, list) and merges and isinstance(merges[0], str):
self.merges = merges
tokenizer_config_file = path / 'tokenizer_config.json'
added_tokens = tokenizer.get('added_tokens')
@ -70,9 +91,10 @@ class SpecialVocab:
else:
continue
# We only need the first match here.
maybe_token_id = next((
atok.get('id') for atok in added_tokens
if atok.get('content') == tc_content), None)
maybe_token_id = next(
(atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
None,
)
self._set_special_token(typ, maybe_token_id)
return True
@ -85,20 +107,3 @@ class SpecialVocab:
for typ in self.special_token_types:
self._set_special_token(typ, config.get(f'{typ}_token_id'))
return True
def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
if len(self.merges) > 0:
if not quiet:
print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges)
for typ, tokid in self.special_token_ids.items():
handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
if handler is None:
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping', file = sys.stderr)
continue
if not quiet:
print(f'gguf: Setting special token type {typ} to {tokid}')
handler(tokid)
def __repr__(self) -> str:
return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'