Update sampling API
This commit is contained in:
parent
78531e5d05
commit
c26e9bf1c1
1 changed file with 14 additions and 8 deletions
|
@ -495,7 +495,9 @@ _lib.llama_sample_softmax.restype = None
|
|||
|
||||
|
||||
# @details Top-K sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
|
||||
def llama_sample_top_k(
    ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
):
    """Forward a top-k sampling request to the native library.

    Thin wrapper: every argument is passed unchanged, in order, to
    ``_lib.llama_sample_top_k``.

    Args:
        ctx: Context handle handed to the native function.
        candidates: Candidate token data; the native signature declares this
            argument as ``llama_token_data_array_p``.
        k: The ``k`` of top-k sampling, passed through as given.
        min_keep: Passed through as given; defaults to ``c_size_t(1)``.
            Presumably the minimum number of candidates to retain -- confirm
            against the llama.cpp header.

    Returns:
        Whatever the native call returns; its ``restype`` is declared as
        ``None`` in this module.
    """
    # NOTE: the default ``c_size_t(1)`` object is created once, at definition
    # time, and is only ever forwarded here -- never mutated.
    return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
|
||||
|
||||
|
||||
|
@ -503,13 +505,15 @@ _lib.llama_sample_top_k.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
c_int,
|
||||
c_int,
|
||||
c_size_t,
|
||||
]
|
||||
_lib.llama_sample_top_k.restype = None
|
||||
|
||||
|
||||
# @details Nucleus sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
|
||||
def llama_sample_top_p(
    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
):
    """Forward a nucleus (top-p) sampling request to the native library.

    Thin wrapper: every argument is passed unchanged, in order, to
    ``_lib.llama_sample_top_p``.

    Args:
        ctx: Context handle handed to the native function.
        candidates: Candidate token data; the native signature declares this
            argument as ``llama_token_data_array_p``.
        p: The ``p`` of top-p sampling, passed through as given.
        min_keep: Passed through as given; defaults to ``c_size_t(1)``.
            Presumably the minimum number of candidates to retain -- confirm
            against the llama.cpp header.

    Returns:
        Whatever the native call returns; its ``restype`` is declared as
        ``None`` in this module.
    """
    # NOTE: the default ``c_size_t(1)`` object is created once, at definition
    # time, and is only ever forwarded here -- never mutated.
    return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
|
||||
|
||||
|
||||
|
@ -517,14 +521,14 @@ _lib.llama_sample_top_p.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
c_float,
|
||||
c_int,
|
||||
c_size_t,
|
||||
]
|
||||
_lib.llama_sample_top_p.restype = None
|
||||
|
||||
|
||||
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
def llama_sample_tail_free(
    ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
):
    """Forward a tail-free sampling request to the native library.

    Thin wrapper: every argument is passed unchanged, in order, to
    ``_lib.llama_sample_tail_free``.

    Args:
        ctx: Context handle handed to the native function.
        candidates: Candidate token data; the native signature declares this
            argument as ``llama_token_data_array_p``.
        z: The ``z`` parameter of tail-free sampling, passed through as given.
        min_keep: Passed through as given; defaults to ``c_size_t(1)``.
            Presumably the minimum number of candidates to retain -- confirm
            against the llama.cpp header.

    Returns:
        Whatever the native call returns; its ``restype`` is declared as
        ``None`` in this module.
    """
    # NOTE: the default ``c_size_t(1)`` object is created once, at definition
    # time, and is only ever forwarded here -- never mutated.
    return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
|
||||
|
||||
|
@ -533,13 +537,15 @@ _lib.llama_sample_tail_free.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
c_float,
|
||||
c_int,
|
||||
c_size_t,
|
||||
]
|
||||
_lib.llama_sample_tail_free.restype = None
|
||||
|
||||
|
||||
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
|
||||
def llama_sample_typical(
    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
):
    """Forward a locally typical sampling request to the native library.

    Thin wrapper: every argument is passed unchanged, in order, to
    ``_lib.llama_sample_typical``.

    Args:
        ctx: Context handle handed to the native function.
        candidates: Candidate token data; the native signature declares this
            argument as ``llama_token_data_array_p``.
        p: The ``p`` parameter of typical sampling, passed through as given.
        min_keep: Passed through as given; defaults to ``c_size_t(1)``.
            Presumably the minimum number of candidates to retain -- confirm
            against the llama.cpp header.

    Returns:
        Whatever the native call returns; its ``restype`` is declared as
        ``None`` in this module.
    """
    # NOTE: the default ``c_size_t(1)`` object is created once, at definition
    # time, and is only ever forwarded here -- never mutated.
    return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
|
||||
|
||||
|
||||
|
@ -547,7 +553,7 @@ _lib.llama_sample_typical.argtypes = [
|
|||
llama_context_p,
|
||||
llama_token_data_array_p,
|
||||
c_float,
|
||||
c_int,
|
||||
c_size_t,
|
||||
]
|
||||
_lib.llama_sample_typical.restype = None
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue