Uniform args parsing and vocab only mode for convert examples

KerfuffleV2 2023-08-29 10:47:47 -06:00
parent 2ea133895c
commit 58fa4dc870
4 changed files with 205 additions and 184 deletions
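
All four converter scripts now share the same command-line surface. As a consolidated reference, here is a minimal sketch of the pattern each file gains below (only the ArgumentParser description string differs between scripts):

    import argparse
    from pathlib import Path

    def parse_args() -> argparse.Namespace:
        parser = argparse.ArgumentParser(description="Convert a model to a GGML compatible file")
        parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
        parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
        parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
        parser.add_argument("ftype", type=int, choices=[0, 1], default=1, help="output format - use 0 for float32, 1 for float16")
        return parser.parse_args()

One caveat: as written, `default=1` on the required positional `ftype` never applies, because argparse only falls back to a positional's default when it is declared optional with `nargs='?'`; callers must still pass 0 or 1 explicitly.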

File 1 of 4:

@@ -33,11 +33,10 @@ def bytes_to_unicode():
             bs.append(b)
             cs.append(2**8+n)
             n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
+    return dict(zip(bs, (chr(n) for n in cs)))
 
-def count_model_parts(dir_model: str) -> int:
+def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
         if filename.startswith("pytorch_model-"):
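
For context, a runnable sketch of the GPT-2 byte-to-unicode table this hunk touches (the lines outside the hunk are reconstructed from the standard GPT-2 reference implementation), plus a round trip through the inverse map that the converters build from it:

    def bytes_to_unicode() -> dict:
        # printable bytes map to themselves; all others shift to 256 + n
        bs = list(range(ord("!"), ord("~") + 1)) \
           + list(range(ord("¡"), ord("¬") + 1)) \
           + list(range(ord("®"), ord("ÿ") + 1))
        cs = bs[:]
        n = 0
        for b in range(2**8):
            if b not in bs:
                bs.append(b)
                cs.append(2**8 + n)
                n += 1
        return dict(zip(bs, (chr(n) for n in cs)))

    byte_decoder = {v: k for k, v in bytes_to_unicode().items()}
    # 'Ġ' (U+0120) is the GPT-2 spelling of a leading space (byte 0x20)
    assert bytearray(byte_decoder[c] for c in "Ġhello") == b" hello"

The change itself is behavior-preserving: materializing `cs` as a list of characters before zipping is replaced by a generator expression inside `zip`, skipping the intermediate list.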
File 2 of 4 (GPT-NeoX converter):

@@ -8,6 +8,7 @@ import struct
 import json
 import numpy as np
 import torch
+import argparse
 from typing import Any, List
 from pathlib import Path
@@ -37,7 +38,7 @@ def bytes_to_unicode():
     return dict(zip(bs, (chr(n) for n in cs)))
 
-def count_model_parts(dir_model: str) -> int:
+def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
         if filename.startswith("pytorch_model-"):
@@ -48,17 +49,22 @@ def count_model_parts(dir_model: str) -> int:
     return num_parts
 
-if len(sys.argv) < 3:
-    print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
-    print(" ftype == 0 -> float32")
-    print(" ftype == 1 -> float16")
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a GPT-NeoX model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
     sys.exit(1)
 
-# output in the same directory as the model
-dir_model = sys.argv[1]
-last_dir = os.path.basename(os.path.normpath(dir_model))
-
 # possible tensor data types
 # ftype == 0 -> float32
 # ftype == 1 -> float16
@@ -66,19 +72,15 @@ last_dir = os.path.basename(os.path.normpath(dir_model))
 # map from ftype to string
 ftype_str = ["f32", "f16"]
 
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-
-fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
 
-print("gguf: loading model "+last_dir)
+print("gguf: loading model "+dir_model.name)
 
-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)
 
 if hparams["architectures"][0] != "GPTNeoXForCausalLM":
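
The default output path is now derived with pathlib rather than string concatenation. A quick check of the resulting name (the model directory is hypothetical):

    from pathlib import Path

    ftype_str = ["f32", "f16"]
    dir_model = Path("models/gptneox-7b")  # hypothetical input directory
    ftype = 1
    fname_out = dir_model / f"ggml-model-{ftype_str[ftype]}.gguf"
    print(fname_out)  # models/gptneox-7b/ggml-model-f16.gguf

The `/` operator also normalizes a trailing slash in the user-supplied path, which the old `sys.argv[1] + "/ggml-model-..."` form silently doubled.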
@@ -96,7 +98,7 @@ print("gguf: get model metadata")
 block_count = hparams["num_hidden_layers"]
 
-gguf_writer.add_name(last_dir)
+gguf_writer.add_name(dir_model.name)
 gguf_writer.add_context_length(hparams["max_position_embeddings"])
 gguf_writer.add_embedding_length(hparams["hidden_size"])
 gguf_writer.add_block_count(block_count)
@@ -112,45 +114,49 @@ print("gguf: get tokenizer metadata")
 tokens: List[bytearray] = []
 
-if Path(dir_model + "/tokenizer.json").is_file():
-    # gpt2 tokenizer
-    gguf_writer.add_tokenizer_model("gpt2")
+tokenizer_json_file = dir_model / 'tokenizer.json'
+if not tokenizer_json_file.is_file():
+    print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
+    sys.exit(1)
 
-    with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
-        tokenizer_json = json.load(f)
+# gpt2 tokenizer
+gguf_writer.add_tokenizer_model("gpt2")
 
-    print("gguf: get gpt2 tokenizer vocab")
+with open(tokenizer_json_file, "r", encoding="utf-8") as f:
+    tokenizer_json = json.load(f)
 
-    vocab_size = len(tokenizer_json["model"]["vocab"])
+print("gguf: get gpt2 tokenizer vocab")
 
-    # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
-    tokenizer = AutoTokenizer.from_pretrained(dir_model)
+vocab_size = len(tokenizer_json["model"]["vocab"])
 
-    reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
-    byte_encoder = bytes_to_unicode()
-    byte_decoder = {v: k for k, v in byte_encoder.items()}
+# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
 
-    for i in range(vocab_size):
-        if i in reverse_vocab:
-            try:
-                text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
-            except KeyError:
-                text = bytearray()
-                for c in reverse_vocab[i]:
-                    if ord(c) < 256: # single byte character
-                        text.append(byte_decoder[ord(c)])
-                    else: # multibyte special token character
-                        text.extend(c.encode('utf-8'))
-        else:
-            print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
-            pad_token = f"[PAD{i}]".encode("utf8")
-            text = bytearray(pad_token)
+reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v: k for k, v in byte_encoder.items()}
 
-        tokens.append(text)
+for i in range(vocab_size):
+    if i in reverse_vocab:
+        try:
+            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
+        except KeyError:
+            text = bytearray()
+            for c in reverse_vocab[i]:
+                if ord(c) < 256: # single byte character
+                    text.append(byte_decoder[ord(c)])
+                else: # multibyte special token character
+                    text.extend(c.encode('utf-8'))
+    else:
+        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
+        pad_token = f"[PAD{i}]".encode("utf8")
+        text = bytearray(pad_token)
+
+    tokens.append(text)
 
 gguf_writer.add_token_list(tokens)
 
-special_vocab = gguf.SpecialVocab(Path(dir_model), load_merges = True)
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
 special_vocab.add_to_gguf(gguf_writer)
 
 # TENSORS
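
The `[PAD{i}]` fallback exists because a checkpoint can declare a larger vocab (and embedding matrix) in tokenizer.json than the number of ids the tokenizer actually defines, and the GGUF writer needs exactly one token entry per embedding row. A toy illustration with hypothetical sizes:

    vocab_size = 50432                                    # declared size (hypothetical)
    reverse_vocab = {i: f"tok{i}" for i in range(50254)}  # tokenizer defines fewer ids
    tokens = []
    for i in range(vocab_size):
        text = reverse_vocab.get(i, f"[PAD{i}]")
        tokens.append(text.encode("utf-8"))
    assert len(tokens) == vocab_size                      # one entry per embedding row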
@@ -168,6 +174,8 @@ else:
 )
 
 for part_name in part_names:
+    if args.vocab_only:
+        break
     print("gguf: loading model part '" + part_name + "'")
     model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
@@ -216,10 +224,11 @@ print("gguf: write header")
 gguf_writer.write_header_to_file()
 print("gguf: write metadata")
 gguf_writer.write_kv_data_to_file()
-print("gguf: write tensors")
-gguf_writer.write_tensors_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
 
 gguf_writer.close()
 
-print("gguf: model successfully exported to '" + fname_out + "'")
+print(f"gguf: model successfully exported to '{fname_out}'")
 print("")
File 3 of 4 (PyTorch 7B LLaMA converter):

@@ -10,6 +10,7 @@ import struct
 import json
 import numpy as np
 import torch
+import argparse
 from typing import Any, List, TypeAlias
 from pathlib import Path
@@ -20,7 +21,7 @@ from sentencepiece import SentencePieceProcessor
 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
 
-def count_model_parts(dir_model: str) -> int:
+def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
         if filename.startswith("consolidated."):
@@ -31,19 +32,22 @@ def count_model_parts(dir_model: str) -> int:
     return num_parts
 
-if len(sys.argv) < 3:
-    print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
-    print(" ftype == 0 -> float32")
-    print(" ftype == 1 -> float16")
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a PyTorch 7B LLaMA model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
     sys.exit(1)
 
-# output in the same directory as the model
-dir_model = sys.argv[1]
-last_dir = os.path.basename(os.path.normpath(dir_model))
-
 # possible tensor data types
 # ftype == 0 -> float32
 # ftype == 1 -> float16
@@ -51,19 +55,15 @@ last_dir = os.path.basename(os.path.normpath(dir_model))
 # map from ftype to string
 ftype_str = ["f32", "f16"]
 
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-
-fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
 
-print("gguf: loading model "+last_dir)
+print("gguf: loading model "+dir_model.name)
 
-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)
 
 if hparams["architectures"][0] != "LlamaForCausalLM":
@@ -107,7 +107,7 @@ else:
     sys.exit()
 
-gguf_writer.add_name(last_dir)
+gguf_writer.add_name(dir_model.name)
 gguf_writer.add_source_hf_repo(hf_repo)
 gguf_writer.add_tensor_data_layout("Meta AI original pth")
 gguf_writer.add_context_length(ctx_length)
@@ -133,54 +133,59 @@ tokens: List[bytes] = []
 scores: List[float] = []
 toktypes: List[int] = []
 
-if Path(dir_model + "/tokenizer.model").is_file():
-    # vocab type sentencepiece
-    print("gguf: get sentencepiece tokenizer vocab and scores")
+tokenizer_model_file = dir_model / 'tokenizer.model'
+if not tokenizer_model_file.is_file():
+    print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
+    sys.exit(1)
 
-    tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")
+# vocab type sentencepiece
+print("gguf: get sentencepiece tokenizer vocab and scores")
 
-    for i in range(tokenizer.vocab_size()):
-        text: bytes
-        score: float
+tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
 
-        piece = tokenizer.id_to_piece(i)
-        text = piece.encode("utf-8")
-        score = tokenizer.get_score(i)
+for i in range(tokenizer.vocab_size()):
+    text: bytes
+    score: float
 
-        toktype = 1 # default to normal token type
-        if tokenizer.is_unknown(i):
-            toktype = 2
-        if tokenizer.is_control(i):
-            toktype = 3
+    piece = tokenizer.id_to_piece(i)
+    text = piece.encode("utf-8")
+    score = tokenizer.get_score(i)
 
-        # toktype = 4 is user-defined = tokens from added_tokens.json
+    toktype = 1 # default to normal token type
+    if tokenizer.is_unknown(i):
+        toktype = 2
+    if tokenizer.is_control(i):
+        toktype = 3
 
-        if tokenizer.is_unused(i):
-            toktype = 5
-        if tokenizer.is_byte(i):
-            toktype = 6
+    # toktype = 4 is user-defined = tokens from added_tokens.json
 
-        tokens.append(text)
-        scores.append(score)
-        toktypes.append(toktype)
+    if tokenizer.is_unused(i):
+        toktype = 5
+    if tokenizer.is_byte(i):
+        toktype = 6
 
-    if Path(dir_model + "/added_tokens.json").is_file():
-        with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f:
-            addtokens_json = json.load(f)
+    tokens.append(text)
+    scores.append(score)
+    toktypes.append(toktype)
 
-            print("gguf: get added tokens")
+added_tokens_file = dir_model / 'added_tokens.json'
+if added_tokens_file.is_file():
+    with open(added_tokens_file, "r", encoding="utf-8") as f:
+        addtokens_json = json.load(f)
 
-            for key in addtokens_json:
-                tokens.append( key.encode("utf-8") )
-                scores.append(-1000.0)
-                toktypes.append(4) # user-defined token type
+        print("gguf: get added tokens")
+
+        for key in addtokens_json:
+            tokens.append( key.encode("utf-8") )
+            scores.append(-1000.0)
+            toktypes.append(4) # user-defined token type
 
 gguf_writer.add_tokenizer_model("llama")
 gguf_writer.add_token_list(tokens)
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)
-special_vocab = gguf.SpecialVocab(Path(dir_model))
+special_vocab = gguf.SpecialVocab(dir_model)
 special_vocab.add_to_gguf(gguf_writer)
 
 # TENSORS
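
The numeric toktype codes assigned in this loop follow llama.cpp's token-type convention. Collected in one place for reference (the enum name is illustrative; the values and meanings come from the literals and comments above):

    from enum import IntEnum

    class TokenType(IntEnum):
        NORMAL = 1        # the default assigned above
        UNKNOWN = 2       # tokenizer.is_unknown(i)
        CONTROL = 3       # tokenizer.is_control(i)
        USER_DEFINED = 4  # tokens from added_tokens.json
        UNUSED = 5        # tokenizer.is_unused(i)
        BYTE = 6          # tokenizer.is_byte(i)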
@@ -193,6 +198,8 @@ print("gguf: get tensor metadata")
 part_names = (f"consolidated.{n:02}.pth" for n in range(0, num_parts))
 
 for part_name in part_names:
+    if args.vocab_only:
+        break
     print("gguf: loading model part '" + part_name + "'")
     model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
@@ -241,11 +248,11 @@ print("gguf: write header")
 gguf_writer.write_header_to_file()
 print("gguf: write metadata")
 gguf_writer.write_kv_data_to_file()
-print("gguf: write tensors")
-gguf_writer.write_tensors_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
 
 gguf_writer.close()
 
-print("gguf: model successfully exported to '" + fname_out + "'")
+print(f"gguf: model successfully exported to '{fname_out}'")
 print("")
File 4 of 4 (HuggingFace LLaMA converter):

@@ -8,6 +8,7 @@ import struct
 import json
 import numpy as np
 import torch
+import argparse
 from typing import Any, List, Optional, TypeAlias
 from pathlib import Path
@@ -43,40 +44,38 @@ def count_model_parts(dir_model: str) -> int:
     return num_parts
 
-if len(sys.argv) < 3:
-    print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
-    print(" ftype == 0 -> float32")
-    print(" ftype == 1 -> float16")
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a HuggingFace LLaMA model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
     sys.exit(1)
 
-# output in the same directory as the model
-dir_model = sys.argv[1]
-last_dir = os.path.basename(os.path.normpath(dir_model))
-
 # possible tensor data types
 # ftype == 0 -> float32
 # ftype == 1 -> float16
 
 # map from ftype to string
 ftype_str = ["f32", "f16"]
 
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-
-fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
 
-print("gguf: loading model "+last_dir)
+print("gguf: loading model "+dir_model.name)
 
-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)
 
 if hparams["architectures"][0] != "LlamaForCausalLM":
@@ -115,7 +114,7 @@ else:
     sys.exit()
 
-gguf_writer.add_name(last_dir)
+gguf_writer.add_name(dir_model.name)
 gguf_writer.add_source_hf_repo(hf_repo)
 gguf_writer.add_tensor_data_layout("Meta AI original pth")
 gguf_writer.add_context_length(ctx_length)
@@ -141,55 +140,60 @@ tokens: List[bytes] = []
 scores: List[float] = []
 toktypes: List[int] = []
 
-if Path(dir_model + "/tokenizer.model").is_file():
-    # vocab type sentencepiece
-    print("gguf: get sentencepiece tokenizer vocab, scores and token types")
+tokenizer_model_file = dir_model / 'tokenizer.model'
+if not tokenizer_model_file.is_file():
+    print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
+    sys.exit(1)
 
-    tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")
+# vocab type sentencepiece
+print("gguf: get sentencepiece tokenizer vocab, scores and token types")
 
-    for i in range(tokenizer.vocab_size()):
-        text: bytes
-        score: float
+tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
 
-        piece = tokenizer.id_to_piece(i)
-        text = piece.encode("utf-8")
-        score = tokenizer.get_score(i)
+for i in range(tokenizer.vocab_size()):
+    text: bytes
+    score: float
 
-        toktype = 1 # default to normal token type
-        if tokenizer.is_unknown(i):
-            toktype = 2
-        if tokenizer.is_control(i):
-            toktype = 3
+    piece = tokenizer.id_to_piece(i)
+    text = piece.encode("utf-8")
+    score = tokenizer.get_score(i)
 
-        # toktype = 4 is user-defined = tokens from added_tokens.json
+    toktype = 1 # default to normal token type
+    if tokenizer.is_unknown(i):
+        toktype = 2
+    if tokenizer.is_control(i):
+        toktype = 3
 
-        if tokenizer.is_unused(i):
-            toktype = 5
-        if tokenizer.is_byte(i):
-            toktype = 6
+    # toktype = 4 is user-defined = tokens from added_tokens.json
 
-        tokens.append(text)
-        scores.append(score)
-        toktypes.append(toktype)
+    if tokenizer.is_unused(i):
+        toktype = 5
+    if tokenizer.is_byte(i):
+        toktype = 6
 
-    if Path(dir_model + "/added_tokens.json").is_file():
-        with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f:
-            addtokens_json = json.load(f)
+    tokens.append(text)
+    scores.append(score)
+    toktypes.append(toktype)
 
-            print("gguf: get added tokens")
+added_tokens_file = dir_model / 'added_tokens.json'
+if added_tokens_file.is_file():
+    with open(added_tokens_file, "r", encoding="utf-8") as f:
+        addtokens_json = json.load(f)
 
-            for key in addtokens_json:
-                tokens.append( key.encode("utf-8") )
-                scores.append(-1000.0)
-                toktypes.append(4) # user-defined token type
+        print("gguf: get added tokens")
+
+        for key in addtokens_json:
+            tokens.append( key.encode("utf-8") )
+            scores.append(-1000.0)
+            toktypes.append(4) # user-defined token type
 
 gguf_writer.add_tokenizer_model("llama")
 gguf_writer.add_token_list(tokens)
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)
-special_vocab = gguf.SpecialVocab(Path(dir_model))
+special_vocab = gguf.SpecialVocab(dir_model)
 special_vocab.add_to_gguf(gguf_writer)
 
 # TENSORS
@@ -207,6 +211,8 @@ else:
 )
 
 for part_name in part_names:
+    if args.vocab_only:
+        break
     print("gguf: loading model part '" + part_name + "'")
     model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
@@ -261,11 +267,11 @@ print("gguf: write header")
 gguf_writer.write_header_to_file()
 print("gguf: write metadata")
 gguf_writer.write_kv_data_to_file()
-print("gguf: write tensors")
-gguf_writer.write_tensors_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
 
 gguf_writer.close()
 
-print("gguf: model successfully exported to '" + fname_out + "'")
+print(f"gguf: model successfully exported to '{fname_out}'")
 print("")