gguf : track writer state

parent 3fcdc9330a
commit d97afcfc02

1 changed file with 25 additions and 2 deletions
@@ -7,7 +7,7 @@ import shutil
 import struct
 import sys
 import tempfile
-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from io import BufferedWriter
 from pathlib import Path
 from typing import IO, Any, BinaryIO, Callable, Sequence
@@ -636,6 +636,13 @@ class GGUFValueType(IntEnum):
             sys.exit()


+class WriterState(Enum):
+    EMPTY   = auto()
+    HEADER  = auto()
+    KV_DATA = auto()
+    TI_DATA = auto()
+
+
 class GGUFWriter:
     fout: BufferedWriter
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
@@ -676,24 +683,37 @@ class GGUFWriter:
         self.tensors = []
         endianess_str = "Big Endian" if self.endianess == GGUFEndian.BIG else "Little Endian"
         print(f"This gguf file is for {endianess_str} only")
+        self.state = WriterState.EMPTY

         self.add_architecture()

     def write_header_to_file(self):
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
         self.fout.write(struct.pack("<I", GGUF_MAGIC))
         self.fout.write(struct.pack(f"{self.pack_prefix}I", GGUF_VERSION))
         self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.ti_data_count))
         self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.kv_data_count))
         self.flush()
-        # print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
+        #print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
+        self.state = WriterState.HEADER

     def write_kv_data_to_file(self):
+        if self.state is not WriterState.HEADER:
+            raise ValueError(f'Expected output file to contain the header, got {self.state}')
+
         self.fout.write(self.kv_data)
         self.flush()
+        self.state = WriterState.KV_DATA

     def write_ti_data_to_file(self):
+        if self.state is not WriterState.KV_DATA:
+            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
+
         self.fout.write(self.ti_data)
         self.flush()
+        self.state = WriterState.TI_DATA

     def add_key(self, key: str):
         self.add_val(key, GGUFValueType.STRING, add_vtype=False)
@@ -828,6 +848,9 @@ class GGUFWriter:
             fp.write(bytes([0] * pad))

     def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
+        if self.state is not WriterState.TI_DATA:
+            raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
+
         if self.endianess==GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
         self.write_padding(self.fout, self.fout.tell())
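
As a reference for the pattern introduced above, here is a minimal, self-contained sketch of how the new state checks enforce the write order (header, then KV data, then tensor info). TinyWriter is a hypothetical stand-in used only for illustration; it is not the real GGUFWriter API and omits all actual file I/O.

from enum import Enum, auto


class WriterState(Enum):
    EMPTY   = auto()
    HEADER  = auto()
    KV_DATA = auto()
    TI_DATA = auto()


class TinyWriter:
    # Toy writer mirroring the ordering checks added in this commit.
    def __init__(self):
        self.state = WriterState.EMPTY

    def write_header_to_file(self):
        if self.state is not WriterState.EMPTY:
            raise ValueError(f'Expected output file to be empty, got {self.state}')
        # ... magic, version and counts would be written here ...
        self.state = WriterState.HEADER

    def write_kv_data_to_file(self):
        if self.state is not WriterState.HEADER:
            raise ValueError(f'Expected output file to contain the header, got {self.state}')
        self.state = WriterState.KV_DATA

    def write_ti_data_to_file(self):
        if self.state is not WriterState.KV_DATA:
            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
        self.state = WriterState.TI_DATA


w = TinyWriter()
w.write_header_to_file()
w.write_kv_data_to_file()
w.write_ti_data_to_file()       # in-order calls succeed
try:
    w.write_header_to_file()    # out of order: state is TI_DATA, not EMPTY
except ValueError as e:
    print(e)                    # Expected output file to be empty, got WriterState.TI_DATA

Calling the write methods in any other order now fails fast with a descriptive ValueError instead of silently producing a malformed file.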