fix up model sizing and result acquisition

Douglas Hanley 2024-02-08 22:43:26 -06:00
parent 96d37f8d55
commit 68758083d6

llama.cpp

@@ -1464,9 +1464,11 @@ static llama_state g_state;
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
+    MODEL_17M,
     MODEL_22M,
     MODEL_33M,
     MODEL_109M,
+    MODEL_335M,
     MODEL_0_5B,
     MODEL_1B,
     MODEL_2B,
@@ -3040,14 +3042,18 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);

-                switch (hparams.n_embd) {
-                    case 384: // MiniLM
-                        switch (hparams.n_layer) {
-                            case 6: model.type = e_model::MODEL_22M; break;
-                            case 12: model.type = e_model::MODEL_33M; break;
-                        } break;
-                    case 768: // BERT-Base
-                        model.type = e_model::MODEL_109M; break;
+                switch (hparams.n_layer) {
+                    case 3:
+                        model.type = e_model::MODEL_17M; break; // bge-micro
+                    case 6:
+                        model.type = e_model::MODEL_22M; break; // MiniLM-L6
+                    case 12:
+                        switch (hparams.n_embd) {
+                            case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small
+                            case 768: model.type = e_model::MODEL_109M; break; // bge-base
+                        } break;
+                    case 24:
+                        model.type = e_model::MODEL_335M; break; // bge-large
                 }
             } break;
         case LLM_ARCH_BLOOM:
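For reference, a minimal standalone sketch of how the new sizing logic above resolves a few known BERT-family configurations: the outer switch keys on layer count and only falls back to the embedding width where 12 layers is ambiguous. The classify helper and the main driver are hypothetical illustration only, not part of the commit.

// Hypothetical illustration mirroring the nested switch in llm_load_hparams above.
#include <cstdio>

enum e_model { MODEL_UNKNOWN, MODEL_17M, MODEL_22M, MODEL_33M, MODEL_109M, MODEL_335M };

// classify() is a made-up helper: switch on n_layer first, then disambiguate
// the 12-layer case by n_embd, as the diff does.
static e_model classify(int n_layer, int n_embd) {
    switch (n_layer) {
        case  3: return MODEL_17M;                  // bge-micro
        case  6: return MODEL_22M;                  // MiniLM-L6
        case 12: return n_embd == 384 ? MODEL_33M   // MiniLM-L12, bge-small
                      : n_embd == 768 ? MODEL_109M  // bge-base
                      : MODEL_UNKNOWN;
        case 24: return MODEL_335M;                 // bge-large
        default: return MODEL_UNKNOWN;
    }
}

int main() {
    // e.g. a 12-layer, 768-wide encoder is sized into the 109M (bge-base) class
    std::printf("%d\n", classify(12, 768) == MODEL_109M); // prints 1
    return 0;
}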
@@ -3851,8 +3857,8 @@ static bool llm_load_tensors(
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
                     model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
                     model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train});
-                    model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-                    model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
+                    model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
+                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});

                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
@@ -7481,21 +7487,15 @@ static int llama_decode_internal(
     ggml_cgraph * gf = llama_build_graph(lctx, batch);

-    // the output is always the last tensor in the graph
-    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embeddings = nullptr;
-    if (strcmp(res->name, "result_embed") == 0) {
-        embeddings = res;
-        res = nullptr;
-    } else {
-        // the embeddings could be the second to last tensor, or the third to last tensor
-        GGML_ASSERT(strcmp(res->name, "result_output") == 0);
-        embeddings = gf->nodes[gf->n_nodes - 2];
-        if (strcmp(embeddings->name, "result_norm") != 0) {
-            embeddings = gf->nodes[gf->n_nodes - 3];
-            GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
-        }
-    }
+    // get logits and embeddings
+    struct ggml_tensor * res = ggml_graph_get_tensor(gf, "result_output");
+    struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "result_norm");
+
+    // if logits are none we must be doing embeddings
+    if (res == nullptr) {
+        embeddings = ggml_graph_get_tensor(gf, "result_embed");
+    }
+    GGML_ASSERT(res || embeddings);

     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
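For context, a minimal standalone sketch of the name-based lookup pattern used above, assuming only the public ggml API (ggml_set_name, ggml_new_graph, ggml_build_forward_expand, ggml_graph_get_tensor). ggml_graph_get_tensor returns NULL when no node in the graph carries the requested name, which is what the nullptr check in the diff relies on. The tensor names and the tiny graph here are illustrative only, not taken from llama.cpp.

// Illustrative sketch: name a graph node, then retrieve it by name.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // build a trivial graph and name its output node
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_sqr(ctx, a);
    ggml_set_name(b, "result_embed");

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, b);

    // one lookup misses ("result_output" was never set), the other hits
    struct ggml_tensor * res        = ggml_graph_get_tensor(gf, "result_output"); // NULL
    struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "result_embed");  // found

    std::printf("res=%p embeddings=%p\n", (void *) res, (void *) embeddings);

    ggml_free(ctx);
    return 0;
}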