feat: create tensors for Jina architecture

Joan Martinez 2024-04-12 12:47:48 +02:00
parent 86a5d96fc6
commit 747d17a62c
4 changed files with 61 additions and 38 deletions

convert-hf-to-gguf.py

@@ -77,6 +77,7 @@ class Model(ABC):
for part_name in self.part_names:
print(f"gguf: loading model part '{part_name}'")
ctx: ContextManager[Any]
if self.is_safetensors:
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
@@ -91,6 +92,7 @@ class Model(ABC):
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name)
print(f'self.block_count {self.block_count}')
self.gguf_writer.add_block_count(self.block_count)
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
@@ -136,6 +138,7 @@ class Model(ABC):
def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
print(f'Block_count {block_count} with tensor_map {tensor_map}')
for name, data_torch in self.get_tensors():
# we don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
@@ -2096,6 +2099,7 @@ class BertModel(Model):
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
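The lookup above first tries the full tensor name, then retries with known suffixes stripped. A minimal sketch of that behavior, using a plain dict as a stand-in for gguf.TensorNameMap (whose real per-architecture tables are changed in the gguf-py files below):

```python
# Sketch only: gguf.TensorNameMap builds the mapping dict from the
# per-architecture tables; here it is assumed to be a plain dict from
# HF base names to GGUF base names.
def get_name(name: str, mapping: dict[str, str],
             try_suffixes: tuple[str, ...] = (".weight", ".bias")) -> str | None:
    if name in mapping:
        return mapping[name]
    for suffix in try_suffixes:
        if name.endswith(suffix):
            base = mapping.get(name[:-len(suffix)])
            if base is not None:
                return base + suffix
    return None
```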
@@ -2166,34 +2170,6 @@ class NomicBertModel(BertModel):
class JinaBertModel(BertModel):
model_arch = gguf.MODEL_ARCH.JINA_BERT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
print(f'hparams {self.hparams}')
assert self.hparams["position_embedding_type"] == "alibi"
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
#
# assert self.hparams["position_embedding_type"] == "alibi"
#
# # GeGLU activation
# assert self.hparams["feed_forward_type"] == "geglu"
#
# def get_tensors(self):
# assert self.vocab_size is not None
# for name, data in super().get_tensors():
# print(f'get_tensors: {name} {data.shape}')
# # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
# if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
# rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
# assert data.shape == (rounded_vocab_size, self.hparams["hidden_size"])
# data = data[:self.vocab_size, :]
# yield name, data
@Model.register("GemmaForCausalLM")
class GemmaModel(Model):
model_arch = gguf.MODEL_ARCH.GEMMA
@@ -2461,9 +2437,7 @@ def main() -> None:
print(f"Loading model: {dir_model.name}")
hparams = Model.load_hparams(dir_model)
with torch.inference_mode():
print(hparams["architectures"])
model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
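The new JinaBertModel above asserts position_embedding_type == "alibi". For reference, a sketch of the standard ALiBi slope schedule (textbook formulation, not code from this commit):

```python
def alibi_slopes(n_heads: int) -> list[float]:
    # Standard ALiBi: a geometric sequence of per-head slopes starting at
    # 2^(-8/n_heads). Non-power-of-two head counts interpolate additional
    # slopes, which this sketch omits.
    start = 2.0 ** (-8.0 / n_heads)
    return [start ** (i + 1) for i in range(n_heads)]
```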

gguf-py/gguf/constants.py

@@ -363,14 +363,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.MPT: [

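The JINA_BERT tensor list above trades OUTPUT_NORM for ATTN_OUT_NORM and adds FFN_GATE because Jina BERT uses a GeGLU feed-forward. A numpy sketch of a GeGLU MLP with the gate and up projections packed into a single matrix, matching the {n_embd, 2*n_ff} gated_layers tensor loaded in llama.cpp below (which half acts as the gate is an assumption here):

```python
import numpy as np

def gelu(x: np.ndarray) -> np.ndarray:
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def geglu_mlp(x: np.ndarray, w_gated: np.ndarray,
              w_out: np.ndarray, b_out: np.ndarray) -> np.ndarray:
    # w_gated packs gate and up projections side by side: [n_embd, 2*n_ff]
    gate, up = np.split(x @ w_gated, 2, axis=-1)
    return (gelu(gate) * up) @ w_out + b_out  # w_out: [n_ff, n_embd] ("wo")
```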
gguf-py/gguf/tensor_mapping.py

@@ -217,9 +217,6 @@ class TensorNameMap:
"model.layers.{bid}.mlp.up_proj", # llama-hf refact
"layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"encoder.layer.{bid}.mlp.gated_layers", # jina-bert
"encoder.layer.{bid}.mlp.layernorm", # jina-bert
"encoder.layer.{bid}.mlp.wo", # jina-bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"model.layers.{bid}.mlp.dense_h_to_4h", # persimmon
@@ -251,6 +248,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
"model.layers.{bid}.feed_forward.w1", # internlm2
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
"encoder.layer.{bid}.mlp.gated_layers", # jina-bert
),
MODEL_TENSOR.FFN_GATE_EXP: (
@@ -278,6 +276,7 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.w2", # internlm2
"encoder.layers.{bid}.mlp.fc2", # nomic-bert
"model.layers.{bid}.mlp.c_proj", # starcoder2
"encoder.layer.{bid}.mlp.wo", # jina-bert
),
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -307,6 +306,7 @@ class TensorNameMap:
"encoder.layer.{bid}.output.LayerNorm", # bert
"encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert
),
MODEL_TENSOR.SSM_IN: (

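Tying the moved entries to the suffix lookup sketched earlier: a hypothetical, hand-expanded fragment of the jina-bert mapping for block 0. The GGUF base names match the LLM_TENSOR_NAMES table in llama.cpp below.

```python
# Hand-built fragment for bid = 0; the real map is generated from the
# tuples above. Reuses get_name() from the earlier sketch.
jina_map = {
    "encoder.layer.0.mlp.gated_layers": "blk.0.ffn_gate",
    "encoder.layer.0.mlp.wo":           "blk.0.ffn_down",
    "encoder.layer.0.mlp.layernorm":    "blk.0.layer_output_norm",
}
print(get_name("encoder.layer.0.mlp.wo.weight", jina_map))  # blk.0.ffn_down.weight
```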
llama.cpp

@@ -680,6 +680,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
@@ -1921,6 +1922,16 @@ struct llama_layer {
// mamba bias
struct ggml_tensor * ssm_conv1d_b;
struct ggml_tensor * ssm_dt_b;
// GLU MLP (jina-bert)
struct ggml_tensor * mlp_gated_layer_w;
struct ggml_tensor * mlp_wo_w;
struct ggml_tensor * mlp_wo_b;
struct ggml_tensor * mlp_norm_w;
struct ggml_tensor * mlp_norm_b;
};
struct llama_kv_cell {
@@ -4813,7 +4824,6 @@ static bool llm_load_tensors(
}
} break;
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT:
case LLM_ARCH_NOMIC_BERT:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -4831,7 +4841,7 @@ static bool llm_load_tensors(
auto & layer = model.layers[i];
if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT) {
if (model.arch == LLM_ARCH_BERT) {
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
@@ -4852,7 +4862,7 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT) {
if (model.arch == LLM_ARCH_BERT) {
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
@@ -4865,6 +4875,44 @@ static bool llm_load_tensors(
layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd});
}
} break;
case LLM_ARCH_JINA_BERT:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // word_embeddings
model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings
model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); // LayerNorm bias? Not sure it's needed
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i]; // JinaBertLayer
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // output_dense
layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); // output_dense
layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); // output_norm
layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd});
// TODO: HANDLE ALL THE MLP
layer.mlp_gated_layer_w = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, 2 * n_ff});
layer.mlp_wo_w = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
layer.mlp_wo_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
layer.mlp_norm_w = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
layer.mlp_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd});
}
} break;
case LLM_ARCH_BLOOM:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
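The LLM_ARCH_JINA_BERT case above expects mlp.gated_layers to carry both GeGLU halves ({n_embd, 2 * n_ff}; note GGML lists dimensions in reverse order relative to PyTorch). A quick, hedged way to sanity-check the raw checkpoint shapes against those create_tensor calls; "model.safetensors" is a placeholder path:

```python
from safetensors import safe_open  # same loader the convert script uses

# "model.safetensors" stands in for the actual Jina checkpoint file.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        if ".mlp." in name or ".attention." in name:
            print(name, f.get_slice(name).get_shape())
```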
@@ -9713,6 +9761,7 @@ static struct ggml_cgraph * llama_build_graph(
result = llm.build_refact();
} break;
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT:
case LLM_ARCH_NOMIC_BERT:
{
result = llm.build_bert();
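JINA_BERT is routed to the existing build_bert graph. For orientation only, a numpy sketch of how a bidirectional ALiBi bias enters pre-softmax attention scores (standard formulation reusing alibi_slopes from the earlier sketch; not the llama.cpp implementation):

```python
import numpy as np

def alibi_bias(n_heads: int, n_tokens: int) -> np.ndarray:
    # Bidirectional ALiBi: each head adds slope * -|i - j| to its scores.
    slopes = np.array(alibi_slopes(n_heads))         # [n_heads]
    pos = np.arange(n_tokens)
    dist = -np.abs(pos[None, :] - pos[:, None])      # [T, T]
    return slopes[:, None, None] * dist[None, :, :]  # [n_heads, T, T]
```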