llama : assert pooling tensor

This commit is contained in:
Georgi Gerganov 2024-03-04 19:24:03 +02:00
parent 79e4eede23
commit fc9af156ff
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -6246,6 +6246,7 @@ struct llm_build_context {
// final output // final output
cur = inpL; cur = inpL;
cb(cur, "result_embd", -1);
// pooling layer // pooling layer
switch (pooling_type) { switch (pooling_type) {
@ -6256,17 +6257,18 @@ struct llm_build_context {
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
{ {
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
cb(cur, "result_embd_pooled", -1);
} break; } break;
case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_CLS:
{ {
cur = ggml_get_rows(ctx0, cur, inp_cls); cur = ggml_get_rows(ctx0, cur, inp_cls);
cb(cur, "result_embd_pooled", -1);
} break; } break;
case LLAMA_POOLING_TYPE_UNSPECIFIED: case LLAMA_POOLING_TYPE_UNSPECIFIED:
{ {
GGML_ASSERT(false && "Invalid pooling type"); GGML_ASSERT(false && "Invalid pooling type");
} break; } break;
} }
cb(cur, "result_embd", -1);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -8281,7 +8283,7 @@ static int llama_decode_internal(
// token or sequence embeddings // token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1]; embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0); GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
} else { } else {
if (strcmp(res->name, "result_output") == 0) { if (strcmp(res->name, "result_output") == 0) {
// the token embeddings could be the second to last tensor, or the third to last tensor // the token embeddings could be the second to last tensor, or the third to last tensor
@ -8413,6 +8415,8 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
{ {
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
// extract sequence embeddings // extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq; auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear(); embd_seq_out.clear();