improve performance

This commit is contained in:
isotr0py 2024-10-31 22:13:15 +08:00
parent 5ce2dbcf38
commit bcef54e10a

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import logging
import os
import struct
from collections import OrderedDict
from typing import Any, Literal, NamedTuple, TypeVar, Union
@ -92,7 +93,7 @@ class GGUFReader:
offs = 0
# Check for GGUF magic
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
if struct.unpack("<I", self.data.read(4))[0] != GGUF_MAGIC:
raise ValueError('GGUF magic invalid')
offs += 4
@ -170,9 +171,14 @@ class GGUFReader:
self.fields[field.name] = field
return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
slen = self._get(offset, np.uint64)
return slen, self._get(offset + 8, np.uint8, slen[0])
def _get_str(self, offset: int, return_size=False) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
self.data.seek(offset)
slen = struct.unpack('<Q', self.data.read(8))
sdata = struct.unpack('<'+'B'*slen[0], self.data.read(1*slen[0]))
output = (slen, sdata)
if return_size:
output += (8 + slen[0],)
return output
def _get_field_parts(
self, orig_offs: int, raw_type: int,
@ -183,8 +189,8 @@ class GGUFReader:
types.append(gtype)
# Handle strings.
if gtype == GGUFValueType.STRING:
sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
size = sum(int(part.nbytes) for part in sparts)
sparts: list[npt.NDArray[Any]] = list(self._get_str(offs, return_size=True))
size = sparts.pop(-1)
return size, sparts, [1], types
# Check if it's a simple scalar type.
nptype = self.gguf_scalar_to_np.get(gtype)
@ -216,23 +222,23 @@ class GGUFReader:
# Get Tensor Name
name_len, name_data = self._get_str(offs)
offs += int(name_len.nbytes + name_data.nbytes)
offs = self.data.tell()
# Get Tensor Dimensions Count
n_dims = self._get(offs, np.uint32)
offs += int(n_dims.nbytes)
offs = self.data.tell()
# Get Tensor Dimension Array
dims = self._get(offs, np.uint64, n_dims[0])
offs += int(dims.nbytes)
offs = self.data.tell()
# Get Tensor Encoding Scheme Type
raw_dtype = self._get(offs, np.uint32)
offs += int(raw_dtype.nbytes)
offs = self.data.tell()
# Get Tensor Offset
offset_tensor = self._get(offs, np.uint64)
offs += int(offset_tensor.nbytes)
offs = self.data.tell()
return ReaderField(
orig_offs,
@ -245,9 +251,9 @@ class GGUFReader:
for _ in range(count):
orig_offs = offs
kv_klen, kv_kdata = self._get_str(offs)
offs += int(kv_klen.nbytes + kv_kdata.nbytes)
offs = self.data.tell()
raw_kv_type = self._get(offs, np.uint32)
offs += int(raw_kv_type.nbytes)
offs = self.data.tell()
parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
idxs_offs = len(parts)
field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
@ -266,7 +272,7 @@ class GGUFReader:
tensor_fields = []
for _ in range(count):
field = self._get_tensor_info_field(offs)
offs += sum(int(part.nbytes) for part in field.parts)
offs = self.data.tell()
tensor_fields.append(field)
return offs, tensor_fields