update conversion script to directly take adept artifacts rather than .safetensors file

Phillip Kravtsov 2023-09-29 14:59:51 -07:00
parent ec0ce978ff
commit 3db04db2b8

@@ -1,28 +1,31 @@
-from convert import lazy_load_safetensors_file
-import sys
 import torch
-from safetensors import safe_open
-from pathlib import Path
+import os
 from pprint import pprint
-from sentencepiece import SentencePieceProcessor
+import sys
 import argparse
+from pathlib import Path
+from sentencepiece import SentencePieceProcessor
+if 'NO_LOCAL_GGUF' not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
 import gguf
-import json
-import struct
 
-def file_is_safetensors(path: Path) -> bool:
-    fp = open(path, 'rb')
-    first8 = fp.read(8)
-    fp.seek(0)
-    if first8[:2] == b'PK':
-        # A zip file, i.e. PyTorch format
-        return False
-    return struct.unpack('<Q', first8)[0] < 16 * 1024 * 1024
+def _flatten_dict(dct, tensors, prefix=None):
+    assert isinstance(dct, dict)
+    for key in dct.keys():
+        new_prefix = prefix + '.' + key if prefix is not None else key
+        if isinstance(dct[key], torch.Tensor):
+            tensors[new_prefix] = dct[key]
+        elif isinstance(dct[key], dict):
+            _flatten_dict(dct[key], tensors, new_prefix)
+        else:
+            raise ValueError(type(dct[key]))
+    return None
 
 def get_tokenizer_info(dir_model: Path):
     tokenizer_path = dir_model / 'adept_vocab.model'
     print('gguf: getting sentencepiece tokenizer from', tokenizer_path)
     tokenizer = SentencePieceProcessor(str(tokenizer_path))
     print('gguf: adding tokens')
     tokens: list[bytes] = []
     scores: list[float] = []
     toktypes: list[int] = []
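Note: the new _flatten_dict recursively walks the checkpoint's nested state dict and registers every tensor under a dotted key path. A minimal sketch of its behavior on a toy nested dict (the key names below are illustrative, not real checkpoint contents):

    import torch

    # _flatten_dict as introduced above, reproduced so the sketch runs standalone.
    def _flatten_dict(dct, tensors, prefix=None):
        assert isinstance(dct, dict)
        for key in dct.keys():
            new_prefix = prefix + '.' + key if prefix is not None else key
            if isinstance(dct[key], torch.Tensor):
                tensors[new_prefix] = dct[key]
            elif isinstance(dct[key], dict):
                _flatten_dict(dct[key], tensors, new_prefix)
            else:
                raise ValueError(type(dct[key]))

    tensors = {}
    nested = {'embedding': {'word_embeddings': {'weight': torch.zeros(2, 3)}}}
    _flatten_dict(nested, tensors)
    print(list(tensors.keys()))  # -> ['embedding.word_embeddings.weight']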
@@ -54,41 +57,40 @@ def get_tokenizer_info(dir_model: Path):
         pass
     return tokens, scores, toktypes
 
-def get_args():
+def main():
     parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
-    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.safetensors)")
+    parser.add_argument("--ckpt-path", type=Path, help="path to persimmon checkpoint .pt file")
+    parser.add_argument("--model-dir", type=Path, help="directory containing model e.g. 8b_chat_model_release")
+    parser.add_argument("--adept-inference-dir", type=str, help="path to adept-inference code directory")
     args = parser.parse_args()
-    return args
-
-def main() -> None:
-    args = get_args()
-    assert file_is_safetensors(args.model), 'Error: model file is not a SafeTensors file'
-    dir_model = args.model.parent
-    with open(dir_model / 'config.json', 'r') as f:
-        hparams = json.load(f)
+    sys.path.append(str(args.adept_inference_dir))
+    persimmon_model = torch.load(args.ckpt_path)
+    hparams = persimmon_model['args']
+    pprint(hparams)
+    tensors = {}
+    _flatten_dict(persimmon_model['model'], tensors, None)
     arch = gguf.MODEL_ARCH.PERSIMMON
     gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch])
-    block_count = hparams['num_layers']
-    head_count = hparams['num_attention_heads']
+    block_count = hparams.num_layers
+    head_count = hparams.num_attention_heads
     head_count_kv = head_count
-    ctx_length = hparams['seq_length']
-    hidden_size = hparams['hidden_size']
+    ctx_length = hparams.seq_length
+    hidden_size = hparams.hidden_size
     gguf_writer.add_name('persimmon-8b-chat')
     gguf_writer.add_context_length(ctx_length)
     gguf_writer.add_embedding_length(hidden_size)
     gguf_writer.add_block_count(block_count)
-    gguf_writer.add_feed_forward_length(hparams['ffn_hidden_size'])
+    gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size)
     gguf_writer.add_rope_dimension_count(hidden_size // head_count)
     gguf_writer.add_head_count(head_count)
     gguf_writer.add_head_count_kv(head_count_kv)
-    gguf_writer.add_rope_freq_base(hparams['rotary_emb_base'])
-    gguf_writer.add_layer_norm_eps(hparams['layernorm_epsilon'])
-    tokens, scores, toktypes = get_tokenizer_info(dir_model)
+    gguf_writer.add_rope_freq_base(hparams.rotary_emb_base)
+    gguf_writer.add_layer_norm_eps(hparams.layernorm_epsilon)
+    tokens, scores, toktypes = get_tokenizer_info(args.model_dir)
     gguf_writer.add_tokenizer_model('llama')
     gguf_writer.add_token_list(tokens)
     gguf_writer.add_token_scores(scores)
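Note: hyperparameter access switches from hparams['num_layers'] to hparams.num_layers because the old script parsed config.json into a plain dict, while the new script reads the 'args' object saved inside the checkpoint; sys.path.append(args.adept_inference_dir) is presumably needed so torch.load can unpickle classes defined in the adept-inference package. A small illustration of the access change, assuming the saved args behave like an argparse.Namespace (an assumption; the diff does not show the concrete type):

    from argparse import Namespace

    hparams_old = {'num_layers': 36}        # old path: json.load on config.json (36 is illustrative)
    hparams_new = Namespace(num_layers=36)  # new path: torch.load(ckpt)['args'], assumed Namespace-like

    print(hparams_old['num_layers'])  # dict indexing
    print(hparams_new.num_layers)     # attribute access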
@@ -98,10 +100,6 @@ def main() -> None:
     tensor_map = gguf.get_tensor_name_map(arch, block_count)
     print(tensor_map)
-    tensors = {}
-    with safe_open(args.model, framework="pt") as f:
-        for k in f.keys():
-            tensors[k] = f.get_tensor(k)
 
     for name in tensors.keys():
         data = tensors[name]
         if name.endswith(".self_attention.rotary_emb.inv_freq"):
@@ -132,4 +130,4 @@ def main() -> None:
 
 if __name__ == '__main__':
     main()
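Note: with this change the converter consumes the released Adept artifacts directly instead of a converted .safetensors file. A hypothetical programmatic invocation (every path below is a placeholder; the script is normally run from the command line with the same flags):

    import sys

    sys.argv = [
        'convert-persimmon-to-gguf.py',                   # placeholder script name
        '--ckpt-path', '8b_chat_model_release/model.pt',  # placeholder .pt checkpoint
        '--model-dir', '8b_chat_model_release',           # must contain adept_vocab.model
        '--adept-inference-dir', 'adept-inference',       # checkout of Adept's inference code
        '--outfile', 'persimmon-8b-chat.gguf',
    ]
    main()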