Make gguf_writer member of Model, rework tokenizer export

This commit is contained in:
Galunid 2023-10-29 00:33:05 +02:00
parent 22201248a0
commit 0ff237105d
2 changed files with 84 additions and 83 deletions

View file

@ -1,5 +1,4 @@
#!/usr/bin/env python3
# HF stablelm --> gguf conversion
from __future__ import annotations
@ -41,37 +40,30 @@ print("gguf: loading model " + dir_model.name)
hparams = model.Model.load_hparams(dir_model)
model_class = model.Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype)
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[model_instance.model_arch])
model_instance = model_class(dir_model, ftype, fname_out)
print("gguf: get model metadata")
model_instance.set_gguf_parameters(gguf_writer)
model_instance.set_gguf_parameters()
# TOKENIZATION
print("gguf: get tokenizer metadata")
gguf_writer.add_tokenizer_model("gpt2")
print("gguf: get gpt2 tokenizer vocab")
tokens, toktypes = model.Model.load_vocab_gpt2(model_instance.dir_model, model_instance.hparams)
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)
model_instance.set_vocab()
# write model
print("gguf: write header")
gguf_writer.write_header_to_file()
model_instance.gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
model_instance.gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
print("gguf: write tensors")
model_instance.write_tensors(gguf_writer)
gguf_writer.write_tensors_to_file()
model_instance.write_tensors()
model_instance.gguf_writer.write_tensors_to_file()
gguf_writer.close()
model_instance.gguf_writer.close()
print(f"gguf: model successfully exported to '{fname_out}'")
print("")

143
model.py
View file

@ -11,14 +11,16 @@ from pathlib import Path
class Model:
def __init__(self, dir_model: Path, ftype: int):
def __init__(self, dir_model: Path, ftype: int, fname_out: Path):
self.dir_model = dir_model
self.ftype = ftype
self.fname_out = fname_out
self.is_safetensors = self._is_model_safetensors()
self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin")
self.part_names = self._get_part_names()
self.hparams = Model.load_hparams(self.dir_model)
self.model_arch = self._get_model_architecture()
self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch])
def _is_model_safetensors(self) -> bool:
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
@ -43,6 +45,41 @@ class Model:
return gguf.MODEL_ARCH.MPT
raise NotImplementedError(f'Architecture "{arch}" not supported!')
def set_vocab(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model)
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(self.gguf_writer)
def get_tensors(self):
for part_name in self.part_names:
print("gguf: loading model part '" + part_name + "'")
@ -57,20 +94,20 @@ class Model:
data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
yield name, data
def set_gguf_parameters(self, gguf_writer: gguf.GGUFWriter):
gguf_writer.add_name(self.dir_model.name)
gguf_writer.add_block_count(self.hparams.get("n_layers", self.hparams.get("num_hidden_layers")))
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.hparams.get("n_layers", self.hparams.get("num_hidden_layers")))
if "max_position_embeddings" in self.hparams:
gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
if "hidden_size" in self.hparams:
gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
if "intermediate_size" in self.hparams:
gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
if "num_attention_head" in self.hparams:
gguf_writer.add_head_count(self.hparams["num_attention_heads"])
gguf_writer.add_parallel_residual(self.hparams["use_parallel_residual"] if "use_parallel_residual" in self.hparams else True)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_parallel_residual(self.hparams["use_parallel_residual"] if "use_parallel_residual" in self.hparams else True)
def write_tensors(self, gguf_writer: gguf.GGUFWriter):
def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers"))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
for name, data in self.get_tensors():
@ -109,7 +146,7 @@ class Model:
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
gguf_writer.add_tensor(new_name, data)
self.gguf_writer.add_tensor(new_name, data)
@staticmethod
def count_model_parts(dir_model: Path, prefix: str) -> int:
@ -126,34 +163,6 @@ class Model:
hparams = json.load(f)
return hparams
@staticmethod
def load_vocab_gpt2(dir_model: Path, hparams):
tokens: list[bytearray] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model)
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
return tokens, toktypes
@staticmethod
def from_model_architecture(model_architecture):
@ -168,30 +177,30 @@ class Model:
return Model
class StableLMModel(Model):
def set_gguf_parameters(self, gguf_writer):
super().set_gguf_parameters(gguf_writer)
gguf_writer.add_rope_dimension_count(int(self.hparams["rope_pct"]*(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])))
gguf_writer.add_layer_norm_eps(1e-5)
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_rope_dimension_count(int(self.hparams["rope_pct"]*(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])))
self.gguf_writer.add_layer_norm_eps(1e-5)
class GPTNeoXModel(Model):
pass
class BloomModel(Model):
def set_gguf_parameters(self, gguf_writer: gguf.GGUFWriter):
gguf_writer.add_name("Bloom")
def set_gguf_parameters(self):
self.gguf_writer.add_name("Bloom")
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
gguf_writer.add_embedding_length(n_embed)
gguf_writer.add_feed_forward_length(4 * n_embed)
gguf_writer.add_block_count(self.hparams["n_layer"])
gguf_writer.add_head_count(n_head)
gguf_writer.add_head_count_kv(n_head)
gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(4 * n_embed)
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
def write_tensors(self, gguf_writer):
def write_tensors(self):
block_count = self.hparams["n_layer"]
tensors = dict(self.get_tensors())
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
@ -258,24 +267,24 @@ class BloomModel(Model):
print(name, "=>", new_name + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype))
gguf_writer.add_tensor(new_name, data)
self.gguf_writer.add_tensor(new_name, data)
if not has_lm_head and name == "word_embeddings.weight":
gguf_writer.add_tensor("output.weight", data)
self.gguf_writer.add_tensor("output.weight", data)
print(name, "=>", "output.weight" + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype)) # noqa
class MPTModel(Model):
def set_gguf_parameters(self, gguf_writer):
def set_gguf_parameters(self):
block_count = self.hparams["n_layers"]
gguf_writer.add_name(self.dir_model.name)
gguf_writer.add_context_length(self.hparams["max_seq_len"])
gguf_writer.add_embedding_length(self.hparams["d_model"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
gguf_writer.add_head_count(self.hparams["n_heads"])
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
self.gguf_writer.add_head_count(self.hparams["n_heads"])
if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
gguf_writer.add_head_count_kv(kv_n_heads)
gguf_writer.add_layer_norm_eps(1e-05)
self.gguf_writer.add_head_count_kv(kv_n_heads)
self.gguf_writer.add_layer_norm_eps(1e-05)
if self.hparams["attn_config"]["clip_qkv"] is not None:
gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])
self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])