fixup! sampling: separate rng per sampling context
parent 123eaf054f
commit 760db9ee35

6 changed files with 23 additions and 24 deletions
@@ -243,6 +243,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             return true;
         }
         params.seed = std::stoul(argv[i]);
+        sparams.seed = std::stoul(argv[i]);
         return true;
     }
     if (arg == "-t" || arg == "--threads") {

@@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

     result->prev.resize(params.n_prev);

-    llama_sampling_set_rng_seed(result, LLAMA_DEFAULT_SEED);
+    llama_sampling_set_rng_seed(result, params.seed);

     return result;
 }

@@ -40,6 +40,7 @@ typedef struct llama_sampling_params {
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
     bool penalize_nl = false; // consider newlines as a repeatable token
+    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,

@@ -107,7 +107,6 @@ int main(int argc, char ** argv){
     bool has_eos = false;

     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
-    llama_sampling_set_rng_seed(ctx_sampling, params.seed);

     std::vector<llama_token> draft;

@@ -520,7 +520,6 @@ int main(int argc, char ** argv) {
     }

     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
-    llama_sampling_set_rng_seed(ctx_sampling, params.seed);

     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict

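Taken together, the hunks above make the seed a property of the sampling context itself: the new seed field in llama_sampling_params carries it, llama_sampling_init() seeds the context's rng from params.seed internally, and the example programs drop their separate llama_sampling_set_rng_seed() calls. A minimal sketch of the resulting call pattern follows; make_seeded_sampler is a hypothetical helper, not part of the diff, and the include path may differ by checkout:

    #include "sampling.h" // llama.cpp common/sampling.h

    // Hypothetical helper showing the post-commit flow: set the seed on the
    // params and let llama_sampling_init() seed the context's rng internally.
    static struct llama_sampling_context * make_seeded_sampler(uint32_t seed) {
        llama_sampling_params sparams;
        sparams.seed = seed; // stays LLAMA_DEFAULT_SEED if never assigned

        // No follow-up llama_sampling_set_rng_seed() call is needed anymore.
        return llama_sampling_init(sparams);
    }
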
@@ -854,7 +854,7 @@ struct server_context {
             slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
             slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
             slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
-            slot.params.seed = json_value(data, "seed", default_params.seed);
+            slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
             slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
             slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

@@ -1028,7 +1028,6 @@ struct server_context {
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
-            llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
         }

         slot.command = SLOT_COMMAND_LOAD_PROMPT;

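On the server side, the two hunks above route the per-request "seed" value into slot.sparams instead of slot.params and drop the explicit llama_sampling_set_rng_seed() call, so the slot's sampling context picks its seed up when it is built from slot.sparams. Assuming the usual completion request body (the endpoint and remaining fields are not shown in this diff), a client could pin a slot's rng with a payload like:

    {"prompt": "Once upon a time", "seed": 42}
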