Fix array type signatures

This commit is contained in:
Andrei Betlen 2023-03-31 02:08:20 -04:00 committed by Don Mahurin
parent a3da39af79
commit 019650f416

View file

@ -111,7 +111,7 @@ _lib.llama_model_quantize.restype = c_int
# Returns 0 on success
def llama_eval(
ctx: llama_context_p,
tokens: llama_token_p,
tokens: ctypes.Array[llama_token],
n_tokens: c_int,
n_past: c_int,
n_threads: c_int,
@ -131,7 +131,7 @@ _lib.llama_eval.restype = c_int
def llama_tokenize(
ctx: llama_context_p,
text: bytes,
tokens: llama_token_p,
tokens: ctypes.Array[llama_token],
n_max_tokens: c_int,
add_bos: c_bool,
) -> c_int:
@ -163,7 +163,7 @@ _lib.llama_n_ctx.restype = c_int
# Can be mutated in order to change the probabilities of the next token
# Rows: n_tokens
# Cols: n_vocab
def llama_get_logits(ctx: llama_context_p):
def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]:
return _lib.llama_get_logits(ctx)
@ -173,7 +173,7 @@ _lib.llama_get_logits.restype = POINTER(c_float)
# Get the embeddings for the input
# shape: [n_embd] (1-dimensional)
def llama_get_embeddings(ctx: llama_context_p):
def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]:
return _lib.llama_get_embeddings(ctx)
@ -211,7 +211,7 @@ _lib.llama_token_eos.restype = llama_token
# TODO: improve the last_n_tokens interface ?
def llama_sample_top_p_top_k(
ctx: llama_context_p,
last_n_tokens_data: llama_token_p,
last_n_tokens_data: ctypes.Array[llama_token],
last_n_tokens_size: c_int,
top_k: c_int,
top_p: c_double,