apply black to ggml-to-pth
commit dfa2d707e9
parent 9cbc404ba6

1 changed file with 41 additions and 11 deletions
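The title says this is a pure black reformat. A minimal sketch of reproducing it through black's Python API; the script file name below is assumed from the commit title (the actual invocation is not recorded), and black.Mode() defaults (88-column target, magic trailing commas) match the wrapping seen in this diff:

# Hedged sketch, not the recorded command: re-run black on the converter script.
import black

with open("convert-ggml-to-pth.py") as f:  # assumed file name
    src = f.read()

print(black.format_str(src, mode=black.Mode()))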
@@ -72,7 +72,12 @@ def dequantize_weights(fin, n_rows, n_cols):
 
 def read_variables(fin):
     model = {}
-    pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
+    pbar = tqdm(
+        total=os.path.getsize(fin.name),
+        unit="B",
+        unit_scale=True,
+        desc="Reading variables",
+    )
     while True:
         start_pos = fin.tell()
         try:
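The tqdm call that black re-wrapped drives a byte-accurate progress bar: the total is the file size, and each loop iteration advances the bar by how far the read cursor moved. A self-contained toy of the same pattern, with an in-memory buffer standing in for the ggml file:

import io

from tqdm import tqdm

buf = io.BytesIO(b"\x00" * 1024)  # stand-in for the ggml file
pbar = tqdm(total=1024, unit="B", unit_scale=True, desc="Reading variables")
while True:
    start_pos = buf.tell()
    if not buf.read(256):  # stand-in for reading one tensor record
        break
    pbar.update(buf.tell() - start_pos)  # advance by bytes consumed
pbar.close()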
@@ -98,7 +103,9 @@ def read_variables(fin):
         data_size = np.prod(shape)
         data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
 
-        model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
+        model[name] = torch.tensor(
+            data, dtype=torch.float32 if dtype == np.float32 else torch.float16
+        )
 
         pbar.update(fin.tell() - start_pos)
 
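The re-wrapped torch.tensor call preserves the on-disk precision: fp32 stays fp32, anything else is stored as fp16. A small illustration of that dtype dispatch:

import numpy as np
import torch

data = np.zeros(4, dtype=np.float16)  # as np.fromfile would return for f16 weights
t = torch.tensor(
    data, dtype=torch.float32 if data.dtype == np.float32 else torch.float16
)
print(t.dtype)  # torch.float16 -- half-precision weights are not promoted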
@@ -112,11 +119,17 @@ def convert_to_hf_format(model, hparams):
     dim = hparams["dim"]
     dims_per_head = dim // n_heads
     base = 10000.0
-    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    inv_freq = 1.0 / (
+        base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
+    )
 
     # permute for sliced rotary
     def permute(w):
-        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+        return (
+            w.view(n_heads, dim // n_heads // 2, 2, dim)
+            .transpose(1, 2)
+            .reshape(dim, dim)
+        )
 
     state_dict = {}
     for layer_i in range(n_layers):
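The permute helper converts each head's rows from the interleaved rotary layout (even/odd dimension pairs) to the two-contiguous-halves layout that the HF LLaMA attention slices. A quick check with toy sizes (n_heads=2 and dim=8 are illustrative values, not the model's real shape):

import torch

n_heads, dim = 2, 8  # toy sizes for illustration

def permute(w):
    return (
        w.view(n_heads, dim // n_heads // 2, 2, dim)
        .transpose(1, 2)
        .reshape(dim, dim)
    )

w = torch.arange(dim).float().unsqueeze(1).repeat(1, dim)  # row i holds the value i
print(permute(w)[:, 0])  # tensor([0., 2., 1., 3., 4., 6., 5., 7.])

Within each head, the even-indexed rows now come first and the odd-indexed rows second, i.e. interleaved pairs have become two contiguous halves.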
@@ -164,16 +177,22 @@ def convert_to_hf_format(model, hparams):
 
 
 def chat(model, hparams, llama_dir):
-    from transformers import (GenerationConfig, LlamaForCausalLM,
-                              LlamaTokenizer, StoppingCriteria,
-                              StoppingCriteriaList)
+    from transformers import (
+        GenerationConfig,
+        LlamaForCausalLM,
+        LlamaTokenizer,
+        StoppingCriteria,
+        StoppingCriteriaList,
+    )
     from transformers.models.llama.configuration_llama import LlamaConfig
 
     class StoppingCriteriaSub(StoppingCriteria):
         def __init__(self):
             super().__init__()
 
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]
+        ):
             print(tokenizer.decode(input_ids[0]), end="", flush=True)
             if input_ids[0][-1] == 13:
                 return True
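StoppingCriteriaSub halts generation as soon as the last sampled token is id 13, the newline token in the LLaMA vocabulary, so each model turn ends at a line break. A minimal standalone sketch of the same idea; the token ids are fabricated for the demo, and in real use the criterion is passed via model.generate(..., stopping_criteria=StoppingCriteriaList([...])):

import torch
from transformers import StoppingCriteria

class StopOnNewline(StoppingCriteria):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ):
        return bool(input_ids[0][-1] == 13)  # 13 == "\n" in the LLaMA vocab

stop = StopOnNewline()
print(stop(torch.tensor([[1, 5, 13]]), scores=None))  # True -> stop generating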
@@ -237,7 +256,11 @@ AI: Hello! How can I assist you today?
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
+        "--input_dir",
+        "-i",
+        type=str,
+        required=True,
+        help="The input directory containing the ggml files.",
     )
     parser.add_argument(
         "--prefix",
@@ -252,14 +275,21 @@ def main():
         help="Whether to save the model in the huggingface format. (default: False)",
     )
     parser.add_argument(
-        "--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
+        "--chat",
+        "-c",
+        action="store_true",
+        help="Whether to open a chat with the model. (default: False)",
    )
     args = parser.parse_args()
 
     llama_dir = os.path.abspath(f"{args.input_dir}/../")
 
     ggml_files = sorted(
-        [f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
+        [
+            f"{args.input_dir}/{f}"
+            for f in os.listdir(args.input_dir)
+            if f.startswith(args.prefix)
+        ]
     )
 
     fin = open(ggml_files[0], "rb")
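main() discovers the ggml shards by prefix match and relies on sorted() to put multi-part files in rank order. A self-contained toy of that step; the shard names are hypothetical stand-ins:

import os
import tempfile

input_dir = tempfile.mkdtemp()
for name in ["ggml-model-f16.bin", "ggml-model-f16.bin.1", "params.json"]:
    open(os.path.join(input_dir, name), "w").close()  # hypothetical file names

prefix = "ggml-model-f16.bin"
ggml_files = sorted(
    f"{input_dir}/{f}" for f in os.listdir(input_dir) if f.startswith(prefix)
)
print([os.path.basename(f) for f in ggml_files])
# ['ggml-model-f16.bin', 'ggml-model-f16.bin.1'] -- params.json filtered out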