Add llama_get_num_logits() function to Llama.cpp API
This commit is contained in:
parent
cf348a60e0
commit
779d7969c0
2 changed files with 6 additions and 0 deletions
|
@ -2779,6 +2779,10 @@ float * llama_get_logits(struct llama_context * ctx) {
|
||||||
return ctx->logits.data();
|
return ctx->logits.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t llama_get_num_logits(struct llama_context * ctx) {
|
||||||
|
return ctx->logits.size();
|
||||||
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings(struct llama_context * ctx) {
|
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
return ctx->embedding.data();
|
return ctx->embedding.data();
|
||||||
}
|
}
|
||||||
|
|
2
llama.h
2
llama.h
|
@ -178,6 +178,8 @@ extern "C" {
|
||||||
// Cols: n_vocab
|
// Cols: n_vocab
|
||||||
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||||
|
|
||||||
|
LLAMA_API size_t llama_get_num_logits(struct llama_context * ctx);
|
||||||
|
|
||||||
// Get the embeddings for the input
|
// Get the embeddings for the input
|
||||||
// shape: [n_embd] (1-dimensional)
|
// shape: [n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue