diff --git a/llama.cpp b/llama.cpp index 1d4ea55d6..f58eb7c7b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1464,9 +1464,11 @@ static llama_state g_state; // available llama models enum e_model { MODEL_UNKNOWN, + MODEL_17M, MODEL_22M, MODEL_33M, MODEL_109M, + MODEL_335M, MODEL_0_5B, MODEL_1B, MODEL_2B, @@ -3040,14 +3042,18 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - switch (hparams.n_embd) { - case 384: // MiniLM - switch (hparams.n_layer) { - case 6: model.type = e_model::MODEL_22M; break; - case 12: model.type = e_model::MODEL_33M; break; + switch (hparams.n_layer) { + case 3: + model.type = e_model::MODEL_17M; break; // bge-micro + case 6: + model.type = e_model::MODEL_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small + case 768: model.type = e_model::MODEL_109M; break; // bge-base } break; - case 768: // BERT-Base - model.type = e_model::MODEL_109M; break; + case 24: + model.type = e_model::MODEL_335M; break; // bge-large } } break; case LLM_ARCH_BLOOM: @@ -3851,8 +3857,8 @@ static bool llm_load_tensors( model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}); - model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7481,21 +7487,15 @@ static int llama_decode_internal( ggml_cgraph * gf = llama_build_graph(lctx, batch); - // the output is always the last tensor in the graph - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embeddings = nullptr; - if (strcmp(res->name, "result_embed") == 0) { - embeddings = res; - res = nullptr; - } else { - // the embeddings could be the second to last tensor, or the third to last tensor - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - embeddings = gf->nodes[gf->n_nodes - 2]; - if (strcmp(embeddings->name, "result_norm") != 0) { - embeddings = gf->nodes[gf->n_nodes - 3]; - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - } + // get logits and embeddings + struct ggml_tensor * res = ggml_graph_get_tensor(gf, "result_output"); + struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "result_norm"); + + // if logits are none we must be doing embeddings + if (res == nullptr) { + embeddings = ggml_graph_get_tensor(gf, "result_embed"); } + GGML_ASSERT(res || embeddings); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);