From 779d7969c01aa6dd30fe510b68a04482fafa3619 Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Wed, 10 May 2023 14:32:05 -0500 Subject: [PATCH] Add llama_get_num_logits() function to Llama.cpp API --- llama.cpp | 4 ++++ llama.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/llama.cpp b/llama.cpp index 4bba93a11..efff9d9ad 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2779,6 +2779,10 @@ float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } +size_t llama_get_num_logits(struct llama_context * ctx) { + return ctx->logits.size(); +} + float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); } diff --git a/llama.h b/llama.h index 58c6e0699..25a7d6ef5 100644 --- a/llama.h +++ b/llama.h @@ -178,6 +178,8 @@ extern "C" { // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); + LLAMA_API size_t llama_get_num_logits(struct llama_context * ctx); + // Get the embeddings for the input // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);