Add support for BERT embedding models (#5423)
* BERT model graph construction (build_bert)
* WordPiece tokenizer (llm_tokenize_wpm)
* Add flag for non-causal attention models
* Allow for models that only output embeddings
* Support conversion of BERT models to GGUF
* Based on prior work by @xyzhang626 and @skeskinen

---------

Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent 97a336507e
commit 2891c8aa9a
8 changed files with 616 additions and 52 deletions
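One of the bullets above adds a flag for non-causal attention: BERT-style encoders attend bidirectionally, so every token can see the whole sequence instead of only earlier positions. As a rough illustration only (a toy numpy sketch, not the llama.cpp implementation), the two additive mask shapes differ like this:

import numpy as np

n_tokens = 4  # toy sequence length, illustration only

# causal (decoder-style) mask: position i may only attend to positions <= i
causal_mask = np.triu(np.full((n_tokens, n_tokens), -np.inf), k=1)

# non-causal (BERT-style) mask: every position may attend to every position
non_causal_mask = np.zeros((n_tokens, n_tokens))

print(causal_mask)      # -inf above the diagonal
print(non_causal_mask)  # all zeros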
@@ -209,6 +209,8 @@ class Model:
             return InternLM2Model
         if model_architecture == "MiniCPMForCausalLM":
             return MiniCPMModel
+        if model_architecture == "BertModel":
+            return BertModel
         return Model

     def _is_model_safetensors(self) -> bool:
@@ -264,6 +266,8 @@ class Model:
             return gguf.MODEL_ARCH.INTERNLM2
         if arch == "MiniCPMForCausalLM":
             return gguf.MODEL_ARCH.MINICPM
+        if arch == "BertModel":
+            return gguf.MODEL_ARCH.BERT

         raise NotImplementedError(f'Architecture "{arch}" not supported!')
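Both dispatch sites key off the "architectures" field of the checkpoint's config.json, which for a Hugging Face BERT encoder is ["BertModel"]. A minimal sketch of that lookup (the helper name is illustrative, not part of the patch):

import json
from pathlib import Path

def model_architecture(model_dir: Path) -> str:
    # illustrative helper, not part of the patch; config.json of a Hugging Face
    # BERT checkpoint carries e.g. {"architectures": ["BertModel"], ...}
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
    return config["architectures"][0]

# "BertModel" now routes to the new BertModel converter class and gguf.MODEL_ARCH.BERT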
@@ -1629,6 +1633,96 @@ in chat mode so that the conversation can end normally.")
             self.post_write_tensors(tensor_map, name, data_torch)


+class BertModel(Model):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"]
+
+    def set_gguf_parameters(self):
+        # TODO(cebtenzzre): merge with parent class
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
+        self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_file_type(self.ftype)
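set_gguf_parameters reads the usual Hugging Face config keys and writes them as GGUF metadata, with causal attention explicitly disabled. For reference, typical values for bert-base-uncased (quoted from the published config from memory, so treat them as illustrative rather than authoritative):

# hyperparameters consumed above, with bert-base-uncased's usual values
bert_base_uncased_hparams = {
    "max_position_embeddings": 512,   # -> context length
    "hidden_size": 768,               # -> embedding length
    "intermediate_size": 3072,        # -> feed-forward length
    "num_hidden_layers": 12,          # -> block count
    "num_attention_heads": 12,        # -> head count
    "layer_norm_eps": 1e-12,          # -> layer-norm epsilon
}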
+
+    def set_vocab(self):
+        path = self.dir_model
+        added_tokens_path = self.dir_model if self.dir_model.exists() else None
+
+        # use huggingface vocab to get all tokens
+        vocab = HfVocab(path, added_tokens_path)
+        tokens, scores, toktypes = zip(*vocab.all_tokens())
+        assert len(tokens) == vocab.vocab_size
+
+        # we need this to validate the size of the token_type embeddings
+        # though currently we are passing all zeros to the token_type embeddings
+        n_token_types = len(set(toktypes))
+        self.gguf_writer.add_token_type_count(n_token_types)
+
+        # convert to phantom space vocab
+        def phantom(tok, typ):
+            if tok.startswith(b"[") and tok.endswith(b"]"):
+                return tok
+            if tok.startswith(b"##"):
+                return tok[2:]
+            return b"\xe2\x96\x81" + tok
+        tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
+
+        # set up bos and eos tokens (cls and sep)
+        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
+        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+
+        # add vocab to gguf
+        self.gguf_writer.add_tokenizer_model("bert")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # handle special tokens
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
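The "phantom space" conversion rewrites WordPiece surface forms into the SentencePiece-style convention the rest of the tokenizer code expects: special tokens such as [CLS] pass through, the "##" continuation prefix is stripped, and every other token gets a leading U+2581 word-boundary marker. A small standalone sketch of the same mapping, with example inputs:

def phantom(tok: bytes) -> bytes:
    # mirrors the mapping in the patch above (the unused token-type argument is dropped here)
    if tok.startswith(b"[") and tok.endswith(b"]"):
        return tok                    # special tokens like [CLS], [SEP] stay as-is
    if tok.startswith(b"##"):
        return tok[2:]                # continuation pieces lose the "##" prefix
    return b"\xe2\x96\x81" + tok      # word-initial pieces get a leading U+2581

assert phantom(b"[CLS]") == b"[CLS]"
assert phantom(b"##ing") == b"ing"
assert phantom(b"hello") == b"\xe2\x96\x81hello"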
+
+    def write_tensors(self):
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+        tensors = dict(self.get_tensors())
+        for name, data_torch in tensors.items():
+            # we are only using BERT for embeddings so we don't need the pooling layer
+            if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
+                continue  # we don't need these
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            data = data_torch.squeeze().numpy()
+            n_dims = len(data.shape)
+            new_dtype: type[np.floating[Any]]
+
+            if (
+                self.ftype == 1 and name.endswith(".weight") and n_dims == 2
+                and name != "embeddings.token_type_embeddings.weight"  # not used with get_rows, must be F32
+            ):
+                # if f16 desired, convert any float32 2-dim weight tensors to float16
+                new_dtype = np.float16
+            else:
+                # if f32 desired, convert any float16 to float32
+                new_dtype = np.float32
+
+            print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
+
+            if data.dtype != new_dtype:
+                data = data.astype(new_dtype)
+
+            self.gguf_writer.add_tensor(new_name, data)
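The dtype logic above keeps most 2-D weights in F16 when an F16 output file is requested, but forces embeddings.token_type_embeddings.weight to stay F32 because it is consumed via get_rows. A standalone restatement of that rule, with a couple of example tensor names (the function name is illustrative, not part of the patch):

import numpy as np

def pick_dtype(ftype: int, name: str, n_dims: int) -> type:
    # illustrative restatement of the rule above; ftype == 1 means an F16 file was requested
    if (
        ftype == 1 and name.endswith(".weight") and n_dims == 2
        and name != "embeddings.token_type_embeddings.weight"  # must stay F32 for get_rows
    ):
        return np.float16
    return np.float32

assert pick_dtype(1, "encoder.layer.0.attention.self.query.weight", 2) is np.float16
assert pick_dtype(1, "embeddings.token_type_embeddings.weight", 2) is np.float32
assert pick_dtype(1, "encoder.layer.0.output.LayerNorm.bias", 1) is np.float32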
 ###### CONVERSION LOGIC ######