change function name to llama_sampling_prepare

This commit is contained in:
Minsoo Cheong 2024-03-24 11:16:29 +09:00
parent 0a243da7d4
commit bb38278e6a
4 changed files with 6 additions and 6 deletions

View file

@ -174,7 +174,7 @@ static llama_token llama_sampling_sample_impl(
const float mirostat_eta = params.mirostat_eta; const float mirostat_eta = params.mirostat_eta;
std::vector<float> original_logits; std::vector<float> original_logits;
auto cur_p = llama_sampling_configure_token_candidates(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits); auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
if (!is_resampling) { if (!is_resampling) {
GGML_ASSERT(!original_logits.empty()); GGML_ASSERT(!original_logits.empty());
} }
@ -245,7 +245,7 @@ static llama_token llama_sampling_sample_impl(
return id; return id;
} }
static llama_token_data_array llama_sampling_configure_token_candidates_impl( static llama_token_data_array llama_sampling_prepare_impl(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
@ -329,14 +329,14 @@ llama_token llama_sampling_sample(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
} }
llama_token_data_array llama_sampling_configure_token_candidates( llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
const int idx, const int idx,
bool apply_grammar, bool apply_grammar,
std::vector<float> * original_logits) { std::vector<float> * original_logits) {
return llama_sampling_configure_token_candidates_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits); return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
} }
void llama_sampling_accept( void llama_sampling_accept(

View file

@ -132,7 +132,7 @@ llama_token llama_sampling_sample(
int idx = 0); int idx = 0);
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters. // Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
llama_token_data_array llama_sampling_configure_token_candidates( llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,

View file

@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
if (params.sparams.temp > 0) { if (params.sparams.temp > 0) {
// stochastic verification // stochastic verification
llama_token_data_array dist_tgt = llama_sampling_configure_token_candidates(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
llama_sample_softmax(ctx_tgt, &dist_tgt); llama_sample_softmax(ctx_tgt, &dist_tgt);
float p_tgt = 0, p_dft = 0; float p_tgt = 0, p_dft = 0;

BIN
retrieval Executable file

Binary file not shown.