gguf-py: gguf_writer: Use BytesIO to build metadata
This commit is contained in:
parent
532dd74e38
commit
446ee3c79f
2 changed files with 18 additions and 18 deletions
|
@@ -5,7 +5,7 @@ import shutil
|
|||
import struct
|
||||
import tempfile
|
||||
from enum import Enum, auto
|
||||
from io import BufferedWriter
|
||||
from io import BufferedWriter, BytesIO
|
||||
from typing import IO, Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
@@ -57,9 +57,9 @@ class GGUFWriter:
|
|||
self.endianess = endianess
|
||||
self.offset_tensor = 0
|
||||
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
||||
self.kv_data = b""
|
||||
self.kv_data = BytesIO()
|
||||
self.kv_data_count = 0
|
||||
self.ti_data = b""
|
||||
self.ti_data = BytesIO()
|
||||
self.ti_data_count = 0
|
||||
self.use_temp_file = use_temp_file
|
||||
self.temp_file = None
|
||||
|
@@ -86,7 +86,7 @@ class GGUFWriter:
|
|||
if self.state is not WriterState.HEADER:
|
||||
raise ValueError(f'Expected output file to contain the header, got {self.state}')
|
||||
|
||||
self.fout.write(self.kv_data)
|
||||
self.fout.write(self.kv_data.getbuffer())
|
||||
self.flush()
|
||||
self.state = WriterState.KV_DATA
|
||||
|
||||
|
@@ -94,7 +94,7 @@ class GGUFWriter:
|
|||
if self.state is not WriterState.KV_DATA:
|
||||
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
|
||||
|
||||
self.fout.write(self.ti_data)
|
||||
self.fout.write(self.ti_data.getbuffer())
|
||||
self.flush()
|
||||
self.state = WriterState.TI_DATA
|
||||
|
||||
|
@@ -163,22 +163,22 @@ class GGUFWriter:
|
|||
vtype = GGUFValueType.get_type(val)
|
||||
|
||||
if add_vtype:
|
||||
self.kv_data += self._pack("I", vtype)
|
||||
self.kv_data.write(self._pack("I", vtype))
|
||||
self.kv_data_count += 1
|
||||
|
||||
pack_fmt = self._simple_value_packing.get(vtype)
|
||||
if pack_fmt is not None:
|
||||
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
|
||||
self.kv_data.write(self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL))
|
||||
elif vtype == GGUFValueType.STRING:
|
||||
encoded_val = val.encode("utf8") if isinstance(val, str) else val
|
||||
self.kv_data += self._pack("Q", len(encoded_val))
|
||||
self.kv_data += encoded_val
|
||||
self.kv_data.write(self._pack("Q", len(encoded_val)))
|
||||
self.kv_data.write(encoded_val)
|
||||
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
|
||||
ltype = GGUFValueType.get_type(val[0])
|
||||
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
|
||||
raise ValueError("All items in a GGUF array should be of the same type")
|
||||
self.kv_data += self._pack("I", ltype)
|
||||
self.kv_data += self._pack("Q", len(val))
|
||||
self.kv_data.write(self._pack("I", ltype))
|
||||
self.kv_data.write(self._pack("Q", len(val)))
|
||||
for item in val:
|
||||
self.add_val(item, add_vtype=False)
|
||||
else:
|
||||
|
@@ -199,18 +199,18 @@ class GGUFWriter:
|
|||
raise ValueError("Only F32 and F16 tensors are supported for now")
|
||||
|
||||
encoded_name = name.encode("utf8")
|
||||
self.ti_data += self._pack("Q", len(encoded_name))
|
||||
self.ti_data += encoded_name
|
||||
self.ti_data.write(self._pack("Q", len(encoded_name)))
|
||||
self.ti_data.write(encoded_name)
|
||||
n_dims = len(tensor_shape)
|
||||
self.ti_data += self._pack("I", n_dims)
|
||||
self.ti_data.write(self._pack("I", n_dims))
|
||||
for i in range(n_dims):
|
||||
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
|
||||
self.ti_data.write(self._pack("Q", tensor_shape[n_dims - 1 - i]))
|
||||
if raw_dtype is None:
|
||||
dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
|
||||
else:
|
||||
dtype = raw_dtype
|
||||
self.ti_data += self._pack("I", dtype)
|
||||
self.ti_data += self._pack("Q", self.offset_tensor)
|
||||
self.ti_data.write(self._pack("I", dtype))
|
||||
self.ti_data.write(self._pack("Q", self.offset_tensor))
|
||||
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
|
||||
self.ti_data_count += 1
|
||||
|
||||
|
|
|
@@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "gguf"
|
||||
version = "0.5.1"
|
||||
version = "0.5.2"
|
||||
description = "Read and write ML models in GGUF for GGML"
|
||||
authors = ["GGML <ggml@ggml.ai>"]
|
||||
packages = [
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue