added support for sbert/distilbert model
Commit 19d82c2244 (parent 8da46278e1)
3 changed files with 153 additions and 57 deletions
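For orientation (not part of this diff): the new loader added below reads a sentence-transformers ("sbert") model directory, whose modules.json file lists the pipeline modules to instantiate. A minimal sketch of a typical DistilBERT-based sbert layout follows; the exact paths and class names are assumptions for illustration, not taken from this commit.

# Hypothetical <model_dir>/modules.json contents for an sbert DistilBERT model,
# expressed as the Python value json.load() would return.
example_modules_config = [
    {"idx": 0, "name": "0", "path": "",          "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
]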
@@ -14,15 +14,19 @@ from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast

 import numpy as np
 import torch

+from collections import OrderedDict
+
+#TYPE_CHECKING = False
 if TYPE_CHECKING:
     from torch import Tensor

-if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
+#if 'NO_LOCAL_GGUF' not in os.environ:
+sys.path.insert(0, str(Path(__file__).parent / 'gguf-py'))
 import gguf

+#archs = [n.name for n in gguf.MODEL_ARCH] + ['DISTILBERT']
+#gguf.MODEL_ARCH = IntEnum('MODEL_ARCH',archs)
+#gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.DISTILBERT] = "distilbert"

 ###### MODEL DEFINITIONS ######

 class SentencePieceTokenTypes(IntEnum):

@@ -84,6 +88,9 @@ class Model:
     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        print ("-------------------")
+        for tt,_ in self.get_tensors():
+            print(tt)
         for name, data_torch in self.get_tensors():
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):

@@ -166,6 +173,8 @@ class Model:
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
+       if model_architecture == "DistilBertModel":
+           return DistilBertModel
        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
        return Model

@@ -203,6 +212,8 @@ class Model:
            return gguf.MODEL_ARCH.PERSIMMON
        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return gguf.MODEL_ARCH.STABLELM
+       if arch == "DistilBertModel":
+           return gguf.MODEL_ARCH.DISTILBERT

        raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -242,7 +253,8 @@ class Model:
         special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
         special_vocab.add_to_gguf(self.gguf_writer)

+
     def _set_vocab_sentencepiece(self):
         from sentencepiece import SentencePieceProcessor

         tokenizer_path = self.dir_model / 'tokenizer.model'
@@ -691,6 +703,57 @@ class StarCoderModel(Model):
         self.gguf_writer.add_file_type(self.ftype)


+class DistilBertModel(Model):
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for part_name in self.part_names:
+            print(f"gguf: loading model part '{part_name}'")
+            ctx: ContextManager[Any]
+            if self.is_safetensors:
+                raise NotImplementedError(f'Safetensors not supported for Distilbert')
+            else:
+                # load modules as in https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py
+                from sentence_transformers.util import import_from_string
+                modules_json_path = os.path.join(self.dir_model, 'modules.json')
+                with open(modules_json_path) as fIn:
+                    modules_config = json.load(fIn)
+
+                modules = OrderedDict()
+                for module_config in modules_config:
+                    module_class = import_from_string(module_config['type'])
+                    module = module_class.load(os.path.join(self.dir_model, module_config['path']))
+                    modules[module_config['name']] = module
+
+                sd = torch.nn.Sequential(modules).state_dict()
+                # deep copy
+                keys = [k for k in sd.keys()]
+                for k in keys:
+                    key_chunks = k.split(".")[1:]
+                    newkey = ".".join([c for c in key_chunks if c != "auto_model"])
+                    sd[newkey] = sd.pop(k)
+
+                ctx = contextlib.nullcontext(sd)
+
+            with ctx as model_part:
+                for name in model_part.keys():
+                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
+                    yield name, data
+
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["n_layers"]
+
+        self.gguf_writer.add_name("DistilBert")
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(self.hparams["dim"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["hidden_dim"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(self.hparams["n_heads"])
+        self.gguf_writer.add_head_count_kv(1)
+        #self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+
 class RefactModel(Model):
     def set_gguf_parameters(self):
         hidden_dim = self.hparams["n_embd"]
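A minimal sketch (not part of the commit) of the key renaming that get_tensors() performs above: sentence-transformers prefixes every state_dict key with the module index (e.g. "0.") and the wrapped "auto_model" attribute; stripping both yields plain Hugging Face DistilBERT tensor names that the GGUF tensor map can look up. The example key is illustrative.

def strip_module_prefix(key: str) -> str:
    # drop the leading module index ("0") and any "auto_model" component
    chunks = key.split(".")[1:]
    return ".".join(c for c in chunks if c != "auto_model")

assert strip_module_prefix(
    "0.auto_model.transformer.layer.0.attention.q_lin.weight"
) == "transformer.layer.0.attention.q_lin.weight"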
@@ -91,7 +91,7 @@ class MODEL_ARCH(IntEnum):
     BERT = auto()
     BLOOM = auto()
     STABLELM = auto()
+    DISTILBERT = auto()

-
 class MODEL_TENSOR(IntEnum):
     TOKEN_EMBD = auto()

@@ -115,6 +115,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_NORM = auto()
     ATTN_Q_NORM = auto()
     ATTN_K_NORM = auto()
+    OUTPUT_LAYER_NORM = auto()


 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {

@@ -131,6 +132,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BERT: "bert",
     MODEL_ARCH.BLOOM: "bloom",
     MODEL_ARCH.STABLELM: "stablelm",
+    MODEL_ARCH.DISTILBERT: "distilbert"
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {

@@ -155,6 +157,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
     MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
     MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+    MODEL_TENSOR.OUTPUT_LAYER_NORM: "blk.{bid}.output_norm"
 }

 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {

@@ -185,6 +188,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.DISTILBERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.OUTPUT_LAYER_NORM,
+        MODEL_TENSOR.OUTPUT,
+    ],
+
     MODEL_ARCH.FALCON: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
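A hedged sketch (not part of the diff) of how the new DISTILBERT entries fit together: each MODEL_TENSOR listed above resolves through TENSOR_NAMES to a GGUF tensor name, with {bid} filled in per block. The block id 0 below is just an example value, and the snippet assumes gguf-py is importable.

from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

for tensor in MODEL_TENSORS[MODEL_ARCH.DISTILBERT]:
    # templates without {bid} (token_embd, output, ...) are returned unchanged
    print(TENSOR_NAMES[tensor].format(bid=0))
# expected output includes e.g. token_embd, position_embd, blk.0.attn_q, blk.0.output_norm, output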
@ -4,7 +4,6 @@ from typing import Sequence
|
||||||
|
|
||||||
from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES
|
from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES
|
||||||
|
|
||||||
|
|
||||||
class TensorNameMap:
|
class TensorNameMap:
|
||||||
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
||||||
# Token embeddings
|
# Token embeddings
|
||||||
|
@ -27,6 +26,7 @@ class TensorNameMap:
|
||||||
# Normalization of token embeddings
|
# Normalization of token embeddings
|
||||||
MODEL_TENSOR.TOKEN_EMBD_NORM: (
|
MODEL_TENSOR.TOKEN_EMBD_NORM: (
|
||||||
"word_embeddings_layernorm", # bloom
|
"word_embeddings_layernorm", # bloom
|
||||||
|
"embeddings.LayerNorm", # distilbert
|
||||||
),
|
),
|
||||||
|
|
||||||
# Position embeddings
|
# Position embeddings
|
||||||
|
@@ -41,6 +41,7 @@ class TensorNameMap:
             "lm_head", # gpt2 mpt falcon llama-hf baichuan
             "output", # llama-pth bloom
             "word_embeddings_for_head", # persimmon
+            "linear.weight", # sbert / distilbert
         ),

         # Output norm
@@ -49,7 +50,7 @@ class TensorNameMap:
             "transformer.ln_f", # gpt2 gpt-j falcon
             "model.norm", # llama-hf baichuan
             "norm", # llama-pth
-            "embeddings.LayerNorm", # bert
+            # "embeddings.LayerNorm", # bert
             "transformer.norm_f", # mpt
             "ln_f", # refact bloom
             "language_model.encoder.final_layernorm", # persimmon

@@ -98,6 +99,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wq", # llama-pth
             "encoder.layer.{bid}.attention.self.query", # bert
             "transformer.h.{bid}.attn.q_proj", # gpt-j
+            "transformer.layer.{bid}.attention.q_lin", #distilbert
         ),

         # Attention key

@@ -106,6 +108,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wk", # llama-pth
             "encoder.layer.{bid}.attention.self.key", # bert
             "transformer.h.{bid}.attn.k_proj", # gpt-j
+            "transformer.layer.{bid}.attention.k_lin", #distilbert
         ),

         # Attention value

@@ -114,6 +117,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wv", # llama-pth
             "encoder.layer.{bid}.attention.self.value", # bert
             "transformer.h.{bid}.attn.v_proj", # gpt-j
+            "transformer.layer.{bid}.attention.v_lin", # distilbert
         ),

         # Attention output

@@ -128,6 +132,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.dense", # bert
             "transformer.h.{bid}.attn.out_proj", # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
+            "transformer.layer.{bid}.attention.out_lin", # distilbert
         ),

         # Rotary embeddings

@@ -147,6 +152,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.LayerNorm", # bert
             "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
             "model.layers.{bid}.ln2", # yi
+            "transformer.layer.{bid}.sa_layer_norm", #distilbert
         ),

         # Feed-forward up

@@ -161,6 +167,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.intermediate.dense", # bert
             "transformer.h.{bid}.mlp.fc_in", # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
+            "transformer.layer.{bid}.ffn.lin1" # distil
         ),

         # Feed-forward gate

@@ -181,6 +188,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.dense", # bert
             "transformer.h.{bid}.mlp.fc_out", # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
+            "transformer.layer.{bid}.ffn.lin2", # distilbert
         ),

         MODEL_TENSOR.ATTN_Q_NORM: (

@@ -194,6 +202,12 @@ class TensorNameMap:
         MODEL_TENSOR.ROPE_FREQS: (
             "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
         ),

+        # Output norm
+        MODEL_TENSOR.OUTPUT_LAYER_NORM: (
+            "transformer.layer.{bid}.output_layer_norm", # distilbert
+        ),
+
     }

     mapping: dict[str, tuple[MODEL_TENSOR, str]]

@@ -216,8 +230,8 @@ class TensorNameMap:
             for key in keys:
                 key = key.format(bid = bid)
                 self.mapping[key] = (tensor, tensor_name)

     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
+        print("getting ", key)
         result = self.mapping.get(key)
         if result is not None:
             return result

@@ -241,6 +255,7 @@ class TensorNameMap:
         return result[0]

     def __getitem__(self, key: str) -> str:
+        #print ("getting", key)
         try:
             return self.mapping[key][1]
         except KeyError:
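Taken together, the tensor_mapping additions are what let the converter translate raw DistilBERT checkpoint names into GGUF names. A hedged usage sketch (not part of the diff), assuming gguf-py is on the path and using an example block count of 6:

import gguf

tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DISTILBERT, 6)
new_name = tensor_map.get_name("transformer.layer.0.attention.q_lin.weight",
                               try_suffixes=(".weight", ".bias"))
print(new_name)  # expected: "blk.0.attn_q.weight"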