tentative push of convert-hf-to-gguf support
commit b7c612088f
parent c33bdf397d

3 changed files with 74 additions and 24 deletions
convert-hf-to-gguf.py

@@ -21,7 +21,9 @@ if TYPE_CHECKING:
 if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
-import gguf
+import importlib
+gguf = importlib.import_module("gguf-py.gguf")
+# import gguf
 
 from convert import LlamaHfVocab, permute
 
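Note: the switch from `import gguf` to `importlib.import_module("gguf-py.gguf")` exists because `gguf-py` contains a hyphen, which the `import` statement's grammar rejects; the import machinery itself only handles module names as strings, so `import_module` sidesteps the parser. A self-contained sketch of the same trick (throwaway names, not from this commit):

    import importlib
    import os
    import sys
    import tempfile

    # Build a package whose directory name contains a hyphen.
    root = tempfile.mkdtemp()
    os.mkdir(os.path.join(root, "demo-pkg"))
    with open(os.path.join(root, "demo-pkg", "__init__.py"), "w") as f:
        f.write("ANSWER = 42\n")

    sys.path.insert(0, root)
    # `import demo-pkg` is a SyntaxError, but this works:
    pkg = importlib.import_module("demo-pkg")
    print(pkg.ANSWER)  # 42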
@@ -43,18 +45,18 @@ AnyModel = TypeVar("AnyModel", bound="type[Model]")
 class Model(ABC):
     _model_classes: dict[str, type[Model]] = {}
 
-    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
+    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, args: argparse.Namespace):
         self.dir_model = dir_model
         self.ftype = ftype
         self.fname_out = fname_out
-        self.is_big_endian = is_big_endian
-        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
-        self.use_temp_file = use_temp_file
+        self.is_big_endian = args.bigendian
+        self.endianess = gguf.GGUFEndian.BIG if args.bigendian else gguf.GGUFEndian.LITTLE
+        self.use_temp_file = args.use_temp_file
         self.is_safetensors = self._is_model_safetensors()
         self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
         self.part_names = self._get_part_names()
         self.hparams = Model.load_hparams(self.dir_model)
-        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
+        self.gguf_writer = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], args, endianess=self.endianess, use_temp_file=self.use_temp_file)
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
 
     @property
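Note: `Model.__init__` now takes the whole `argparse.Namespace` so the split options can reach `GGUFManager` without widening the signature further; the cost is that `Model` can no longer be constructed without a CLI-shaped object. Any object exposing the expected attributes would do, e.g. (hypothetical, for programmatic callers):

    import argparse

    # Only the attributes the constructor and GGUFManager read need to exist.
    args = argparse.Namespace(
        bigendian=False,
        use_temp_file=False,
        split=True,
        split_max_tensors=128,
        split_max_size=None,
        dry_run=False,
        large_first_shard=False,
    )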
@@ -174,14 +176,11 @@ class Model(ABC):
 
     def write(self):
         self.write_tensors()
-        self.gguf_writer.write_header_to_file()
-        self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file()
+        self.gguf_writer.write_to_file()
         self.gguf_writer.close()
 
     def write_vocab(self):
-        self.gguf_writer.write_header_to_file()
-        self.gguf_writer.write_kv_data_to_file()
+        self.gguf_writer.write_to_file(meta_only=True)
         self.gguf_writer.close()
 
     @staticmethod
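Note: the separate header/kv/tensor writer calls collapse into a single `write_to_file()` because, per the comments in `GGUFManager` further down, the header needs final kv and tensor counts before anything can be serialized. A minimal sketch of that buffer-then-flush shape (assumed behavior, not this commit's code):

    class TinyManager:
        def __init__(self):
            self.kv = {}
            self.tensors = []

        def add_kv(self, key, val):
            self.kv[key] = val

        def add_tensor(self, name, data):
            self.tensors.append((name, data))

        def write_to_file(self, meta_only=False):
            # counts are final only here, so the header can now go first
            print(f"header: {len(self.kv)} kv, {len(self.tensors)} tensors")
            if not meta_only:
                for name, _ in self.tensors:
                    print("tensor:", name)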
@@ -1711,7 +1710,7 @@ class MiniCPMModel(Model):
 
         self.gguf_writer.add_tensor(new_name, data)
 
-
+# TODO what the hell is this?
 @Model.register("QWenLMHeadModel")
 class QwenModel(Model):
     model_arch = gguf.MODEL_ARCH.QWEN
@@ -2843,6 +2842,11 @@ def parse_args() -> argparse.Namespace:
         help="directory containing model file",
     )
     parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)")
+    parser.add_argument("--split", action="store_true", help="split the converted model into multiple files")
+    parser.add_argument("--split-max-tensors", type=int, help="max tensors in each split")
+    parser.add_argument("--split-max-size", type=str, help="max size per split N(M|G)")
+    parser.add_argument("--dry-run", action="store_true", help="only print out a split plan and exit, without writing any new files")
+    parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default: metadata only)")
 
     return parser.parse_args()
 
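Note: with these flags a hypothetical invocation looks like `python convert-hf-to-gguf.py <model-dir> --split --split-max-tensors 128 --dry-run`, which would print the shard plan and exit; `--split-max-size 2G` is the size-based alternative, and `--large-first-shard` opts out of the metadata-only first file.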
@@ -2869,6 +2873,15 @@ def main() -> None:
         print(f'Error: {args.model} is not a directory', file=sys.stderr)
         sys.exit(1)
 
+    if args.split and not (args.split_max_tensors or args.split_max_size):
+        raise ValueError("Need to specify one of --split-max-tensors or --split-max-size when splitting")
+
+    if args.split_max_tensors and args.split_max_size:
+        raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
+
+    if args.split_max_size:
+        args.split_max_size = gguf.SplitStrategy.split_str_to_n_bytes(args.split_max_size)
+
     ftype_map = {
         "f32": gguf.GGMLQuantizationType.F32,
         "f16": gguf.GGMLQuantizationType.F16,
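Note: `gguf.SplitStrategy.split_str_to_n_bytes` is referenced but not shown in this diff. A plausible reading of the documented `N(M|G)` format, written here purely as an assumption about the helper:

    def split_str_to_n_bytes(split_str: str) -> int:
        # Hypothetical: whether the real helper uses 1000- or 1024-based
        # units is not visible in this commit.
        if split_str.endswith("M"):
            return int(split_str[:-1]) * 1000 * 1000
        if split_str.endswith("G"):
            return int(split_str[:-1]) * 1000 * 1000 * 1000
        raise ValueError(f"invalid split size: {split_str}, must be N(M|G)")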
@@ -2886,7 +2899,7 @@ def main() -> None:
 
     with torch.inference_mode():
         model_class = Model.from_model_architecture(hparams["architectures"][0])
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file)
+        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args)
 
         print("Set model parameters")
         model_instance.set_gguf_parameters()
convert.py

@@ -1465,7 +1465,7 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
     parser.add_argument("--split", action="store_true", help="split the converted model into multiple files")
     parser.add_argument("--split-max-tensors", type=int, help="max tensors in each split")
-    parser.add_argument("--split-max-size", type=str, help="max size per split N(M|G)+")
+    parser.add_argument("--split-max-size", type=str, help="max size per split N(M|G)")
     parser.add_argument("--dry-run", action="store_true", help="only print out a split plan and exit, without writing any new files")
     parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default: metadata only)")
 
gguf-py/gguf/gguf_manager.py

@@ -4,6 +4,7 @@ import os
 import shutil
 import struct
 import tempfile
+import time
 from enum import IntEnum
 from typing import TYPE_CHECKING, Any, Sequence, Mapping
 from string import ascii_letters, digits
@@ -174,7 +175,9 @@ class GGUFManager:
     # have to consolidate because we need to know kv data count and tensor count before we can write the header
     # and we need to write tensor info before we can write metadata
     # these all kinda show up around the same places anyway so it's not a huge deal?
-    def write_to_file(self, meta_only: bool = False, ftype: int = 0, concurrency: int = 8, write_tensor_data: function = None) -> None:
+    def write_to_file(self, meta_only: bool = False, ftype: int = 0, concurrency: int = 8,
+                      write_tensor_data: function = None
+                      ) -> None:
 
         # here is the first place you can assume you have all tensors written and you can establish the size of the file - so logic goes here
         self.total_tensors = len(self.tensors)
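Note: `function` is not a Python builtin, so the `write_tensor_data: function = None` annotation only survives if the module postpones annotation evaluation (e.g. via `from __future__ import annotations`, not visible in this hunk); otherwise the `def` itself would raise a NameError. The conventional spelling would be `typing.Callable`, sketched here as a suggestion rather than what the commit does:

    from typing import Any, Callable, Optional

    class GGUFManagerStub:
        # Suggested annotation only; the commit itself writes `function`.
        def write_to_file(self, meta_only: bool = False, ftype: int = 0,
                          concurrency: int = 8,
                          write_tensor_data: Optional[Callable[..., Any]] = None) -> None:
            pass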
@@ -218,19 +221,37 @@ class GGUFManager:
 
         if self.args.dry_run:
             print("\nDry run, not writing files")
+            # instantiating GGUFWriters creates files
+            for name, _, _ in self.split_strategy:
+                os.remove(name)
             return
 
         # run add_tensor_info, write data, then write_tensor_data - taken from convert.py
         running_total = self.total_tensors
+        start = time.time()
         for i, (_, tensors, writer) in enumerate(self.split_strategy):
 
             if tensors:
-                for name, tensor in tensors:
+                for j, (name, tensor) in enumerate(tensors):
                     n_elements = int(np.prod(tensor.shape))
-                    raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
-                    data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
-                    data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
-                    writer.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
+                    # logic from convert.py
+                    if getattr(tensor, 'data_type', None):
+                        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+                        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+                        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+                        writer.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype)
+                    # logic from convert-hf-to-gguf.py
+                    else:
+                        # stolen from write_tensor_data because that doesn't get called with this logic
+                        elapsed = time.time() - start
+                        size = ' x '.join(f"{dim:6d}" for dim in tensor.shape)
+                        padi = len(str(self.total_tensors))
+                        dtype = str(tensor.dtype)
+                        print(
+                            f"[{j + 1:{padi}d}/{len(tensors)}] Writing tensor {name:38s} | size {size:16} | type {dtype:8} | T+{int(elapsed):4}"
+                        )
+                        writer.add_tensor(name, tensor)
+
 
             writer.write_header_to_file()
             writer.write_kv_data_to_file()
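Note: two things fall out of this hunk. First, each element of `self.split_strategy` is evidently a `(filename, tensors, writer)` triple whose `tensors` list may be empty for a metadata-only first shard, and instantiating the underlying `GGUFWriter`s already creates the files on disk, hence the `os.remove` cleanup in the dry-run branch. Second, a toy illustration of the shape the loop consumes (names hypothetical):

    split_strategy = [
        ("model-00001-of-00002.gguf", [], None),  # metadata-only first shard
        ("model-00002-of-00002.gguf", [("tok_embd.weight", object())], None),
    ]
    for i, (name, tensors, writer) in enumerate(split_strategy):
        print(i, name, f"{len(tensors)} tensors")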
@@ -240,8 +261,9 @@ class GGUFManager:
             print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
             running_total -= len(tensors)
 
-            # convert.py's write_tensor_data is dependent on so many objects in convert.py itself that it's easier to pass the function as a parameter and call it here
-            write_tensor_data(ftype, dict(tensors), concurrency, writer)
+            if write_tensor_data:
+                # convert.py's write_tensor_data is dependent on so many objects in convert.py itself that it's easier to pass the function as a parameter and call it here
+                write_tensor_data(ftype, dict(tensors), concurrency, writer)
 
     def add_uint8(self, key: str, val: int) -> None:
         self.kv_data[key] = (val, GGUFValueType.UINT8)
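Note: the new `if write_tensor_data:` guard reflects the two callers: `convert.py` passes its own serializer in, while the convert-hf path already handed every tensor to the writer via `writer.add_tensor(...)` in the branch above, so it calls `write_to_file()` with no callback and there is nothing left to lazily serialize here.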
@@ -295,8 +317,23 @@ class GGUFManager:
         self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
-        # TODO WRITE
-        pass
+        if self.endianess == GGUFEndian.BIG:
+            tensor.byteswap(inplace=True)
+
+        # TODO reimplement temp file
+        #if self.use_temp_file and self.temp_file is None:
+        #    fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
+        #    fp.seek(0)
+        #    self.temp_file = fp
+
+        self.add_tensor_info(name, tensor)
+
+        #if self.temp_file is None:
+        #    self.tensors.append(tensor)
+        #    return
+
+        #tensor.tofile(self.temp_file)
+        #self.write_padding(self.temp_file, tensor.nbytes)
 
     def write_tensors_to_file(self) -> None:
         # TODO WRITE
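Note: two tentative observations on `add_tensor` as committed. `ndarray.byteswap(inplace=True)` swaps the raw bytes, which is what a big-endian output file wants, but it leaves the array's dtype flagged little-endian, so any later in-process read of the values sees scrambled numbers unless the dtype is re-viewed:

    import numpy as np

    t = np.arange(4, dtype="<f4")       # little-endian float32
    t.byteswap(inplace=True)            # raw buffer is now big-endian bytes
    t = t.view(t.dtype.newbyteorder())  # keep numpy's view of the values consistent

And `self.add_tensor_info(name, tensor)` passes two arguments where the shard loop earlier calls `add_tensor_info(name, shape, data_type, data_nbytes, raw_dtype=...)`, so one of the two call sites presumably relies on defaults or dispatch not visible in this diff.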