Streaming conversion without PyTorch

Dmitry Wolf 2023-03-15 21:25:01 +03:00
parent 2f700a2738
commit 289637a6a3


@@ -17,11 +17,16 @@
# and vocabulary.
#
from collections import defaultdict
import sys
import json
import struct
import numpy as np
import torch
from tqdm import tqdm
import zipfile
import pickle
import concurrent.futures
from sentencepiece import SentencePieceProcessor
if len(sys.argv) < 3:
@@ -73,19 +78,22 @@ hparams.update({"vocab_size": tokenizer.vocab_size()})
n_parts = get_n_parts(hparams["dim"])
print(hparams)
print('n_parts = ', n_parts)
print(f'Model params.json: {hparams}')
print(f'Parts to process: {n_parts}')
for p in range(n_parts):
print('Processing part ', p)
#fname_model = sys.argv[1] + "/consolidated.00.pth"
fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
def get_fname(p):
fname = "/consolidated.0" + str(p) + ".pth"
return fname
def process_part(p):
fname = get_fname(p)
fname_model = sys.argv[1] + fname
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
if (p > 0):
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
model = torch.load(fname_model, map_location="cpu")
print(f"Processing part {fname}")
fout = open(fname_out, "wb")
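For reference, restating the naming that process_part relies on so the paths line up in one place (the directory and the "f16" string below are illustrative assumptions, not taken from this commit): part p reads consolidated.0<p>.pth and writes ggml-model-<ftype>.bin, with a ".<p>" suffix appended for parts after the first. A minimal sketch:

def part_paths(model_dir, ftype_str, p):
    # Input checkpoint for part p, e.g. consolidated.01.pth when p == 1.
    fname_in = model_dir + "/consolidated.0" + str(p) + ".pth"
    # Part 0 writes ggml-model-<ftype>.bin; later parts get ".1", ".2", ... appended.
    fname_out = model_dir + "/ggml-model-" + ftype_str + ".bin"
    if p > 0:
        fname_out += "." + str(p)
    return fname_in, fname_out

print(part_paths("models/13B", "f16", 1))
# -> ('models/13B/consolidated.01.pth', 'models/13B/ggml-model-f16.bin.1')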
@@ -123,7 +131,58 @@ for p in range(n_parts):
fout.write(struct.pack("i", len(text)))
fout.write(text)
for k, v in model.items():
def load_model(fname):
class Tensor():
def __init__(self, shape, dtype, loadinfo):
self.shape = shape
self.dtype = dtype
self.loadinfo = loadinfo
# print(shape, dtype)
def numpy(self):
fname_model, base_name, storage_offset, k, shape, dtype = self.loadinfo
with zipfile.ZipFile(fname_model, 'r') as myzip:
with myzip.open(f'{base_name}/data/{k}') as myfile:
bytes_size = np.dtype(self.dtype).itemsize
myfile.seek(storage_offset * bytes_size, 1)
ret = np.empty(shape, dtype=dtype)
myfile.readinto(ret.data)
return ret
def my_unpickle(datapkl, fname_model, base_name):
def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
storage_type = storage[1]
obj_key = storage[2]
return Tensor(shape=size, dtype=storage_type, loadinfo=(
fname_model, base_name, storage_offset,
obj_key, size, storage_type
))
class MyUnpickler(pickle.Unpickler):
def find_class(self, *p):
if p == ('torch', 'HalfStorage'): return np.float16
if p == ('torch', 'FloatStorage'): return np.float32
if p == ('torch._utils', '_rebuild_tensor_v2'): return my_rebuild_tensor
if p == ('collections', 'OrderedDict'): return dict
raise ValueError(f'Unrecognized pickle {p}')
def persistent_load(self, pid):
return pid
return MyUnpickler(datapkl).load()
with zipfile.ZipFile(fname, 'r') as myzip:
base_name = myzip.namelist()[0].split('/', 1)[0]
# print(myzip.namelist())
with myzip.open(f'{base_name}/data.pkl') as myfile:
model = my_unpickle(myfile, fname, base_name)
return model
model = load_model(fname_model)
for k, v in (t := tqdm(model.items())):
t.set_description(f"Processing {k} with shape {tuple(v.shape)} and type {np.dtype(v.dtype)}")
name = k
shape = v.shape
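Background for the hunk above (not part of the diff): a PyTorch zip checkpoint keeps the pickled state dict in <archive>/data.pkl and each tensor's raw storage bytes in <archive>/data/<key>. The custom Unpickler maps torch storage classes to numpy dtypes and intercepts torch._utils._rebuild_tensor_v2, so tensors come back as lazy Tensor wrappers that read straight from the zip member. A self-contained sketch of the same technique, loading eagerly into numpy arrays (function and variable names are illustrative, not from the commit; only fp16/fp32 contiguous tensors are handled, which is what the LLaMA checkpoints contain):

import pickle
import zipfile
import numpy as np

def load_checkpoint_numpy(path):
    # torch storage classes that appear in the pickle, mapped to numpy dtypes
    dtypes = {("torch", "HalfStorage"): np.float16,
              ("torch", "FloatStorage"): np.float32}

    with zipfile.ZipFile(path) as zf:
        base = zf.namelist()[0].split("/", 1)[0]

        def rebuild_tensor(storage, storage_offset, size, stride, *unused):
            # storage is the persistent-id tuple ('storage', dtype, key, location, numel);
            # dtype is already a numpy dtype because find_class mapped the storage class.
            dtype, key = storage[1], storage[2]
            itemsize = np.dtype(dtype).itemsize
            with zf.open(f"{base}/data/{key}") as f:
                f.seek(storage_offset * itemsize, 1)  # usually 0 for these checkpoints
                raw = f.read(itemsize * int(np.prod(size, dtype=np.int64)))
            # stride is ignored: the checkpoints store contiguous tensors
            return np.frombuffer(raw, dtype=dtype).reshape(size)

        class NumpyUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if (module, name) in dtypes:
                    return dtypes[(module, name)]
                if (module, name) == ("torch._utils", "_rebuild_tensor_v2"):
                    return rebuild_tensor
                if (module, name) == ("collections", "OrderedDict"):
                    return dict
                raise pickle.UnpicklingError(f"unsupported global: {module}.{name}")

            def persistent_load(self, pid):
                # pass the persistent id through unchanged; rebuild_tensor unpacks it
                return pid

        with zf.open(f"{base}/data.pkl") as f:
            return NumpyUnpickler(f).load()

# state_dict = load_checkpoint_numpy("models/7B/consolidated.00.pth")

The diff's version defers the read until .numpy() is called inside the write loop, so only one tensor's data is materialized at a time; that deferral is what makes the conversion "streaming".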
@@ -131,11 +190,11 @@ for p in range(n_parts):
if name[-5:] == "freqs":
continue
print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
# print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
#data = tf.train.load_variable(dir_model, name).squeeze()
data = v.numpy().squeeze()
n_dims = len(data.shape);
n_dims = len(data.shape)
# for efficiency - transpose some matrices
# "model/h.*/attn/c_attn/w"
@@ -154,7 +213,7 @@ for p in range(n_parts):
# default type is fp16
ftype_cur = 1
if ftype == 0 or n_dims == 1:
print(" Converting to float32")
# print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
@@ -163,7 +222,7 @@ for p in range(n_parts):
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
fout.write(sname);
fout.write(sname)
# data
data.tofile(fout)
@@ -175,3 +234,10 @@ for p in range(n_parts):
print("Done. Output file: " + fname_out + ", (part ", p, ")")
print("")
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
futures = {executor.submit(process_part, p) for p in range(n_parts)}
for f in (concurrent.futures.as_completed(futures)):
if f.exception() is not None: raise f.exception()
print("All done.")