From b2b36e9e95249bbeaf2d833377777c5e32c39576 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 4 Sep 2024 22:16:30 +0300
Subject: [PATCH] example : fix build + fix speculative

ggml-ci
---
 common/sampling.cpp                            | 12 +++++-
 common/sampling.h                              |  5 ++-
 examples/batched.swift/Sources/main.swift      |  4 +-
 .../llama/src/main/cpp/llama-android.cpp       |  2 +-
 .../llama.cpp.swift/LibLlama.swift             |  6 ++-
 examples/speculative/speculative.cpp           | 38 ++++++++++++------
 6 files changed, 45 insertions(+), 22 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 18d3be845..edc6cd05b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -120,7 +120,7 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl) {
         /* .bias = */ llama_constraint_cp(gsmpl->bias),
         /* .pnlt = */ llama_constraint_cp(gsmpl->pnlt),
         /* .grmr = */ llama_constraint_cp(gsmpl->grmr),
-        /* .smpl = */ llama_sampler_cp(gsmpl->smpl)
+        /* .smpl = */ llama_sampler_cp (gsmpl->smpl)
     };
 }
 
@@ -158,7 +158,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_da
     return llama_sampler_sample(gsmpl->smpl, cur_p);
 }
 
-llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) {
+llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
     auto & bias = gsmpl->bias;
     auto & pnlt = gsmpl->pnlt;
     auto & grmr = gsmpl->grmr;
@@ -173,10 +173,18 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     llama_constraint_apply(bias, cur_p);
     llama_constraint_apply(pnlt, cur_p);
 
+    if (grammar_first) {
+        llama_constraint_apply(grmr, cur_p);
+    }
+
     llama_sampler_apply(smpl, cur_p);
 
     const llama_token id = llama_sampler_sample(smpl, cur_p);
 
+    if (grammar_first) {
+        return id;
+    }
+
     // check if it the sampled token fits the grammar
     {
         llama_token_data single_token_data = { id, 1.0f, 0.0f };
diff --git a/common/sampling.h b/common/sampling.h
index 9bdeadf78..87673efa3 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -92,7 +92,10 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl);
 // - check if the token fits the grammar (if any)
 // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
-llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx);
+// if grammar_first is true, the grammar is applied before the constraints (slower)
+// useful in cases where all the resulting candidates must fit the grammar
+//
+llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
 // helpers
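For context on the grammar_first flag documented above: with grammar_first == false, the sampler picks a token first and validates only that single pick against the grammar, resampling on a mismatch (fast path); with grammar_first == true, the grammar filters the whole candidate set before sampling, so every surviving candidate is grammar-valid (slower, but needed when callers inspect all candidates, as the speculative example below does). The toy program below sketches the difference in ordering; fits_grammar() and sample() are invented stand-ins, not llama.cpp API:

// Illustrative only: a toy model of the two sampling paths, not llama.cpp API.
#include <cstdio>
#include <vector>
#include <algorithm>

struct candidate { int id; float p; };

// stand-in for a grammar: only even token ids are "valid"
static bool fits_grammar(int id) { return id % 2 == 0; }

// stand-in for the sampler chain: pick the most probable candidate
static int sample(const std::vector<candidate> & cur_p) {
    return std::max_element(cur_p.begin(), cur_p.end(),
        [](const candidate & a, const candidate & b) { return a.p < b.p; })->id;
}

int main() {
    std::vector<candidate> cur_p = { {1, 0.5f}, {2, 0.3f}, {3, 0.2f} };

    // fast path (grammar_first == false): sample, then check the single pick
    int id = sample(cur_p);
    if (!fits_grammar(id)) {
        // resample after filtering the candidates through the grammar
        std::vector<candidate> valid;
        for (const auto & c : cur_p) {
            if (fits_grammar(c.id)) valid.push_back(c);
        }
        id = sample(valid);
    }
    printf("fast path      -> %d\n", id); // id 1 is rejected, resampling gives 2

    // grammar_first == true: filter all candidates up front, then sample once;
    // slower, but every candidate handed to the sampler fits the grammar
    std::vector<candidate> filtered;
    for (const auto & c : cur_p) {
        if (fits_grammar(c.id)) filtered.push_back(c);
    }
    printf("grammar first  -> %d\n", sample(filtered)); // 2

    return 0;
}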
diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 6b9f3e0d5..6ff62ae06 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -141,9 +141,7 @@ while n_cur <= n_len {
 
         llama_sampler_set_logits(smpl, logits)
 
-        let new_token_id = llama_sampler_sample_dist(smpl, nil)
-
-        // const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nil, false);
+        let new_token_id = llama_sampler_sample(smpl, nil)
 
         // is it an end of stream? -> mark the stream as finished
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 666e89764..1a4908501 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -399,7 +399,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     llama_sampler_set_logits(sampling, logits);
 
     // sample the most likely token
-    const auto new_token_id = llama_sampler_sample_greedy(sampling, nullptr, false);
+    const auto new_token_id = llama_sampler_sample(sampling, nullptr);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index 930336b27..bd6513d34 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -43,7 +43,9 @@ actor LlamaContext {
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
         self.temporary_invalid_cchars = []
-        self.sampling = llama_sampler_init(context, llama_sampler_default_params())
+        var sparams = llama_sampler_default_params()
+        sparams.type = LLAMA_SAMPLER_TYPE_GREEDY
+        self.sampling = llama_sampler_init(context, sparams)
     }
 
     deinit {
@@ -151,7 +153,7 @@ actor LlamaContext {
 
         llama_sampler_set_logits(sampling, logits);
 
-        new_token_id = llama_sampler_sample_greedy(sampling, nil, false)
+        new_token_id = llama_sampler_sample(sampling, nil)
 
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             print("\n")
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 5cd14c49d..d51c76849 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -228,20 +228,13 @@ int main(int argc, char ** argv) {
             bool accept = false;
             if (params.sparams.temp > 0) {
                 // stochastic verification
-                const float * logits = llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
-
-                gpt_sampler_set_logits(smpl, logits);
+                gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
 
                 auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
 
-                gpt_sampler_apply_grammar(smpl, &dist_tgt);
-                llama_constraint_apply(softmax, &dist_tgt);
-
                 float p_tgt = 0.0f;
                 float p_dft = 0.0f;
 
-                // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
-
                 while (active_seqs.size() > 0) {
                     // randomly select a sequence to verify from active sequences
                     std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
@@ -259,9 +252,13 @@ int main(int argc, char ** argv) {
                         }
                         continue;
                     }
 
+                    LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
                     float r = u_dist(rng);
                     llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
+
+                    //GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+
                     // acquire the token probabilities assigned by the draft and target models
                     for (size_t i = 0; i < dist_tgt.size; i++) {
                         if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
@@ -291,7 +288,6 @@ int main(int argc, char ** argv) {
                     // calculate residual probability
                     GGML_ASSERT(dist_tgt.sorted);
                     GGML_ASSERT(dist_dft.sorted);
-                    float sum_probs = 0.0f;
 
                     // sort dist by id
                     std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
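Background on the verification loop above: p_tgt and p_dft hold the probabilities that the target and draft models assign to the drafted token, and the accept test (in lines not shown by this patch) draws r uniformly from [0, 1) and accepts when r < p_tgt / p_dft, the standard stochastic speculative-decoding rule. A minimal sketch of that rule on made-up probabilities, assuming this is the criterion in use; no llama.cpp types are involved:

// Illustrative only: the acceptance test behind stochastic verification,
// shown on toy numbers rather than the llama.cpp data structures.
#include <cstdio>
#include <random>

int main() {
    std::mt19937 rng(42);
    std::uniform_real_distribution<float> u_dist(0.0f, 1.0f);

    // probabilities the target and draft models assign to the drafted token
    const float p_tgt = 0.30f;
    const float p_dft = 0.60f;

    // accept with probability min(1, p_tgt / p_dft); this keeps the overall
    // output distribution identical to sampling from the target model alone
    const float r = u_dist(rng);
    if (r < p_tgt / p_dft) {
        printf("accepted (r = %.3f < %.3f)\n", r, p_tgt / p_dft);
    } else {
        printf("rejected (r = %.3f >= %.3f) -> fall back to the residual distribution\n", r, p_tgt / p_dft);
    }
    return 0;
}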
@@ -301,10 +297,18 @@ int main(int argc, char ** argv) {
                         return a.id < b.id;
                     });
 
+                    float sum_probs = 0.0f;
+
                     for (size_t i = 0; i < dist_tgt.size; i++) {
-                        dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+                        if (i < dist_dft.size) {
+                            dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+                        } else {
+                            dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
+                        }
+
                         sum_probs += dist_tgt.data[i].p;
                     }
+
                     for (size_t i = 0; i < dist_tgt.size; i++) {
                         dist_tgt.data[i].p /= sum_probs;
                     }
@@ -334,7 +338,16 @@ int main(int argc, char ** argv) {
                     // all drafted tokens were rejected
                     // sample from the target model
                     LOG("all drafted tokens were rejected, sampling from residual distribution\n");
-                    token_id = gpt_sampler_sample(smpl, &dist_tgt);
+                    std::vector<float> probs(dist_tgt.size);
+                    for (size_t i = 0; i < dist_tgt.size; ++i) {
+                        probs[i] = dist_tgt.data[i].p;
+                    }
+
+                    std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+                    const int idx = dist(rng);
+
+                    token_id = dist_tgt.data[idx].id;
                     gpt_sampler_accept(smpl, token_id, true);
                     token_str = llama_token_to_piece(ctx_tgt, token_id);
                 }
@@ -467,7 +480,7 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
+            gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
 
             const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);
 
@@ -512,7 +525,6 @@ int main(int argc, char ** argv) {
             }
 
             drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl);
-
             sa.push_back(n_seq_cur);
 
             n_seq_cur++;
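For reference, the rejection path rewritten above builds a residual distribution, max(0, p_tgt - p_dft) renormalized, and then samples from it with std::discrete_distribution; the new bounds check covers the case where the draft ranks fewer tokens than the target. A self-contained sketch of the same computation on toy probability arrays (aligning tokens by index is a simplification of the id-sorted candidate arrays in the patch):

// Illustrative only: forming and sampling the residual distribution used
// when every drafted token is rejected (toy data, mirrors the hunks above).
#include <cstdio>
#include <random>
#include <vector>
#include <algorithm>

int main() {
    // token probabilities from the target and draft models, aligned by index
    std::vector<float> p_tgt = { 0.5f, 0.3f, 0.2f };
    std::vector<float> p_dft = { 0.6f, 0.2f }; // the draft may rank fewer tokens

    // residual: max(0, p_tgt - p_dft), then renormalize
    float sum_probs = 0.0f;
    for (size_t i = 0; i < p_tgt.size(); i++) {
        const float q = i < p_dft.size() ? p_dft[i] : 0.0f;
        p_tgt[i] = std::max(0.0f, p_tgt[i] - q);
        sum_probs += p_tgt[i];
    }
    for (size_t i = 0; i < p_tgt.size(); i++) {
        p_tgt[i] /= sum_probs; // here: {0, 1/3, 2/3}
    }

    // sample an index proportionally to the residual probabilities
    std::mt19937 rng(42);
    std::discrete_distribution<> dist(p_tgt.begin(), p_tgt.end());
    printf("sampled index: %d\n", dist(rng));

    return 0;
}

Sampling from this residual, rather than directly from p_tgt, is what keeps the overall draft-accept-reject procedure distributed exactly like sampling from the target model alone.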