gguf-py: gguf_writer: Use BytesIO to build metadata

KerfuffleV2 2023-11-12 15:43:09 -07:00
parent 532dd74e38
commit 446ee3c79f
2 changed files with 18 additions and 18 deletions
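The substance of the change: CPython bytes objects are immutable, so building the metadata with repeated += allocates a new object and copies everything written so far on every append, which grows quadratically with the amount of metadata; io.BytesIO appends into a growable in-memory buffer instead. A minimal standalone sketch of the two patterns (illustrative only, not code from this commit; the loop and packed values are made up):

    import struct
    from io import BytesIO

    def build_with_bytes(n: int) -> bytes:
        # Old pattern: every += copies the whole accumulated buffer.
        data = b""
        for i in range(n):
            data += struct.pack("<I", i)
        return data

    def build_with_bytesio(n: int) -> bytes:
        # New pattern: BytesIO appends in place, without re-copying earlier data.
        buf = BytesIO()
        for i in range(n):
            buf.write(struct.pack("<I", i))
        return buf.getvalue()

    assert build_with_bytes(1024) == build_with_bytesio(1024)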

gguf-py/gguf/gguf_writer.py

@@ -5,7 +5,7 @@ import shutil
 import struct
 import tempfile
 from enum import Enum, auto
-from io import BufferedWriter
+from io import BufferedWriter, BytesIO
 from typing import IO, Any, Sequence
 
 import numpy as np
@@ -57,9 +57,9 @@ class GGUFWriter:
         self.endianess = endianess
         self.offset_tensor = 0
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
-        self.kv_data = b""
+        self.kv_data = BytesIO()
         self.kv_data_count = 0
-        self.ti_data = b""
+        self.ti_data = BytesIO()
         self.ti_data_count = 0
         self.use_temp_file = use_temp_file
         self.temp_file = None
@@ -86,7 +86,7 @@ class GGUFWriter:
         if self.state is not WriterState.HEADER:
             raise ValueError(f'Expected output file to contain the header, got {self.state}')
 
-        self.fout.write(self.kv_data)
+        self.fout.write(self.kv_data.getbuffer())
         self.flush()
         self.state = WriterState.KV_DATA
 
@@ -94,7 +94,7 @@ class GGUFWriter:
         if self.state is not WriterState.KV_DATA:
             raise ValueError(f'Expected output file to contain KV data, got {self.state}')
 
-        self.fout.write(self.ti_data)
+        self.fout.write(self.ti_data.getbuffer())
         self.flush()
         self.state = WriterState.TI_DATA
 
@@ -163,22 +163,22 @@ class GGUFWriter:
             vtype = GGUFValueType.get_type(val)
 
         if add_vtype:
-            self.kv_data += self._pack("I", vtype)
+            self.kv_data.write(self._pack("I", vtype))
             self.kv_data_count += 1
 
         pack_fmt = self._simple_value_packing.get(vtype)
         if pack_fmt is not None:
-            self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
+            self.kv_data.write(self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL))
         elif vtype == GGUFValueType.STRING:
             encoded_val = val.encode("utf8") if isinstance(val, str) else val
-            self.kv_data += self._pack("Q", len(encoded_val))
-            self.kv_data += encoded_val
+            self.kv_data.write(self._pack("Q", len(encoded_val)))
+            self.kv_data.write(encoded_val)
         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
             ltype = GGUFValueType.get_type(val[0])
             if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
                 raise ValueError("All items in a GGUF array should be of the same type")
-            self.kv_data += self._pack("I", ltype)
-            self.kv_data += self._pack("Q", len(val))
+            self.kv_data.write(self._pack("I", ltype))
+            self.kv_data.write(self._pack("Q", len(val)))
             for item in val:
                 self.add_val(item, add_vtype=False)
         else:
@@ -199,18 +199,18 @@ class GGUFWriter:
             raise ValueError("Only F32 and F16 tensors are supported for now")
 
         encoded_name = name.encode("utf8")
-        self.ti_data += self._pack("Q", len(encoded_name))
-        self.ti_data += encoded_name
+        self.ti_data.write(self._pack("Q", len(encoded_name)))
+        self.ti_data.write(encoded_name)
         n_dims = len(tensor_shape)
-        self.ti_data += self._pack("I", n_dims)
+        self.ti_data.write(self._pack("I", n_dims))
         for i in range(n_dims):
-            self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
+            self.ti_data.write(self._pack("Q", tensor_shape[n_dims - 1 - i]))
         if raw_dtype is None:
             dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
         else:
             dtype = raw_dtype
-        self.ti_data += self._pack("I", dtype)
-        self.ti_data += self._pack("Q", self.offset_tensor)
+        self.ti_data.write(self._pack("I", dtype))
+        self.ti_data.write(self._pack("Q", self.offset_tensor))
         self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
         self.ti_data_count += 1
 
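Note also that the write-out paths above use kv_data.getbuffer() / ti_data.getbuffer() rather than getvalue(): getbuffer() returns a memoryview over the BytesIO's internal buffer without copying, whereas getvalue() would allocate a fresh bytes object. A small standalone illustration of the difference (not code from the commit):

    from io import BytesIO

    buf = BytesIO()
    buf.write(b"gguf metadata")

    view = buf.getbuffer()      # zero-copy view of the internal buffer
    assert bytes(view) == buf.getvalue()

    # While an exported view is alive, the BytesIO cannot be written to or
    # resized, so release the view before appending more data.
    view.release()
    buf.write(b" more metadata")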

gguf-py/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.5.1"
+version = "0.5.2"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [