From a7a6d88793deaf73629adffefc9e820dda5c52ef Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 31 Mar 2023 03:20:15 -0400 Subject: [PATCH] Fix ctypes typing issue for Arrays --- examples/llama_cpp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/llama_cpp.py b/examples/llama_cpp.py index 1e8054e5d..2a43ca328 100644 --- a/examples/llama_cpp.py +++ b/examples/llama_cpp.py @@ -9,8 +9,8 @@ from ctypes import ( c_bool, POINTER, Structure, + Array ) - import pathlib from itertools import chain @@ -111,7 +111,7 @@ _lib.llama_model_quantize.restype = c_int # Returns 0 on success def llama_eval( ctx: llama_context_p, - tokens: ctypes.Array[llama_token], + tokens, # type: Array[llama_token] n_tokens: c_int, n_past: c_int, n_threads: c_int, @@ -131,7 +131,7 @@ _lib.llama_eval.restype = c_int def llama_tokenize( ctx: llama_context_p, text: bytes, - tokens: ctypes.Array[llama_token], + tokens, # type: Array[llama_token] n_max_tokens: c_int, add_bos: c_bool, ) -> c_int: @@ -163,7 +163,7 @@ _lib.llama_n_ctx.restype = c_int # Can be mutated in order to change the probabilities of the next token # Rows: n_tokens # Cols: n_vocab -def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]: +def llama_get_logits(ctx: llama_context_p): return _lib.llama_get_logits(ctx) @@ -173,7 +173,7 @@ _lib.llama_get_logits.restype = POINTER(c_float) # Get the embeddings for the input # shape: [n_embd] (1-dimensional) -def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]: +def llama_get_embeddings(ctx: llama_context_p): return _lib.llama_get_embeddings(ctx) @@ -211,7 +211,7 @@ _lib.llama_token_eos.restype = llama_token # TODO: improve the last_n_tokens interface ? def llama_sample_top_p_top_k( ctx: llama_context_p, - last_n_tokens_data: ctypes.Array[llama_token], + last_n_tokens_data, # type: Array[llama_token] last_n_tokens_size: c_int, top_k: c_int, top_p: c_double,