gguf-dump: refactor GGUFReader for clarity
This commit is contained in:
parent
0fee5b6d56
commit
c0e6537508
1 changed files with 23 additions and 4 deletions
|
@ -89,9 +89,13 @@ class GGUFReader:
|
||||||
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
|
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
|
||||||
self.data = np.memmap(path, mode = mode)
|
self.data = np.memmap(path, mode = mode)
|
||||||
offs = 0
|
offs = 0
|
||||||
|
|
||||||
|
# Check for GGUF magic
|
||||||
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
|
if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
|
||||||
raise ValueError('GGUF magic invalid')
|
raise ValueError('GGUF magic invalid')
|
||||||
offs += 4
|
offs += 4
|
||||||
|
|
||||||
|
# Check GGUF version
|
||||||
temp_version = self._get(offs, np.uint32)
|
temp_version = self._get(offs, np.uint32)
|
||||||
if temp_version[0] & 65535 == 0:
|
if temp_version[0] & 65535 == 0:
|
||||||
# If we get 0 here that means it's (probably) a GGUF file created for
|
# If we get 0 here that means it's (probably) a GGUF file created for
|
||||||
|
@ -104,12 +108,16 @@ class GGUFReader:
|
||||||
self.fields: OrderedDict[str, ReaderField] = OrderedDict()
|
self.fields: OrderedDict[str, ReaderField] = OrderedDict()
|
||||||
self.tensors: list[ReaderTensor] = []
|
self.tensors: list[ReaderTensor] = []
|
||||||
offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
|
offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
|
||||||
|
|
||||||
|
# Check tensor count and kv count
|
||||||
temp_counts = self._get(offs, np.uint64, 2)
|
temp_counts = self._get(offs, np.uint64, 2)
|
||||||
offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
|
offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
|
||||||
offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
|
offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
|
||||||
tensor_count, kv_count = temp_counts
|
tensor_count, kv_count = temp_counts
|
||||||
offs = self._build_fields(offs, kv_count)
|
offs = self._build_fields(offs, kv_count)
|
||||||
offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
|
|
||||||
|
# Build Tensor Info Fields
|
||||||
|
offs, tensors_fields = self._build_tensors_info_fields(offs, tensor_count)
|
||||||
new_align = self.fields.get('general.alignment')
|
new_align = self.fields.get('general.alignment')
|
||||||
if new_align is not None:
|
if new_align is not None:
|
||||||
if new_align.types != [GGUFValueType.UINT32]:
|
if new_align.types != [GGUFValueType.UINT32]:
|
||||||
|
@ -195,18 +203,29 @@ class GGUFReader:
|
||||||
# We can't deal with this one.
|
# We can't deal with this one.
|
||||||
raise ValueError('Unknown/unhandled field type {gtype}')
|
raise ValueError('Unknown/unhandled field type {gtype}')
|
||||||
|
|
||||||
def _get_tensor(self, orig_offs: int) -> ReaderField:
|
def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
|
||||||
offs = orig_offs
|
offs = orig_offs
|
||||||
|
|
||||||
|
# Tensor Info Name
|
||||||
name_len, name_data = self._get_str(offs)
|
name_len, name_data = self._get_str(offs)
|
||||||
offs += int(name_len.nbytes + name_data.nbytes)
|
offs += int(name_len.nbytes + name_data.nbytes)
|
||||||
|
|
||||||
|
# Tensor Info Dimensions Count
|
||||||
n_dims = self._get(offs, np.uint32)
|
n_dims = self._get(offs, np.uint32)
|
||||||
offs += int(n_dims.nbytes)
|
offs += int(n_dims.nbytes)
|
||||||
|
|
||||||
|
# Tensor Info Dimension Array
|
||||||
dims = self._get(offs, np.uint64, n_dims[0])
|
dims = self._get(offs, np.uint64, n_dims[0])
|
||||||
offs += int(dims.nbytes)
|
offs += int(dims.nbytes)
|
||||||
|
|
||||||
|
# Tensor Info Tensor Type
|
||||||
raw_dtype = self._get(offs, np.uint32)
|
raw_dtype = self._get(offs, np.uint32)
|
||||||
offs += int(raw_dtype.nbytes)
|
offs += int(raw_dtype.nbytes)
|
||||||
|
|
||||||
|
# Tensor Info Offset
|
||||||
offset_tensor = self._get(offs, np.uint64)
|
offset_tensor = self._get(offs, np.uint64)
|
||||||
offs += int(offset_tensor.nbytes)
|
offs += int(offset_tensor.nbytes)
|
||||||
|
|
||||||
return ReaderField(
|
return ReaderField(
|
||||||
orig_offs,
|
orig_offs,
|
||||||
str(bytes(name_data), encoding = 'utf-8'),
|
str(bytes(name_data), encoding = 'utf-8'),
|
||||||
|
@ -235,10 +254,10 @@ class GGUFReader:
|
||||||
offs += field_size
|
offs += field_size
|
||||||
return offs
|
return offs
|
||||||
|
|
||||||
def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
|
def _build_tensors_info_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
|
||||||
tensor_fields = []
|
tensor_fields = []
|
||||||
for _ in range(count):
|
for _ in range(count):
|
||||||
field = self._get_tensor(offs)
|
field = self._get_tensor_info_field(offs)
|
||||||
offs += sum(int(part.nbytes) for part in field.parts)
|
offs += sum(int(part.nbytes) for part in field.parts)
|
||||||
tensor_fields.append(field)
|
tensor_fields.append(field)
|
||||||
return offs, tensor_fields
|
return offs, tensor_fields
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue