From e1b1e12a414aa2289b6fbee588d2b70393dd812d Mon Sep 17 00:00:00 2001 From: qunash Date: Tue, 14 Mar 2023 22:52:22 +0300 Subject: [PATCH] modularize --- convert-pth-to-ggml.py | 196 +++++++++++++++++++++-------------------- 1 file changed, 100 insertions(+), 96 deletions(-) diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index 05f3cbd7c..8815f3a6b 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -16,7 +16,6 @@ # At the start of the ggml file we write the model parameters # and vocabulary. # - import argparse import sys import json @@ -25,129 +24,134 @@ import numpy as np import torch from sentencepiece import SentencePieceProcessor -parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') -parser.add_argument('dir_model', help='directory containing the model checkpoint') -parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)') -args = parser.parse_args() +def parse_args(): -# output in the same directory as the model -dir_model = args.dir_model -ftype = args.ftype - -fname_hparams = f"{dir_model}/params.json" -fname_tokenizer = f"{dir_model}/../tokenizer.model" + parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') + parser.add_argument('dir_model', help='directory containing the model checkpoint') + parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)') + return parser.parse_args() def get_n_parts(dim): + mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8} n_parts = mappings.get(dim) if n_parts is None: print(f"Invalid dim: {dim}") sys.exit(1) + + print(f"n_parts = {n_parts}\n") return n_parts -# map from ftype to string -ftype_str = ["f32", "f16"] +def load_hparams_and_tokenizer(dir_model): + + fname_hparams = f"{dir_model}/params.json" + fname_tokenizer = f"{dir_model}/../tokenizer.model" -with open(fname_hparams, "r") as f: - hparams = json.load(f) + with open(fname_hparams, "r") as f: + hparams = json.load(f) + print(hparams) -tokenizer = SentencePieceProcessor(fname_tokenizer) + tokenizer = SentencePieceProcessor(fname_tokenizer) + hparams.update({"vocab_size": tokenizer.vocab_size()}) -hparams.update({"vocab_size": tokenizer.vocab_size()}) + return hparams, tokenizer -n_parts = get_n_parts(hparams["dim"]) +def write_header(fout, hparams, ftype): + + keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"] + values = [ + 0x67676d6c, # magic: ggml in hex + *[hparams[key] for key in keys], + hparams["dim"] // hparams["n_heads"], # rot (obsolete) + ftype + ] + fout.write(struct.pack("i" * len(values), *values)) -print(hparams) -print(f"n_parts = {n_parts}\n") +def write_tokens(fout, tokenizer): -for p in range(n_parts): - print(f"Processing part {p}\n") + for i in range(32000): + if tokenizer.is_unknown(i): + text = " \u2047 ".encode("utf-8") + elif tokenizer.is_control(i): + text = b"" + elif tokenizer.is_byte(i): + piece = tokenizer.id_to_piece(i) + if len(piece) != 6: + print(f"Invalid token: {piece}") + sys.exit(1) + byte_value = int(piece[3:-1], 16) + text = struct.pack("B", byte_value) + else: + text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") + fout.write(struct.pack("i", len(text))) + fout.write(text) - #fname_model = args[1] + "/consolidated.00.pth" - fname_model = f"{dir_model}/consolidated.0{p}.pth" - fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}" +def process_and_write_variables(fout, model, ftype): - model = torch.load(fname_model, map_location="cpu") + for name, data in model.items(): + + if name.endswith("freqs"): + continue + + shape = data.shape + + print(f"Processing variable: {name} with shape: {shape} and type: {data.dtype}\n") + + data = np.squeeze(data) + n_dims = len(shape) - with open(fname_out, "wb") as fout: + # for efficiency - transpose some matrices + # "model/h.*/attn/c_attn/w" + # "model/h.*/attn/c_proj/w" + # "model/h.*/mlp/c_fc/w" + # "model/h.*/mlp/c_proj/w" + #if name.endswith(("/attn/c_attn/w", "/attn/c_proj/w", "/mlp/c_fc/w", "/mlp/c_proj/w")): + # print("Transposing") + # data = data.transpose() - fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex - fout.write(struct.pack("i", hparams["vocab_size"])) - fout.write(struct.pack("i", hparams["dim"])) - fout.write(struct.pack("i", hparams["multiple_of"])) - fout.write(struct.pack("i", hparams["n_heads"])) - fout.write(struct.pack("i", hparams["n_layers"])) - fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete) - fout.write(struct.pack("i", ftype)) + # default type is fp16 + ftype_cur = 1 + if ftype == 0 or n_dims == 1: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 - # Is this correct?? - for i in range(32000): - if tokenizer.is_unknown(i): - # "" token (translated as ??) - text = " \u2047 ".encode("utf-8") - fout.write(struct.pack("i", len(text))) - fout.write(text) - elif tokenizer.is_control(i): - # ""/"" tokens - fout.write(struct.pack("i", 0)) - elif tokenizer.is_byte(i): - # "" tokens (which may be invalid UTF-8) - piece = tokenizer.id_to_piece(i) - if len(piece) != 6: - print(f"Invalid token: {piece}") - sys.exit(1) - byte_value = int(piece[3:-1], 16) - fout.write(struct.pack("i", 1)) - fout.write(struct.pack("B", byte_value)) - else: - # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces. - text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") - fout.write(struct.pack("i", len(text))) - fout.write(text) + # header + sname = name.encode('utf-8') + fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur)) + for dim in reversed(data.shape): + fout.write(struct.pack("i", dim)) + fout.write(sname) + + # data + data.tofile(fout) - for name, data in model.items(): - shape = data.shape +def main(): - # skip layers.X.attention.inner_attention.rope.freqs - if name.endswith("freqs"): - continue + args = parse_args() + dir_model = args.dir_model + ftype = args.ftype + ftype_str = ["f32", "f16"] - print(f"Processing variable: {name} with shape: {shape} and type: {data.dtype}\n") + hparams, tokenizer = load_hparams_and_tokenizer(dir_model) + n_parts = get_n_parts(hparams["dim"]) - #data = tf.train.load_variable(dir_model, name).squeeze() - data = np.squeeze(data) - n_dims = len(data.shape) + for p in range(n_parts): - # for efficiency - transpose some matrices - # "model/h.*/attn/c_attn/w" - # "model/h.*/attn/c_proj/w" - # "model/h.*/mlp/c_fc/w" - # "model/h.*/mlp/c_proj/w" - #if name[-14:] == "/attn/c_attn/w" or \ - # name[-14:] == "/attn/c_proj/w" or \ - # name[-11:] == "/mlp/c_fc/w" or \ - # name[-13:] == "/mlp/c_proj/w": - # print(" Transposing") - # data = data.transpose() + print(f"Processing part {p}\n") + + fname_model = f"{dir_model}/consolidated.0{p}.pth" + fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}" - # default type is fp16 - ftype_cur = 1 - if ftype == 0 or n_dims == 1: - print(" Converting to float32") - data = data.astype(np.float32) - ftype_cur = 0 + model = torch.load(fname_model, map_location="cpu") - # header - sname = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur)) - for dim in reversed(data.shape): - fout.write(struct.pack("i", dim)) - - fout.write(sname) - - # data - data.tofile(fout) + with open(fname_out, "wb") as fout: + write_header(fout, hparams, ftype) + write_tokens(fout, tokenizer) + process_and_write_variables(fout, model, ftype) del model + print(f"Done. Output file: {fname_out}, (part {p})\n") - print(f"Done. Output file: {fname_out}, (part {p})\n") +if __name__ == "__main__": + main()