diff --git a/convert-ggml-to-pth.py b/convert-ggml-to-pth.py
index 20158c9ca..8cc677851 100644
--- a/convert-ggml-to-pth.py
+++ b/convert-ggml-to-pth.py
@@ -72,7 +72,12 @@ def dequantize_weights(fin, n_rows, n_cols):
 
 def read_variables(fin):
     model = {}
-    pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
+    pbar = tqdm(
+        total=os.path.getsize(fin.name),
+        unit="B",
+        unit_scale=True,
+        desc="Reading variables",
+    )
     while True:
         start_pos = fin.tell()
         try:
@@ -98,7 +103,9 @@ def read_variables(fin):
 
         data_size = np.prod(shape)
         data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
-        model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
+        model[name] = torch.tensor(
+            data, dtype=torch.float32 if dtype == np.float32 else torch.float16
+        )
 
         pbar.update(fin.tell() - start_pos)
 
@@ -112,11 +119,17 @@ def convert_to_hf_format(model, hparams):
     dim = hparams["dim"]
     dims_per_head = dim // n_heads
     base = 10000.0
-    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    inv_freq = 1.0 / (
+        base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
+    )
 
     # permute for sliced rotary
     def permute(w):
-        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+        return (
+            w.view(n_heads, dim // n_heads // 2, 2, dim)
+            .transpose(1, 2)
+            .reshape(dim, dim)
+        )
 
     state_dict = {}
     for layer_i in range(n_layers):
@@ -164,16 +177,22 @@ def convert_to_hf_format(model, hparams):
 
 
 def chat(model, hparams, llama_dir):
-    from transformers import (GenerationConfig, LlamaForCausalLM,
-                              LlamaTokenizer, StoppingCriteria,
-                              StoppingCriteriaList)
+    from transformers import (
+        GenerationConfig,
+        LlamaForCausalLM,
+        LlamaTokenizer,
+        StoppingCriteria,
+        StoppingCriteriaList,
+    )
     from transformers.models.llama.configuration_llama import LlamaConfig
 
     class StoppingCriteriaSub(StoppingCriteria):
         def __init__(self):
             super().__init__()
 
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]
+        ):
             print(tokenizer.decode(input_ids[0]), end="", flush=True)
             if input_ids[0][-1] == 13:
                 return True
@@ -237,7 +256,11 @@ AI: Hello! How can I assist you today?
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
+        "--input_dir",
+        "-i",
+        type=str,
+        required=True,
+        help="The input directory containing the ggml files.",
     )
     parser.add_argument(
         "--prefix",
@@ -252,14 +275,21 @@ def main():
         help="Whether to save the model in the huggingface format. (default: False)",
     )
     parser.add_argument(
-        "--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
+        "--chat",
+        "-c",
+        action="store_true",
+        help="Whether to open a chat with the model. (default: False)",
    )
     args = parser.parse_args()
 
     llama_dir = os.path.abspath(f"{args.input_dir}/../")
 
     ggml_files = sorted(
-        [f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
+        [
+            f"{args.input_dir}/{f}"
+            for f in os.listdir(args.input_dir)
+            if f.startswith(args.prefix)
+        ]
     )
 
     fin = open(ggml_files[0], "rb")
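
Side note on the `permute` helper reflowed in the third hunk: it reorders the rows of each attention head's Q/K projection from interleaved rotary (real, imag) pairs into the sliced layout that the Hugging Face LLaMA implementation expects. A minimal standalone sketch of the same expression on toy sizes (the `n_heads` and `dim` values below are illustrative only, not from any real checkpoint):

```python
import torch

# Illustrative sizes only; real LLaMA checkpoints use e.g. n_heads=32, dim=4096.
n_heads, dim = 2, 8

def permute(w):
    # Same expression as in convert_to_hf_format: within each head, regroup
    # rows from interleaved pairs (r0, i0, r1, i1, ...) into the sliced
    # order (r0, r1, ..., i0, i1, ...).
    return (
        w.view(n_heads, dim // n_heads // 2, 2, dim)
        .transpose(1, 2)
        .reshape(dim, dim)
    )

# Number the rows so the reordering is visible: with n_heads=2 and dim=8,
# rows 0..3 (head 0) come out as 0, 2, 1, 3 and rows 4..7 as 4, 6, 5, 7.
w = torch.arange(dim * dim, dtype=torch.float32).reshape(dim, dim)
print(permute(w)[:, 0])  # tensor([ 0., 16.,  8., 24., 32., 48., 40., 56.])
```

The operation is its own inverse per head-pair grouping, which is why the same helper shape appears in the upstream Meta-to-HF conversion scripts; here it only gets a cosmetic reflow, with behavior unchanged.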