llama : assert pooling tensor

This commit is contained in:
Georgi Gerganov 2024-03-04 19:24:03 +02:00
parent 79e4eede23
commit fc9af156ff
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -6246,6 +6246,7 @@ struct llm_build_context {
// final output
cur = inpL;
cb(cur, "result_embd", -1);
// pooling layer
switch (pooling_type) {
@ -6256,17 +6257,18 @@ struct llm_build_context {
case LLAMA_POOLING_TYPE_MEAN:
{
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_CLS:
{
cur = ggml_get_rows(ctx0, cur, inp_cls);
cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ASSERT(false && "Invalid pooling type");
} break;
}
cb(cur, "result_embd", -1);
ggml_build_forward_expand(gf, cur);
@ -8281,7 +8283,7 @@ static int llama_decode_internal(
// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
} else {
if (strcmp(res->name, "result_output") == 0) {
// the token embeddings could be the second to last tensor, or the third to last tensor
@ -8413,6 +8415,8 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_MEAN:
{
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();