support splits in convert.py

This commit is contained in:
Christian Azinn 2024-04-27 00:08:55 -04:00
parent 928e0b7013
commit 874c3411c2

View file

@ -44,9 +44,16 @@ ARCH = gguf.MODEL_ARCH.LLAMA
DEFAULT_CONCURRENCY = 8
DEFAULT_SPLIT_TENSORS = 128
ADDED_TOKENS_FILE = 'added_tokens.json'
FAST_TOKENIZER_FILE = 'tokenizer.json'
LLM_KV_SPLIT_NO = "split.no"
LLM_KV_SPLIT_COUNT = "split.count"
LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
#
# data types
#
@ -1235,6 +1242,49 @@ class OutputFile:
of.close()
@staticmethod
def write_split(
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
total_tensors: int, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False, tensors_per_shard: int = DEFAULT_SPLIT_TENSORS, small_first_shard: bool = True,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
model_list = list(model.items())
total_shards = math.ceil(total_tensors / tensors_per_shard) + small_first_shard
shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, total_shards)) for i in range(total_shards)]
for i, shard in enumerate(shard_files):
of = OutputFile(shard, endianess=endianess)
if i == 0:
of.add_meta_arch(params)
if isinstance(vocab, Vocab):
of.add_meta_vocab(vocab)
of.add_meta_special_vocab(svocab)
else: # NoVocab
of.gguf.add_tokenizer_model(vocab.tokenizer_model)
of.gguf.add_uint16(LLM_KV_SPLIT_NO, i)
of.gguf.add_uint16(LLM_KV_SPLIT_COUNT, total_shards)
of.gguf.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, total_tensors)
# have the option to write a first shard with only the metadata
if small_first_shard and i == 0:
of.write_meta()
of.close()
continue
stop = min((i + 1 - small_first_shard) * tensors_per_shard, total_tensors)
shard_models = model_list[(i - small_first_shard) * tensors_per_shard:stop]
for name, lazy_tensor in shard_models:
of.add_tensor_info(name, lazy_tensor)
of.write_meta()
of.write_tensor_info()
of.write_tensor_data(ftype, dict(shard_models), concurrency)
of.close()
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
@ -1473,6 +1523,9 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine")
parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
parser.add_argument("--split", action="store_true", help="split the converted model into multiple files")
parser.add_argument("--split-max-tensors", type=int, help=f"maximum number of tensors per file when splitting (default: {DEFAULT_SPLIT_TENSORS})", default=DEFAULT_SPLIT_TENSORS)
parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default is to only include metadata)")
args = parser.parse_args(args_in)
if args.no_vocab and args.vocab_only:
@ -1544,6 +1597,23 @@ def main(args_in: list[str] | None = None) -> None:
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
params.ftype = ftype
if args.split:
total_tensors = len(model)
if total_tensors < args.split_max_tensors:
print("Model has fewer tensors than the split threshold, not splitting")
print(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
else:
print(f"Writing {outfile} as shards, format {ftype}")
OutputFile.write_split(outfile, ftype, params, model, vocab, special_vocab, total_tensors,
concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab,
tensors_per_shard=args.split_max_tensors, small_first_shard=not args.large_first_shard)
print(f"Wrote {outfile}")
else:
print(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,