improve performance

This commit is contained in:
isotr0py 2024-10-31 22:13:15 +08:00
parent 5ce2dbcf38
commit bcef54e10a

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import logging import logging
import os import os
import struct
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Literal, NamedTuple, TypeVar, Union from typing import Any, Literal, NamedTuple, TypeVar, Union
@ -92,7 +93,7 @@ class GGUFReader:
offs = 0 offs = 0
# Check for GGUF magic # Check for GGUF magic
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: if struct.unpack("<I", self.data.read(4))[0] != GGUF_MAGIC:
raise ValueError('GGUF magic invalid') raise ValueError('GGUF magic invalid')
offs += 4 offs += 4
@ -170,9 +171,14 @@ class GGUFReader:
self.fields[field.name] = field self.fields[field.name] = field
return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
def _get_str(self, offset: int, return_size=False) -> tuple[Any, ...]:
    """Read one length-prefixed string from ``self.data`` at ``offset``.

    The on-disk layout read here is an 8-byte little-endian unsigned length
    followed by that many raw bytes.

    Both pieces are returned as NumPy arrays, not plain tuples, so callers
    that keep them as field parts can still use array attributes such as
    ``.nbytes``.

    Args:
        offset: Absolute position in ``self.data`` where the length prefix
            starts. ``self.data`` must support ``seek`` and ``read``.
        return_size: When true, append the total number of bytes consumed
            (``8 + length``) as a third tuple element.

    Returns:
        ``(slen, sdata)`` where ``slen`` is a one-element unsigned 64-bit
        array holding the length and ``sdata`` is a ``uint8`` array with the
        string bytes; or ``(slen, sdata, size)`` when ``return_size`` is true.

    Raises:
        ValueError: If the stream ends before the length prefix or the
            string payload has been read completely.

    After a successful call the stream position is just past the string, so
    callers may continue from ``self.data.tell()``.
    """
    self.data.seek(offset)

    len_bytes = self.data.read(8)
    if len(len_bytes) != 8:
        raise ValueError('GGUF string length truncated')
    # NOTE(review): little-endian is hard-coded here, as it was in the
    # struct-based version; confirm whether big-endian files must be handled.
    slen = np.frombuffer(len_bytes, dtype=np.dtype('<u8'))
    n_bytes = int(slen[0])

    payload = self.data.read(n_bytes)
    if len(payload) != n_bytes:
        raise ValueError('GGUF string data truncated')
    # A single buffer view replaces an n-character struct format string and
    # an n-element tuple of Python ints.
    sdata = np.frombuffer(payload, dtype=np.uint8)

    if return_size:
        return slen, sdata, 8 + n_bytes
    return slen, sdata
def _get_field_parts( def _get_field_parts(
self, orig_offs: int, raw_type: int, self, orig_offs: int, raw_type: int,
@ -183,8 +189,8 @@ class GGUFReader:
types.append(gtype) types.append(gtype)
# Handle strings. # Handle strings.
if gtype == GGUFValueType.STRING: if gtype == GGUFValueType.STRING:
sparts: list[npt.NDArray[Any]] = list(self._get_str(offs)) sparts: list[npt.NDArray[Any]] = list(self._get_str(offs, return_size=True))
size = sum(int(part.nbytes) for part in sparts) size = sparts.pop(-1)
return size, sparts, [1], types return size, sparts, [1], types
# Check if it's a simple scalar type. # Check if it's a simple scalar type.
nptype = self.gguf_scalar_to_np.get(gtype) nptype = self.gguf_scalar_to_np.get(gtype)
@ -216,23 +222,23 @@ class GGUFReader:
# Get Tensor Name # Get Tensor Name
name_len, name_data = self._get_str(offs) name_len, name_data = self._get_str(offs)
offs += int(name_len.nbytes + name_data.nbytes) offs = self.data.tell()
# Get Tensor Dimensions Count # Get Tensor Dimensions Count
n_dims = self._get(offs, np.uint32) n_dims = self._get(offs, np.uint32)
offs += int(n_dims.nbytes) offs = self.data.tell()
# Get Tensor Dimension Array # Get Tensor Dimension Array
dims = self._get(offs, np.uint64, n_dims[0]) dims = self._get(offs, np.uint64, n_dims[0])
offs += int(dims.nbytes) offs = self.data.tell()
# Get Tensor Encoding Scheme Type # Get Tensor Encoding Scheme Type
raw_dtype = self._get(offs, np.uint32) raw_dtype = self._get(offs, np.uint32)
offs += int(raw_dtype.nbytes) offs = self.data.tell()
# Get Tensor Offset # Get Tensor Offset
offset_tensor = self._get(offs, np.uint64) offset_tensor = self._get(offs, np.uint64)
offs += int(offset_tensor.nbytes) offs = self.data.tell()
return ReaderField( return ReaderField(
orig_offs, orig_offs,
@ -245,9 +251,9 @@ class GGUFReader:
for _ in range(count): for _ in range(count):
orig_offs = offs orig_offs = offs
kv_klen, kv_kdata = self._get_str(offs) kv_klen, kv_kdata = self._get_str(offs)
offs += int(kv_klen.nbytes + kv_kdata.nbytes) offs = self.data.tell()
raw_kv_type = self._get(offs, np.uint32) raw_kv_type = self._get(offs, np.uint32)
offs += int(raw_kv_type.nbytes) offs = self.data.tell()
parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type] parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
idxs_offs = len(parts) idxs_offs = len(parts)
field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0]) field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
@ -266,7 +272,7 @@ class GGUFReader:
tensor_fields = [] tensor_fields = []
for _ in range(count): for _ in range(count):
field = self._get_tensor_info_field(offs) field = self._get_tensor_info_field(offs)
offs += sum(int(part.nbytes) for part in field.parts) offs = self.data.tell()
tensor_fields.append(field) tensor_fields.append(field)
return offs, tensor_fields return offs, tensor_fields