gguf-py: gguf_writer: Use BytesIO to build metadata

This commit is contained in:
KerfuffleV2 2023-11-12 15:43:09 -07:00
parent 532dd74e38
commit 446ee3c79f
2 changed files with 18 additions and 18 deletions

View file

@ -5,7 +5,7 @@ import shutil
import struct
import tempfile
from enum import Enum, auto
from io import BufferedWriter
from io import BufferedWriter, BytesIO
from typing import IO, Any, Sequence
import numpy as np
@ -57,9 +57,9 @@ class GGUFWriter:
self.endianess = endianess
self.offset_tensor = 0
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.kv_data = b""
self.kv_data = BytesIO()
self.kv_data_count = 0
self.ti_data = b""
self.ti_data = BytesIO()
self.ti_data_count = 0
self.use_temp_file = use_temp_file
self.temp_file = None
@ -86,7 +86,7 @@ class GGUFWriter:
if self.state is not WriterState.HEADER:
raise ValueError(f'Expected output file to contain the header, got {self.state}')
self.fout.write(self.kv_data)
self.fout.write(self.kv_data.getbuffer())
self.flush()
self.state = WriterState.KV_DATA
@ -94,7 +94,7 @@ class GGUFWriter:
if self.state is not WriterState.KV_DATA:
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
self.fout.write(self.ti_data)
self.fout.write(self.ti_data.getbuffer())
self.flush()
self.state = WriterState.TI_DATA
@ -163,22 +163,22 @@ class GGUFWriter:
vtype = GGUFValueType.get_type(val)
if add_vtype:
self.kv_data += self._pack("I", vtype)
self.kv_data.write(self._pack("I", vtype))
self.kv_data_count += 1
pack_fmt = self._simple_value_packing.get(vtype)
if pack_fmt is not None:
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
self.kv_data.write(self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL))
elif vtype == GGUFValueType.STRING:
encoded_val = val.encode("utf8") if isinstance(val, str) else val
self.kv_data += self._pack("Q", len(encoded_val))
self.kv_data += encoded_val
self.kv_data.write(self._pack("Q", len(encoded_val)))
self.kv_data.write(encoded_val)
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
ltype = GGUFValueType.get_type(val[0])
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
raise ValueError("All items in a GGUF array should be of the same type")
self.kv_data += self._pack("I", ltype)
self.kv_data += self._pack("Q", len(val))
self.kv_data.write(self._pack("I", ltype))
self.kv_data.write(self._pack("Q", len(val)))
for item in val:
self.add_val(item, add_vtype=False)
else:
@ -199,18 +199,18 @@ class GGUFWriter:
raise ValueError("Only F32 and F16 tensors are supported for now")
encoded_name = name.encode("utf8")
self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name
self.ti_data.write(self._pack("Q", len(encoded_name)))
self.ti_data.write(encoded_name)
n_dims = len(tensor_shape)
self.ti_data += self._pack("I", n_dims)
self.ti_data.write(self._pack("I", n_dims))
for i in range(n_dims):
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
self.ti_data.write(self._pack("Q", tensor_shape[n_dims - 1 - i]))
if raw_dtype is None:
dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
else:
dtype = raw_dtype
self.ti_data += self._pack("I", dtype)
self.ti_data += self._pack("Q", self.offset_tensor)
self.ti_data.write(self._pack("I", dtype))
self.ti_data.write(self._pack("Q", self.offset_tensor))
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
self.ti_data_count += 1

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "gguf"
version = "0.5.1"
version = "0.5.2"
description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [