added support for sbert/distilbert model
parent 8da46278e1
commit 19d82c2244
3 changed files with 153 additions and 57 deletions
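The converter dispatches on the model class name recorded in the checkpoint's config.json; this commit adds a branch for "DistilBertModel" plus sentence-transformers style loading. A minimal sketch of how that dispatch key can be inspected for a local checkpoint (the directory path is a placeholder, not something from this commit):

import json
from pathlib import Path

model_dir = Path("path/to/sbert-distilbert-model")  # placeholder checkpoint directory
config = json.loads((model_dir / "config.json").read_text())

# Hugging Face configs list the model class name(s) here; for an SBERT
# distilbert checkpoint this is typically ["DistilBertModel"].
print(config.get("architectures"))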
@@ -14,15 +14,19 @@ from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
import numpy as np
import torch
from collections import OrderedDict

#TYPE_CHECKING = False
if TYPE_CHECKING:
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
#if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(0, str(Path(__file__).parent / 'gguf-py'))
import gguf


#archs = [n.name for n in gguf.MODEL_ARCH] + ['DISTILBERT']
#gguf.MODEL_ARCH = IntEnum('MODEL_ARCH', archs)
#gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.DISTILBERT] = "distilbert"

###### MODEL DEFINITIONS ######

class SentencePieceTokenTypes(IntEnum):
@@ -84,6 +88,9 @@ class Model:
    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        print("-------------------")
        for tt, _ in self.get_tensors():
            print(tt)
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
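The block-count lookup above chains dict.get fallbacks because different Hugging Face configs name the layer count differently (DistilBERT uses n_layers). A minimal sketch of that pattern, with illustrative distilbert-base style values rather than anything read from this commit:

# illustrative hyperparameters, assumed to resemble a distilbert-base config
hparams = {"n_layers": 6, "dim": 768, "n_heads": 12, "hidden_dim": 3072}

# try "n_layers" first, then "num_hidden_layers", then "n_layer"
block_count = hparams.get("n_layers", hparams.get("num_hidden_layers", hparams.get("n_layer")))
print(block_count)  # -> 6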
@@ -166,6 +173,8 @@ class Model:
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
        if model_architecture == "DistilBertModel":
            return DistilBertModel
        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
        return Model
@@ -203,6 +212,8 @@ class Model:
            return gguf.MODEL_ARCH.PERSIMMON
        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return gguf.MODEL_ARCH.STABLELM
        if arch == "DistilBertModel":
            return gguf.MODEL_ARCH.DISTILBERT

        raise NotImplementedError(f'Architecture "{arch}" not supported!')
@@ -242,6 +253,7 @@ class Model:
        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_sentencepiece(self):
        from sentencepiece import SentencePieceProcessor
@@ -691,6 +703,57 @@ class StarCoderModel(Model):
        self.gguf_writer.add_file_type(self.ftype)


class DistilBertModel(Model):

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                raise NotImplementedError('Safetensors not supported for DistilBert')
            else:
                # load modules as in https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py
                from sentence_transformers.util import import_from_string
                modules_json_path = os.path.join(self.dir_model, 'modules.json')
                with open(modules_json_path) as fIn:
                    modules_config = json.load(fIn)

                modules = OrderedDict()
                for module_config in modules_config:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(os.path.join(self.dir_model, module_config['path']))
                    modules[module_config['name']] = module

                sd = torch.nn.Sequential(modules).state_dict()
                # rename keys in place: drop the leading module index and the "auto_model" prefix
                keys = list(sd.keys())
                for k in keys:
                    key_chunks = k.split(".")[1:]
                    newkey = ".".join([c for c in key_chunks if c != "auto_model"])
                    sd[newkey] = sd.pop(k)

                ctx = contextlib.nullcontext(sd)

            with ctx as model_part:
                for name in model_part.keys():
                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
                    yield name, data
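The key renaming at the end of get_tensors() strips the sentence-transformers wrapping so the tensors carry plain Hugging Face DistilBERT names, which is what the tensor name map changes below expect. A small sketch of that transformation; the example key is illustrative, not taken from a real checkpoint:

def strip_sbert_prefix(key: str) -> str:
    # drop the leading Sequential module index ("0", "1", ...) and any "auto_model" component
    chunks = key.split(".")[1:]
    return ".".join(c for c in chunks if c != "auto_model")

raw_key = "0.auto_model.transformer.layer.0.attention.q_lin.weight"
print(strip_sbert_prefix(raw_key))
# -> transformer.layer.0.attention.q_lin.weight, matching the distilbert entries
#    added to the tensor name map further down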
    def set_gguf_parameters(self):
        block_count = self.hparams["n_layers"]

        self.gguf_writer.add_name("DistilBert")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["dim"])
        self.gguf_writer.add_feed_forward_length(self.hparams["hidden_dim"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        self.gguf_writer.add_head_count_kv(1)
        #self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)
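For reference, the metadata written by set_gguf_parameters() comes straight from the checkpoint's config.json. The values below mirror a typical distilbert-base configuration and are assumptions for illustration only:

# assumed distilbert-base style config fields -> GGUF writer calls above
config = {
    "max_position_embeddings": 512,  # -> add_context_length
    "dim": 768,                      # -> add_embedding_length
    "hidden_dim": 3072,              # -> add_feed_forward_length
    "n_layers": 6,                   # -> add_block_count
    "n_heads": 12,                   # -> add_head_count
}
print(config)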


class RefactModel(Model):
    def set_gguf_parameters(self):
        hidden_dim = self.hparams["n_embd"]

@@ -91,7 +91,7 @@ class MODEL_ARCH(IntEnum):
    BERT = auto()
    BLOOM = auto()
    STABLELM = auto()
    DISTILBERT = auto()


class MODEL_TENSOR(IntEnum):
    TOKEN_EMBD = auto()
@@ -115,6 +115,7 @@ class MODEL_TENSOR(IntEnum):
    FFN_NORM = auto()
    ATTN_Q_NORM = auto()
    ATTN_K_NORM = auto()
    OUTPUT_LAYER_NORM = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -131,6 +132,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.BERT: "bert",
    MODEL_ARCH.BLOOM: "bloom",
    MODEL_ARCH.STABLELM: "stablelm",
    MODEL_ARCH.DISTILBERT: "distilbert",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -155,6 +157,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
    MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
    MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
    MODEL_TENSOR.OUTPUT_LAYER_NORM: "blk.{bid}.output_norm",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -185,6 +188,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.DISTILBERT: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.POS_EMBD,
        MODEL_TENSOR.TOKEN_EMBD_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.OUTPUT_LAYER_NORM,
        MODEL_TENSOR.OUTPUT,
    ],
    MODEL_ARCH.FALCON: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
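Together with TENSOR_NAMES, the MODEL_TENSORS entry above fixes which per-block tensor names a DISTILBERT GGUF file is expected to contain. A minimal sketch of that expansion; the two small dicts are trimmed-down stand-ins for the real gguf-py tables, and the layer count of 6 is an assumed distilbert-base value:

# mini stand-ins for a few of the TENSOR_NAMES entries shown in the diff above
TENSOR_NAME_TEMPLATES = {
    "FFN_UP": "blk.{bid}.ffn_up",
    "FFN_DOWN": "blk.{bid}.ffn_down",
    "OUTPUT_LAYER_NORM": "blk.{bid}.output_norm",
}
DISTILBERT_BLOCK_TENSORS = ["FFN_UP", "FFN_DOWN", "OUTPUT_LAYER_NORM"]

n_layers = 6  # assumed distilbert-base layer count
for bid in range(n_layers):
    for tensor in DISTILBERT_BLOCK_TENSORS:
        print(TENSOR_NAME_TEMPLATES[tensor].format(bid=bid) + ".weight")
# blk.0.ffn_up.weight, blk.0.ffn_down.weight, blk.0.output_norm.weight, blk.1.ffn_up.weight, ...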
@@ -4,7 +4,6 @@ from typing import Sequence

from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES


class TensorNameMap:
    mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
        # Token embeddings
@@ -27,6 +26,7 @@ class TensorNameMap:
        # Normalization of token embeddings
        MODEL_TENSOR.TOKEN_EMBD_NORM: (
            "word_embeddings_layernorm",  # bloom
            "embeddings.LayerNorm",  # distilbert
        ),

        # Position embeddings
@@ -41,6 +41,7 @@ class TensorNameMap:
            "lm_head",  # gpt2 mpt falcon llama-hf baichuan
            "output",  # llama-pth bloom
            "word_embeddings_for_head",  # persimmon
            "linear.weight",  # sbert / distilbert
        ),

        # Output norm
@@ -49,7 +50,7 @@ class TensorNameMap:
            "transformer.ln_f",  # gpt2 gpt-j falcon
            "model.norm",  # llama-hf baichuan
            "norm",  # llama-pth
            "embeddings.LayerNorm",  # bert
            # "embeddings.LayerNorm",  # bert
            "transformer.norm_f",  # mpt
            "ln_f",  # refact bloom
            "language_model.encoder.final_layernorm",  # persimmon
@@ -98,6 +99,7 @@ class TensorNameMap:
            "layers.{bid}.attention.wq",  # llama-pth
            "encoder.layer.{bid}.attention.self.query",  # bert
            "transformer.h.{bid}.attn.q_proj",  # gpt-j
            "transformer.layer.{bid}.attention.q_lin",  # distilbert
        ),

        # Attention key
@@ -106,6 +108,7 @@ class TensorNameMap:
            "layers.{bid}.attention.wk",  # llama-pth
            "encoder.layer.{bid}.attention.self.key",  # bert
            "transformer.h.{bid}.attn.k_proj",  # gpt-j
            "transformer.layer.{bid}.attention.k_lin",  # distilbert
        ),

        # Attention value
@@ -114,6 +117,7 @@ class TensorNameMap:
            "layers.{bid}.attention.wv",  # llama-pth
            "encoder.layer.{bid}.attention.self.value",  # bert
            "transformer.h.{bid}.attn.v_proj",  # gpt-j
            "transformer.layer.{bid}.attention.v_lin",  # distilbert
        ),

        # Attention output
@@ -128,6 +132,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.attention.output.dense",  # bert
            "transformer.h.{bid}.attn.out_proj",  # gpt-j
            "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
            "transformer.layer.{bid}.attention.out_lin",  # distilbert
        ),

        # Rotary embeddings
@@ -147,6 +152,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.output.LayerNorm",  # bert
            "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
            "model.layers.{bid}.ln2",  # yi
            "transformer.layer.{bid}.sa_layer_norm",  # distilbert
        ),

        # Feed-forward up
@@ -161,6 +167,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.intermediate.dense",  # bert
            "transformer.h.{bid}.mlp.fc_in",  # gpt-j
            "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
            "transformer.layer.{bid}.ffn.lin1",  # distilbert
        ),

        # Feed-forward gate
@@ -181,6 +188,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.output.dense",  # bert
            "transformer.h.{bid}.mlp.fc_out",  # gpt-j
            "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
            "transformer.layer.{bid}.ffn.lin2",  # distilbert
        ),

        MODEL_TENSOR.ATTN_Q_NORM: (
@@ -194,6 +202,12 @@ class TensorNameMap:
        MODEL_TENSOR.ROPE_FREQS: (
            "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",  # persimmon
        ),

        # Output norm
        MODEL_TENSOR.OUTPUT_LAYER_NORM: (
            "transformer.layer.{bid}.output_layer_norm",  # distilbert
        ),
    }

    mapping: dict[str, tuple[MODEL_TENSOR, str]]
@@ -216,8 +230,8 @@ class TensorNameMap:
            for key in keys:
                key = key.format(bid = bid)
                self.mapping[key] = (tensor, tensor_name)

    def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
        print("getting ", key)
        result = self.mapping.get(key)
        if result is not None:
            return result
@@ -241,6 +255,7 @@ class TensorNameMap:
            return result[0]

    def __getitem__(self, key: str) -> str:
        #print ("getting", key)
        try:
            return self.mapping[key][1]
        except KeyError:
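Once the {bid} templates are expanded, TensorNameMap resolves the new distilbert keys to GGUF names, retrying with common suffixes stripped. A rough, self-contained sketch of that lookup; the two-entry mapping and the blk.0.attn_q name are illustrative stand-ins for the real table, not values read from this commit:

# hand-built stand-in for the mapping the class above constructs for block 0
mapping = {
    "transformer.layer.0.attention.q_lin": ("ATTN_Q", "blk.0.attn_q"),
    "transformer.layer.0.ffn.lin1": ("FFN_UP", "blk.0.ffn_up"),
}

def get_name(key, try_suffixes=(".weight", ".bias")):
    result = mapping.get(key)
    if result is not None:
        return result[1]
    # retry with a known suffix stripped, then re-append it to the GGUF name
    for suffix in try_suffixes:
        if key.endswith(suffix) and key[:-len(suffix)] in mapping:
            return mapping[key[:-len(suffix)]][1] + suffix
    return None

print(get_name("transformer.layer.0.ffn.lin1.weight"))  # -> blk.0.ffn_up.weight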