llama : assert pooling tensor
This commit is contained in:
parent
79e4eede23
commit
fc9af156ff
1 changed file with 6 additions and 2 deletions
|
@ -6246,6 +6246,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
// final output
|
// final output
|
||||||
cur = inpL;
|
cur = inpL;
|
||||||
|
cb(cur, "result_embd", -1);
|
||||||
|
|
||||||
// pooling layer
|
// pooling layer
|
||||||
switch (pooling_type) {
|
switch (pooling_type) {
|
||||||
|
@ -6256,17 +6257,18 @@ struct llm_build_context {
|
||||||
case LLAMA_POOLING_TYPE_MEAN:
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
{
|
{
|
||||||
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
|
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
|
||||||
|
cb(cur, "result_embd_pooled", -1);
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_CLS:
|
case LLAMA_POOLING_TYPE_CLS:
|
||||||
{
|
{
|
||||||
cur = ggml_get_rows(ctx0, cur, inp_cls);
|
cur = ggml_get_rows(ctx0, cur, inp_cls);
|
||||||
|
cb(cur, "result_embd_pooled", -1);
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false && "Invalid pooling type");
|
GGML_ASSERT(false && "Invalid pooling type");
|
||||||
} break;
|
} break;
|
||||||
}
|
}
|
||||||
cb(cur, "result_embd", -1);
|
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
@ -8281,7 +8283,7 @@ static int llama_decode_internal(
|
||||||
// token or sequence embeddings
|
// token or sequence embeddings
|
||||||
embd = gf->nodes[gf->n_nodes - 1];
|
embd = gf->nodes[gf->n_nodes - 1];
|
||||||
|
|
||||||
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
|
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
|
||||||
} else {
|
} else {
|
||||||
if (strcmp(res->name, "result_output") == 0) {
|
if (strcmp(res->name, "result_output") == 0) {
|
||||||
// the token embeddings could be the second to last tensor, or the third to last tensor
|
// the token embeddings could be the second to last tensor, or the third to last tensor
|
||||||
|
@ -8413,6 +8415,8 @@ static int llama_decode_internal(
|
||||||
case LLAMA_POOLING_TYPE_CLS:
|
case LLAMA_POOLING_TYPE_CLS:
|
||||||
case LLAMA_POOLING_TYPE_MEAN:
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
|
||||||
|
|
||||||
// extract sequence embeddings
|
// extract sequence embeddings
|
||||||
auto & embd_seq_out = lctx.embd_seq;
|
auto & embd_seq_out = lctx.embd_seq;
|
||||||
embd_seq_out.clear();
|
embd_seq_out.clear();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue