fix style

This commit is contained in:
Minsoo Cheong 2024-02-27 15:29:14 +09:00
parent fb18827b4e
commit 34b942a429

View file

@ -208,12 +208,11 @@ int main(int argc, char ** argv) {
bool accept = false; bool accept = false;
if (params.sparams.temp > 0) { if (params.sparams.temp > 0) {
// stochastic verification // stochastic verification
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
float p_tgt, p_dft;
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
float p_tgt, p_dft;
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) { if (!drafts[s].active) {
continue; continue;
@ -234,7 +233,6 @@ int main(int argc, char ** argv) {
llama_token_data_array dist_dft = drafts[s].dist[i_dft]; llama_token_data_array dist_dft = drafts[s].dist[i_dft];
// acquire the probability of the token from the draft model // acquire the probability of the token from the draft model
for (int i = 0; i < dist_tgt.size; i++) { for (int i = 0; i < dist_tgt.size; i++) {
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
p_tgt = dist_tgt.data[i].p; p_tgt = dist_tgt.data[i].p;
} }
@ -258,7 +256,7 @@ int main(int argc, char ** argv) {
} else { } else {
LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
drafts[s].active = false; drafts[s].active = false;
// calculate residual probability // calculate residual probability
GGML_ASSERT(dist_tgt.sorted); GGML_ASSERT(dist_tgt.sorted);
GGML_ASSERT(dist_dft.sorted); GGML_ASSERT(dist_dft.sorted);