improvement(tools): optimize with argparse

This commit is contained in:
tpoisonooo 2023-03-17 16:53:53 +08:00
parent 904d2a8d6a
commit 3c7cb413fb

View file

@ -22,19 +22,27 @@ import json
import struct import struct
import numpy as np import numpy as np
import torch import torch
import argparse
import os
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
if len(sys.argv) < 3:
print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
print(" ftype == 0 -> float32")
print(" ftype == 1 -> float16")
sys.exit(1)
# output in the same directory as the model def parse_args():
dir_model = sys.argv[1] parser = argparse.ArgumentParser(
description='Convert ckpt models to ggml models.')
parser.add_argument('dir_model',
type=str,
help='Directory path of the checkpoint model')
parser.add_argument('ftype',
type=str,
choices=['f32', 'f16'],
help='Data type of the converted tensor, f32 or f16')
parser.add_argument('out_dir',
type=str,
help='Directory path for storing ggml model')
return parser.parse_args()
fname_hparams = sys.argv[1] + "/params.json"
fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
def get_n_parts(dim): def get_n_parts(dim):
if dim == 4096: if dim == 4096:
@ -49,41 +57,43 @@ def get_n_parts(dim):
print("Invalid dim: " + str(dim)) print("Invalid dim: " + str(dim))
sys.exit(1) sys.exit(1)
# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]
ftype = 1 def main():
if len(sys.argv) > 2: args = parse_args()
ftype = int(sys.argv[2]) dir_model = args.dir_model
if ftype < 0 or ftype > 1: out_dir = args.out_dir
print("Invalid ftype: " + str(ftype))
sys.exit(1)
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
with open(fname_hparams, "r") as f: if not os.path.exists(out_dir):
os.mkdir(out_dir)
ftype = args.ftype
ftype_int = {'f32': 0, 'f16': 1}
fname_hparams = os.path.join(dir_model, 'params.json')
fname_tokenizer = os.path.join(dir_model, '..', 'tokenizer.model')
with open(fname_hparams, "r") as f:
hparams = json.load(f) hparams = json.load(f)
tokenizer = SentencePieceProcessor(fname_tokenizer) tokenizer = SentencePieceProcessor(fname_tokenizer)
hparams.update({"vocab_size": tokenizer.vocab_size()}) hparams.update({"vocab_size": tokenizer.vocab_size()})
n_parts = get_n_parts(hparams["dim"]) n_parts = get_n_parts(hparams["dim"])
print(hparams) print(hparams)
print('n_parts = ', n_parts) print('n_parts = ', n_parts)
for p in range(n_parts): for p in range(n_parts):
print('Processing part ', p) print('Processing part ', p)
#fname_model = sys.argv[1] + "/consolidated.00.pth" #fname_model = sys.argv[1] + "/consolidated.00.pth"
fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth" fname_model = os.path.join(dir_model, "consolidated.0{}.pth".format(p))
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" if p > 0:
if (p > 0): fname_out = os.path.join(out_dir,
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p) "ggml-model-{}.bin.{}".format(ftype, p))
else:
fname_out = os.path.join(out_dir,
"ggml-model-{}.bin".format(ftype))
model = torch.load(fname_model, map_location="cpu") model = torch.load(fname_model, map_location="cpu")
@ -95,8 +105,9 @@ for p in range(n_parts):
fout.write(struct.pack("i", hparams["multiple_of"])) fout.write(struct.pack("i", hparams["multiple_of"]))
fout.write(struct.pack("i", hparams["n_heads"])) fout.write(struct.pack("i", hparams["n_heads"]))
fout.write(struct.pack("i", hparams["n_layers"])) fout.write(struct.pack("i", hparams["n_layers"]))
fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete) fout.write(struct.pack("i", hparams["dim"] //
fout.write(struct.pack("i", ftype)) hparams["n_heads"])) # rot (obsolete)
fout.write(struct.pack("i", ftype_int[ftype]))
# Is this correct?? # Is this correct??
for i in range(tokenizer.vocab_size()): for i in range(tokenizer.vocab_size()):
@ -119,7 +130,8 @@ for p in range(n_parts):
fout.write(struct.pack("B", byte_value)) fout.write(struct.pack("B", byte_value))
else: else:
# normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces. # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") text = tokenizer.id_to_piece(i).replace("\u2581",
" ").encode("utf-8")
fout.write(struct.pack("i", len(text))) fout.write(struct.pack("i", len(text)))
fout.write(text) fout.write(text)
@ -131,11 +143,12 @@ for p in range(n_parts):
if name[-5:] == "freqs": if name[-5:] == "freqs":
continue continue
print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype) print("Processing variable: " + name + " with shape: ", shape,
" and type: ", v.dtype)
#data = tf.train.load_variable(dir_model, name).squeeze() #data = tf.train.load_variable(dir_model, name).squeeze()
data = v.numpy().squeeze() data = v.numpy().squeeze()
n_dims = len(data.shape); n_dims = len(data.shape)
# for efficiency - transpose some matrices # for efficiency - transpose some matrices
# "model/h.*/attn/c_attn/w" # "model/h.*/attn/c_attn/w"
@ -153,7 +166,7 @@ for p in range(n_parts):
# default type is fp16 # default type is fp16
ftype_cur = 1 ftype_cur = 1
if ftype == 0 or n_dims == 1: if ftype == 'f32' or n_dims == 1:
print(" Converting to float32") print(" Converting to float32")
data = data.astype(np.float32) data = data.astype(np.float32)
ftype_cur = 0 ftype_cur = 0
@ -163,7 +176,7 @@ for p in range(n_parts):
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur)) fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
for i in range(n_dims): for i in range(n_dims):
fout.write(struct.pack("i", dshape[n_dims - 1 - i])) fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
fout.write(sname); fout.write(sname)
# data # data
data.tofile(fout) data.tofile(fout)
@ -175,3 +188,7 @@ for p in range(n_parts):
print("Done. Output file: " + fname_out + ", (part ", p, ")") print("Done. Output file: " + fname_out + ", (part ", p, ")")
print("") print("")
if __name__ == '__main__':
main()