apply black to ggml-to-pth

This commit is contained in:
Geeks-sid 2023-03-29 19:15:06 -04:00
parent 9cbc404ba6
commit dfa2d707e9

View file

@ -72,7 +72,12 @@ def dequantize_weights(fin, n_rows, n_cols):
def read_variables(fin): def read_variables(fin):
model = {} model = {}
pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables") pbar = tqdm(
total=os.path.getsize(fin.name),
unit="B",
unit_scale=True,
desc="Reading variables",
)
while True: while True:
start_pos = fin.tell() start_pos = fin.tell()
try: try:
@ -98,7 +103,9 @@ def read_variables(fin):
data_size = np.prod(shape) data_size = np.prod(shape)
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape) data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16) model[name] = torch.tensor(
data, dtype=torch.float32 if dtype == np.float32 else torch.float16
)
pbar.update(fin.tell() - start_pos) pbar.update(fin.tell() - start_pos)
@ -112,11 +119,17 @@ def convert_to_hf_format(model, hparams):
dim = hparams["dim"] dim = hparams["dim"]
dims_per_head = dim // n_heads dims_per_head = dim // n_heads
base = 10000.0 base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) inv_freq = 1.0 / (
base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
)
# permute for sliced rotary # permute for sliced rotary
def permute(w): def permute(w):
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim) return (
w.view(n_heads, dim // n_heads // 2, 2, dim)
.transpose(1, 2)
.reshape(dim, dim)
)
state_dict = {} state_dict = {}
for layer_i in range(n_layers): for layer_i in range(n_layers):
@ -164,16 +177,22 @@ def convert_to_hf_format(model, hparams):
def chat(model, hparams, llama_dir): def chat(model, hparams, llama_dir):
from transformers import (GenerationConfig, LlamaForCausalLM, from transformers import (
LlamaTokenizer, StoppingCriteria, GenerationConfig,
StoppingCriteriaList) LlamaForCausalLM,
LlamaTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.configuration_llama import LlamaConfig
class StoppingCriteriaSub(StoppingCriteria): class StoppingCriteriaSub(StoppingCriteria):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]): def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]
):
print(tokenizer.decode(input_ids[0]), end="", flush=True) print(tokenizer.decode(input_ids[0]), end="", flush=True)
if input_ids[0][-1] == 13: if input_ids[0][-1] == 13:
return True return True
@ -237,7 +256,11 @@ AI: Hello! How can I assist you today?
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files." "--input_dir",
"-i",
type=str,
required=True,
help="The input directory containing the ggml files.",
) )
parser.add_argument( parser.add_argument(
"--prefix", "--prefix",
@ -252,14 +275,21 @@ def main():
help="Whether to save the model in the huggingface format. (default: False)", help="Whether to save the model in the huggingface format. (default: False)",
) )
parser.add_argument( parser.add_argument(
"--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)" "--chat",
"-c",
action="store_true",
help="Whether to open a chat with the model. (default: False)",
) )
args = parser.parse_args() args = parser.parse_args()
llama_dir = os.path.abspath(f"{args.input_dir}/../") llama_dir = os.path.abspath(f"{args.input_dir}/../")
ggml_files = sorted( ggml_files = sorted(
[f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)] [
f"{args.input_dir}/{f}"
for f in os.listdir(args.input_dir)
if f.startswith(args.prefix)
]
) )
fin = open(ggml_files[0], "rb") fin = open(ggml_files[0], "rb")