sampling : one sequence per sampling context
ggml-ci
This commit is contained in:
parent
370359e5ba
commit
5261aee8d8
4 changed files with 28 additions and 86 deletions
|
@ -1,50 +1,14 @@
|
|||
#include "sampling.h"
|
||||
|
||||
llama_sampling_context::~llama_sampling_context() {
|
||||
for (auto & it : sequence_contexts) {
|
||||
if (it.second.grammar != NULL) {
|
||||
llama_grammar_free(it.second.grammar);
|
||||
it.second.grammar = NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampling_context llama_sampling_context_init(
|
||||
const struct gpt_params & params,
|
||||
llama_grammar * grammar) {
|
||||
llama_sampling_context result;
|
||||
llama_sampling_context result;
|
||||
|
||||
result.params = params.sampling_params;
|
||||
result.grammar = grammar;
|
||||
return result;
|
||||
}
|
||||
result.params = params.sampling_params;
|
||||
result.grammar = grammar;
|
||||
|
||||
// Note: Creates the context if it doesn't exist, so this always return something.
|
||||
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
|
||||
llama_sampling_context & ctx_sampling,
|
||||
const llama_seq_id seq) {
|
||||
const auto it = ctx_sampling.sequence_contexts.find(seq);
|
||||
if (it != ctx_sampling.sequence_contexts.end()) {
|
||||
return it->second;
|
||||
}
|
||||
llama_sampler_sequence_context new_ctx = {
|
||||
2.0f * ctx_sampling.params.mirostat_tau,
|
||||
ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
|
||||
};
|
||||
return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
|
||||
}
|
||||
|
||||
bool llama_sampling_context_reset(
|
||||
llama_sampling_context & ctx_sampling,
|
||||
const llama_seq_id seq) {
|
||||
const auto it = ctx_sampling.sequence_contexts.find(seq);
|
||||
if (it == ctx_sampling.sequence_contexts.end()) return false;
|
||||
if (it->second.grammar != NULL) {
|
||||
llama_grammar_free(it->second.grammar);
|
||||
it->second.grammar = NULL;
|
||||
}
|
||||
ctx_sampling.sequence_contexts.erase(it);
|
||||
return true;
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_token llama_sampling_sample(
|
||||
|
@ -53,8 +17,7 @@ llama_token llama_sampling_sample(
|
|||
struct llama_sampling_context & ctx_sampling,
|
||||
const std::vector<llama_token> & last_tokens,
|
||||
std::vector<llama_token_data> & candidates,
|
||||
const int idx,
|
||||
llama_seq_id seq) {
|
||||
const int idx) {
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
|
@ -115,10 +78,8 @@ llama_token llama_sampling_sample(
|
|||
}
|
||||
}
|
||||
|
||||
llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
|
||||
|
||||
if (ctx_seq.grammar != NULL) {
|
||||
llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
|
||||
if (ctx_sampling.grammar != NULL) {
|
||||
llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
|
||||
}
|
||||
|
||||
if (temp <= 0) {
|
||||
|
@ -128,10 +89,10 @@ llama_token llama_sampling_sample(
|
|||
if (mirostat == 1) {
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temp(ctx, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
|
||||
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
|
||||
} else if (mirostat == 2) {
|
||||
llama_sample_temp(ctx, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
|
||||
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
size_t min_keep = std::max(1, params.n_probs);
|
||||
|
@ -158,8 +119,8 @@ llama_token llama_sampling_sample(
|
|||
}
|
||||
}
|
||||
|
||||
if (ctx_seq.grammar != NULL) {
|
||||
llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
|
||||
if (ctx_sampling.grammar != NULL) {
|
||||
llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
|
||||
}
|
||||
|
||||
return id;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue