diff --git a/examples/llama_cpp.py b/examples/llama_cpp.py index bafc40112..1e8054e5d 100644 --- a/examples/llama_cpp.py +++ b/examples/llama_cpp.py @@ -111,7 +111,7 @@ _lib.llama_model_quantize.restype = c_int # Returns 0 on success def llama_eval( ctx: llama_context_p, - tokens: llama_token_p, + tokens: ctypes.Array[llama_token], n_tokens: c_int, n_past: c_int, n_threads: c_int, @@ -131,7 +131,7 @@ _lib.llama_eval.restype = c_int def llama_tokenize( ctx: llama_context_p, text: bytes, - tokens: llama_token_p, + tokens: ctypes.Array[llama_token], n_max_tokens: c_int, add_bos: c_bool, ) -> c_int: @@ -163,7 +163,7 @@ _lib.llama_n_ctx.restype = c_int # Can be mutated in order to change the probabilities of the next token # Rows: n_tokens # Cols: n_vocab -def llama_get_logits(ctx: llama_context_p): +def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]: return _lib.llama_get_logits(ctx) @@ -173,7 +173,7 @@ _lib.llama_get_logits.restype = POINTER(c_float) # Get the embeddings for the input # shape: [n_embd] (1-dimensional) -def llama_get_embeddings(ctx: llama_context_p): +def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]: return _lib.llama_get_embeddings(ctx) @@ -211,7 +211,7 @@ _lib.llama_token_eos.restype = llama_token # TODO: improve the last_n_tokens interface ? def llama_sample_top_p_top_k( ctx: llama_context_p, - last_n_tokens_data: llama_token_p, + last_n_tokens_data: ctypes.Array[llama_token], last_n_tokens_size: c_int, top_k: c_int, top_p: c_double,