randomly select next sequence to verify + fix bug in memory freeing

This commit is contained in:
Minsoo Cheong 2024-02-29 15:47:41 +09:00
parent 6b35c8b3cf
commit 2ad3f7c28c

View file

@ -5,6 +5,7 @@
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <vector> #include <vector>
#include <set>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@ -46,6 +47,7 @@ int main(int argc, char ** argv) {
} }
std::default_random_engine rng(params.seed); std::default_random_engine rng(params.seed);
std::uniform_real_distribution<> u_dist; std::uniform_real_distribution<> u_dist;
std::uniform_int_distribution<> u_int_dist;
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log")); log_set_target(log_filename_generator("speculative", "log"));
@ -188,12 +190,15 @@ int main(int argc, char ** argv) {
drafts[0].i_batch_tgt[0] = 0; drafts[0].i_batch_tgt[0] = 0;
while (true) { while (true) {
std::set<int> active_seqs = {};
// print current draft sequences // print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) { if (!drafts[s].active) {
continue; continue;
} }
active_seqs.insert(s);
const auto & tokens = drafts[s].tokens; const auto & tokens = drafts[s].tokens;
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
@ -219,12 +224,13 @@ int main(int argc, char ** argv) {
float p_tgt = 0, p_dft = 0; float p_tgt = 0, p_dft = 0;
// GGML_ASSERT(dist_tgt.size() == dist_dft.size()); // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) { while (active_seqs.size() > 0) {
continue; // randomly select a sequence to verify from active sequences
} int s = *std::next(active_seqs.begin(), u_int_dist(rng) % active_seqs.size());
if (i_dft >= (int) drafts[s].tokens.size()) { if (i_dft >= (int) drafts[s].tokens.size()) {
drafts[s].active = false; drafts[s].active = false;
active_seqs.erase(s);
continue; continue;
} }
if (accept) { if (accept) {
@ -232,9 +238,10 @@ int main(int argc, char ** argv) {
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
drafts[s].active = false; drafts[s].active = false;
} }
active_seqs.erase(s);
continue; continue;
} }
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
float r = u_dist(rng); float r = u_dist(rng);
llama_token_data_array dist_dft = drafts[s].dist[i_dft]; llama_token_data_array dist_dft = drafts[s].dist[i_dft];
// acquire the token probabilities assigned by the draft and target models // acquire the token probabilities assigned by the draft and target models
@ -290,13 +297,19 @@ int main(int argc, char ** argv) {
}); });
} }
for(int i = s; i < n_seq_dft; i++) { active_seqs.erase(s);
for(int i = 0; i < n_seq_dft; i++) {
if (i == s) {
continue;
}
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
// synchronize active status for sequences with the same drafted token // synchronize active status for sequences with the same drafted token
drafts[i].active = drafts[i].active && accept; drafts[i].active = drafts[i].active && accept;
if (!drafts[i].active) {
active_seqs.erase(s);
}
} }
} }
} }
if (!accept) { if (!accept) {
@ -380,16 +393,22 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_keep(ctx_tgt, 0); llama_kv_cache_seq_keep(ctx_tgt, 0);
} }
std::set<llama_token_data *> freed_addrs;
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false; drafts[s].active = false;
drafts[s].tokens.clear(); drafts[s].tokens.clear();
drafts[s].i_batch_tgt.clear(); drafts[s].i_batch_tgt.clear();
// free dist and clear // free dist and clear
for (int i = 0; i < drafts[s].dist.size(); i++) { for (int i = 0; i < drafts[s].dist.size(); i++) {
if (freed_addrs.find(drafts[s].dist[i].data) != freed_addrs.end()) {
continue;
}
free(drafts[s].dist[i].data); free(drafts[s].dist[i].data);
freed_addrs.insert(drafts[s].dist[i].data);
} }
drafts[s].dist.clear(); drafts[s].dist.clear();
} }
freed_addrs.clear();
// note: will be erased after the speculation phase // note: will be erased after the speculation phase
drafts[0].tokens.push_back(token_id); drafts[0].tokens.push_back(token_id);
drafts[0].dist.push_back(llama_token_data_array{}); drafts[0].dist.push_back(llama_token_data_array{});