replace use of rand() with mt19937 sampling

This commit is contained in:
Minsoo Cheong 2024-02-29 00:26:23 +09:00
parent 6afc1f60e1
commit 94f6256fd0

View file

@ -41,11 +41,13 @@ int main(int argc, char ** argv) {
// probability threshold for splitting a draft branch (only for n_seq_dft > 1) // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split; const float p_split = params.p_split;
std::mt19937 r_gen;
if (params.seed >= 0) { if (params.seed >= 0) {
srand(params.seed); r_gen = std::mt19937(params.seed);
} else { } else {
srand(time(NULL)); r_gen = std::mt19937(time(NULL));
} }
std::uniform_int_distribution<std::mt19937::result_type> u_dist(0, RAND_MAX);
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log")); log_set_target(log_filename_generator("speculative", "log"));
@ -235,7 +237,7 @@ int main(int argc, char ** argv) {
continue; continue;
} }
float r = rand() / (float) RAND_MAX; float r = u_dist(r_gen) / (float) RAND_MAX;
llama_token_data_array dist_dft = drafts[s].dist[i_dft]; llama_token_data_array dist_dft = drafts[s].dist[i_dft];
// acquire the probability of the token from the draft model // acquire the probability of the token from the draft model
for (int i = 0; i < dist_tgt.size; i++) { for (int i = 0; i < dist_tgt.size; i++) {