try to refactor kv data (still fails)

This commit is contained in:
Christian Zhou-Zheng 2024-06-09 10:29:47 -04:00
parent 97dd416903
commit ff2dd7d30d

View file

@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
NUM_SHARD_KV_DATA = 6
NUM_SHARD_KV_DATA = 3
METADATA_ONLY_INDICATOR = -1
KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType | None]] # {key: (value, type)}
@ -92,11 +92,11 @@ class SplitStyle(Enum):
class GGUFWriter:
fout: list[BufferedWriter | None]
fout: list[BufferedWriter | None] | None
path: os.PathLike[str] | str | None
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: list[dict[str, TensorInfo]]
kv_data: dict[str, GGUFValue]
kv_data: list[dict[str, GGUFValue]]
state: WriterState
_simple_value_packing = {
GGUFValueType.UINT8: "B",
@ -125,7 +125,7 @@ class GGUFWriter:
self.use_temp_file = use_temp_file
self.temp_file = None
self.tensors = []
self.kv_data = dict()
self.kv_data = [dict()]
logger.info("gguf: This GGUF file is for {0} Endian only".format(
"Big" if self.endianess == GGUFEndian.BIG else "Little",
))
@ -188,6 +188,20 @@ class GGUFWriter:
logger.info("Dry run, not writing files")
exit()
def add_shard_kv_data(self) -> None:
if self.split_arguments.split_style == SplitStyle.NONE:
return
total_tensors = sum(len(t) for t in self.tensors)
for i in range(len(self.fout)):
try: # TODO better way to do this
self.kv_data[i]
except IndexError:
self.kv_data.append(dict())
self.kv_data[i][Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(len(self.fout), GGUFValueType.UINT16)
self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
self.verify_arguments()
self.open_output_file(path)
@ -197,50 +211,35 @@ class GGUFWriter:
raise ValueError(f'Expected output file to be empty, got {self.state}')
assert len(self.fout) == len(self.tensors)
assert len(self.kv_data) == 1
self.add_shard_kv_data()
for i in range(len(self.fout)):
fout = self.fout[i]
#print(f"writing header: GGUF_VERSION={GGUF_VERSION}, GGUF_MAGIC={GGUF_MAGIC}, n_tensors={len(self.tensors[i])}, n_kv_data={len(self.kv_data[i])}")
self._write_packed(fout, "<I", GGUF_MAGIC, skip_pack_prefix = True)
self._write_packed(fout, "I", GGUF_VERSION)
self._write_packed(fout, "Q", len(self.tensors[i]))
kv_data_len = len(self.kv_data) if i == 0 else 0
if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
kv_data_len += NUM_SHARD_KV_DATA
self._write_packed(fout, "Q", kv_data_len)
self._write_packed(fout, "Q", len(self.kv_data[i]))
self.fout[i].flush()
self.state = WriterState.HEADER
def add_shard_kv_data(self, kv_data: bytearray, shard_no: int) -> bytearray:
total_tensors = sum(len(t) for t in self.tensors)
kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_NO, GGUFValueType.STRING, add_vtype=False)
kv_data += self._pack_val(shard_no, GGUFValueType.UINT16, add_vtype=True)
kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_COUNT, GGUFValueType.STRING, add_vtype=False)
kv_data += self._pack_val(len(self.fout), GGUFValueType.UINT16, add_vtype=True)
kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT, GGUFValueType.STRING, add_vtype=False)
kv_data += self._pack_val(total_tensors, GGUFValueType.INT32, add_vtype=True)
return kv_data
def write_kv_data_to_file(self) -> None:
if self.state is not WriterState.HEADER:
raise ValueError(f'Expected output file to contain the header, got {self.state}')
assert self.fout is not None
kv_data = bytearray()
for fout, kv_data in zip(self.fout, self.kv_data):
kv_bytes = bytearray()
for key, val in self.kv_data.items():
kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
kv_data += self._pack_val(val.value, val.type, add_vtype=True)
for key, val in kv_data.items():
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
if len(self.fout) > 1:
kv_data = self.add_shard_kv_data(kv_data, 0)
fout.write(kv_bytes)
# only the first shard needs kv data
self.fout[0].write(kv_data)
self.fout[0].flush()
for i in range(1, len(self.fout)):
self.fout[i].write(self.add_shard_kv_data(bytearray(), i))
self.fout[i].flush()
self.flush()
self.state = WriterState.KV_DATA
def write_ti_data_to_file(self) -> None:
@ -271,7 +270,7 @@ class GGUFWriter:
if key in self.kv_data:
raise ValueError(f'Duplicated key name {key!r}')
self.kv_data[key] = GGUFValue(value=val, type=vtype)
self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
def add_uint8(self, key: str, val: int) -> None:
self.add_key_value(key,val, GGUFValueType.UINT8)