apply black to ggml-to-pth

parent 9cbc404ba6
commit dfa2d707e9

1 changed file with 41 additions and 11 deletions
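The diff below matches black's default style (88-column line length, magic trailing commas). A minimal sketch of reproducing it, assuming the script is named `ggml-to-pth.py` and black is installed:

```python
# Minimal sketch: re-run black over the script, either from the shell,
#     python -m black ggml-to-pth.py
# or through black's library API with its default 88-column Mode:
import black

with open("ggml-to-pth.py") as f:
    src = f.read()

with open("ggml-to-pth.py", "w") as f:
    f.write(black.format_str(src, mode=black.Mode()))
```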
|
@@ -72,7 +72,12 @@ def dequantize_weights(fin, n_rows, n_cols):
 
 def read_variables(fin):
     model = {}
-    pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
+    pbar = tqdm(
+        total=os.path.getsize(fin.name),
+        unit="B",
+        unit_scale=True,
+        desc="Reading variables",
+    )
     while True:
         start_pos = fin.tell()
         try:
@@ -98,7 +103,9 @@ def read_variables(fin):
         data_size = np.prod(shape)
         data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
 
-        model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
+        model[name] = torch.tensor(
+            data, dtype=torch.float32 if dtype == np.float32 else torch.float16
+        )
 
         pbar.update(fin.tell() - start_pos)
 
@@ -112,11 +119,17 @@ def convert_to_hf_format(model, hparams):
     dim = hparams["dim"]
     dims_per_head = dim // n_heads
     base = 10000.0
-    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    inv_freq = 1.0 / (
+        base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
+    )
 
     # permute for sliced rotary
     def permute(w):
-        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+        return (
+            w.view(n_heads, dim // n_heads // 2, 2, dim)
+            .transpose(1, 2)
+            .reshape(dim, dim)
+        )
 
     state_dict = {}
     for layer_i in range(n_layers):
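An aside on the hunk above, not part of the diff: `inv_freq` is the standard RoPE inverse-frequency table, `1 / base ** (2i / dims_per_head)` for `i = 0 .. dims_per_head/2 - 1`, and `permute` only re-orders rows so the original interleaved rotary layout matches Hugging Face's sliced one. A standalone sketch with hypothetical sizes, showing the permute preserves both shape and contents:

```python
# Standalone sketch (hypothetical sizes, not from this file): the
# sliced-rotary permute is a pure re-ordering of the weight's rows.
import torch

n_heads, dim = 4, 32  # hypothetical; e.g. LLaMA-7B uses 32 heads, dim 4096

def permute(w):
    return (
        w.view(n_heads, dim // n_heads // 2, 2, dim)
        .transpose(1, 2)
        .reshape(dim, dim)
    )

w = torch.randn(dim, dim)
out = permute(w)
assert out.shape == w.shape  # shape is unchanged
assert sorted(out.flatten().tolist()) == sorted(w.flatten().tolist())  # same entries, new order
```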
@@ -164,16 +177,22 @@ def convert_to_hf_format(model, hparams):
 
 
 def chat(model, hparams, llama_dir):
-    from transformers import (GenerationConfig, LlamaForCausalLM,
-                              LlamaTokenizer, StoppingCriteria,
-                              StoppingCriteriaList)
+    from transformers import (
+        GenerationConfig,
+        LlamaForCausalLM,
+        LlamaTokenizer,
+        StoppingCriteria,
+        StoppingCriteriaList,
+    )
     from transformers.models.llama.configuration_llama import LlamaConfig
 
     class StoppingCriteriaSub(StoppingCriteria):
         def __init__(self):
             super().__init__()
 
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]
+        ):
             print(tokenizer.decode(input_ids[0]), end="", flush=True)
             if input_ids[0][-1] == 13:
                 return True
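Another aside: in the hunk above, `__call__` returns True once the last sampled token is id 13, the newline token in the LLaMA sentencepiece vocabulary, so generation stops at the end of a chat line. A hedged sketch of how such a criterion is typically passed to `generate` (the `model`, `input_ids`, and `generation_config` names are assumptions, not taken from this diff):

```python
# Hedged usage sketch: model, input_ids and generation_config are assumed
# to already exist; stopping_criteria is a standard generate() kwarg.
from transformers import StoppingCriteriaList

output_ids = model.generate(
    input_ids,
    generation_config=generation_config,
    stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub()]),
)
```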
@@ -237,7 +256,11 @@ AI: Hello! How can I assist you today?
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
+        "--input_dir",
+        "-i",
+        type=str,
+        required=True,
+        help="The input directory containing the ggml files.",
     )
     parser.add_argument(
         "--prefix",
@@ -252,14 +275,21 @@ def main():
         help="Whether to save the model in the huggingface format. (default: False)",
     )
     parser.add_argument(
-        "--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
+        "--chat",
+        "-c",
+        action="store_true",
+        help="Whether to open a chat with the model. (default: False)",
    )
     args = parser.parse_args()
 
     llama_dir = os.path.abspath(f"{args.input_dir}/../")
 
     ggml_files = sorted(
-        [f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
+        [
+            f"{args.input_dir}/{f}"
+            for f in os.listdir(args.input_dir)
+            if f.startswith(args.prefix)
+        ]
     )
 
     fin = open(ggml_files[0], "rb")