Merge branch 'master' into gguf-dump-grouping-fix

This commit is contained in:
Brian 2024-06-24 20:58:06 +10:00 committed by GitHub
commit 64f2a194db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1123 additions and 642 deletions

View file

@ -65,7 +65,8 @@ class Model:
# subclasses should define this! # subclasses should define this!
model_arch: gguf.MODEL_ARCH model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None): def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
if type(self) is Model: if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model self.dir_model = dir_model
@ -96,7 +97,8 @@ class Model:
ftype_lw: str = ftype_up.lower() ftype_lw: str = ftype_up.lower()
# allow templating the file name with the output ftype, useful with the "auto" ftype # allow templating the file name with the output ftype, useful with the "auto" ftype
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@classmethod @classmethod
def __init_subclass__(cls): def __init_subclass__(cls):
@ -332,6 +334,8 @@ class Model:
self.gguf_writer.close() self.gguf_writer.close()
def write_vocab(self): def write_vocab(self):
if len(self.gguf_writer.tensors) != 1:
raise ValueError('Splitting the vocabulary is not supported')
self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_header_to_file(self.fname_out)
self.gguf_writer.write_kv_data_to_file() self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.close() self.gguf_writer.close()
@ -2974,10 +2978,44 @@ def parse_args() -> argparse.Namespace:
"--verbose", action="store_true", "--verbose", action="store_true",
help="increase output verbosity", help="increase output verbosity",
) )
parser.add_argument(
"--split-max-tensors", type=int, default=0,
help="max tensors in each split",
)
parser.add_argument(
"--split-max-size", type=str, default="0",
help="max size per split N(M|G)",
)
parser.add_argument(
"--dry-run", action="store_true",
help="only print out a split plan and exit, without writing any new files",
)
parser.add_argument(
"--no-tensor-first-split", action="store_true",
help="do not add tensors to the first split (disabled by default)"
)
return parser.parse_args() return parser.parse_args()
def split_str_to_n_bytes(split_str: str) -> int:
if split_str.endswith("K"):
n = int(split_str[:-1]) * 1000
elif split_str.endswith("M"):
n = int(split_str[:-1]) * 1000 * 1000
elif split_str.endswith("G"):
n = int(split_str[:-1]) * 1000 * 1000 * 1000
elif split_str.isnumeric():
n = int(split_str)
else:
raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")
if n < 0:
raise ValueError(f"Invalid split size: {split_str}, must be positive")
return n
def main() -> None: def main() -> None:
args = parse_args() args = parse_args()
@ -3010,6 +3048,10 @@ def main() -> None:
"auto": gguf.LlamaFileType.GUESSED, "auto": gguf.LlamaFileType.GUESSED,
} }
if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
logger.error("Error: Cannot use temp file when splitting")
sys.exit(1)
if args.outfile is not None: if args.outfile is not None:
fname_out = args.outfile fname_out = args.outfile
else: else:
@ -3027,7 +3069,10 @@ def main() -> None:
logger.error(f"Model {hparams['architectures'][0]} is not supported") logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1) sys.exit(1)
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name) model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
logger.info("Set model parameters") logger.info("Set model parameters")
model_instance.set_gguf_parameters() model_instance.set_gguf_parameters()
@ -3038,13 +3083,13 @@ def main() -> None:
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
if args.vocab_only: if args.vocab_only:
logger.info(f"Exporting model vocab to '{model_instance.fname_out}'") logger.info("Exporting model vocab...")
model_instance.write_vocab() model_instance.write_vocab()
logger.info("Model vocab successfully exported.")
else: else:
logger.info(f"Exporting model to '{model_instance.fname_out}'") logger.info("Exporting model...")
model_instance.write() model_instance.write()
logger.info("Model successfully exported.")
logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S; static constexpr int qi = QI3_S;
}; };
static int get_mmq_x_max_host(const int cc) { static constexpr int get_mmq_x_max_host(int cc) {
#ifdef CUDA_USE_TENSOR_CORES #ifdef CUDA_USE_TENSOR_CORES
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
#else #else
@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
} }
// Round rows to this value for --split-mode row: // Round rows to this value for --split-mode row:
static int get_mmq_y_host(const int cc) { static constexpr int get_mmq_y_host(int cc) {
return cc >= CC_VOLTA ? 128 : 64; return cc >= CC_VOLTA ? 128 : 64;
} }

View file

@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
GGML_CUDA_ASSUME(ret < K); GGML_CUDA_ASSUME(ret < K);
return ret; return ret;
} }
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
}; };
struct mma_int_A_I16K8 { struct mma_int_A_I16K8 {
@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
GGML_CUDA_ASSUME(ret < K); GGML_CUDA_ASSUME(ret < K);
return ret; return ret;
} }
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
}; };
struct mma_int_B_J8K4 { struct mma_int_B_J8K4 {
@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
GGML_CUDA_ASSUME(ret < K); GGML_CUDA_ASSUME(ret < K);
return ret; return ret;
} }
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(x[0])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
}; };
struct mma_int_B_J8K8 { struct mma_int_B_J8K8 {
@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
GGML_CUDA_ASSUME(ret < K); GGML_CUDA_ASSUME(ret < K);
return ret; return ret;
} }
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
}; };
struct mma_int_C_I16J8 { struct mma_int_C_I16J8 {

File diff suppressed because it is too large Load diff

View file

@ -75,6 +75,11 @@ class Keys:
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
class Split:
LLM_KV_SPLIT_NO = "split.no"
LLM_KV_SPLIT_COUNT = "split.count"
LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
class SSM: class SSM:
CONV_KERNEL = "{arch}.ssm.conv_kernel" CONV_KERNEL = "{arch}.ssm.conv_kernel"
INNER_SIZE = "{arch}.ssm.inner_size" INNER_SIZE = "{arch}.ssm.inner_size"

View file

@ -7,6 +7,7 @@ import struct
import tempfile import tempfile
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path
from io import BufferedWriter from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping from typing import IO, Any, Sequence, Mapping
from string import ascii_letters, digits from string import ascii_letters, digits
@ -31,6 +32,9 @@ from .quants import quant_shape_from_byte_shape
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
@dataclass @dataclass
class TensorInfo: class TensorInfo:
shape: Sequence[int] shape: Sequence[int]
@ -55,11 +59,11 @@ class WriterState(Enum):
class GGUFWriter: class GGUFWriter:
fout: BufferedWriter | None fout: list[BufferedWriter] | None
path: os.PathLike[str] | str | None path: Path | None
temp_file: tempfile.SpooledTemporaryFile[bytes] | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: dict[str, TensorInfo] tensors: list[dict[str, TensorInfo]]
kv_data: dict[str, GGUFValue] kv_data: list[dict[str, GGUFValue]]
state: WriterState state: WriterState
_simple_value_packing = { _simple_value_packing = {
GGUFValueType.UINT8: "B", GGUFValueType.UINT8: "B",
@ -76,26 +80,38 @@ class GGUFWriter:
} }
def __init__( def __init__(
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
endianess: GGUFEndian = GGUFEndian.LITTLE, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
): ):
self.fout = None self.fout = None
self.path = path self.path = Path(path) if path else None
self.arch = arch self.arch = arch
self.endianess = endianess self.endianess = endianess
self.data_alignment = GGUF_DEFAULT_ALIGNMENT self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.use_temp_file = use_temp_file self.use_temp_file = use_temp_file
self.temp_file = None self.temp_file = None
self.tensors = dict() self.tensors = [{}]
self.kv_data = dict() self.kv_data = [{}]
self.split_max_tensors = split_max_tensors
self.split_max_size = split_max_size
self.dry_run = dry_run
self.small_first_shard = small_first_shard
logger.info("gguf: This GGUF file is for {0} Endian only".format( logger.info("gguf: This GGUF file is for {0} Endian only".format(
"Big" if self.endianess == GGUFEndian.BIG else "Little", "Big" if self.endianess == GGUFEndian.BIG else "Little",
)) ))
self.state = WriterState.NO_FILE self.state = WriterState.NO_FILE
if self.small_first_shard:
self.tensors.append({})
self.add_architecture() self.add_architecture()
def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None: def format_shard_names(self, path: Path) -> list[Path]:
if len(self.tensors) == 1:
return [path]
return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
def open_output_file(self, path: Path | None = None) -> None:
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path): if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
# allow calling this multiple times as long as the path is the same # allow calling this multiple times as long as the path is the same
return return
@ -106,22 +122,58 @@ class GGUFWriter:
self.path = path self.path = path
if self.path is not None: if self.path is not None:
if self.fout is not None: filenames = self.print_plan()
self.fout.close() self.fout = [open(filename, "wb") for filename in filenames]
self.fout = open(self.path, "wb")
self.state = WriterState.EMPTY self.state = WriterState.EMPTY
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None: def print_plan(self) -> list[Path]:
logger.info("Writing the following files:")
assert self.path is not None
filenames = self.format_shard_names(self.path)
assert len(filenames) == len(self.tensors)
for name, tensors in zip(filenames, self.tensors):
logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
if self.dry_run:
logger.info("Dry run, not writing files")
exit()
return filenames
def add_shard_kv_data(self) -> None:
if len(self.tensors) == 1:
return
total_tensors = sum(len(t) for t in self.tensors)
assert self.fout is not None
total_splits = len(self.fout)
self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
for i, kv_data in enumerate(self.kv_data):
kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
def write_header_to_file(self, path: Path | None = None) -> None:
if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
logger.warning("Model fails split requirements, not splitting")
self.open_output_file(path) self.open_output_file(path)
if self.state is not WriterState.EMPTY: if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected output file to be empty, got {self.state}') raise ValueError(f'Expected output file to be empty, got {self.state}')
self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True) assert self.fout is not None
self._write_packed("I", GGUF_VERSION) assert len(self.fout) == len(self.tensors)
self._write_packed("Q", len(self.tensors)) assert len(self.kv_data) == 1
self._write_packed("Q", len(self.kv_data))
self.flush() self.add_shard_kv_data()
for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
fout.write(self._pack("I", GGUF_VERSION))
fout.write(self._pack("Q", len(tensors)))
fout.write(self._pack("Q", len(kv_data)))
fout.flush()
self.state = WriterState.HEADER self.state = WriterState.HEADER
def write_kv_data_to_file(self) -> None: def write_kv_data_to_file(self) -> None:
@ -129,13 +181,15 @@ class GGUFWriter:
raise ValueError(f'Expected output file to contain the header, got {self.state}') raise ValueError(f'Expected output file to contain the header, got {self.state}')
assert self.fout is not None assert self.fout is not None
kv_data = bytearray() for fout, kv_data in zip(self.fout, self.kv_data):
kv_bytes = bytearray()
for key, val in self.kv_data.items(): for key, val in kv_data.items():
kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
kv_data += self._pack_val(val.value, val.type, add_vtype=True) kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
fout.write(kv_bytes)
self.fout.write(kv_data)
self.flush() self.flush()
self.state = WriterState.KV_DATA self.state = WriterState.KV_DATA
@ -144,28 +198,29 @@ class GGUFWriter:
raise ValueError(f'Expected output file to contain KV data, got {self.state}') raise ValueError(f'Expected output file to contain KV data, got {self.state}')
assert self.fout is not None assert self.fout is not None
ti_data = bytearray() for fout, tensors in zip(self.fout, self.tensors):
offset_tensor = 0 ti_data = bytearray()
offset_tensor = 0
for name, ti in self.tensors.items(): for name, ti in tensors.items():
ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False) ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
n_dims = len(ti.shape) n_dims = len(ti.shape)
ti_data += self._pack("I", n_dims) ti_data += self._pack("I", n_dims)
for i in range(n_dims): for j in range(n_dims):
ti_data += self._pack("Q", ti.shape[n_dims - 1 - i]) ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
ti_data += self._pack("I", ti.dtype) ti_data += self._pack("I", ti.dtype)
ti_data += self._pack("Q", offset_tensor) ti_data += self._pack("Q", offset_tensor)
offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment) offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
self.fout.write(ti_data) fout.write(ti_data)
self.flush() fout.flush()
self.state = WriterState.TI_DATA self.state = WriterState.TI_DATA
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
if key in self.kv_data: if any(key in kv_data for kv_data in self.kv_data):
raise ValueError(f'Duplicated key name {key!r}') raise ValueError(f'Duplicated key name {key!r}')
self.kv_data[key] = GGUFValue(value=val, type=vtype) self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
def add_uint8(self, key: str, val: int) -> None: def add_uint8(self, key: str, val: int) -> None:
self.add_key_value(key,val, GGUFValueType.UINT8) self.add_key_value(key,val, GGUFValueType.UINT8)
@ -206,9 +261,6 @@ class GGUFWriter:
self.add_key_value(key, val, GGUFValueType.STRING) self.add_key_value(key, val, GGUFValueType.STRING)
def add_array(self, key: str, val: Sequence[Any]) -> None: def add_array(self, key: str, val: Sequence[Any]) -> None:
if not isinstance(val, Sequence):
raise ValueError("Value must be a sequence for array type")
self.add_key_value(key, val, GGUFValueType.ARRAY) self.add_key_value(key, val, GGUFValueType.ARRAY)
@staticmethod @staticmethod
@ -222,7 +274,7 @@ class GGUFWriter:
if self.state is not WriterState.NO_FILE: if self.state is not WriterState.NO_FILE:
raise ValueError(f'Expected output file to be not yet opened, got {self.state}') raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
if name in self.tensors: if any(name in tensors for tensors in self.tensors):
raise ValueError(f'Duplicated tensor name {name!r}') raise ValueError(f'Duplicated tensor name {name!r}')
if raw_dtype is None: if raw_dtype is None:
@ -247,7 +299,18 @@ class GGUFWriter:
if tensor_dtype == np.uint8: if tensor_dtype == np.uint8:
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) # make sure there is at least one tensor before splitting
if len(self.tensors[-1]) > 0:
if ( # split when over tensor limit
self.split_max_tensors != 0
and len(self.tensors[-1]) >= self.split_max_tensors
) or ( # split when over size limit
self.split_max_size != 0
and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
):
self.tensors.append({})
self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
def add_tensor( def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
@ -264,7 +327,7 @@ class GGUFWriter:
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype) self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
if self.temp_file is None: if self.temp_file is None:
self.tensors[name].tensor = tensor self.tensors[-1][name].tensor = tensor
return return
tensor.tofile(self.temp_file) tensor.tofile(self.temp_file)
@ -282,9 +345,24 @@ class GGUFWriter:
if self.endianess == GGUFEndian.BIG: if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True) tensor.byteswap(inplace=True)
self.write_padding(self.fout, self.fout.tell())
tensor.tofile(self.fout) file_id = -1
self.write_padding(self.fout, tensor.nbytes) for i, tensors in enumerate(self.tensors):
if len(tensors) > 0:
file_id = i
break
fout = self.fout[file_id]
# pop the first tensor info
# TODO: cleaner way to get the first key
first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
ti = self.tensors[file_id].pop(first_tensor_name)
assert ti.nbytes == tensor.nbytes
self.write_padding(fout, fout.tell())
tensor.tofile(fout)
self.write_padding(fout, tensor.nbytes)
self.state = WriterState.WEIGHTS self.state = WriterState.WEIGHTS
@ -293,31 +371,43 @@ class GGUFWriter:
assert self.fout is not None assert self.fout is not None
self.write_padding(self.fout, self.fout.tell()) for fout in self.fout:
self.write_padding(fout, fout.tell())
if self.temp_file is None: if self.temp_file is None:
shard_bar = None
bar = None bar = None
if progress: if progress:
from tqdm import tqdm from tqdm import tqdm
total_bytes = sum(t.nbytes for t in self.tensors.values()) total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
if len(self.fout) > 1:
shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
# relying on the fact that Python dicts preserve insertion order (since 3.7) for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
for ti in self.tensors.values(): if shard_bar is not None:
assert ti.tensor is not None # can only iterate once over the tensors shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
assert ti.tensor.nbytes == ti.nbytes total = sum(ti.nbytes for ti in tensors.values())
ti.tensor.tofile(self.fout) shard_bar.reset(total=(total if total > 0 else None))
if bar is not None:
bar.update(ti.nbytes) # relying on the fact that Python dicts preserve insertion order (since 3.7)
self.write_padding(self.fout, ti.nbytes) for ti in tensors.values():
ti.tensor = None assert ti.tensor is not None # can only iterate once over the tensors
assert ti.tensor.nbytes == ti.nbytes
ti.tensor.tofile(fout)
if shard_bar is not None:
shard_bar.update(ti.nbytes)
if bar is not None:
bar.update(ti.nbytes)
self.write_padding(fout, ti.nbytes)
ti.tensor = None
else: else:
self.temp_file.seek(0) self.temp_file.seek(0)
shutil.copyfileobj(self.temp_file, self.fout) shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
self.flush() self.flush()
self.temp_file.close() self.temp_file.close()
@ -325,11 +415,13 @@ class GGUFWriter:
def flush(self) -> None: def flush(self) -> None:
assert self.fout is not None assert self.fout is not None
self.fout.flush() for fout in self.fout:
fout.flush()
def close(self) -> None: def close(self) -> None:
if self.fout is not None: if self.fout is not None:
self.fout.close() for fout in self.fout:
fout.close()
self.fout = None self.fout = None
def add_architecture(self) -> None: def add_architecture(self) -> None:
@ -626,6 +718,13 @@ class GGUFWriter:
return kv_data return kv_data
def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: @staticmethod
assert self.fout is not None def format_n_bytes_to_str(num: int) -> str:
self.fout.write(self._pack(fmt, value, skip_pack_prefix)) if num == 0:
return "negligible - metadata only"
fnum = float(num)
for unit in ("", "K", "M", "G"):
if abs(fnum) < 1000.0:
return f"{fnum:3.1f}{unit}"
fnum /= 1000.0
return f"{fnum:.1f}T - over 1TB, split recommended"