remove malloc code by utilizing vectors
parent 45465b21d1
commit 67ad517e11

1 changed file with 8 additions and 20 deletions
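The change below replaces the manually managed candidate buffers (malloc / memcpy / free, plus a std::set used to avoid double frees when sequences share a buffer) with a std::vector<std::vector<llama_token_data>> member named dists; a llama_token_data_array is then built on demand as a non-owning view over one of those vectors. A minimal, self-contained sketch of that ownership pattern, using simplified stand-ins for the llama.h structs and a hypothetical view_of() helper (not code from this repository):

    #include <cstddef>
    #include <vector>

    // simplified stand-ins for the llama.h types touched by this diff (illustration only)
    struct llama_token_data       { int id; float logit; float p; };
    struct llama_token_data_array { llama_token_data * data; size_t size; bool sorted; };

    struct seq_draft_sketch {
        // the vectors own the storage for one candidate distribution per drafted token;
        // no malloc/free and no pointer bookkeeping is needed on cleanup
        std::vector<std::vector<llama_token_data>> dists;
    };

    // build a temporary, non-owning view for APIs that expect a llama_token_data_array;
    // it stays valid only while the owning vector is alive and not reallocated
    static llama_token_data_array view_of(std::vector<llama_token_data> & dist) {
        return { dist.data(), dist.size(), true };
    }

    int main() {
        seq_draft_sketch draft;
        draft.dists.push_back({ { 42, 0.0f, 1.0f } }); // store a (copied) candidate distribution

        llama_token_data_array dist_view = view_of(draft.dists.back());
        (void) dist_view;                              // would be handed to the verification code

        draft.dists.clear();                           // replaces the old free() loop
        return 0;
    }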
@@ -19,7 +19,7 @@ struct seq_draft {
     std::vector<int> i_batch_tgt;
 
     std::vector<llama_token> tokens;
-    std::vector<llama_token_data_array> dist;
+    std::vector<std::vector<llama_token_data>> dists;
 
     struct llama_sampling_context * ctx_sampling;
 };
@@ -243,7 +243,7 @@ int main(int argc, char ** argv) {
                     }
                     LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
                     float r = u_dist(rng);
-                    llama_token_data_array dist_dft = drafts[s].dist[i_dft];
+                    llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
                     // acquire the token probabilities assigned by the draft and target models
                     for (size_t i = 0; i < dist_tgt.size; i++) {
                         if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
@@ -393,25 +393,15 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_keep(ctx_tgt, 0);
         }
 
-        std::set<llama_token_data *> freed_addrs;
         for (int s = 0; s < n_seq_dft; ++s) {
             drafts[s].active = false;
             drafts[s].tokens.clear();
             drafts[s].i_batch_tgt.clear();
-            // free dist and clear
-            for (size_t i = 0; i < drafts[s].dist.size(); i++) {
-                if (freed_addrs.find(drafts[s].dist[i].data) != freed_addrs.end()) {
-                    continue;
-                }
-                free(drafts[s].dist[i].data);
-                freed_addrs.insert(drafts[s].dist[i].data);
-            }
-            drafts[s].dist.clear();
+            drafts[s].dists.clear();
         }
-        freed_addrs.clear();
         // note: will be erased after the speculation phase
         drafts[0].tokens.push_back(token_id);
-        drafts[0].dist.push_back(llama_token_data_array{});
+        drafts[0].dists.push_back(std::vector<llama_token_data>());
         drafts[0].i_batch_tgt.push_back(0);
 
         llama_batch_clear(batch_dft);
@@ -493,7 +483,7 @@ int main(int argc, char ** argv) {
                     drafts[n_seq_cur].skip = true;
 
                     drafts[n_seq_cur].tokens = drafts[s].tokens;
-                    drafts[n_seq_cur].dist = drafts[s].dist;
+                    drafts[n_seq_cur].dists = drafts[s].dists;
                     drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                     drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
 
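The dists copy in this hunk is also why the freed_addrs set removed above is no longer needed: the old code duplicated raw pointers to the same malloc'd buffer when a sequence was split, so cleanup had to deduplicate addresses before calling free(), whereas copying vector members deep-copies the elements and every branch owns its own storage. A small standalone sketch of that difference, again using a simplified llama_token_data stand-in rather than the real header:

    #include <cstdlib>
    #include <vector>

    struct llama_token_data { int id; float logit; float p; }; // simplified stand-in

    int main() {
        // old scheme: splitting a draft copied the pointer, so two drafts shared one buffer
        llama_token_data * buf = (llama_token_data *) malloc(sizeof(llama_token_data));
        buf[0] = { 1, 0.0f, 1.0f };
        llama_token_data * draft_a = buf;
        llama_token_data * draft_b = draft_a; // split: same underlying memory
        free(draft_a);                        // freeing draft_b as well would be a double free,
        (void) draft_b;                       // hence the freed_addrs bookkeeping

        // new scheme: copying the vector deep-copies the elements, so each branch
        // owns independent storage and plain clear() is sufficient on cleanup
        std::vector<llama_token_data> dists_a = { { 1, 0.0f, 1.0f } };
        std::vector<llama_token_data> dists_b = dists_a;
        dists_a.clear();
        dists_b.clear();
        return 0;
    }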
@@ -516,10 +506,8 @@ int main(int argc, char ** argv) {
                 llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
 
                 drafts[s].tokens.push_back(id);
-                // save cur_p into drafts[s].dist
-                llama_token_data *data = (llama_token_data *)malloc(sizeof(llama_token_data) * cur_p.size());
-                memcpy(data, cur_p.data(), sizeof(llama_token_data) * cur_p.size());
-                drafts[s].dist.push_back(llama_token_data_array{data, cur_p.size(), true});
+                // save cur_p.data into drafts[s].dists
+                drafts[s].dists.push_back(cur_p);
 
                 // add unique drafted tokens to the target batch
                 drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@@ -571,7 +559,7 @@ int main(int argc, char ** argv) {
             }
 
             drafts[s].tokens.erase(drafts[s].tokens.begin());
-            drafts[s].dist.erase(drafts[s].dist.begin());
+            drafts[s].dists.erase(drafts[s].dists.begin());
         }
     }