streaming conversion without pytorch
commit 289637a6a3
parent 2f700a2738
1 changed file with 79 additions and 13 deletions
@@ -17,11 +17,16 @@
 # and vocabulary.
 #
 
+from collections import defaultdict
 import sys
 import json
 import struct
 import numpy as np
-import torch
+from tqdm import tqdm
+import zipfile
+import pickle
+import concurrent.futures
 
 from sentencepiece import SentencePieceProcessor
 
 if len(sys.argv) < 3:
@@ -73,19 +78,22 @@ hparams.update({"vocab_size": tokenizer.vocab_size()})
 
 n_parts = get_n_parts(hparams["dim"])
 
-print(hparams)
-print('n_parts = ', n_parts)
+print(f'Model params.json: {hparams}')
+print(f'Parts to process: {n_parts}')
 
-for p in range(n_parts):
-    print('Processing part ', p)
 
-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
+def get_fname(p):
+    fname = "/consolidated.0" + str(p) + ".pth"
+    return fname
+
+def process_part(p):
+    fname = get_fname(p)
+    fname_model = sys.argv[1] + fname
     fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
     if (p > 0):
         fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
 
-    model = torch.load(fname_model, map_location="cpu")
+    print(f"Processing part {fname}")
 
     fout = open(fname_out, "wb")
 
@@ -123,7 +131,58 @@ for p in range(n_parts):
         fout.write(struct.pack("i", len(text)))
         fout.write(text)
 
-    for k, v in model.items():
+    def load_model(fname):
+        class Tensor():
+            def __init__(self, shape, dtype, loadinfo):
+                self.shape = shape
+                self.dtype = dtype
+                self.loadinfo = loadinfo
+                # print(shape, dtype)
+
+            def numpy(self):
+                fname_model, base_name, storage_offset, k, shape, dtype = self.loadinfo
+                with zipfile.ZipFile(fname_model, 'r') as myzip:
+                    with myzip.open(f'{base_name}/data/{k}') as myfile:
+                        bytes_size = np.dtype(self.dtype).itemsize
+                        myfile.seek(storage_offset * bytes_size, 1)
+                        ret = np.empty(shape, dtype=dtype)
+                        myfile.readinto(ret.data)
+                        return ret
+
+        def my_unpickle(datapkl, fname_model, base_name):
+            def my_rebuild_tensor(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
+                storage_type = storage[1]
+                obj_key = storage[2]
+                return Tensor(shape=size, dtype=storage_type, loadinfo=(
+                    fname_model, base_name, storage_offset,
+                    obj_key, size, storage_type
+                ))
+
+            class MyUnpickler(pickle.Unpickler):
+                def find_class(self, *p):
+                    if p == ('torch', 'HalfStorage'): return np.float16
+                    if p == ('torch', 'FloatStorage'): return np.float32
+                    if p == ('torch._utils', '_rebuild_tensor_v2'): return my_rebuild_tensor
+                    if p == ('collections', 'OrderedDict'): return dict
+                    raise ValueError(f'Unrecognized pickle {p}')
+
+                def persistent_load(self, pid):
+                    return pid
+
+            return MyUnpickler(datapkl).load()
+
+        with zipfile.ZipFile(fname, 'r') as myzip:
+            base_name = myzip.namelist()[0].split('/', 1)[0]
+            # print(myzip.namelist())
+            with myzip.open(f'{base_name}/data.pkl') as myfile:
+                model = my_unpickle(myfile, fname, base_name)
+        return model
+
+    model = load_model(fname_model)
+
+    for k, v in (t := tqdm(model.items())):
+        t.set_description(f"Processing {k} with shape {tuple(v.shape)} and type {np.dtype(v.dtype)}")
         name = k
         shape = v.shape
@@ -131,11 +190,11 @@ for p in range(n_parts):
         if name[-5:] == "freqs":
             continue
 
-        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
+        # print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
 
         #data = tf.train.load_variable(dir_model, name).squeeze()
         data = v.numpy().squeeze()
-        n_dims = len(data.shape);
+        n_dims = len(data.shape)
 
         # for efficiency - transpose some matrices
         # "model/h.*/attn/c_attn/w"
@@ -154,7 +213,7 @@ for p in range(n_parts):
         # default type is fp16
         ftype_cur = 1
         if ftype == 0 or n_dims == 1:
-            print(" Converting to float32")
+            # print(" Converting to float32")
             data = data.astype(np.float32)
             ftype_cur = 0
 
@@ -163,7 +222,7 @@ for p in range(n_parts):
         fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
         for i in range(n_dims):
             fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-        fout.write(sname);
+        fout.write(sname)
 
         # data
         data.tofile(fout)
@@ -175,3 +234,10 @@ for p in range(n_parts):
 
     print("Done. Output file: " + fname_out + ", (part ", p, ")")
     print("")
+
+with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+    futures = {executor.submit(process_part, p) for p in range(n_parts)}
+    for f in (concurrent.futures.as_completed(futures)):
+        if f.exception() is not None: raise f.exception()
+
+print("All done.")
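The core of the change is that the checkpoint is no longer materialized with torch.load: a .pth file written by torch.save is a zip archive holding a pickle (data.pkl) plus one raw storage blob per tensor, so it can be decoded with zipfile, pickle, and numpy alone, and each tensor is only read from the archive when its numpy() method is called. Below is a minimal standalone sketch of that idea for reference; the helper name list_tensors and the checkpoint path consolidated.00.pth are illustrative assumptions, not part of the commit, and only the storage types handled by the diff are covered.

# Standalone sketch (not part of the diff): list the tensors stored in a
# torch-format checkpoint without importing torch.  A .pth file is a zip
# archive containing <base>/data.pkl plus one raw storage file per tensor.
import pickle
import zipfile
import numpy as np

def list_tensors(fname):
    class LazyUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            # Map the torch storage classes to numpy dtypes and intercept
            # tensor rebuilding so no tensor data is actually read here.
            if (module, name) == ('torch', 'HalfStorage'): return np.float16
            if (module, name) == ('torch', 'FloatStorage'): return np.float32
            if (module, name) == ('torch._utils', '_rebuild_tensor_v2'):
                # storage is the persistent-id tuple: storage[1] is the dtype
                # stand-in returned above, storage[2] is the blob key in the zip.
                return lambda storage, offset, size, stride, *rest: (storage[2], tuple(size), storage[1])
            if (module, name) == ('collections', 'OrderedDict'): return dict
            raise ValueError(f'unexpected global {module}.{name}')

        def persistent_load(self, pid):
            # Keep the raw persistent id; the rebuild lambda above unpacks it.
            return pid

    with zipfile.ZipFile(fname, 'r') as myzip:
        base_name = myzip.namelist()[0].split('/', 1)[0]
        with myzip.open(f'{base_name}/data.pkl') as myfile:
            # Returns a dict mapping parameter name -> (key, shape, dtype).
            return LazyUnpickler(myfile).load()

if __name__ == '__main__':
    for name, (key, shape, dtype) in list_tensors('consolidated.00.pth').items():
        print(name, shape, np.dtype(dtype).name)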