gguf : track writer state

This commit is contained in:
Cebtenzzre 2023-10-01 19:31:08 -04:00 committed by cebtenzzre
parent 3fcdc9330a
commit d97afcfc02

View file

@ -7,7 +7,7 @@ import shutil
import struct import struct
import sys import sys
import tempfile import tempfile
from enum import IntEnum, auto from enum import Enum, IntEnum, auto
from io import BufferedWriter from io import BufferedWriter
from pathlib import Path from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Sequence from typing import IO, Any, BinaryIO, Callable, Sequence
@ -636,6 +636,13 @@ class GGUFValueType(IntEnum):
sys.exit() sys.exit()
class WriterState(Enum):
EMPTY = auto()
HEADER = auto()
KV_DATA = auto()
TI_DATA = auto()
class GGUFWriter: class GGUFWriter:
fout: BufferedWriter fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None
@ -676,24 +683,37 @@ class GGUFWriter:
self.tensors = [] self.tensors = []
endianess_str = "Big Endian" if self.endianess == GGUFEndian.BIG else "Little Endian" endianess_str = "Big Endian" if self.endianess == GGUFEndian.BIG else "Little Endian"
print(f"This gguf file is for {endianess_str} only") print(f"This gguf file is for {endianess_str} only")
self.state = WriterState.EMPTY
self.add_architecture() self.add_architecture()
def write_header_to_file(self): def write_header_to_file(self):
if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected output file to be empty, got {self.state}')
self.fout.write(struct.pack("<I", GGUF_MAGIC)) self.fout.write(struct.pack("<I", GGUF_MAGIC))
self.fout.write(struct.pack(f"{self.pack_prefix}I", GGUF_VERSION)) self.fout.write(struct.pack(f"{self.pack_prefix}I", GGUF_VERSION))
self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.ti_data_count)) self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.ti_data_count))
self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.kv_data_count)) self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.kv_data_count))
self.flush() self.flush()
# print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count)) #print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
self.state = WriterState.HEADER
def write_kv_data_to_file(self): def write_kv_data_to_file(self):
if self.state is not WriterState.HEADER:
raise ValueError(f'Expected output file to contain the header, got {self.state}')
self.fout.write(self.kv_data) self.fout.write(self.kv_data)
self.flush() self.flush()
self.state = WriterState.KV_DATA
def write_ti_data_to_file(self): def write_ti_data_to_file(self):
if self.state is not WriterState.KV_DATA:
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
self.fout.write(self.ti_data) self.fout.write(self.ti_data)
self.flush() self.flush()
self.state = WriterState.TI_DATA
def add_key(self, key: str): def add_key(self, key: str):
self.add_val(key, GGUFValueType.STRING, add_vtype=False) self.add_val(key, GGUFValueType.STRING, add_vtype=False)
@ -828,6 +848,9 @@ class GGUFWriter:
fp.write(bytes([0] * pad)) fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]): def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
if self.endianess==GGUFEndian.BIG: if self.endianess==GGUFEndian.BIG:
tensor.byteswap(inplace=True) tensor.byteswap(inplace=True)
self.write_padding(self.fout, self.fout.tell()) self.write_padding(self.fout, self.fout.tell())