From 67ad517e110567ecc52404fc7d800cb17b4b5412 Mon Sep 17 00:00:00 2001
From: Minsoo Cheong
Date: Mon, 4 Mar 2024 14:55:35 +0900
Subject: [PATCH] remove malloc code by utilizing vectors

---
 examples/speculative/speculative.cpp | 28 ++++++++--------------------
 1 file changed, 8 insertions(+), 20 deletions(-)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 0546e68fb..85bc0a762 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -19,7 +19,7 @@ struct seq_draft {
     std::vector<int> i_batch_tgt;

     std::vector<llama_token> tokens;
-    std::vector<llama_token_data_array> dist;
+    std::vector<std::vector<llama_token_data>> dists;

     struct llama_sampling_context * ctx_sampling;
 };
@@ -243,7 +243,7 @@ int main(int argc, char ** argv) {
                         }
                         LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
                         float r = u_dist(rng);
-                        llama_token_data_array dist_dft = drafts[s].dist[i_dft];
+                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
                         // acquire the token probabilities assigned by the draft and target models
                         for (size_t i = 0; i < dist_tgt.size; i++) {
                             if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
@@ -393,25 +393,15 @@ int main(int argc, char ** argv) {
                 llama_kv_cache_seq_keep(ctx_tgt, 0);
             }

-            std::set<llama_token_data *> freed_addrs;
             for (int s = 0; s < n_seq_dft; ++s) {
                 drafts[s].active = false;
                 drafts[s].tokens.clear();
                 drafts[s].i_batch_tgt.clear();
-                // free dist and clear
-                for (size_t i = 0; i < drafts[s].dist.size(); i++) {
-                    if (freed_addrs.find(drafts[s].dist[i].data) != freed_addrs.end()) {
-                        continue;
-                    }
-                    free(drafts[s].dist[i].data);
-                    freed_addrs.insert(drafts[s].dist[i].data);
-                }
-                drafts[s].dist.clear();
+                drafts[s].dists.clear();
             }
-            freed_addrs.clear();
             // note: will be erased after the speculation phase
             drafts[0].tokens.push_back(token_id);
-            drafts[0].dist.push_back(llama_token_data_array{});
+            drafts[0].dists.push_back(std::vector<llama_token_data>());
             drafts[0].i_batch_tgt.push_back(0);

             llama_batch_clear(batch_dft);
@@ -493,7 +483,7 @@ int main(int argc, char ** argv) {
                         drafts[n_seq_cur].skip = true;

                         drafts[n_seq_cur].tokens = drafts[s].tokens;
-                        drafts[n_seq_cur].dist = drafts[s].dist;
+                        drafts[n_seq_cur].dists = drafts[s].dists;
                         drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                         drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

@@ -516,10 +506,8 @@ int main(int argc, char ** argv) {
                     llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);

                     drafts[s].tokens.push_back(id);
-                    // save cur_p into drafts[s].dist
-                    llama_token_data *data = (llama_token_data *)malloc(sizeof(llama_token_data) * cur_p.size());
-                    memcpy(data, cur_p.data(), sizeof(llama_token_data) * cur_p.size());
-                    drafts[s].dist.push_back(llama_token_data_array{data, cur_p.size(), true});
+                    // save cur_p.data into drafts[s].dists
+                    drafts[s].dists.push_back(cur_p);

                     // add unique drafted tokens to the target batch
                     drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@@ -571,7 +559,7 @@ int main(int argc, char ** argv) {
                 }

                 drafts[s].tokens.erase(drafts[s].tokens.begin());
-                drafts[s].dist.erase(drafts[s].dist.begin());
+                drafts[s].dists.erase(drafts[s].dists.begin());
             }
         }
