Make loading weights 10-100x faster
This is a breaking change that's going to give you three benefits: 1. Your inference commands should load 100x faster 2. You may be able to safely load models 2x larger 3. You can run many concurrent inference processes This was accomplished by changing the file format so we can mmap() weights directly into memory without having to read() or copy them thereby ensuring the kernel can make its file cache pages directly accessible to our inference processes; and secondly, that the file cache pages are much less likely to get evicted (which would force loads to hit disk) because they're no longer competing with memory pages that were needlessly created by gigabytes of standard i/o. The new file format supports single-file models like LLaMA 7b, and it also supports multi-file models like LLaMA 13B. Our Python tool now merges the foo.1, foo.2, etc. files back into a single file so that the C++ code which maps it doesn't need to reshape data every time. That's made llama.cpp so much simpler. Much of its load code has now been deleted. Furthermore, this change ensures that tensors are aligned properly on a 32-byte boundary. That opens the door to seeing if we can get additional performance gains on some microprocessors, by using ops that require memory alignment. Lastly note that both POSIX and the Windows platform are supported Fixes #91
This commit is contained in:
parent
a017390358
commit
78ca9838ee
7 changed files with 336 additions and 375 deletions
|
@ -24,8 +24,57 @@ import torch
|
|||
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
def parse_args():
|
||||
QK = 32
|
||||
|
||||
GGML_TYPE_Q4_0 = 0
|
||||
GGML_TYPE_Q4_1 = 1
|
||||
GGML_TYPE_I8 = 2
|
||||
GGML_TYPE_I16 = 3
|
||||
GGML_TYPE_I32 = 4
|
||||
GGML_TYPE_F16 = 5
|
||||
GGML_TYPE_F32 = 6
|
||||
|
||||
WTYPES = {
|
||||
0: GGML_TYPE_F32,
|
||||
1: GGML_TYPE_F16,
|
||||
2: GGML_TYPE_Q4_0,
|
||||
3: GGML_TYPE_Q4_1,
|
||||
}
|
||||
|
||||
GGML_BLCK_SIZE = {
|
||||
GGML_TYPE_Q4_0: QK,
|
||||
GGML_TYPE_Q4_1: QK,
|
||||
GGML_TYPE_I8: 1,
|
||||
GGML_TYPE_I16: 1,
|
||||
GGML_TYPE_I32: 1,
|
||||
GGML_TYPE_F16: 1,
|
||||
GGML_TYPE_F32: 1,
|
||||
}
|
||||
|
||||
GGML_TYPE_SIZE = {
|
||||
GGML_TYPE_Q4_0: 4 + QK/2,
|
||||
GGML_TYPE_Q4_1: 4*2 + QK/2,
|
||||
GGML_TYPE_I8: 1,
|
||||
GGML_TYPE_I16: 2,
|
||||
GGML_TYPE_I32: 4,
|
||||
GGML_TYPE_F16: 2,
|
||||
GGML_TYPE_F32: 4,
|
||||
}
|
||||
|
||||
def ggml_nelements(shape):
|
||||
r = 1
|
||||
for i in shape:
|
||||
r *= i
|
||||
return r
|
||||
|
||||
def ggml_nbytes(shape, ftype):
|
||||
x = ggml_nelements(shape)
|
||||
t = WTYPES[ftype]
|
||||
x *= GGML_TYPE_SIZE[t]
|
||||
x //= GGML_BLCK_SIZE[t]
|
||||
return x
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
|
||||
parser.add_argument('dir_model', help='directory containing the model checkpoint')
|
||||
parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
|
||||
|
@ -33,7 +82,6 @@ def parse_args():
|
|||
return parser.parse_args()
|
||||
|
||||
def get_n_parts(dim):
|
||||
|
||||
mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
|
||||
n_parts = mappings.get(dim)
|
||||
if n_parts is None:
|
||||
|
@ -44,30 +92,24 @@ def get_n_parts(dim):
|
|||
return n_parts
|
||||
|
||||
def load_hparams_and_tokenizer(dir_model):
|
||||
|
||||
# `dir_model` is something like `models/7B` or `models/7B/`.
|
||||
# "tokenizer.model" is expected under model's parent dir.
|
||||
# When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
|
||||
# Let's use the model's parent dir directly.
|
||||
model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
|
||||
|
||||
fname_hparams = f"{dir_model}/params.json"
|
||||
fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
|
||||
|
||||
with open(fname_hparams, "r") as f:
|
||||
hparams = json.load(f)
|
||||
print(hparams)
|
||||
|
||||
tokenizer = SentencePieceProcessor(fname_tokenizer)
|
||||
hparams.update({"vocab_size": tokenizer.vocab_size()})
|
||||
|
||||
return hparams, tokenizer
|
||||
|
||||
def write_header(fout, hparams, ftype):
|
||||
|
||||
keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
|
||||
values = [
|
||||
0x67676d66, # magic: ggmf in hex
|
||||
0x67676a74, # magic: ggjt in hex
|
||||
1, # file version
|
||||
*[hparams[key] for key in keys],
|
||||
hparams["dim"] // hparams["n_heads"], # rot (obsolete)
|
||||
|
@ -76,7 +118,6 @@ def write_header(fout, hparams, ftype):
|
|||
fout.write(struct.pack("i" * len(values), *values))
|
||||
|
||||
def write_tokens(fout, tokenizer):
|
||||
|
||||
for i in range(tokenizer.vocab_size()):
|
||||
if tokenizer.is_unknown(i):
|
||||
text = " \u2047 ".encode("utf-8")
|
||||
|
@ -95,85 +136,141 @@ def write_tokens(fout, tokenizer):
|
|||
fout.write(text)
|
||||
fout.write(struct.pack("f", tokenizer.get_score(i)))
|
||||
|
||||
def process_and_write_variables(fout, model, ftype):
|
||||
|
||||
def process_and_write_variables(fout, model, ftype, part_id, n_parts):
|
||||
for name, datao in model.items():
|
||||
|
||||
if name.endswith("freqs"):
|
||||
continue
|
||||
|
||||
shape = datao.shape
|
||||
|
||||
print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
|
||||
|
||||
# remove dimensions with a single element
|
||||
data = datao.numpy().squeeze()
|
||||
n_dims = len(shape)
|
||||
partshape = data.shape
|
||||
n_dims = len(data.shape)
|
||||
assert n_dims in (1, 2)
|
||||
|
||||
# default type is fp16
|
||||
print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}")
|
||||
|
||||
# coerce single-dimensional tensors from float16 to float32
|
||||
ftype_cur = 1
|
||||
if ftype == 0 or n_dims == 1:
|
||||
print(" Converting to float32")
|
||||
data = data.astype(np.float32)
|
||||
ftype_cur = 0
|
||||
blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
|
||||
type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]
|
||||
|
||||
# header
|
||||
# determine dimension along which multipart tensor is sharded
|
||||
#
|
||||
# split_dim 0 regex:
|
||||
# - output.*
|
||||
# - layers.*.attention.wq.weight
|
||||
# - layers.*.attention.wk.weight
|
||||
# - layers.*.attention.wv.weight
|
||||
# - layers.*.feed_forward.w1.weight
|
||||
# - layers.*.feed_forward.w3.weight
|
||||
#
|
||||
# split_dim 1 regex:
|
||||
# - tok_embeddings.*
|
||||
# - layers.*.attention.wo.weight
|
||||
# - layers.*.feed_forward.w2.weight
|
||||
#
|
||||
if n_dims > 1:
|
||||
split_dim = 1
|
||||
if "tok_embeddings" in name:
|
||||
split_dim = 1
|
||||
elif "layers" in name:
|
||||
if "attention.wo.weight" in name:
|
||||
split_dim = 1
|
||||
elif "feed_forward.w2.weight" in name:
|
||||
split_dim = 1
|
||||
else:
|
||||
split_dim = 0
|
||||
elif "output" in name:
|
||||
split_dim = 0
|
||||
|
||||
# output tensor header
|
||||
fullshape = list(partshape)
|
||||
if n_dims > 1:
|
||||
fullshape[split_dim] *= n_parts
|
||||
sname = name.encode('utf-8')
|
||||
fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
|
||||
for dim in reversed(data.shape):
|
||||
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
|
||||
for dim in reversed(fullshape):
|
||||
fout.write(struct.pack("i", dim))
|
||||
fout.write(sname)
|
||||
|
||||
# data output to file
|
||||
data.tofile(fout)
|
||||
# ensure tensor data is aligned
|
||||
tensor_data_offset = fout.tell()
|
||||
while tensor_data_offset % QK != 0:
|
||||
fout.write(struct.pack("B", 0))
|
||||
tensor_data_offset += 1
|
||||
|
||||
# output unified mappable tensor data
|
||||
if n_dims == 1 or n_parts == 1:
|
||||
# copy tensor which we thankfully received in one piece
|
||||
if part_id == 0:
|
||||
data.tofile(fout)
|
||||
elif split_dim == 0:
|
||||
# reassemble multifile tensor containing some of the rows
|
||||
rows_per_chunk = partshape[0]
|
||||
current_row = part_id * rows_per_chunk
|
||||
bytes_per_row = fullshape[1] // blck_size * type_size
|
||||
offset = current_row * bytes_per_row
|
||||
fout.seek(tensor_data_offset + offset)
|
||||
data.tofile(fout)
|
||||
elif split_dim == 1:
|
||||
# reassemble multifile tensor containing some of the cols
|
||||
cols_per_chunk = partshape[1]
|
||||
current_col = part_id * cols_per_chunk
|
||||
bytes_per_row = fullshape[1] // blck_size * type_size
|
||||
offset_current_col = current_col // blck_size * type_size
|
||||
for row in range(partshape[0]):
|
||||
offset_row = row * bytes_per_row
|
||||
offset = offset_row + offset_current_col
|
||||
fout.seek(tensor_data_offset + offset)
|
||||
data[row].tofile(fout)
|
||||
|
||||
# advance file position to next tensor
|
||||
fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))
|
||||
|
||||
def main():
|
||||
|
||||
args = parse_args()
|
||||
dir_model = args.dir_model
|
||||
ftype = args.ftype
|
||||
ftype_str = ["f32", "f16"]
|
||||
|
||||
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
|
||||
|
||||
print(args)
|
||||
|
||||
# if only writing vocab to file
|
||||
if args.vocab_only:
|
||||
|
||||
fname_model = f"{dir_model}/consolidated.00.pth"
|
||||
fname_out = f"{dir_model}/ggml-vocab.bin"
|
||||
|
||||
print(f"Extracting only the vocab from '{fname_model}'\n")
|
||||
|
||||
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
|
||||
|
||||
del model
|
||||
print(f"Done. Output file: {fname_out}\n")
|
||||
|
||||
return
|
||||
|
||||
n_parts = get_n_parts(hparams["dim"])
|
||||
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
|
||||
|
||||
for p in range(n_parts):
|
||||
# we output a single file for ggml
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
offset_of_tensors = fout.tell()
|
||||
# the tensors we load could be split across multiple files
|
||||
for part_id in range(n_parts):
|
||||
fout.seek(offset_of_tensors)
|
||||
print(f"Processing part {part_id+1} of {n_parts}\n")
|
||||
fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
process_and_write_variables(fout, model, ftype, part_id, n_parts)
|
||||
del model
|
||||
|
||||
print(f"Processing part {p+1} of {n_parts}\n")
|
||||
|
||||
fname_model = f"{dir_model}/consolidated.0{p}.pth"
|
||||
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
|
||||
|
||||
model = torch.load(fname_model, map_location="cpu")
|
||||
|
||||
with open(fname_out, "wb") as fout:
|
||||
write_header(fout, hparams, ftype)
|
||||
write_tokens(fout, tokenizer)
|
||||
process_and_write_variables(fout, model, ftype)
|
||||
|
||||
del model
|
||||
|
||||
print(f"Done. Output file: {fname_out}, (part {p})\n")
|
||||
print(f"Done. Output file: {fname_out}\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue