further simplify GGUFManager
This commit is contained in:
parent
3e9430df33
commit
f6fd3ea4e9
3 changed files with 54 additions and 51 deletions
|
@ -81,14 +81,7 @@ models = [
|
||||||
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
|
||||||
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
|
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
|
||||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
|
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
|
||||||
<<<<<<< Updated upstream
|
|
||||||
{"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
|
{"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
|
||||||
=======
|
|
||||||
{"name": "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom-7b1", },
|
|
||||||
{"name": "gptbigcode", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/gpt_bigcode-santacoder", },
|
|
||||||
{"name": "phi2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
|
|
||||||
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B-Chat", },
|
|
||||||
>>>>>>> Stashed changes
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ class Model:
|
||||||
tensor_map: gguf.TensorNameMap
|
tensor_map: gguf.TensorNameMap
|
||||||
tensor_names: set[str] | None
|
tensor_names: set[str] | None
|
||||||
fname_out: Path
|
fname_out: Path
|
||||||
gguf_writer: gguf.GGUFManager
|
gguf_writer: gguf.GGUFWriter
|
||||||
|
|
||||||
# subclasses should define this!
|
# subclasses should define this!
|
||||||
model_arch: gguf.MODEL_ARCH
|
model_arch: gguf.MODEL_ARCH
|
||||||
|
@ -329,11 +329,16 @@ class Model:
|
||||||
|
|
||||||
def write(self):
|
def write(self):
|
||||||
self.write_tensors()
|
self.write_tensors()
|
||||||
self.gguf_writer.write_to_file()
|
self.gguf_writer.write_header_to_file()
|
||||||
|
self.gguf_writer.write_kv_data_to_file()
|
||||||
|
self.gguf_writer.write_ti_data_to_file()
|
||||||
self.gguf_writer.close()
|
self.gguf_writer.close()
|
||||||
|
|
||||||
def write_vocab(self):
|
def write_vocab(self):
|
||||||
self.gguf_writer.write_to_file(meta_only=True)
|
if self.gguf_writer.split_arguments.split:
|
||||||
|
raise ValueError('Splitting the vocabulary is not supported')
|
||||||
|
self.gguf_writer.write_header_to_file()
|
||||||
|
self.gguf_writer.write_kv_data_to_file()
|
||||||
self.gguf_writer.close()
|
self.gguf_writer.close()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1563,7 +1568,6 @@ class MiniCPMModel(Model):
|
||||||
|
|
||||||
return [(self.map_tensor_name(name), data_torch)]
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
# TODO what the hell is this?
|
|
||||||
@Model.register("QWenLMHeadModel")
|
@Model.register("QWenLMHeadModel")
|
||||||
class QwenModel(Model):
|
class QwenModel(Model):
|
||||||
model_arch = gguf.MODEL_ARCH.QWEN
|
model_arch = gguf.MODEL_ARCH.QWEN
|
||||||
|
|
|
@ -2,8 +2,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import TYPE_CHECKING, Any, Sequence, Mapping
|
from typing import TYPE_CHECKING, Any, Sequence
|
||||||
from string import ascii_letters, digits
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
@ -18,7 +17,7 @@ from .constants import (
|
||||||
GGUFEndian,
|
GGUFEndian,
|
||||||
GGUFValueType
|
GGUFValueType
|
||||||
)
|
)
|
||||||
from .gguf_writer import GGUFWriter
|
from .gguf_writer import GGUFWriter, WriterState
|
||||||
|
|
||||||
|
|
||||||
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
|
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
|
||||||
|
@ -74,7 +73,7 @@ class SplitStrategy(deque):
|
||||||
self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
|
self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
|
||||||
|
|
||||||
elif split_arguments.split_style == SplitStyle.SIZE:
|
elif split_arguments.split_style == SplitStyle.SIZE:
|
||||||
shards = deque()
|
shards = []
|
||||||
|
|
||||||
# we have to determine the shards first to determine how many shards there will be in total - two passes
|
# we have to determine the shards first to determine how many shards there will be in total - two passes
|
||||||
for i, shard in enumerate(model):
|
for i, shard in enumerate(model):
|
||||||
|
@ -135,7 +134,6 @@ class SplitStrategy(deque):
|
||||||
num /= 1024.0
|
num /= 1024.0
|
||||||
return f"{num:.1f}T - over 1TB, --split recommended"
|
return f"{num:.1f}T - over 1TB, --split recommended"
|
||||||
|
|
||||||
# TODO fall back to normal GGUFWriter in convert-hf-to-gguf.py if no --split
|
|
||||||
class GGUFManager(GGUFWriter):
|
class GGUFManager(GGUFWriter):
|
||||||
kv_data: KVTempData
|
kv_data: KVTempData
|
||||||
tensors: list[TensorTempData]
|
tensors: list[TensorTempData]
|
||||||
|
@ -145,27 +143,25 @@ class GGUFManager(GGUFWriter):
|
||||||
def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
|
def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
|
||||||
use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
|
use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO be able to use superclass constructor
|
# we intentionally don't call superclass constructor
|
||||||
# super().__init__(path, arch, use_temp_file=use_temp_file, endianess=endianess)
|
|
||||||
self.arch = arch
|
self.arch = arch
|
||||||
self.path = path
|
self.path = path
|
||||||
self.endianess = endianess
|
self.endianess = endianess
|
||||||
self.offset_tensor = 0
|
|
||||||
self.kv_data = {}
|
self.kv_data = {}
|
||||||
self.tensors = []
|
self.tensors = []
|
||||||
# TODO how many of these do you need
|
|
||||||
self.split_strategy = None
|
self.split_strategy = None
|
||||||
self.total_shards = None
|
self.total_shards = 0
|
||||||
self.total_tensors = None
|
self.total_tensors = 0
|
||||||
self.use_temp_file = use_temp_file
|
self.use_temp_file = use_temp_file
|
||||||
self.split_arguments = split_arguments
|
self.split_arguments = split_arguments
|
||||||
self.recent_key = None
|
self.recent_key = None
|
||||||
|
self.state = WriterState.EMPTY
|
||||||
self.add_architecture()
|
self.add_architecture()
|
||||||
|
|
||||||
# TODO split back into write_header_to_file, write_kv_data_to_file, write_ti_data_to_file
|
def write_header_to_file(self) -> None:
|
||||||
def write_to_file(self, meta_only: bool = False) -> None:
|
if self.state is not WriterState.EMPTY:
|
||||||
|
raise ValueError(f'Expected GGUFManager state to be EMPTY, got {self.state}')
|
||||||
|
|
||||||
# here is the first place you can assume you have all tensors written and you can establish the size of the file - so logic goes here
|
|
||||||
self.total_tensors = len(self.tensors)
|
self.total_tensors = len(self.tensors)
|
||||||
total_size = sum(SplitStrategy.get_tensor_size(tensor[1]) for tensor in self.tensors)
|
total_size = sum(SplitStrategy.get_tensor_size(tensor[1]) for tensor in self.tensors)
|
||||||
|
|
||||||
|
@ -182,26 +178,6 @@ class GGUFManager(GGUFWriter):
|
||||||
del self.tensors
|
del self.tensors
|
||||||
self.total_shards = len(self.split_strategy)
|
self.total_shards = len(self.split_strategy)
|
||||||
|
|
||||||
# only the first shard needs all the KV data
|
|
||||||
for key, (value, etype) in self.kv_data.items():
|
|
||||||
self.split_strategy[0][2].add_key(key)
|
|
||||||
self.split_strategy[0][2].add_val(value, etype)
|
|
||||||
|
|
||||||
if self.split_arguments.split_style != SplitStyle.NONE:
|
|
||||||
for i, (_, _, writer) in enumerate(self.split_strategy):
|
|
||||||
writer.add_uint16(LLM_KV_SPLIT_NO, i)
|
|
||||||
writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
|
|
||||||
writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
|
|
||||||
|
|
||||||
# metadata/vocab only can write and return here
|
|
||||||
if meta_only:
|
|
||||||
for i, (_, _, writer) in enumerate(self.split_strategy):
|
|
||||||
writer.write_header_to_file()
|
|
||||||
writer.write_kv_data_to_file()
|
|
||||||
return
|
|
||||||
|
|
||||||
# tensor writing code starts here
|
|
||||||
|
|
||||||
print("\nWriting the following files:")
|
print("\nWriting the following files:")
|
||||||
for (shard_path, shard_tensors, _) in self.split_strategy:
|
for (shard_path, shard_tensors, _) in self.split_strategy:
|
||||||
size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
|
size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
|
||||||
|
@ -214,10 +190,38 @@ class GGUFManager(GGUFWriter):
|
||||||
os.remove(name)
|
os.remove(name)
|
||||||
return
|
return
|
||||||
|
|
||||||
# run add_tensor_info, write data, then write_tensor_data - taken from convert.py
|
self.state = WriterState.HEADER
|
||||||
|
|
||||||
|
def write_kv_data_to_file(self) -> None:
|
||||||
|
if self.split_arguments.dry_run:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.state is not WriterState.HEADER:
|
||||||
|
raise ValueError(f'Expected GGUFManager state to be HEADER, got {self.state}')
|
||||||
|
|
||||||
|
# only the first shard needs all the KV data
|
||||||
|
for key, (value, etype) in self.kv_data.items():
|
||||||
|
self.split_strategy[0][2].add_key(key)
|
||||||
|
self.split_strategy[0][2].add_val(value, etype)
|
||||||
|
|
||||||
|
# the other shards need shard data
|
||||||
|
if self.split_arguments.split_style != SplitStyle.NONE:
|
||||||
|
for i, (_, _, writer) in enumerate(self.split_strategy):
|
||||||
|
writer.add_uint16(LLM_KV_SPLIT_NO, i)
|
||||||
|
writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
|
||||||
|
writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
|
||||||
|
|
||||||
|
self.state = WriterState.KV_DATA
|
||||||
|
|
||||||
|
def write_ti_data_to_file(self) -> None:
|
||||||
|
if self.split_arguments.dry_run:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.state is not WriterState.KV_DATA:
|
||||||
|
raise ValueError(f'Expected GGUFManager state to be KV_DATA, got {self.state}')
|
||||||
|
|
||||||
running_total = self.total_tensors
|
running_total = self.total_tensors
|
||||||
ct = 0
|
for ct in range(self.total_shards):
|
||||||
while True:
|
|
||||||
try:
|
try:
|
||||||
(_, tensors, writer) = self.split_strategy.popleft()
|
(_, tensors, writer) = self.split_strategy.popleft()
|
||||||
tensors = deque(tensors) if tensors else None
|
tensors = deque(tensors) if tensors else None
|
||||||
|
@ -234,15 +238,17 @@ class GGUFManager(GGUFWriter):
|
||||||
break
|
break
|
||||||
writer.add_tensor(name, tensor, raw_dtype=dtype)
|
writer.add_tensor(name, tensor, raw_dtype=dtype)
|
||||||
|
|
||||||
print(f"Writing to shard {ct + 1}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
|
print(f"Writing to shard {ct}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
|
||||||
running_total -= shard_num_tensors
|
running_total -= shard_num_tensors
|
||||||
|
|
||||||
|
# need to write everything down here
|
||||||
writer.write_header_to_file()
|
writer.write_header_to_file()
|
||||||
writer.write_kv_data_to_file()
|
writer.write_kv_data_to_file()
|
||||||
writer.write_tensors_to_file(progress=True)
|
writer.write_tensors_to_file(progress=True)
|
||||||
ct = ct + 1
|
|
||||||
del tensors
|
del tensors
|
||||||
|
|
||||||
|
self.state = WriterState.TI_DATA
|
||||||
|
|
||||||
# override add_key, add_val to handle kv data separately
|
# override add_key, add_val to handle kv data separately
|
||||||
def add_key(self, key: str) -> None:
|
def add_key(self, key: str) -> None:
|
||||||
self.recent_key = key
|
self.recent_key = key
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue