add python bindings for functions to get and set the whole llama state

(rng, logits, embedding and kv_cache)
This commit is contained in:
xaedes 2023-04-14 03:16:50 +02:00
parent 5f6b715071
commit ed6b64fb98
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -97,6 +97,15 @@ lib.llama_reset_timings.restype = None
lib.llama_print_system_info.argtypes = [] lib.llama_print_system_info.argtypes = []
lib.llama_print_system_info.restype = c_char_p lib.llama_print_system_info.restype = c_char_p
lib.llama_get_state_size.argtypes = [llama_context_p]
lib.llama_get_state_size.restype = c_size_t
lib.llama_copy_state_data.argtypes = [llama_context_p, c_ubyte_p]
lib.llama_copy_state_data.restype = c_size_t
lib.llama_set_state_data.argtypes = [llama_context_p, c_ubyte_p]
lib.llama_set_state_data.restype = c_size_t
# Python functions # Python functions
def llama_context_default_params() -> llama_context_params: def llama_context_default_params() -> llama_context_params:
params = lib.llama_context_default_params() params = lib.llama_context_default_params()
@ -171,3 +180,12 @@ def llama_reset_timings(ctx: llama_context_p):
def llama_print_system_info() -> str: def llama_print_system_info() -> str:
"""Print system informaiton""" """Print system informaiton"""
return lib.llama_print_system_info().decode('utf-8') return lib.llama_print_system_info().decode('utf-8')
def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
return lib.llama_get_state_size(ctx)
def llama_copy_state_data(ctx: llama_context_p, dst: c_ubyte_p) -> c_size_t:
return lib.llama_copy_state_data(ctx, dst)
def llama_set_state_data(ctx: llama_context_p, src: c_ubyte_p) -> c_size_t:
return lib.llama_set_state_data(ctx, src)