added support for sbert/distilbert model

Andrew 2023-11-17 13:49:27 -05:00
parent 8da46278e1
commit 19d82c2244
3 changed files with 153 additions and 57 deletions

View file

@@ -14,15 +14,19 @@ from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
import numpy as np
import torch
from collections import OrderedDict
#TYPE_CHECKING = False
if TYPE_CHECKING:
from torch import Tensor
-if 'NO_LOCAL_GGUF' not in os.environ:
-sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
+#if 'NO_LOCAL_GGUF' not in os.environ:
+sys.path.insert(0, str(Path(__file__).parent / 'gguf-py'))
import gguf
#archs = [n.name for n in gguf.MODEL_ARCH] + ['DISTILBERT']
#gguf.MODEL_ARCH = IntEnum('MODEL_ARCH',archs)
#gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.DISTILBERT] = "distilbert"
###### MODEL DEFINITIONS ######
class SentencePieceTokenTypes(IntEnum):
@@ -84,6 +88,9 @@ class Model:
def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
print ("-------------------")
for tt,_ in self.get_tensors():
print(tt)
for name, data_torch in self.get_tensors():
# we don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
@@ -166,6 +173,8 @@ class Model:
return RefactModel
if model_architecture == "PersimmonForCausalLM":
return PersimmonModel
if model_architecture == "DistilBertModel":
return DistilBertModel
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"): if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return StableLMModel return StableLMModel
return Model return Model
@ -203,6 +212,8 @@ class Model:
return gguf.MODEL_ARCH.PERSIMMON return gguf.MODEL_ARCH.PERSIMMON
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"): if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return gguf.MODEL_ARCH.STABLELM return gguf.MODEL_ARCH.STABLELM
if arch == "DistilBertModel":
return gguf.MODEL_ARCH.DISTILBERT
raise NotImplementedError(f'Architecture "{arch}" not supported!')
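For context, a hypothetical sketch (not part of this commit) of where the architecture string matched above typically comes from: sentence-transformers DistilBERT checkpoints usually list "DistilBertModel" in the "architectures" field of config.json, which is what the convert script dispatches on. The path below is a placeholder.

import json
from pathlib import Path

dir_model = Path('path/to/sbert-distilbert-model')  # placeholder path
hparams = json.loads((dir_model / 'config.json').read_text())
print(hparams['architectures'][0])  # expected: "DistilBertModel" -> DistilBertModel / gguf.MODEL_ARCH.DISTILBERT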
@@ -242,7 +253,8 @@ class Model:
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_sentencepiece(self):
from sentencepiece import SentencePieceProcessor
tokenizer_path = self.dir_model / 'tokenizer.model'
@@ -691,6 +703,57 @@ class StarCoderModel(Model):
self.gguf_writer.add_file_type(self.ftype)
class DistilBertModel(Model):
    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                raise NotImplementedError('Safetensors not supported for DistilBert')
            else:
                # load modules as in https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py
                from sentence_transformers.util import import_from_string
                modules_json_path = os.path.join(self.dir_model, 'modules.json')
                with open(modules_json_path) as fIn:
                    modules_config = json.load(fIn)
                modules = OrderedDict()
                for module_config in modules_config:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(os.path.join(self.dir_model, module_config['path']))
                    modules[module_config['name']] = module
                sd = torch.nn.Sequential(modules).state_dict()
                # rename keys in place: drop the Sequential module index (e.g. "0.")
                # and the "auto_model" wrapper so names match the Hugging Face layout
                keys = list(sd.keys())
                for k in keys:
                    key_chunks = k.split(".")[1:]
                    newkey = ".".join([c for c in key_chunks if c != "auto_model"])
                    sd[newkey] = sd.pop(k)
                ctx = contextlib.nullcontext(sd)

            with ctx as model_part:
                for name in model_part.keys():
                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
                    yield name, data

    def set_gguf_parameters(self):
        block_count = self.hparams["n_layers"]

        self.gguf_writer.add_name("DistilBert")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["dim"])
        self.gguf_writer.add_feed_forward_length(self.hparams["hidden_dim"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        self.gguf_writer.add_head_count_kv(1)
        #self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)
class RefactModel(Model):
    def set_gguf_parameters(self):
        hidden_dim = self.hparams["n_embd"]
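For reference, a minimal sketch (not part of the commit) of the key renaming that DistilBertModel.get_tensors above applies after loading the sentence-transformers modules. The sample keys are hypothetical, but follow the "<module index>.auto_model.<hf name>" pattern that torch.nn.Sequential(modules).state_dict() produces for a Transformer module plus an optional Dense module:

sample_keys = [
    "0.auto_model.embeddings.word_embeddings.weight",
    "0.auto_model.transformer.layer.0.attention.q_lin.weight",
    "2.linear.weight",  # a Dense (projection) module, if the checkpoint has one
]
for k in sample_keys:
    key_chunks = k.split(".")[1:]                                    # drop the Sequential index
    newkey = ".".join([c for c in key_chunks if c != "auto_model"])  # drop the auto_model wrapper
    print(f"{k} -> {newkey}")
# 0.auto_model.embeddings.word_embeddings.weight -> embeddings.word_embeddings.weight
# 0.auto_model.transformer.layer.0.attention.q_lin.weight -> transformer.layer.0.attention.q_lin.weight
# 2.linear.weight -> linear.weight (matched by the new "linear.weight" entry in tensor_mapping.py)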

View file

@@ -91,7 +91,7 @@ class MODEL_ARCH(IntEnum):
BERT = auto()
BLOOM = auto()
STABLELM = auto()
DISTILBERT = auto()
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
@@ -115,6 +115,7 @@ class MODEL_TENSOR(IntEnum):
FFN_NORM = auto()
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
OUTPUT_LAYER_NORM = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -131,6 +132,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.DISTILBERT: "distilbert"
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -155,6 +157,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
MODEL_TENSOR.OUTPUT_LAYER_NORM: "blk.{bid}.output_norm"
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -185,6 +188,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
    MODEL_ARCH.DISTILBERT: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.POS_EMBD,
        MODEL_TENSOR.TOKEN_EMBD_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.OUTPUT_LAYER_NORM,
        MODEL_TENSOR.OUTPUT,
    ],
MODEL_ARCH.FALCON: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
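Taken together with TENSOR_NAMES, the new MODEL_TENSORS entry determines the GGUF-side tensor names written for a DistilBERT model. A small illustration (not part of the commit; assumes gguf-py is importable, e.g. via the sys.path insert in the convert script):

from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

for tensor in MODEL_TENSORS[MODEL_ARCH.DISTILBERT]:
    # per-block names contain "{bid}"; str.format() leaves the non-block names unchanged
    print(TENSOR_NAMES[tensor].format(bid=0))  # e.g. token_embd, blk.0.attn_q, blk.0.output_norm, ...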

View file

@@ -4,7 +4,6 @@ from typing import Sequence
from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES
class TensorNameMap:
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Token embeddings
@@ -27,6 +26,7 @@ class TensorNameMap:
# Normalization of token embeddings
MODEL_TENSOR.TOKEN_EMBD_NORM: (
"word_embeddings_layernorm", # bloom
"embeddings.LayerNorm", # distilbert
),
# Position embeddings
@@ -41,6 +41,7 @@ class TensorNameMap:
"lm_head", # gpt2 mpt falcon llama-hf baichuan
"output", # llama-pth bloom
"word_embeddings_for_head", # persimmon
"linear.weight", # sbert / distilbery
),
# Output norm
@@ -49,7 +50,7 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan
"norm", # llama-pth
-"embeddings.LayerNorm", # bert
+# "embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt
"ln_f", # refact bloom
"language_model.encoder.final_layernorm", # persimmon
@@ -98,6 +99,7 @@ class TensorNameMap:
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
"transformer.layer.{bid}.attention.q_lin", #distilbert
),
# Attention key
@@ -106,6 +108,7 @@ class TensorNameMap:
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
"transformer.layer.{bid}.attention.k_lin", #distilbert
),
# Attention value
@@ -114,6 +117,7 @@ class TensorNameMap:
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
"transformer.layer.{bid}.attention.v_lin", # distilbert
),
# Attention output
@@ -128,6 +132,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"transformer.layer.{bid}.attention.out_lin", # distilbert
),
# Rotary embeddings
@@ -147,6 +152,7 @@ class TensorNameMap:
"encoder.layer.{bid}.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
"transformer.layer.{bid}.sa_layer_norm", #distilbert
),
# Feed-forward up
@@ -161,6 +167,7 @@ class TensorNameMap:
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"transformer.layer.{bid}.ffn.lin1" # distil
),
# Feed-forward gate
@@ -181,6 +188,7 @@ class TensorNameMap:
"encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
"transformer.layer.{bid}.ffn.lin2", # distilbert
),
MODEL_TENSOR.ATTN_Q_NORM: (
@@ -194,6 +202,12 @@ class TensorNameMap:
MODEL_TENSOR.ROPE_FREQS: (
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
),
# Output layer norm (per-block, after the FFN)
MODEL_TENSOR.OUTPUT_LAYER_NORM: (
"transformer.layer.{bid}.output_layer_norm", # distilbert
),
}
mapping: dict[str, tuple[MODEL_TENSOR, str]] mapping: dict[str, tuple[MODEL_TENSOR, str]]
@@ -216,8 +230,8 @@ class TensorNameMap:
for key in keys:
key = key.format(bid = bid)
self.mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
print("getting ", key)
result = self.mapping.get(key)
if result is not None:
return result
@@ -241,6 +255,7 @@ class TensorNameMap:
return result[0]
def __getitem__(self, key: str) -> str:
#print ("getting", key)
try:
return self.mapping[key][1]
except KeyError:
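Finally, a small usage sketch (not part of the commit) of how the new distilbert entries resolve through the updated mapping; it assumes gguf-py is on sys.path and a 6-block DistilBERT:

import gguf

tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DISTILBERT, 6)
for hf_name in ("transformer.layer.0.attention.q_lin.weight",
                "transformer.layer.0.sa_layer_norm.weight",
                "embeddings.LayerNorm.bias"):
    # try_suffixes lets the map match the bare module name and re-append ".weight"/".bias"
    print(hf_name, "->", tensor_map.get_type_and_name(hf_name, try_suffixes=(".weight", ".bias")))
# e.g. (MODEL_TENSOR.ATTN_Q, "blk.0.attn_q.weight"), (MODEL_TENSOR.FFN_NORM, "blk.0.ffn_norm.weight"),
#      (MODEL_TENSOR.TOKEN_EMBD_NORM, "token_embd_norm.bias")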