This commit is contained in:
Amir Eslampanah 2023-05-03 10:51:59 +02:00 committed by GitHub
commit df69e4e43e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 272 additions and 0 deletions

181
convert-hf-to-ggml-v2.py Normal file
View file

@ -0,0 +1,181 @@
import io
import os
import sys
import struct
import json
import torch
import numpy as np
import tempfile
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoConfig
conv_map = {
'word_embeddings': 'tok_embeddings',
'word_embeddings_layernorm': 'norm',
'input_layernorm': 'attention_norm',
'self_attention.query_key_value': 'attention.query_key_value',
'self_attention.dense': 'attention.wo',
'post_attention_layernorm': 'ffn_norm',
'mlp.dense_h_to_4h': 'feed_forward.w1',
'mlp.dense_4h_to_h': 'feed_forward.w2',
'ln_f': 'output_norm',
'lm_head': 'output',
}
parser = argparse.ArgumentParser(description='Convert a model from HF format to GGML format.')
parser.add_argument('model_name', type=str, help='directory of the model to convert. Example: "bigscience/bloomz-560m"')
parser.add_argument('dir_output', type=str, help='directory where the output file will be written')
parser.add_argument('--use-f32', action='store_true', help='if present, use float32 instead of float16')
parser.add_argument('--debug', action='store_true', help='if present, dump the progress as it happens')
args = parser.parse_args()
model_name = args.model_name
dir_out = args.dir_output
os.makedirs(dir_out, exist_ok=True)
ftype_str = ["f32", "f16"]
ftype = 0 if args.use_f32 else 1
debug_flag = args.debug
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
hparams = config.to_dict()
print("Loading model: ", model_name)
# Save the model to disk
model_dir = f"{model_name}_tmp"
os.makedirs(model_dir, exist_ok=True)
config.save_pretrained(model_dir)
fname_out = dir_out + f"/ggml-model-{model_name.split('/')[-1]}-{ftype_str[ftype]}.bin"
fout = open(fname_out, "wb")
hparams["multiple_of"] = 1
fout.write(struct.pack("i", 0x67676d6c))
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["hidden_size"]))
fout.write(struct.pack("i", hparams["multiple_of"]))
fout.write(struct.pack("i", hparams["n_head"]))
fout.write(struct.pack("i", hparams["n_layer"]))
fout.write(struct.pack("i", ftype))
dot_token = tokenizer.encode(".")[0]
for i in range(hparams["vocab_size"]):
text = tokenizer.decode([i]).encode('utf-8')
fout.write(struct.pack("i", len(text)))
fout.write(text)
# Create temporary files for chunks
temp_files = {}
# Define the chunk size
chunk_size = 1000 * 1000 * 1024
# Load the PyTorch model weights from the saved files
# Find the files in the model directory
model_files = sorted([f for f in os.listdir(model_name) if f.startswith("pytorch_model") and f.endswith(".bin")])
added_head = False
state_dict = {}
for model_file in tqdm(model_files, desc="Processing model files in: " + model_name):
file_path = os.path.join(model_name, model_file)
model_part = torch.load(file_path, map_location=torch.device('cpu'))
state_dict.update(model_part)
# Add the missing lm_head.weight tensor
lm_head_weight_key = 'lm_head.weight'
word_embeddings_weight_key = 'word_embeddings.weight'
if lm_head_weight_key not in state_dict and not added_head:
# Use the word_embeddings.weight tensor for the lm_head.weight
word_embeddings_weight = state_dict[word_embeddings_weight_key]
# Add the tensor to the state_dict
state_dict[lm_head_weight_key] = word_embeddings_weight
added_head = True
for name in tqdm(state_dict.keys(), desc="Processing nodes"):
src = name
nn = name.split(".")
# Handle layer indices
if nn[0].isdigit():
layer_idx = nn[0]
nn = nn[1:]
else:
layer_idx = None
if debug_flag:
if nn[0].isdigit():
print("For Layer: " + layer_idx)
if nn[0] == "h":
nn[0] = "layers"
mapped = conv_map[".".join(nn[2:-1])]
if layer_idx is not None:
name = f"{layer_idx}.{nn[0]}.{layer_idx}.{mapped}.{nn[-1]}"
else:
name = ".".join(nn[:2] + [mapped] + nn[-1:])
else:
mapped = conv_map[".".join(nn[:-1])]
if layer_idx is not None:
name = f"{layer_idx}.{mapped}.{nn[-1]}"
else:
name = ".".join([mapped] + nn[-1:])
if "query_key_value" in src:
q, k, v = state_dict[src].reshape(config.n_head, 3, -1).unbind(1)
state_dict[src] = torch.cat([q, k, v], dim=0).reshape_as(state_dict[src])
if debug_flag:
print(src, ' -> ', name)
tensor = state_dict[src].cpu()
# If the tensor dtype is bfloat16, convert it to float32
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float32)
data = tensor.squeeze().numpy()
data = data.astype(np.float32)
n_dims = len(data.shape)
if debug_flag:
print(name, n_dims, data.shape)
# Check if the current data type is float16
if data.dtype == np.float16:
ftype_cur = 1
else:
ftype_cur = 0
# If the specified ftype is float16 and the current ftype is not, convert data to float16
if ftype == 1 and ftype_cur == 0 and n_dims > 1:
if debug_flag:
print("Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
str = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(str)
# Write data to file in chunks
data_buffer = data.tobytes()
data_len = len(data_buffer)
for offset in range(0, data_len, chunk_size):
chunk = data_buffer[offset: offset + chunk_size]
fout.write(chunk)
# Free some memory as we don't need the previous layer's state
state_dict = {}
fout.close()
print("Done. Output file: " + fname_out)
print("")

91
tokenconvert.py Executable file
View file

@ -0,0 +1,91 @@
import json
import os
import tempfile
import sys
import sentencepiece as spm
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
if len(sys.argv) < 3:
print("Usage: python tokenconvert.py tokenizer_type token-dir [dir-output]")
print(" tokenizer_type: The type of tokenizer (check the model information), eg: BPE, WordPiece, SentencePiece.")
print(" token-dir: Directory of the model containing the tokenizer.json. Example: 'bigscience/bloomz-560m'")
print(" dir-output: directory where the output file will be written, eg: ./tokenizer.model , by default writes to the same directory.")
sys.exit(1)
tokenizer_type = sys.argv[1]
token_dir = sys.argv[2]
if len(sys.argv) < 4:
dir_out = token_dir
else:
dir_out = sys.argv[3]
def load_tokenizer_from_json(json_path, special_tokens_map_path, tokenizer_config_path, tokenizer_type):
with open(json_path, "r", encoding="utf-8") as f:
json_data = json.load(f)
with open(special_tokens_map_path, "r", encoding="utf-8") as f:
special_tokens_map = json.load(f)
with open(tokenizer_config_path, "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
model_data = json_data["model"]
vocab = model_data["vocab"]
merges = model_data["merges"]
with tempfile.NamedTemporaryFile(mode="w", delete=False) as vocab_file:
json.dump(vocab, vocab_file)
if tokenizer_type == "BPE":
with tempfile.NamedTemporaryFile(mode="w", delete=False) as merges_file:
merges_file.write("\n".join(merges))
tokenizer = Tokenizer(models.BPE.from_file(vocab_file.name, merges_file.name))
os.unlink(merges_file.name)
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()
elif tokenizer_type == "WordPiece":
tokenizer = Tokenizer(models.WordPiece.from_file(vocab_file.name))
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
tokenizer.decoder = decoders.WordPiece()
elif tokenizer_type == "SentencePiece":
sp_model = spm.SentencePieceProcessor()
sp_model.Load(vocab_file.name)
tokenizer = Tokenizer(models.Model.from_sentencepiece(sp_model))
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.Metaspace(), pre_tokenizers.Split()])
tokenizer.decoder = decoders.Sequence([decoders.Split(), decoders.Metaspace()])
else:
raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
os.unlink(vocab_file.name)
bos_token_id = tokenizer.token_to_id(special_tokens_map["bos_token"])
eos_token_id = tokenizer.token_to_id(special_tokens_map["eos_token"])
tokenizer.post_processor = processors.TemplateProcessing(
single=f"{special_tokens_map['bos_token']} $A {special_tokens_map['eos_token']}",
pair=f"{special_tokens_map['bos_token']} $A {special_tokens_map['eos_token']} {special_tokens_map['bos_token']} $B {special_tokens_map['eos_token']}",
special_tokens=[
(special_tokens_map["bos_token"], bos_token_id),
(special_tokens_map["eos_token"], eos_token_id),
],
)
return tokenizer
if __name__ == "__main__":
input_json_path = os.path.join(token_dir, "tokenizer.json")
special_tokens_map_path = os.path.join(token_dir, "special_tokens_map.json")
tokenizer_config_path = os.path.join(token_dir, "tokenizer_config.json")
output_model_path = os.path.join(dir_out, "tokenizer.model")
tokenizer = load_tokenizer_from_json(input_json_path, special_tokens_map_path, tokenizer_config_path, tokenizer_type)
print(f"Saving.. tokenizer.model to {output_model_path}")
tokenizer.save(output_model_path)
print(f"Saved tokenizer.model to {output_model_path}")