diff --git a/examples/llama_cpp.py b/examples/llama_cpp.py index b5f83baa2..1862605b4 100644 --- a/examples/llama_cpp.py +++ b/examples/llama_cpp.py @@ -42,6 +42,7 @@ llama_token_data_p = POINTER(llama_token_data) llama_progress_callback = ctypes.CFUNCTYPE(None, c_double, c_void_p) + class llama_context_params(Structure): _fields_ = [ ("n_ctx", c_int), # text context @@ -163,6 +164,14 @@ _lib.llama_n_ctx.argtypes = [llama_context_p] _lib.llama_n_ctx.restype = c_int +def llama_n_embd(ctx: llama_context_p) -> c_int: + return _lib.llama_n_ctx(ctx) + + +_lib.llama_n_embd.argtypes = [llama_context_p] +_lib.llama_n_embd.restype = c_int + + # Token logits obtained from the last call to llama_eval() # The logits for the last token are stored in the last row # Can be mutated in order to change the probabilities of the next token