gguf-py : decouple adding metadata from writing in GGUFWriter
This commit is contained in:
parent
da799b4189
commit
fe59f20d26
3 changed files with 138 additions and 117 deletions
|
@ -47,7 +47,7 @@ class Model:
|
||||||
_model_classes: dict[str, type[Model]] = {}
|
_model_classes: dict[str, type[Model]] = {}
|
||||||
|
|
||||||
dir_model: Path
|
dir_model: Path
|
||||||
ftype: int
|
ftype: gguf.LlamaFileType
|
||||||
is_big_endian: bool
|
is_big_endian: bool
|
||||||
endianess: gguf.GGUFEndian
|
endianess: gguf.GGUFEndian
|
||||||
use_temp_file: bool
|
use_temp_file: bool
|
||||||
|
@ -94,7 +94,7 @@ class Model:
|
||||||
ftype_lw: str = ftype_up.lower()
|
ftype_lw: str = ftype_up.lower()
|
||||||
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
# allow templating the file name with the output ftype, useful with the "auto" ftype
|
||||||
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
|
||||||
self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls):
|
def __init_subclass__(cls):
|
||||||
|
@ -324,13 +324,13 @@ class Model:
|
||||||
|
|
||||||
def write(self):
|
def write(self):
|
||||||
self.write_tensors()
|
self.write_tensors()
|
||||||
self.gguf_writer.write_header_to_file()
|
self.gguf_writer.write_header_to_file(self.fname_out)
|
||||||
self.gguf_writer.write_kv_data_to_file()
|
self.gguf_writer.write_kv_data_to_file()
|
||||||
self.gguf_writer.write_tensors_to_file(progress=True)
|
self.gguf_writer.write_tensors_to_file(progress=True)
|
||||||
self.gguf_writer.close()
|
self.gguf_writer.close()
|
||||||
|
|
||||||
def write_vocab(self):
|
def write_vocab(self):
|
||||||
self.gguf_writer.write_header_to_file()
|
self.gguf_writer.write_header_to_file(self.fname_out)
|
||||||
self.gguf_writer.write_kv_data_to_file()
|
self.gguf_writer.write_kv_data_to_file()
|
||||||
self.gguf_writer.close()
|
self.gguf_writer.close()
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
import struct
|
import struct
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from io import BufferedWriter
|
from io import BufferedWriter
|
||||||
from typing import IO, Any, Sequence, Mapping
|
from typing import IO, Any, Sequence, Mapping
|
||||||
|
@ -30,7 +31,22 @@ from .quants import quant_shape_from_byte_shape
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TensorInfo:
|
||||||
|
shape: Sequence[int]
|
||||||
|
dtype: GGMLQuantizationType
|
||||||
|
nbytes: int
|
||||||
|
tensor: np.ndarray[Any, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GGUFValue:
|
||||||
|
value: Any
|
||||||
|
type: GGUFValueType
|
||||||
|
|
||||||
|
|
||||||
class WriterState(Enum):
|
class WriterState(Enum):
|
||||||
|
NO_FILE = auto()
|
||||||
EMPTY = auto()
|
EMPTY = auto()
|
||||||
HEADER = auto()
|
HEADER = auto()
|
||||||
KV_DATA = auto()
|
KV_DATA = auto()
|
||||||
|
@ -38,9 +54,11 @@ class WriterState(Enum):
|
||||||
|
|
||||||
|
|
||||||
class GGUFWriter:
|
class GGUFWriter:
|
||||||
fout: BufferedWriter
|
fout: BufferedWriter | None
|
||||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||||
tensors: list[np.ndarray[Any, Any]]
|
tensors: dict[str, TensorInfo]
|
||||||
|
kv_data: dict[str, GGUFValue]
|
||||||
|
state: WriterState
|
||||||
_simple_value_packing = {
|
_simple_value_packing = {
|
||||||
GGUFValueType.UINT8: "B",
|
GGUFValueType.UINT8: "B",
|
||||||
GGUFValueType.INT8: "b",
|
GGUFValueType.INT8: "b",
|
||||||
|
@ -56,141 +74,129 @@ class GGUFWriter:
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
|
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
|
||||||
endianess: GGUFEndian = GGUFEndian.LITTLE,
|
endianess: GGUFEndian = GGUFEndian.LITTLE,
|
||||||
):
|
):
|
||||||
self.fout = open(path, "wb")
|
self.fout = open(path, "wb") if path is not None else None
|
||||||
self.arch = arch
|
self.arch = arch
|
||||||
self.endianess = endianess
|
self.endianess = endianess
|
||||||
self.offset_tensor = 0
|
|
||||||
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
|
||||||
self.kv_data = bytearray()
|
|
||||||
self.kv_data_count = 0
|
|
||||||
self.ti_data = bytearray()
|
|
||||||
self.ti_data_count = 0
|
|
||||||
self.ti_names = set()
|
|
||||||
self.use_temp_file = use_temp_file
|
self.use_temp_file = use_temp_file
|
||||||
self.temp_file = None
|
self.temp_file = None
|
||||||
self.tensors = []
|
self.tensors = dict()
|
||||||
|
self.kv_data = dict()
|
||||||
logger.info("gguf: This GGUF file is for {0} Endian only".format(
|
logger.info("gguf: This GGUF file is for {0} Endian only".format(
|
||||||
"Big" if self.endianess == GGUFEndian.BIG else "Little",
|
"Big" if self.endianess == GGUFEndian.BIG else "Little",
|
||||||
))
|
))
|
||||||
self.state = WriterState.EMPTY
|
self.state = WriterState.NO_FILE if self.fout is None else WriterState.EMPTY
|
||||||
|
|
||||||
self.add_architecture()
|
self.add_architecture()
|
||||||
|
|
||||||
def write_header_to_file(self) -> None:
|
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||||
|
# NOTE: not checking for WriterState.NO_FILE,
|
||||||
|
# because writing can technically be started over from any state,
|
||||||
|
# as long as a new path is provided
|
||||||
|
if path is not None:
|
||||||
|
if self.fout is not None:
|
||||||
|
self.fout.close()
|
||||||
|
self.fout = open(path, "wb")
|
||||||
|
self.state = WriterState.EMPTY
|
||||||
if self.state is not WriterState.EMPTY:
|
if self.state is not WriterState.EMPTY:
|
||||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
||||||
|
|
||||||
self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
|
self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
|
||||||
self._write_packed("I", GGUF_VERSION)
|
self._write_packed("I", GGUF_VERSION)
|
||||||
self._write_packed("Q", self.ti_data_count)
|
self._write_packed("Q", len(self.tensors))
|
||||||
self._write_packed("Q", self.kv_data_count)
|
self._write_packed("Q", len(self.kv_data))
|
||||||
self.flush()
|
self.flush()
|
||||||
self.state = WriterState.HEADER
|
self.state = WriterState.HEADER
|
||||||
|
|
||||||
def write_kv_data_to_file(self) -> None:
|
def write_kv_data_to_file(self) -> None:
|
||||||
if self.state is not WriterState.HEADER:
|
if self.state is not WriterState.HEADER:
|
||||||
raise ValueError(f'Expected output file to contain the header, got {self.state}')
|
raise ValueError(f'Expected output file to contain the header, got {self.state}')
|
||||||
|
assert self.fout is not None
|
||||||
|
|
||||||
self.fout.write(self.kv_data)
|
kv_data = bytearray()
|
||||||
|
|
||||||
|
for key, val in self.kv_data.items():
|
||||||
|
kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||||
|
kv_data += self._pack_val(val.value, val.type, add_vtype=True)
|
||||||
|
|
||||||
|
self.fout.write(kv_data)
|
||||||
self.flush()
|
self.flush()
|
||||||
self.state = WriterState.KV_DATA
|
self.state = WriterState.KV_DATA
|
||||||
|
|
||||||
def write_ti_data_to_file(self) -> None:
|
def write_ti_data_to_file(self) -> None:
|
||||||
if self.state is not WriterState.KV_DATA:
|
if self.state is not WriterState.KV_DATA:
|
||||||
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
|
raise ValueError(f'Expected output file to contain KV data, got {self.state}')
|
||||||
|
assert self.fout is not None
|
||||||
|
|
||||||
self.fout.write(self.ti_data)
|
ti_data = bytearray()
|
||||||
|
offset_tensor = 0
|
||||||
|
|
||||||
|
for name, ti in self.tensors.items():
|
||||||
|
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
|
||||||
|
n_dims = len(ti.shape)
|
||||||
|
ti_data += self._pack("I", n_dims)
|
||||||
|
for i in range(n_dims):
|
||||||
|
ti_data += self._pack("Q", ti.shape[n_dims - 1 - i])
|
||||||
|
ti_data += self._pack("I", ti.dtype)
|
||||||
|
ti_data += self._pack("Q", offset_tensor)
|
||||||
|
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
|
||||||
|
|
||||||
|
self.fout.write(ti_data)
|
||||||
self.flush()
|
self.flush()
|
||||||
self.state = WriterState.TI_DATA
|
self.state = WriterState.TI_DATA
|
||||||
|
|
||||||
def add_key(self, key: str) -> None:
|
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
||||||
self.add_val(key, GGUFValueType.STRING, add_vtype=False)
|
if key in self.kv_data:
|
||||||
|
raise ValueError(f'Duplicated key name {key!r}')
|
||||||
|
|
||||||
|
self.kv_data[key] = GGUFValue(value=val, type=vtype)
|
||||||
|
|
||||||
def add_uint8(self, key: str, val: int) -> None:
|
def add_uint8(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key,val, GGUFValueType.UINT8)
|
||||||
self.add_val(val, GGUFValueType.UINT8)
|
|
||||||
|
|
||||||
def add_int8(self, key: str, val: int) -> None:
|
def add_int8(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.INT8)
|
||||||
self.add_val(val, GGUFValueType.INT8)
|
|
||||||
|
|
||||||
def add_uint16(self, key: str, val: int) -> None:
|
def add_uint16(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.UINT16)
|
||||||
self.add_val(val, GGUFValueType.UINT16)
|
|
||||||
|
|
||||||
def add_int16(self, key: str, val: int) -> None:
|
def add_int16(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.INT16)
|
||||||
self.add_val(val, GGUFValueType.INT16)
|
|
||||||
|
|
||||||
def add_uint32(self, key: str, val: int) -> None:
|
def add_uint32(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.UINT32)
|
||||||
self.add_val(val, GGUFValueType.UINT32)
|
|
||||||
|
|
||||||
def add_int32(self, key: str, val: int) -> None:
|
def add_int32(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.INT32)
|
||||||
self.add_val(val, GGUFValueType.INT32)
|
|
||||||
|
|
||||||
def add_float32(self, key: str, val: float) -> None:
|
def add_float32(self, key: str, val: float) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.FLOAT32)
|
||||||
self.add_val(val, GGUFValueType.FLOAT32)
|
|
||||||
|
|
||||||
def add_uint64(self, key: str, val: int) -> None:
|
def add_uint64(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.UINT64)
|
||||||
self.add_val(val, GGUFValueType.UINT64)
|
|
||||||
|
|
||||||
def add_int64(self, key: str, val: int) -> None:
|
def add_int64(self, key: str, val: int) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.INT64)
|
||||||
self.add_val(val, GGUFValueType.INT64)
|
|
||||||
|
|
||||||
def add_float64(self, key: str, val: float) -> None:
|
def add_float64(self, key: str, val: float) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.FLOAT64)
|
||||||
self.add_val(val, GGUFValueType.FLOAT64)
|
|
||||||
|
|
||||||
def add_bool(self, key: str, val: bool) -> None:
|
def add_bool(self, key: str, val: bool) -> None:
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.BOOL)
|
||||||
self.add_val(val, GGUFValueType.BOOL)
|
|
||||||
|
|
||||||
def add_string(self, key: str, val: str) -> None:
|
def add_string(self, key: str, val: str) -> None:
|
||||||
if not val:
|
if not val:
|
||||||
return
|
return
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.STRING)
|
||||||
self.add_val(val, GGUFValueType.STRING)
|
|
||||||
|
|
||||||
def add_array(self, key: str, val: Sequence[Any]) -> None:
|
def add_array(self, key: str, val: Sequence[Any]) -> None:
|
||||||
if not isinstance(val, Sequence):
|
if not isinstance(val, Sequence):
|
||||||
raise ValueError("Value must be a sequence for array type")
|
raise ValueError("Value must be a sequence for array type")
|
||||||
|
|
||||||
self.add_key(key)
|
self.add_key_value(key, val, GGUFValueType.ARRAY)
|
||||||
self.add_val(val, GGUFValueType.ARRAY)
|
|
||||||
|
|
||||||
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
|
|
||||||
if vtype is None:
|
|
||||||
vtype = GGUFValueType.get_type(val)
|
|
||||||
|
|
||||||
if add_vtype:
|
|
||||||
self.kv_data += self._pack("I", vtype)
|
|
||||||
self.kv_data_count += 1
|
|
||||||
|
|
||||||
pack_fmt = self._simple_value_packing.get(vtype)
|
|
||||||
if pack_fmt is not None:
|
|
||||||
self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
|
|
||||||
elif vtype == GGUFValueType.STRING:
|
|
||||||
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
|
|
||||||
self.kv_data += self._pack("Q", len(encoded_val))
|
|
||||||
self.kv_data += encoded_val
|
|
||||||
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
|
|
||||||
ltype = GGUFValueType.get_type(val[0])
|
|
||||||
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
|
|
||||||
raise ValueError("All items in a GGUF array should be of the same type")
|
|
||||||
self.kv_data += self._pack("I", ltype)
|
|
||||||
self.kv_data += self._pack("Q", len(val))
|
|
||||||
for item in val:
|
|
||||||
self.add_val(item, add_vtype=False)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid GGUF metadata value type or value")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ggml_pad(x: int, n: int) -> int:
|
def ggml_pad(x: int, n: int) -> int:
|
||||||
|
@ -200,16 +206,12 @@ class GGUFWriter:
|
||||||
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
|
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
|
||||||
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
|
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.state is not WriterState.EMPTY:
|
if self.state is not WriterState.EMPTY and self.state is not WriterState.NO_FILE:
|
||||||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
raise ValueError(f'Expected output file to be empty or absent, got {self.state}')
|
||||||
|
|
||||||
if name in self.ti_names:
|
if name in self.tensors:
|
||||||
raise ValueError(f'Duplicated tensor name {name}')
|
raise ValueError(f'Duplicated tensor name {name!r}')
|
||||||
self.ti_names.add(name)
|
|
||||||
|
|
||||||
encoded_name = name.encode("utf-8")
|
|
||||||
self.ti_data += self._pack("Q", len(encoded_name))
|
|
||||||
self.ti_data += encoded_name
|
|
||||||
if raw_dtype is None:
|
if raw_dtype is None:
|
||||||
if tensor_dtype == np.float16:
|
if tensor_dtype == np.float16:
|
||||||
dtype = GGMLQuantizationType.F16
|
dtype = GGMLQuantizationType.F16
|
||||||
|
@ -231,14 +233,8 @@ class GGUFWriter:
|
||||||
dtype = raw_dtype
|
dtype = raw_dtype
|
||||||
if tensor_dtype == np.uint8:
|
if tensor_dtype == np.uint8:
|
||||||
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
||||||
n_dims = len(tensor_shape)
|
|
||||||
self.ti_data += self._pack("I", n_dims)
|
self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
|
||||||
for i in range(n_dims):
|
|
||||||
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
|
|
||||||
self.ti_data += self._pack("I", dtype)
|
|
||||||
self.ti_data += self._pack("Q", self.offset_tensor)
|
|
||||||
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
|
|
||||||
self.ti_data_count += 1
|
|
||||||
|
|
||||||
def add_tensor(
|
def add_tensor(
|
||||||
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||||
|
@ -255,7 +251,7 @@ class GGUFWriter:
|
||||||
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
|
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
|
||||||
|
|
||||||
if self.temp_file is None:
|
if self.temp_file is None:
|
||||||
self.tensors.append(tensor)
|
self.tensors[name].tensor = tensor
|
||||||
return
|
return
|
||||||
|
|
||||||
tensor.tofile(self.temp_file)
|
tensor.tofile(self.temp_file)
|
||||||
|
@ -269,6 +265,7 @@ class GGUFWriter:
|
||||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
||||||
if self.state is not WriterState.TI_DATA:
|
if self.state is not WriterState.TI_DATA:
|
||||||
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
|
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
|
||||||
|
assert self.fout is not None
|
||||||
|
|
||||||
if self.endianess == GGUFEndian.BIG:
|
if self.endianess == GGUFEndian.BIG:
|
||||||
tensor.byteswap(inplace=True)
|
tensor.byteswap(inplace=True)
|
||||||
|
@ -279,34 +276,30 @@ class GGUFWriter:
|
||||||
def write_tensors_to_file(self, *, progress: bool = False) -> None:
|
def write_tensors_to_file(self, *, progress: bool = False) -> None:
|
||||||
self.write_ti_data_to_file()
|
self.write_ti_data_to_file()
|
||||||
|
|
||||||
|
assert self.fout is not None
|
||||||
|
|
||||||
self.write_padding(self.fout, self.fout.tell())
|
self.write_padding(self.fout, self.fout.tell())
|
||||||
|
|
||||||
if self.temp_file is None:
|
if self.temp_file is None:
|
||||||
self.tensors.reverse() # to pop from the "beginning" in constant time
|
bar = None
|
||||||
|
|
||||||
if progress:
|
if progress:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
total_bytes = sum(t.nbytes for t in self.tensors)
|
total_bytes = sum(t.nbytes for t in self.tensors.values())
|
||||||
|
|
||||||
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
|
||||||
|
|
||||||
while True:
|
# relying on the fact that Python dicts preserve insertion order (since 3.7)
|
||||||
try:
|
for ti in self.tensors.values():
|
||||||
tensor = self.tensors.pop()
|
assert ti.tensor is not None # can only iterate once over the tensors
|
||||||
except IndexError:
|
assert ti.tensor.nbytes == ti.nbytes
|
||||||
break
|
ti.tensor.tofile(self.fout)
|
||||||
tensor.tofile(self.fout)
|
if bar is not None:
|
||||||
bar.update(tensor.nbytes)
|
bar.update(ti.nbytes)
|
||||||
self.write_padding(self.fout, tensor.nbytes)
|
self.write_padding(self.fout, ti.nbytes)
|
||||||
return
|
ti.tensor = None
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
tensor = self.tensors.pop()
|
|
||||||
except IndexError:
|
|
||||||
break
|
|
||||||
tensor.tofile(self.fout)
|
|
||||||
self.write_padding(self.fout, tensor.nbytes)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self.temp_file.seek(0)
|
self.temp_file.seek(0)
|
||||||
|
@ -316,10 +309,13 @@ class GGUFWriter:
|
||||||
self.temp_file.close()
|
self.temp_file.close()
|
||||||
|
|
||||||
def flush(self) -> None:
|
def flush(self) -> None:
|
||||||
|
assert self.fout is not None
|
||||||
self.fout.flush()
|
self.fout.flush()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
|
if self.fout is not None:
|
||||||
self.fout.close()
|
self.fout.close()
|
||||||
|
self.fout = None
|
||||||
|
|
||||||
def add_architecture(self) -> None:
|
def add_architecture(self) -> None:
|
||||||
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
self.add_string(Keys.General.ARCHITECTURE, self.arch)
|
||||||
|
@ -449,7 +445,7 @@ class GGUFWriter:
|
||||||
def add_rope_scaling_factor(self, value: float) -> None:
|
def add_rope_scaling_factor(self, value: float) -> None:
|
||||||
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
|
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
|
||||||
|
|
||||||
def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None:
|
def add_rope_scaling_attn_factors(self, value: float) -> None:
|
||||||
self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
|
self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
|
||||||
|
|
||||||
def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
|
def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
|
||||||
|
@ -571,5 +567,32 @@ class GGUFWriter:
|
||||||
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
|
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
|
||||||
return struct.pack(f'{pack_prefix}{fmt}', value)
|
return struct.pack(f'{pack_prefix}{fmt}', value)
|
||||||
|
|
||||||
|
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
|
||||||
|
kv_data = bytearray()
|
||||||
|
|
||||||
|
if add_vtype:
|
||||||
|
kv_data += self._pack("I", vtype)
|
||||||
|
|
||||||
|
pack_fmt = self._simple_value_packing.get(vtype)
|
||||||
|
if pack_fmt is not None:
|
||||||
|
kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
|
||||||
|
elif vtype == GGUFValueType.STRING:
|
||||||
|
encoded_val = val.encode("utf-8") if isinstance(val, str) else val
|
||||||
|
kv_data += self._pack("Q", len(encoded_val))
|
||||||
|
kv_data += encoded_val
|
||||||
|
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
|
||||||
|
ltype = GGUFValueType.get_type(val[0])
|
||||||
|
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
|
||||||
|
raise ValueError("All items in a GGUF array should be of the same type")
|
||||||
|
kv_data += self._pack("I", ltype)
|
||||||
|
kv_data += self._pack("Q", len(val))
|
||||||
|
for item in val:
|
||||||
|
kv_data += self._pack_val(item, ltype, add_vtype=False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid GGUF metadata value type or value")
|
||||||
|
|
||||||
|
return kv_data
|
||||||
|
|
||||||
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
|
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
|
||||||
|
assert self.fout is not None
|
||||||
self.fout.write(self._pack(fmt, value, skip_pack_prefix))
|
self.fout.write(self._pack(fmt, value, skip_pack_prefix))
|
||||||
|
|
|
@ -101,8 +101,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
||||||
logger.debug(f'Copying {field.name}')
|
logger.debug(f'Copying {field.name}')
|
||||||
|
|
||||||
if val.value is not None:
|
if val.value is not None:
|
||||||
writer.add_key(field.name)
|
writer.add_key_value(field.name, val.value, val.type)
|
||||||
writer.add_val(val.value, val.type)
|
|
||||||
|
|
||||||
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
|
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
|
||||||
logger.debug('Adding chat template(s)')
|
logger.debug('Adding chat template(s)')
|
||||||
|
@ -111,8 +110,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
||||||
|
|
||||||
for key, val in new_metadata.items():
|
for key, val in new_metadata.items():
|
||||||
logger.debug(f'Adding {key}: "{val.value}" {val.description}')
|
logger.debug(f'Adding {key}: "{val.value}" {val.description}')
|
||||||
writer.add_key(key)
|
writer.add_key_value(key, val.value, val.type)
|
||||||
writer.add_val(val.value, val.type)
|
|
||||||
|
|
||||||
total_bytes = 0
|
total_bytes = 0
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue