From 66982abcb1cab7e88d849e5c9ce946beb079ef37 Mon Sep 17 00:00:00 2001 From: Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com> Date: Fri, 24 May 2024 12:22:59 -0400 Subject: [PATCH] fixes --- examples/duo/README.md | 2 +- examples/duo/duo.cpp | 33 ++++++++++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/examples/duo/README.md b/examples/duo/README.md index 56b1dad3b..dfcbb1023 100644 --- a/examples/duo/README.md +++ b/examples/duo/README.md @@ -4,4 +4,4 @@ Minimal example. What's not implemented, but can be implemented separately in pi * tree-based speculation * correct sampling * support more than 2 instances -* \ No newline at end of file +* just one instance speculates diff --git a/examples/duo/duo.cpp b/examples/duo/duo.cpp index ab66bcd0f..75abf2467 100644 --- a/examples/duo/duo.cpp +++ b/examples/duo/duo.cpp @@ -1,6 +1,7 @@ #include "common.h" #include "llama.h" +#include #include #include #include @@ -59,6 +60,7 @@ static void split_done_cb(int split) if (split == 1 || split == 2) { std::lock_guard guard(spec_ctx.mtx); + fprintf(stderr, "split_done = %d\n", split); spec_ctx.active_id = split - 1; } } @@ -121,7 +123,24 @@ static int speculation( size_t match_len; // TODO: here we need to not generate too many and wait - while (true) { + while (true) + { + // silliest thing ever + bool wait = false; + { + std::lock_guard g(spec_ctx->mtx); + if (spec_ctx->active_id != 0) + { + wait = true; + } + } + if (wait) + { + std::this_thread::sleep_for(std::chrono::milliseconds{10}); + continue; + } + + auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1); if (next_tokens.size() != 1) { fprintf(stderr, "invalid next tokens\n"); @@ -139,6 +158,7 @@ static int speculation( auto& shared = spec_ctx->candidate; bool match = true; match_len = local.size() - 1; + fprintf(stderr, "spec #%d: %zu | %zu\n", active, shared.size(), local.size()); for (size_t i = 0; i < std::min(shared.size(), local.size()); i++) { if (shared[i] != local[i]) @@ -159,7 +179,11 @@ static int speculation( { local = shared; } - active = spec_ctx->active_id; + if (active != spec_ctx->active_id) + { + active = spec_ctx->active_id; + fprintf(stderr, "updating active_id = %d\n", active); + } } llama_batch_clear(batch); @@ -294,6 +318,8 @@ static int target( break; } + fprintf(stderr, "tgt: input_seq.size() = %zu\n", input_seq.size()); + llama_batch_clear(batch); for (size_t i = 0; i < input_seq.size(); i++) { @@ -365,7 +391,8 @@ int main(int argc, char ** argv) { params.n_threads = params.n_threads_draft; } params.n_threads_batch = params.n_threads_batch_draft; - + + params.cb_split_done = nullptr; params.rpc_servers = draft_rpcs.substr(0, i); std::tie(draft_models[0], draft_ctx[0]) = llama_init_from_gpt_params(params); params.rpc_servers = draft_rpcs.substr(i + 1);