fixes based on review (@JohannesGaessler)
This commit is contained in:
parent
94f6256fd0
commit
e4896e71b5
1 changed files with 4 additions and 4 deletions
|
@ -48,7 +48,7 @@ int main(int argc, char ** argv) {
|
||||||
r_gen = std::mt19937(time(NULL));
|
r_gen = std::mt19937(time(NULL));
|
||||||
}
|
}
|
||||||
std::uniform_int_distribution<std::mt19937::result_type> u_dist(0, RAND_MAX);
|
std::uniform_int_distribution<std::mt19937::result_type> u_dist(0, RAND_MAX);
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("speculative", "log"));
|
log_set_target(log_filename_generator("speculative", "log"));
|
||||||
LOG_TEE("Log start\n");
|
LOG_TEE("Log start\n");
|
||||||
|
@ -218,7 +218,7 @@ int main(int argc, char ** argv) {
|
||||||
// stochastic verification
|
// stochastic verification
|
||||||
|
|
||||||
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
float p_tgt, p_dft;
|
float p_tgt = 0, p_dft = 0;
|
||||||
|
|
||||||
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
|
@ -239,7 +239,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
float r = u_dist(r_gen) / (float) RAND_MAX;
|
float r = u_dist(r_gen) / (float) RAND_MAX;
|
||||||
llama_token_data_array dist_dft = drafts[s].dist[i_dft];
|
llama_token_data_array dist_dft = drafts[s].dist[i_dft];
|
||||||
// acquire the probability of the token from the draft model
|
// acquire the token probabilities assigned by the draft and target models
|
||||||
for (int i = 0; i < dist_tgt.size; i++) {
|
for (int i = 0; i < dist_tgt.size; i++) {
|
||||||
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
||||||
p_tgt = dist_tgt.data[i].p;
|
p_tgt = dist_tgt.data[i].p;
|
||||||
|
@ -295,7 +295,7 @@ int main(int argc, char ** argv) {
|
||||||
for(int i = s; i < n_seq_dft; i++) {
|
for(int i = s; i < n_seq_dft; i++) {
|
||||||
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
||||||
// synchronize active status for sequences with the same drafted token
|
// synchronize active status for sequences with the same drafted token
|
||||||
drafts[i].active = drafts[i].active & accept;
|
drafts[i].active = drafts[i].active && accept;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue