feat: embedding gets results

parent a40156a077
commit b00d38b0b1

4 changed files with 42 additions and 7 deletions
convert-hf-to-gguf.py

@@ -2170,6 +2170,29 @@ class NomicBertModel(BertModel):
 
+class JinaBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.JINA_BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.intermediate_size = self.hparams["intermediate_size"]
+
+    def get_tensors(self):
+        import string
+        print(f'Intermediate SIZE: {self.intermediate_size}')
+
+        for name, data in super().get_tensors():
+            if 'gated_layers' in name:
+                print(f'name {name} => {data.shape}')
+                d1 = data[:self.intermediate_size, :]
+                name1 = name.replace('gated_layers', 'gated_layers_w')
+                d2 = data[self.intermediate_size:, :]
+                name2 = name.replace('gated_layers', 'gated_layers_v')
+                print(f'd1 {d1.shape}, d2 {d2.shape}')
+                yield name1, d1
+                yield name2, d2
+                continue
+
+            yield name, data
+
+
 @Model.register("GemmaForCausalLM")
 class GemmaModel(Model):
     model_arch = gguf.MODEL_ARCH.GEMMA
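The split above follows jina-bert's fused GEGLU layout: mlp.gated_layers stores the gate and up projections stacked along the first dimension, so each half has intermediate_size rows. A minimal sketch of the same slicing, using numpy and made-up shapes (illustration only, not part of the commit):

    import numpy as np

    intermediate_size, hidden_size = 8, 4
    fused = np.random.rand(2 * intermediate_size, hidden_size)  # mlp.gated_layers

    w = fused[:intermediate_size, :]   # rows 0..I-1  -> gated_layers_w
    v = fused[intermediate_size:, :]   # rows I..2I-1 -> gated_layers_v
    assert np.array_equal(np.vstack([w, v]), fused)  # nothing lost in the split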
gguf-py/gguf/constants.py

@@ -369,6 +369,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
gguf-py/gguf/tensor_mapping.py

@@ -228,7 +228,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.w3",     # internlm2
             "encoder.layers.{bid}.mlp.fc11",          # nomic-bert
             "model.layers.{bid}.mlp.c_fc",            # starcoder2
-            "encoder.layer.{bid}.mlp.gated_layers",   # jina-bert
+            "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -249,6 +249,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.gate_proj", # plamo
             "model.layers.{bid}.feed_forward.w1",      # internlm2
             "encoder.layers.{bid}.mlp.fc12",           # nomic-bert
+            "encoder.layer.{bid}.mlp.gated_layers_w",  # jina-bert
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
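With these two entries, the converter's output names resolve to the internal FFN_UP / FFN_GATE tensors per layer. A sketch of the assumed lookup (TensorNameMap templates carry a {bid} placeholder for the block index):

    # illustration only: how a per-layer template resolves
    template = "encoder.layer.{bid}.mlp.gated_layers_w"
    assert template.format(bid=3) == "encoder.layer.3.mlp.gated_layers_w"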
llama.cpp
@@ -4870,7 +4870,7 @@ static bool llm_load_tensors(
                 model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD, "weight"),      {n_embd, n_vocab});      // word_embeddings
                 model.type_embd  = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_TYPES, "weight"),     {n_embd, n_vocab_type}); // token_type_embeddings
                 model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});               // LayerNorm
-                model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});               // LayerNorm bias? Not sure needed
+                model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});               // LayerNorm bias
 
                 for (int i = 0; i < n_layer; ++i) {
                     ggml_context * ctx_layer = ctx_for_layer(i);
@@ -4893,8 +4893,8 @@ static bool llm_load_tensors(
                     layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); // output_norm
                     layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
 
-                    // TODO: HANDLE ALL THE MLP
-                    layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff});
+                    layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                    layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
 
                     layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
                     layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
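Splitting the fused tensor keeps the parameter count unchanged: one {n_embd, 2*n_ff} matrix becomes two {n_embd, n_ff} matrices. A quick sanity check with illustrative sizes (placeholder numbers, not jina-bert's real hyperparameters):

    n_embd, n_ff = 768, 3072
    fused_params = n_embd * (2 * n_ff)
    split_params = 2 * (n_embd * n_ff)
    assert fused_params == split_params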
@@ -5851,7 +5851,7 @@ static struct ggml_tensor * llm_build_ffn(
           llm_ffn_gate_type   type_gate,
          const llm_build_cb & cb,
                         int   il) {
-    struct ggml_tensor * tmp = ggml_mul_mat(ctx, up, cur);
+    struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur) : cur;
     cb(tmp, "ffn_up", il);
 
     if (up_b) {
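The new ternary makes the up projection optional: when up is NULL the input tensor passes through unchanged, presumably so call sites that supply only a gate projection still work. A one-line Python analogue of the guard (a sketch, not the real ggml API):

    def maybe_up(up, cur, mul_mat):
        return mul_mat(up, cur) if up is not None else cur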
@@ -7522,8 +7522,11 @@ struct llm_build_context {
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
+        struct ggml_tensor * inp_pos = nullptr;
 
-        struct ggml_tensor * inp_pos  = build_inp_pos();
+        if (model.arch != LLM_ARCH_JINA_BERT) {
+            inp_pos = build_inp_pos();
+        }
         struct ggml_tensor * inp_mean = build_inp_mean();
         struct ggml_tensor * inp_cls  = build_inp_cls();
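Skipping build_inp_pos() for JINA_BERT is consistent with jina-bert-v2 using ALiBi: positions enter as a fixed linear bias on the attention scores rather than as an embedding indexed by position ids. A rough numpy sketch of a symmetric ALiBi bias (slopes simplified; assumes this is the model's positional scheme):

    import numpy as np

    n_head, n_tok = 4, 6
    slopes = 2.0 ** -np.arange(1, n_head + 1)                        # one slope per head
    dist = np.abs(np.arange(n_tok)[None, :] - np.arange(n_tok)[:, None])
    bias = -slopes[:, None, None] * dist                             # added to Q.K^T scores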
@@ -7644,13 +7647,20 @@ struct llm_build_context {
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT) {
+            if (model.arch == LLM_ARCH_BERT) {
                 cur = llm_build_ffn(ctx0, cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+            } else if (model.arch == LLM_ARCH_JINA_BERT) {
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
             } else {
                 cur = llm_build_ffn(ctx0, cur,
                         model.layers[il].ffn_up, NULL,
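The new JINA_BERT branch builds a parallel gated FFN (LLM_FFN_GELU with LLM_FFN_PAR): the gate path is activated with GELU and multiplied elementwise with the up path before the down projection. In numpy, with illustrative shapes (a sketch of the math, not the ggml code):

    import numpy as np

    def gelu(x):
        return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

    n_embd, n_ff = 8, 16
    x = np.random.rand(n_embd)
    ffn_up     = np.random.rand(n_ff, n_embd)   # from gated_layers_v
    ffn_gate   = np.random.rand(n_ff, n_embd)   # from gated_layers_w
    ffn_down   = np.random.rand(n_embd, n_ff)
    ffn_down_b = np.random.rand(n_embd)

    y = ffn_down @ (gelu(ffn_gate @ x) * (ffn_up @ x)) + ffn_down_b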