try to refactor kv data (still fails)
This commit is contained in:
parent
97dd416903
commit
ff2dd7d30d
1 changed file with 31 additions and 32 deletions
|
@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
|
||||
NUM_SHARD_KV_DATA = 6
|
||||
NUM_SHARD_KV_DATA = 3
|
||||
METADATA_ONLY_INDICATOR = -1
|
||||
|
||||
KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType | None]] # {key: (value, type)}
|
||||
|
@ -92,11 +92,11 @@ class SplitStyle(Enum):
|
|||
|
||||
|
||||
class GGUFWriter:
|
||||
fout: list[BufferedWriter | None]
|
||||
fout: list[BufferedWriter | None] | None
|
||||
path: os.PathLike[str] | str | None
|
||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||
tensors: list[dict[str, TensorInfo]]
|
||||
kv_data: dict[str, GGUFValue]
|
||||
kv_data: list[dict[str, GGUFValue]]
|
||||
state: WriterState
|
||||
_simple_value_packing = {
|
||||
GGUFValueType.UINT8: "B",
|
||||
|
@ -125,7 +125,7 @@ class GGUFWriter:
|
|||
self.use_temp_file = use_temp_file
|
||||
self.temp_file = None
|
||||
self.tensors = []
|
||||
self.kv_data = dict()
|
||||
self.kv_data = [dict()]
|
||||
logger.info("gguf: This GGUF file is for {0} Endian only".format(
|
||||
"Big" if self.endianess == GGUFEndian.BIG else "Little",
|
||||
))
|
||||
|
@ -188,6 +188,20 @@ class GGUFWriter:
|
|||
logger.info("Dry run, not writing files")
|
||||
exit()
|
||||
|
||||
def add_shard_kv_data(self) -> None:
    """Record the split metadata keys in every shard's key/value table.

    Does nothing when ``split_arguments.split_style`` is ``SplitStyle.NONE``.
    Otherwise ``self.kv_data`` is grown until it holds one dict per entry of
    ``self.fout``, and each of those dicts receives three entries: the shard's
    index, the number of shards, and the tensor count summed over all shards.
    Existing entries stored under the same keys are overwritten.
    """
    if self.split_arguments.split_style == SplitStyle.NONE:
        return

    total_tensors = sum(len(t) for t in self.tensors)
    shard_count = len(self.fout)

    # Pad kv_data once so every shard has a dict to write into, instead of
    # probing each index and catching IndexError. A non-positive difference
    # yields an empty range, so an already long enough list is left alone.
    missing = shard_count - len(self.kv_data)
    self.kv_data.extend(dict() for _ in range(missing))

    # The slice is shallow: it yields the same dict objects held in
    # self.kv_data, limited to the shards that have an output file.
    for shard_no, shard_kv in enumerate(self.kv_data[:shard_count]):
        shard_kv[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(shard_no, GGUFValueType.UINT16)
        shard_kv[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(shard_count, GGUFValueType.UINT16)
        shard_kv[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
|
||||
|
||||
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
|
||||
self.verify_arguments()
|
||||
self.open_output_file(path)
|
||||
|
@ -197,50 +211,35 @@ class GGUFWriter:
|
|||
raise ValueError(f'Expected output file to be empty, got {self.state}')
|
||||
|
||||
assert len(self.fout) == len(self.tensors)
|
||||
assert len(self.kv_data) == 1
|
||||
|
||||
self.add_shard_kv_data()
|
||||
|
||||
for i in range(len(self.fout)):
|
||||
fout = self.fout[i]
|
||||
#print(f"writing header: GGUF_VERSION={GGUF_VERSION}, GGUF_MAGIC={GGUF_MAGIC}, n_tensors={len(self.tensors[i])}, n_kv_data={len(self.kv_data[i])}")
|
||||
self._write_packed(fout, "<I", GGUF_MAGIC, skip_pack_prefix = True)
|
||||
self._write_packed(fout, "I", GGUF_VERSION)
|
||||
self._write_packed(fout, "Q", len(self.tensors[i]))
|
||||
kv_data_len = len(self.kv_data) if i == 0 else 0
|
||||
if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
|
||||
kv_data_len += NUM_SHARD_KV_DATA
|
||||
self._write_packed(fout, "Q", kv_data_len)
|
||||
self._write_packed(fout, "Q", len(self.kv_data[i]))
|
||||
self.fout[i].flush()
|
||||
self.state = WriterState.HEADER
|
||||
|
||||
def add_shard_kv_data(self, kv_data: bytearray, shard_no: int) -> bytearray:
|
||||
total_tensors = sum(len(t) for t in self.tensors)
|
||||
kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_NO, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_data += self._pack_val(shard_no, GGUFValueType.UINT16, add_vtype=True)
|
||||
kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_COUNT, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_data += self._pack_val(len(self.fout), GGUFValueType.UINT16, add_vtype=True)
|
||||
kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_data += self._pack_val(total_tensors, GGUFValueType.INT32, add_vtype=True)
|
||||
return kv_data
|
||||
|
||||
def write_kv_data_to_file(self) -> None:
|
||||
if self.state is not WriterState.HEADER:
|
||||
raise ValueError(f'Expected output file to contain the header, got {self.state}')
|
||||
assert self.fout is not None
|
||||
|
||||
kv_data = bytearray()
|
||||
for fout, kv_data in zip(self.fout, self.kv_data):
|
||||
kv_bytes = bytearray()
|
||||
|
||||
for key, val in self.kv_data.items():
|
||||
kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_data += self._pack_val(val.value, val.type, add_vtype=True)
|
||||
for key, val in kv_data.items():
|
||||
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
|
||||
|
||||
if len(self.fout) > 1:
|
||||
kv_data = self.add_shard_kv_data(kv_data, 0)
|
||||
fout.write(kv_bytes)
|
||||
|
||||
# only the first shard needs kv data
|
||||
self.fout[0].write(kv_data)
|
||||
self.fout[0].flush()
|
||||
|
||||
for i in range(1, len(self.fout)):
|
||||
self.fout[i].write(self.add_shard_kv_data(bytearray(), i))
|
||||
self.fout[i].flush()
|
||||
self.flush()
|
||||
self.state = WriterState.KV_DATA
|
||||
|
||||
def write_ti_data_to_file(self) -> None:
|
||||
|
@ -271,7 +270,7 @@ class GGUFWriter:
|
|||
if key in self.kv_data:
|
||||
raise ValueError(f'Duplicated key name {key!r}')
|
||||
|
||||
self.kv_data[key] = GGUFValue(value=val, type=vtype)
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
|
||||
|
||||
def add_uint8(self, key: str, val: int) -> None:
|
||||
self.add_key_value(key,val, GGUFValueType.UINT8)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue