diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 57893c2eb..124133a35 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -46,6 +47,7 @@ int main(int argc, char ** argv) { } std::default_random_engine rng(params.seed); std::uniform_real_distribution<> u_dist; + std::uniform_int_distribution<> u_int_dist; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); @@ -188,12 +190,15 @@ int main(int argc, char ** argv) { drafts[0].i_batch_tgt[0] = 0; while (true) { + std::set active_seqs = {}; + // print current draft sequences for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { continue; } + active_seqs.insert(s); const auto & tokens = drafts[s].tokens; LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); @@ -219,12 +224,13 @@ int main(int argc, char ** argv) { float p_tgt = 0, p_dft = 0; // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); - for (int s = 0; s < n_seq_dft; ++s) { - if (!drafts[s].active) { - continue; - } + + while (active_seqs.size() > 0) { + // randomly select a sequence to verify from active sequences + int s = *std::next(active_seqs.begin(), u_int_dist(rng) % active_seqs.size()); if (i_dft >= (int) drafts[s].tokens.size()) { drafts[s].active = false; + active_seqs.erase(s); continue; } if (accept) { @@ -232,9 +238,10 @@ int main(int argc, char ** argv) { if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { drafts[s].active = false; } + active_seqs.erase(s); continue; } - + LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); float r = u_dist(rng); llama_token_data_array dist_dft = drafts[s].dist[i_dft]; // acquire the token probabilities assigned by the draft and target models @@ -290,13 +297,19 @@ int main(int argc, char ** argv) { }); } - for(int i = s; i < n_seq_dft; i++) { + active_seqs.erase(s); + for(int i = 0; i < n_seq_dft; i++) { + if (i == s) { + continue; + } if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { // synchronize active status for sequences with the same drafted token drafts[i].active = drafts[i].active && accept; + if (!drafts[i].active) { + active_seqs.erase(s); + } } } - } if (!accept) { @@ -380,16 +393,22 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_keep(ctx_tgt, 0); } + std::set freed_addrs; for (int s = 0; s < n_seq_dft; ++s) { drafts[s].active = false; drafts[s].tokens.clear(); drafts[s].i_batch_tgt.clear(); // free dist and clear for (int i = 0; i < drafts[s].dist.size(); i++) { + if (freed_addrs.find(drafts[s].dist[i].data) != freed_addrs.end()) { + continue; + } free(drafts[s].dist[i].data); + freed_addrs.insert(drafts[s].dist[i].data); } drafts[s].dist.clear(); } + freed_addrs.clear(); // note: will be erased after the speculation phase drafts[0].tokens.push_back(token_id); drafts[0].dist.push_back(llama_token_data_array{});