convert : fix convert for refact models
parent 0faf92e74c
commit 03e940cdec
2 changed files with 16 additions and 0 deletions
@@ -1013,6 +1013,18 @@ class StarCoderModel(Model):
 class RefactModel(Model):
     model_arch = gguf.MODEL_ARCH.REFACT
 
+    def set_vocab(self):
+        super().set_vocab()
+
+        # TODO: how to determine special FIM tokens automatically?
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
+        special_vocab._set_special_token("prefix", 1)
+        special_vocab._set_special_token("suffix", 3)
+        special_vocab._set_special_token("middle", 2)
+        special_vocab._set_special_token("fsep", 4) # is this correct?
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         hidden_dim = self.hparams["n_embd"]
         inner_dim = 4 * hidden_dim
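Review note, not part of the diff: the TODO above asks how to determine the FIM special tokens automatically, and the fsep assignment still carries an "is this correct?" question. A minimal sketch for eyeballing the hard-coded ids against the actual Refact tokenizer, assuming transformers is installed and the checkpoint is available locally (the path below is a placeholder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/refact-checkpoint")  # placeholder path, not from the commit
for token_id in (1, 2, 3, 4):
    # Print the token string behind each id that set_vocab() hard-codes,
    # so prefix=1, middle=2, suffix=3 and the questionable fsep=4 can be checked by eye.
    print(token_id, tok.convert_ids_to_tokens(token_id))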
@@ -137,6 +137,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wk", # llama-pth
             "encoder.layer.{bid}.attention.self.key", # bert
             "transformer.h.{bid}.attn.k_proj", # gpt-j
+            "transformer.h.{bid}.attn.k", # refact
             "model.layers.layers.{bid}.self_attn.k_proj", # plamo
             "model.layers.{bid}.attention.wk", # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
@@ -148,6 +149,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wv", # llama-pth
             "encoder.layer.{bid}.attention.self.value", # bert
             "transformer.h.{bid}.attn.v_proj", # gpt-j
+            "transformer.h.{bid}.attn.v", # refact
             "model.layers.layers.{bid}.self_attn.v_proj", # plamo
             "model.layers.{bid}.attention.wv", # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
@@ -229,6 +231,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.w3", # llama-pth
             "encoder.layer.{bid}.intermediate.dense", # bert
             "transformer.h.{bid}.mlp.fc_in", # gpt-j
+            "transformer.h.{bid}.mlp.linear_3", # refact
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
             "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon
             "transformer.h.{bid}.mlp.w1", # qwen
@@ -266,6 +269,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.gate_proj", # plamo
             "model.layers.{bid}.feed_forward.w1", # internlm2
             "encoder.layers.{bid}.mlp.fc12", # nomic-bert
+            "transformer.h.{bid}.mlp.linear_1", # refact
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
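Review note, not part of the diff: the four new "# refact" entries let TensorNameMap resolve the Refact checkpoint's attention and MLP tensor names to their GGUF counterparts; without them the convert script cannot map those tensors and the conversion fails. A rough sketch of the lookup, assuming gguf-py exposes get_tensor_name_map and TensorNameMap.get_name the way the convert scripts use them:

import gguf

block_count = 32  # hypothetical layer count; the convert script reads this from config.json
tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.REFACT, block_count)

for name in ("transformer.h.0.attn.k.weight",
             "transformer.h.0.attn.v.weight",
             "transformer.h.0.mlp.linear_1.weight",
             "transformer.h.0.mlp.linear_3.weight"):
    # With the entries above in place, each lookup returns the canonical GGUF
    # name (e.g. something like blk.0.attn_k.weight); without them it returns None.
    new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
    print(name, "->", new_name)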