diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py
index d2557500a..75e182cc1 100644
--- a/convert-pth-to-ggml.py
+++ b/convert-pth-to-ggml.py
@@ -17,11 +17,16 @@
 # and vocabulary.
 #
 
+from collections import defaultdict
 import sys
 import json
 import struct
 import numpy as np
-import torch
+from tqdm import tqdm
+import zipfile
+import pickle
+import concurrent.futures
+
 from sentencepiece import SentencePieceProcessor
 
 if len(sys.argv) < 3:
@@ -73,19 +78,22 @@
 hparams.update({"vocab_size": tokenizer.vocab_size()})
 
 n_parts = get_n_parts(hparams["dim"])
 
-print(hparams)
-print('n_parts = ', n_parts)
+print(f'Model params.json: {hparams}')
+print(f'Parts to process: {n_parts}')
 
-for p in range(n_parts):
-    print('Processing part ', p)
-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
+def get_fname(p):
+    fname = "/consolidated.0" + str(p) + ".pth"
+    return fname
+
+def process_part(p):
+    fname = get_fname(p)
+    fname_model = sys.argv[1] + fname
     fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
     if (p > 0):
         fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
 
-    model = torch.load(fname_model, map_location="cpu")
+    print(f"Processing part {fname}")
 
     fout = open(fname_out, "wb")
 
@@ -123,7 +131,58 @@
         fout.write(struct.pack("i", len(text)))
         fout.write(text)
 
-    for k, v in model.items():
+
+    def load_model(fname):
+        class Tensor():
+            def __init__(self, shape, dtype, loadinfo):
+                self.shape = shape
+                self.dtype = dtype
+                self.loadinfo = loadinfo
+                # print(shape, dtype)
+
+            def numpy(self):
+                fname_model, base_name, storage_offset, k, shape, dtype = self.loadinfo
+                with zipfile.ZipFile(fname_model, 'r') as myzip:
+                    with myzip.open(f'{base_name}/data/{k}') as myfile:
+                        bytes_size = np.dtype(self.dtype).itemsize
+                        myfile.seek(storage_offset * bytes_size, 1)
+                        ret = np.empty(shape, dtype=dtype)
+                        myfile.readinto(ret.data)
+                        return ret
+
+        def my_unpickle(datapkl, fname_model, base_name):
+            def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
+                storage_type = storage[1]
+                obj_key = storage[2]
+                return Tensor(shape=size, dtype=storage_type, loadinfo=(
+                    fname_model, base_name, storage_offset,
+                    obj_key, size, storage_type
+                ))
+
+            class MyUnpickler(pickle.Unpickler):
+                def find_class(self, *p):
+                    if p == ('torch', 'HalfStorage'): return np.float16
+                    if p == ('torch', 'FloatStorage'): return np.float32
+                    if p == ('torch._utils', '_rebuild_tensor_v2'): return my_rebuild_tensor
+                    if p == ('collections', 'OrderedDict'): return dict
+                    raise ValueError(f'Unrecognized pickle {p}')
+
+                def persistent_load(self, pid):
+                    return pid
+
+            return MyUnpickler(datapkl).load()
+
+        with zipfile.ZipFile(fname, 'r') as myzip:
+            base_name = myzip.namelist()[0].split('/', 1)[0]
+            # print(myzip.namelist())
+            with myzip.open(f'{base_name}/data.pkl') as myfile:
+                model = my_unpickle(myfile, fname, base_name)
+        return model
+
+    model = load_model(fname_model)
+
+    for k, v in (t := tqdm(model.items())):
+        t.set_description(f"Processing {k} with shape {tuple(v.shape)} and type {np.dtype(v.dtype)}")
         name = k
         shape = v.shape
 
@@ -131,11 +190,11 @@
         if name[-5:] == "freqs":
             continue
 
-        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
+        # print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
 
         #data = tf.train.load_variable(dir_model, name).squeeze()
         data = v.numpy().squeeze()
-        n_dims = len(data.shape);
+        n_dims = len(data.shape)
 
         # for efficiency - transpose some matrices
         # "model/h.*/attn/c_attn/w"
@@ -154,7 +213,7 @@
         # default type is fp16
         ftype_cur = 1
         if ftype == 0 or n_dims == 1:
-            print("  Converting to float32")
+            # print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0
 
@@ -163,7 +222,7 @@
         fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
         for i in range(n_dims):
             fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-        fout.write(sname);
+        fout.write(sname)
 
         # data
         data.tofile(fout)
@@ -175,3 +234,10 @@
 
     print("Done. Output file: " + fname_out + ", (part ", p, ")")
     print("")
+
+with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+    futures = {executor.submit(process_part, p) for p in range(n_parts)}
+    for f in (concurrent.futures.as_completed(futures)):
+        if f.exception() is not None: raise f.exception()
+
+print("All done.")
\ No newline at end of file
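
The core of the patch is the torch-free checkpoint reader. A .pth file written by torch.save in the zip format is an ordinary zip archive: a pickle index at <base>/data.pkl plus one raw storage blob per tensor at <base>/data/<key>. The pickle stream refers to storages through persistent ids of the form ('storage', storage_type, key, location, numel) and rebuilds tensors via torch._utils._rebuild_tensor_v2, so a pickle.Unpickler subclass that intercepts find_class and persistent_load can map those onto numpy without importing torch at all. The following self-contained sketch shows the same trick outside the converter; the script name, LazyTensor, and load_checkpoint are illustrative names, not part of the patch, and it assumes contiguous fp16/fp32 tensors (strides are ignored), which holds for the LLaMA checkpoints this script targets.

# inspect_pth.py -- list the tensors in a PyTorch zip-format checkpoint
# without importing torch. A minimal sketch of the patch's technique;
# assumes contiguous fp16/fp32 tensors, as in the LLaMA checkpoints.
import pickle
import sys
import zipfile

import numpy as np

class LazyTensor:
    """Records where a tensor's bytes live inside the checkpoint zip."""
    def __init__(self, shape, dtype, zip_path, base_name, key, storage_offset):
        self.shape = tuple(shape)
        self.dtype = dtype
        self._loc = (zip_path, base_name, key, storage_offset)

    def numpy(self):
        zip_path, base_name, key, storage_offset = self._loc
        with zipfile.ZipFile(zip_path) as z, z.open(f"{base_name}/data/{key}") as f:
            # Skip to the tensor's offset within its storage, then read raw bytes.
            f.seek(storage_offset * np.dtype(self.dtype).itemsize, 1)
            out = np.empty(self.shape, dtype=self.dtype)
            f.readinto(out.data)
            return out

def load_checkpoint(zip_path):
    with zipfile.ZipFile(zip_path) as z:
        base_name = z.namelist()[0].split("/", 1)[0]

        def rebuild_tensor(storage, storage_offset, size, stride,
                           requires_grad=None, backward_hooks=None, metadata=None):
            # `storage` is the persistent id: ('storage', dtype, key, location, numel).
            dtype, key = storage[1], storage[2]
            return LazyTensor(size, dtype, zip_path, base_name, key, storage_offset)

        class Unpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if (module, name) == ("torch", "HalfStorage"):
                    return np.float16
                if (module, name) == ("torch", "FloatStorage"):
                    return np.float32
                if (module, name) == ("torch._utils", "_rebuild_tensor_v2"):
                    return rebuild_tensor
                if (module, name) == ("collections", "OrderedDict"):
                    return dict
                raise pickle.UnpicklingError(f"unexpected class {module}.{name}")

            def persistent_load(self, pid):
                return pid  # hand the raw storage id through to rebuild_tensor

        with z.open(f"{base_name}/data.pkl") as f:
            return Unpickler(f).load()

if __name__ == "__main__":
    for name, t in load_checkpoint(sys.argv[1]).items():
        print(name, t.shape, np.dtype(t.dtype))

Because the placeholders only read their bytes when numpy() is called, the converter holds at most one tensor's data in memory at a time, instead of materializing an entire part up front the way torch.load does.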
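
The per-variable print calls are commented out in favor of a single progress bar: the walrus assignment in `for k, v in (t := tqdm(model.items())):` keeps a handle on the bar so each iteration can relabel it via set_description (the := operator requires Python 3.8+). In isolation the idiom looks like this, with a made-up dict standing in for the model:

from tqdm import tqdm

# Hypothetical stand-in for the model's name -> tensor mapping.
items = {"tok_embeddings.weight": (32000, 4096), "norm.weight": (4096,)}

for name, shape in (t := tqdm(items.items())):
    t.set_description(f"Processing {name} with shape {shape}")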
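
At the bottom, the per-part work is fanned out over a two-worker thread pool, and the f.exception() check re-raises any worker failure on the main thread; without it, an exception inside process_part would be stored on the Future and dropped silently. A minimal sketch of the pattern, with a hypothetical do_part in place of process_part:

import concurrent.futures

def do_part(p: int) -> None:
    # Hypothetical worker; part 2 fails on purpose to show propagation.
    if p == 2:
        raise RuntimeError(f"part {p} failed")

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = {executor.submit(do_part, p) for p in range(4)}
    for f in concurrent.futures.as_completed(futures):
        if f.exception() is not None:
            raise f.exception()

Calling f.result() inside the loop is the more common idiom (it re-raises the stored exception together with its traceback), but the explicit check used in the patch is equivalent in effect.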
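
The command-line interface is unchanged by the patch: sys.argv[1] is still the model directory and sys.argv[2] the output type selector indexing the existing ftype_str table (0 for float32, 1 for float16). Against a stock LLaMA layout the invocation should therefore still look something like:

python convert-pth-to-ggml.py models/7B 1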