From ff2dd7d30dcd7b2dc172fcf448f6a67b42504247 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Sun, 9 Jun 2024 10:29:47 -0400 Subject: [PATCH] try to refactor kv data (still fails) --- gguf-py/gguf/gguf_writer.py | 63 ++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 84190837d..0a8749d7f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf" -NUM_SHARD_KV_DATA = 6 +NUM_SHARD_KV_DATA = 3 METADATA_ONLY_INDICATOR = -1 KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType | None]] # {key: (value, type)} @@ -92,11 +92,11 @@ class SplitStyle(Enum): class GGUFWriter: - fout: list[BufferedWriter | None] + fout: list[BufferedWriter | None] | None path: os.PathLike[str] | str | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None tensors: list[dict[str, TensorInfo]] - kv_data: dict[str, GGUFValue] + kv_data: list[dict[str, GGUFValue]] state: WriterState _simple_value_packing = { GGUFValueType.UINT8: "B", @@ -125,7 +125,7 @@ class GGUFWriter: self.use_temp_file = use_temp_file self.temp_file = None self.tensors = [] - self.kv_data = dict() + self.kv_data = [dict()] logger.info("gguf: This GGUF file is for {0} Endian only".format( "Big" if self.endianess == GGUFEndian.BIG else "Little", )) @@ -188,6 +188,20 @@ class GGUFWriter: logger.info("Dry run, not writing files") exit() + def add_shard_kv_data(self) -> None: + if self.split_arguments.split_style == SplitStyle.NONE: + return + + total_tensors = sum(len(t) for t in self.tensors) + for i in range(len(self.fout)): + try: # TODO better way to do this + self.kv_data[i] + except IndexError: + self.kv_data.append(dict()) + self.kv_data[i][Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16) + self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(len(self.fout), GGUFValueType.UINT16) + self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32) + def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: self.verify_arguments() self.open_output_file(path) @@ -197,50 +211,35 @@ class GGUFWriter: raise ValueError(f'Expected output file to be empty, got {self.state}') assert len(self.fout) == len(self.tensors) + assert len(self.kv_data) == 1 + + self.add_shard_kv_data() for i in range(len(self.fout)): fout = self.fout[i] + #print(f"writing header: GGUF_VERSION={GGUF_VERSION}, GGUF_MAGIC={GGUF_MAGIC}, n_tensors={len(self.tensors[i])}, n_kv_data={len(self.kv_data[i])}") self._write_packed(fout, " bytearray: - total_tensors = sum(len(t) for t in self.tensors) - kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_NO, GGUFValueType.STRING, add_vtype=False) - kv_data += self._pack_val(shard_no, GGUFValueType.UINT16, add_vtype=True) - kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_COUNT, GGUFValueType.STRING, add_vtype=False) - kv_data += self._pack_val(len(self.fout), GGUFValueType.UINT16, add_vtype=True) - kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT, GGUFValueType.STRING, add_vtype=False) - kv_data += self._pack_val(total_tensors, GGUFValueType.INT32, add_vtype=True) - return kv_data - def write_kv_data_to_file(self) -> None: if self.state is not WriterState.HEADER: raise ValueError(f'Expected output file to contain the header, got {self.state}') assert self.fout is not None - kv_data = bytearray() + for fout, kv_data in zip(self.fout, self.kv_data): + kv_bytes = bytearray() - for key, val in self.kv_data.items(): - kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) - kv_data += self._pack_val(val.value, val.type, add_vtype=True) + for key, val in kv_data.items(): + kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) + kv_bytes += self._pack_val(val.value, val.type, add_vtype=True) - if len(self.fout) > 1: - kv_data = self.add_shard_kv_data(kv_data, 0) + fout.write(kv_bytes) - # only the first shard needs kv data - self.fout[0].write(kv_data) - self.fout[0].flush() - - for i in range(1, len(self.fout)): - self.fout[i].write(self.add_shard_kv_data(bytearray(), i)) - self.fout[i].flush() + self.flush() self.state = WriterState.KV_DATA def write_ti_data_to_file(self) -> None: @@ -271,7 +270,7 @@ class GGUFWriter: if key in self.kv_data: raise ValueError(f'Duplicated key name {key!r}') - self.kv_data[key] = GGUFValue(value=val, type=vtype) + self.kv_data[0][key] = GGUFValue(value=val, type=vtype) def add_uint8(self, key: str, val: int) -> None: self.add_key_value(key,val, GGUFValueType.UINT8)