Fix llama_cpp and Llama type signatures. Closes #221

This commit is contained in:
Andrei Betlen 2023-05-19 11:59:33 -04:00 committed by Don Mahurin
parent 601b19203f
commit fda33ddbd5

View file

@ -206,7 +206,7 @@ _lib.llama_free.restype = None
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given # nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
def llama_model_quantize( def llama_model_quantize(
fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
) -> c_int: ) -> int:
return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread) return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
@ -225,7 +225,7 @@ def llama_apply_lora_from_file(
path_lora: c_char_p, path_lora: c_char_p,
path_base_model: c_char_p, path_base_model: c_char_p,
n_threads: c_int, n_threads: c_int,
) -> c_int: ) -> int:
return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads) return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
@ -234,7 +234,7 @@ _lib.llama_apply_lora_from_file.restype = c_int
# Returns the number of tokens in the KV cache # Returns the number of tokens in the KV cache
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int: def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
return _lib.llama_get_kv_cache_token_count(ctx) return _lib.llama_get_kv_cache_token_count(ctx)
@ -253,7 +253,7 @@ _lib.llama_set_rng_seed.restype = None
# Returns the maximum size in bytes of the state (rng, logits, embedding # Returns the maximum size in bytes of the state (rng, logits, embedding
# and kv_cache) - will often be smaller after compacting tokens # and kv_cache) - will often be smaller after compacting tokens
def llama_get_state_size(ctx: llama_context_p) -> c_size_t: def llama_get_state_size(ctx: llama_context_p) -> int:
return _lib.llama_get_state_size(ctx) return _lib.llama_get_state_size(ctx)
@ -293,7 +293,7 @@ def llama_load_session_file(
tokens_out, # type: Array[llama_token] tokens_out, # type: Array[llama_token]
n_token_capacity: c_size_t, n_token_capacity: c_size_t,
n_token_count_out, # type: _Pointer[c_size_t] n_token_count_out, # type: _Pointer[c_size_t]
) -> c_size_t: ) -> int:
return _lib.llama_load_session_file( return _lib.llama_load_session_file(
ctx, path_session, tokens_out, n_token_capacity, n_token_count_out ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
) )
@ -314,7 +314,7 @@ def llama_save_session_file(
path_session: bytes, path_session: bytes,
tokens, # type: Array[llama_token] tokens, # type: Array[llama_token]
n_token_count: c_size_t, n_token_count: c_size_t,
) -> c_size_t: ) -> int:
return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count) return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
@ -337,7 +337,7 @@ def llama_eval(
n_tokens: c_int, n_tokens: c_int,
n_past: c_int, n_past: c_int,
n_threads: c_int, n_threads: c_int,
) -> c_int: ) -> int:
return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads) return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)
@ -364,7 +364,7 @@ _lib.llama_tokenize.argtypes = [llama_context_p, c_char_p, llama_token_p, c_int,
_lib.llama_tokenize.restype = c_int _lib.llama_tokenize.restype = c_int
def llama_n_vocab(ctx: llama_context_p) -> c_int: def llama_n_vocab(ctx: llama_context_p) -> int:
return _lib.llama_n_vocab(ctx) return _lib.llama_n_vocab(ctx)
@ -372,7 +372,7 @@ _lib.llama_n_vocab.argtypes = [llama_context_p]
_lib.llama_n_vocab.restype = c_int _lib.llama_n_vocab.restype = c_int
def llama_n_ctx(ctx: llama_context_p) -> c_int: def llama_n_ctx(ctx: llama_context_p) -> int:
return _lib.llama_n_ctx(ctx) return _lib.llama_n_ctx(ctx)
@ -380,7 +380,7 @@ _lib.llama_n_ctx.argtypes = [llama_context_p]
_lib.llama_n_ctx.restype = c_int _lib.llama_n_ctx.restype = c_int
def llama_n_embd(ctx: llama_context_p) -> c_int: def llama_n_embd(ctx: llama_context_p) -> int:
return _lib.llama_n_embd(ctx) return _lib.llama_n_embd(ctx)
@ -426,7 +426,7 @@ _lib.llama_token_to_str.restype = c_char_p
# Special tokens # Special tokens
def llama_token_bos() -> llama_token: def llama_token_bos() -> int:
return _lib.llama_token_bos() return _lib.llama_token_bos()
@ -434,7 +434,7 @@ _lib.llama_token_bos.argtypes = []
_lib.llama_token_bos.restype = llama_token _lib.llama_token_bos.restype = llama_token
def llama_token_eos() -> llama_token: def llama_token_eos() -> int:
return _lib.llama_token_eos() return _lib.llama_token_eos()
@ -442,7 +442,7 @@ _lib.llama_token_eos.argtypes = []
_lib.llama_token_eos.restype = llama_token _lib.llama_token_eos.restype = llama_token
def llama_token_nl() -> llama_token: def llama_token_nl() -> int:
return _lib.llama_token_nl() return _lib.llama_token_nl()
@ -625,7 +625,7 @@ def llama_sample_token_mirostat(
eta: c_float, eta: c_float,
m: c_int, m: c_int,
mu, # type: _Pointer[c_float] mu, # type: _Pointer[c_float]
) -> llama_token: ) -> int:
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu) return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
@ -651,7 +651,7 @@ def llama_sample_token_mirostat_v2(
tau: c_float, tau: c_float,
eta: c_float, eta: c_float,
mu, # type: _Pointer[c_float] mu, # type: _Pointer[c_float]
) -> llama_token: ) -> int:
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu) return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
@ -669,7 +669,7 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
def llama_sample_token_greedy( def llama_sample_token_greedy(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates, # type: _Pointer[llama_token_data_array]
) -> llama_token: ) -> int:
return _lib.llama_sample_token_greedy(ctx, candidates) return _lib.llama_sample_token_greedy(ctx, candidates)
@ -684,7 +684,7 @@ _lib.llama_sample_token_greedy.restype = llama_token
def llama_sample_token( def llama_sample_token(
ctx: llama_context_p, ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array] candidates, # type: _Pointer[llama_token_data_array]
) -> llama_token: ) -> int:
return _lib.llama_sample_token(ctx, candidates) return _lib.llama_sample_token(ctx, candidates)