fixup! sampling: separate rng per sampling context
This commit is contained in:
parent
123eaf054f
commit
760db9ee35
6 changed files with 23 additions and 24 deletions
|
@ -243,6 +243,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
params.seed = std::stoul(argv[i]);
|
params.seed = std::stoul(argv[i]);
|
||||||
|
sparams.seed = std::stoul(argv[i]);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (arg == "-t" || arg == "--threads") {
|
if (arg == "-t" || arg == "--threads") {
|
||||||
|
|
|
@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
||||||
|
|
||||||
result->prev.resize(params.n_prev);
|
result->prev.resize(params.n_prev);
|
||||||
|
|
||||||
llama_sampling_set_rng_seed(result, LLAMA_DEFAULT_SEED);
|
llama_sampling_set_rng_seed(result, params.seed);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,25 +21,26 @@ enum class llama_sampler_type : char {
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
typedef struct llama_sampling_params {
|
typedef struct llama_sampling_params {
|
||||||
int32_t n_prev = 64; // number of previous tokens to remember
|
int32_t n_prev = 64; // number of previous tokens to remember
|
||||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||||
int32_t top_k = 40; // <= 0 to use vocab size
|
int32_t top_k = 40; // <= 0 to use vocab size
|
||||||
float top_p = 0.95f; // 1.0 = disabled
|
float top_p = 0.95f; // 1.0 = disabled
|
||||||
float min_p = 0.05f; // 0.0 = disabled
|
float min_p = 0.05f; // 0.0 = disabled
|
||||||
float tfs_z = 1.00f; // 1.0 = disabled
|
float tfs_z = 1.00f; // 1.0 = disabled
|
||||||
float typical_p = 1.00f; // 1.0 = disabled
|
float typical_p = 1.00f; // 1.0 = disabled
|
||||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||||
float penalty_present = 0.00f; // 0.0 = disabled
|
float penalty_present = 0.00f; // 0.0 = disabled
|
||||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
float mirostat_tau = 5.00f; // target entropy
|
float mirostat_tau = 5.00f; // target entropy
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||||
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||||
|
|
||||||
std::vector<llama_sampler_type> samplers_sequence = {
|
std::vector<llama_sampler_type> samplers_sequence = {
|
||||||
llama_sampler_type::TOP_K,
|
llama_sampler_type::TOP_K,
|
||||||
|
|
|
@ -107,7 +107,6 @@ int main(int argc, char ** argv){
|
||||||
bool has_eos = false;
|
bool has_eos = false;
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
|
|
||||||
|
|
||||||
std::vector<llama_token> draft;
|
std::vector<llama_token> draft;
|
||||||
|
|
||||||
|
|
|
@ -520,7 +520,6 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||||
llama_sampling_set_rng_seed(ctx_sampling, params.seed);
|
|
||||||
|
|
||||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
|
|
|
@ -854,7 +854,7 @@ struct server_context {
|
||||||
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||||
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
||||||
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
|
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
|
||||||
slot.params.seed = json_value(data, "seed", default_params.seed);
|
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
|
||||||
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||||
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||||
|
|
||||||
|
@ -1028,7 +1028,6 @@ struct server_context {
|
||||||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.command = SLOT_COMMAND_LOAD_PROMPT;
|
slot.command = SLOT_COMMAND_LOAD_PROMPT;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue