diff --git a/examples/llama_cpp.py b/examples/llama_cpp.py
index 601ffc6c2..4e4596ea7 100644
--- a/examples/llama_cpp.py
+++ b/examples/llama_cpp.py
@@ -495,7 +495,9 @@ _lib.llama_sample_softmax.restype = None
 
 
 # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
+def llama_sample_top_k(
+    ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
 
 
@@ -503,13 +505,15 @@ _lib.llama_sample_top_k.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_int,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_k.restype = None
 
 
 # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_top_p(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
 
 
@@ -517,14 +521,14 @@ _lib.llama_sample_top_p.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_p.restype = None
 
 
 # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
 def llama_sample_tail_free(
-    ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
+    ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
 
@@ -533,13 +537,15 @@ _lib.llama_sample_tail_free.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_tail_free.restype = None
 
 
 # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_typical(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
 
 
@@ -547,7 +553,7 @@ _lib.llama_sample_typical.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_typical.restype = None
 