add python bindings for functions to get and set the whole llama state
(rng, logits, embedding and kv_cache)
This commit is contained in:
parent
5f6b715071
commit
ed6b64fb98
1 changed files with 18 additions and 0 deletions
|
@ -97,6 +97,15 @@ lib.llama_reset_timings.restype = None
|
|||
lib.llama_print_system_info.argtypes = []
|
||||
lib.llama_print_system_info.restype = c_char_p
|
||||
|
||||
lib.llama_get_state_size.argtypes = [llama_context_p]
|
||||
lib.llama_get_state_size.restype = c_size_t
|
||||
|
||||
lib.llama_copy_state_data.argtypes = [llama_context_p, c_ubyte_p]
|
||||
lib.llama_copy_state_data.restype = c_size_t
|
||||
|
||||
lib.llama_set_state_data.argtypes = [llama_context_p, c_ubyte_p]
|
||||
lib.llama_set_state_data.restype = c_size_t
|
||||
|
||||
# Python functions
|
||||
def llama_context_default_params() -> llama_context_params:
|
||||
params = lib.llama_context_default_params()
|
||||
|
@ -171,3 +180,12 @@ def llama_reset_timings(ctx: llama_context_p):
|
|||
def llama_print_system_info() -> str:
|
||||
"""Print system informaiton"""
|
||||
return lib.llama_print_system_info().decode('utf-8')
|
||||
|
||||
def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
|
||||
return lib.llama_get_state_size(ctx)
|
||||
|
||||
def llama_copy_state_data(ctx: llama_context_p, dst: c_ubyte_p) -> c_size_t:
|
||||
return lib.llama_copy_state_data(ctx, dst)
|
||||
|
||||
def llama_set_state_data(ctx: llama_context_p, src: c_ubyte_p) -> c_size_t:
|
||||
return lib.llama_set_state_data(ctx, src)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue