Update sampling api

This commit is contained in:
Andrei Betlen 2023-05-01 14:47:55 -04:00 committed by Don Mahurin
parent 78531e5d05
commit c26e9bf1c1

View file

@ -495,7 +495,9 @@ _lib.llama_sample_softmax.restype = None
# @details Top-K sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
def llama_sample_top_k(
ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
@ -503,13 +505,15 @@ _lib.llama_sample_top_k.argtypes = [
llama_context_p,
llama_token_data_array_p,
c_int,
c_int,
c_size_t,
]
_lib.llama_sample_top_k.restype = None
# @details Nucleus sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
def llama_sample_top_p(
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
@ -517,14 +521,14 @@ _lib.llama_sample_top_p.argtypes = [
llama_context_p,
llama_token_data_array_p,
c_float,
c_int,
c_size_t,
]
_lib.llama_sample_top_p.restype = None
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
def llama_sample_tail_free(
ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
@ -533,13 +537,15 @@ _lib.llama_sample_tail_free.argtypes = [
llama_context_p,
llama_token_data_array_p,
c_float,
c_int,
c_size_t,
]
_lib.llama_sample_tail_free.restype = None
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
def llama_sample_typical(
ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
):
return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
@ -547,7 +553,7 @@ _lib.llama_sample_typical.argtypes = [
llama_context_p,
llama_token_data_array_p,
c_float,
c_int,
c_size_t,
]
_lib.llama_sample_typical.restype = None