fixup! sampling: separate rng per sampling context
This commit is contained in:
parent
123eaf054f
commit
760db9ee35
6 changed files with 23 additions and 24 deletions
|
@ -243,6 +243,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
return true;
|
||||
}
|
||||
params.seed = std::stoul(argv[i]);
|
||||
sparams.seed = std::stoul(argv[i]);
|
||||
return true;
|
||||
}
|
||||
if (arg == "-t" || arg == "--threads") {
|
||||
|
|
|
@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
|||
|
||||
result->prev.resize(params.n_prev);
|
||||
|
||||
llama_sampling_set_rng_seed(result, LLAMA_DEFAULT_SEED);
|
||||
llama_sampling_set_rng_seed(result, params.seed);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -21,25 +21,26 @@ enum class llama_sampler_type : char {
|
|||
|
||||
// sampling parameters
|
||||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float tfs_z = 1.00f; // 1.0 = disabled
|
||||
float typical_p = 1.00f; // 1.0 = disabled
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
float tfs_z = 1.00f; // 1.0 = disabled
|
||||
float typical_p = 1.00f; // 1.0 = disabled
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||
|
||||
std::vector<llama_sampler_type> samplers_sequence = {
|
||||
llama_sampler_type::TOP_K,
|
||||
|
|
|
@ -107,7 +107,6 @@ int main(int argc, char ** argv){
|
|||
bool has_eos = false;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
|
||||
|
||||
std::vector<llama_token> draft;
|
||||
|
||||
|
|
|
@ -520,7 +520,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
|
||||
|
||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||
// predict
|
||||
|
|
|
@ -854,7 +854,7 @@ struct server_context {
|
|||
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
||||
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
|
||||
slot.params.seed = json_value(data, "seed", default_params.seed);
|
||||
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
|
||||
|
@ -1028,7 +1028,6 @@ struct server_context {
|
|||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
|
||||
}
|
||||
|
||||
slot.command = SLOT_COMMAND_LOAD_PROMPT;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue