replace use of rand() with mt19937 sampling
This commit is contained in:
parent
6afc1f60e1
commit
94f6256fd0
1 changed files with 6 additions and 4 deletions
|
@ -41,11 +41,13 @@ int main(int argc, char ** argv) {
|
|||
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
||||
const float p_split = params.p_split;
|
||||
|
||||
std::mt19937 r_gen;
|
||||
if (params.seed >= 0) {
|
||||
srand(params.seed);
|
||||
r_gen = std::mt19937(params.seed);
|
||||
} else {
|
||||
srand(time(NULL));
|
||||
r_gen = std::mt19937(time(NULL));
|
||||
}
|
||||
std::uniform_int_distribution<std::mt19937::result_type> u_dist(0, RAND_MAX);
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("speculative", "log"));
|
||||
|
@ -235,7 +237,7 @@ int main(int argc, char ** argv) {
|
|||
continue;
|
||||
}
|
||||
|
||||
float r = rand() / (float) RAND_MAX;
|
||||
float r = u_dist(r_gen) / (float) RAND_MAX;
|
||||
llama_token_data_array dist_dft = drafts[s].dist[i_dft];
|
||||
// acquire the probability of the token from the draft model
|
||||
for (int i = 0; i < dist_tgt.size; i++) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue