llama : assert pooling tensor
This commit is contained in:
parent
79e4eede23
commit
fc9af156ff
1 changed files with 6 additions and 2 deletions
|
@ -6246,6 +6246,7 @@ struct llm_build_context {
|
|||
|
||||
// final output
|
||||
cur = inpL;
|
||||
cb(cur, "result_embd", -1);
|
||||
|
||||
// pooling layer
|
||||
switch (pooling_type) {
|
||||
|
@ -6256,17 +6257,18 @@ struct llm_build_context {
|
|||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
|
||||
cb(cur, "result_embd_pooled", -1);
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
{
|
||||
cur = ggml_get_rows(ctx0, cur, inp_cls);
|
||||
cb(cur, "result_embd_pooled", -1);
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ASSERT(false && "Invalid pooling type");
|
||||
} break;
|
||||
}
|
||||
cb(cur, "result_embd", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
|
@ -8281,7 +8283,7 @@ static int llama_decode_internal(
|
|||
// token or sequence embeddings
|
||||
embd = gf->nodes[gf->n_nodes - 1];
|
||||
|
||||
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
|
||||
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
|
||||
} else {
|
||||
if (strcmp(res->name, "result_output") == 0) {
|
||||
// the token embeddings could be the second to last tensor, or the third to last tensor
|
||||
|
@ -8413,6 +8415,8 @@ static int llama_decode_internal(
|
|||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
{
|
||||
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
|
||||
|
||||
// extract sequence embeddings
|
||||
auto & embd_seq_out = lctx.embd_seq;
|
||||
embd_seq_out.clear();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue