Fix ctypes typing issue for Arrays

Andrei Betlen 2023-03-31 03:20:15 -04:00 committed by Don Mahurin
parent 019650f416
commit a7a6d88793

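The change below swaps runtime ctypes annotations for PEP 484 type comments. An annotation such as `ctypes.Array[llama_token]` is evaluated when the `def` statement executes, and on interpreters where `ctypes.Array` does not support subscripting at runtime that evaluation raises TypeError and aborts the module import; a `# type:` comment conveys the same type to static checkers without ever being evaluated. A minimal sketch of the failure mode and the fix (the function names here are hypothetical, and the `llama_token` alias is an assumption mirroring llama_cpp.py):

    from ctypes import Array, c_int

    llama_token = c_int  # assumption: llama_cpp.py aliases llama_token to a ctypes int

    # Before (kept commented so this sketch runs on any interpreter): the
    # annotation is evaluated when "def" executes, and on Pythons where
    # ctypes.Array is not subscriptable at runtime that raises TypeError,
    # which aborts the module import.
    #
    #     def eval_before(tokens: Array[llama_token]) -> None: ...

    # After: no runtime annotation; static checkers read the PEP 484 type
    # comment instead, and nothing is evaluated at import time.
    def eval_after(
        tokens,  # type: Array[llama_token]
    ):
        pass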

@@ -9,8 +9,8 @@ from ctypes import (
     c_bool,
     POINTER,
     Structure,
+    Array,
 )
 import pathlib
 from itertools import chain
@@ -111,7 +111,7 @@ _lib.llama_model_quantize.restype = c_int
 # Returns 0 on success
 def llama_eval(
     ctx: llama_context_p,
-    tokens: ctypes.Array[llama_token],
+    tokens,  # type: Array[llama_token]
     n_tokens: c_int,
     n_past: c_int,
     n_threads: c_int,
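
At runtime the `tokens` argument is still a ctypes array; in ctypes an array type is made by multiplying the element type by a length and instantiating it. A hedged call sketch, assuming `ctx` came from an earlier context-creation call and with placeholder values throughout:

    token_ids = [1, 2, 3, 4]  # placeholder values, not real tokens
    tokens = (llama_token * len(token_ids))(*token_ids)  # ctypes Array[llama_token]
    rc = llama_eval(ctx, tokens, c_int(len(token_ids)), c_int(0), c_int(1))
    assert rc == 0  # "Returns 0 on success", per the comment in the hunk
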
@@ -131,7 +131,7 @@ _lib.llama_eval.restype = c_int
 def llama_tokenize(
     ctx: llama_context_p,
     text: bytes,
-    tokens: ctypes.Array[llama_token],
+    tokens,  # type: Array[llama_token]
     n_max_tokens: c_int,
     add_bos: c_bool,
 ) -> c_int:
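
`llama_tokenize` fills a caller-allocated buffer, so the usual pattern is to preallocate `n_max_tokens` slots and read back the count it returns. A hedged sketch, again assuming `ctx` from earlier:

    max_tokens = 256  # arbitrary capacity for this example
    buf = (llama_token * max_tokens)()  # zero-initialized output buffer
    n = llama_tokenize(ctx, b"Hello, world", buf, c_int(max_tokens), c_bool(True))
    token_ids = buf[:n]  # assumption: a non-negative return is the count written
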
@@ -163,7 +163,7 @@ _lib.llama_n_ctx.restype = c_int
 # Can be mutated in order to change the probabilities of the next token
 # Rows: n_tokens
 # Cols: n_vocab
-def llama_get_logits(ctx: llama_context_p) -> ctypes.Array[c_float]:
+def llama_get_logits(ctx: llama_context_p):
     return _lib.llama_get_logits(ctx)
@@ -173,7 +173,7 @@ _lib.llama_get_logits.restype = POINTER(c_float)
 # Get the embeddings for the input
 # shape: [n_embd] (1-dimensional)
-def llama_get_embeddings(ctx: llama_context_p) -> ctypes.Array[c_float]:
+def llama_get_embeddings(ctx: llama_context_p):
     return _lib.llama_get_embeddings(ctx)
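
The removed `-> ctypes.Array[c_float]` annotations were also inaccurate: `_lib.llama_get_logits.restype` is `POINTER(c_float)` (visible in the hunk header above), so the value handed back is a pointer, not an array, and the caller indexes it using sizes known from elsewhere. A sketch with the sizes as stand-in assumptions:

    logits = llama_get_logits(ctx)          # POINTER(c_float), per the restype above
    embeddings = llama_get_embeddings(ctx)  # assumption: also POINTER(c_float)
    n_vocab = 32000  # assumption: queried from the model in real code
    n_embd = 4096    # assumption: queried from the model in real code
    row = [logits[i] for i in range(n_vocab)]      # cols: n_vocab, per the comments
    embd = [embeddings[i] for i in range(n_embd)]  # shape [n_embd], per the comment
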
@@ -211,7 +211,7 @@ _lib.llama_token_eos.restype = llama_token
 # TODO: improve the last_n_tokens interface ?
 def llama_sample_top_p_top_k(
     ctx: llama_context_p,
-    last_n_tokens_data: ctypes.Array[llama_token],
+    last_n_tokens_data,  # type: Array[llama_token]
     last_n_tokens_size: c_int,
     top_k: c_int,
     top_p: c_double,
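
The sampler takes its recent-token history the same way, as an array plus an explicit length. The hunk above is truncated after `top_p`, so the sketch below stops at building the arguments rather than guessing the rest of the signature:

    history = [5, 9, 13]  # placeholder token ids
    last_n_tokens_data = (llama_token * len(history))(*history)
    last_n_tokens_size = c_int(len(history))
    # llama_sample_top_p_top_k(ctx, last_n_tokens_data, last_n_tokens_size,
    #                          top_k, top_p, ...)  # signature truncated above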