diff --git a/examples/duo/duo.cpp b/examples/duo/duo.cpp index fa9cecc46..1250db018 100644 --- a/examples/duo/duo.cpp +++ b/examples/duo/duo.cpp @@ -10,9 +10,9 @@ static void dbg_color(const std::string & s, const std::string & fg) { static const std::string kReset = "\033[0m"; - static const std::string bold[] = { "", "\033[1m" }; + static const std::string kBold[] = { "", "\033[1m" }; static size_t index = 0; - std::cout << bold[index] << fg << s << kReset << std::flush; + std::cout << kBold[index] << fg << s << kReset << std::flush; index = 1 - index; } @@ -98,7 +98,7 @@ static int speculation( std::vector model, speculation_context * spec_ctx, std::vector ctx, - std::vector input /* copy here */) { + llama_tokens input /* copy here */) { int32_t active = 1; @@ -117,18 +117,18 @@ static int speculation( } int logit_idx = batch.n_tokens - 1; - std::vector local_spec = input; + llama_tokens local = input; size_t match_len; // TODO: here we need to not generate too many and wait while (true) { auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1); if (next_tokens.size() != 1) { - fprintf(stderr, "invalid next tokens\n"); - return 1; + fprintf(stderr, "invalid next tokens\n"); + return 1; } - local_spec.push_back(next_tokens[0]); + local.push_back(next_tokens[0]); { std::lock_guard _lock(spec_ctx->mtx); @@ -136,12 +136,12 @@ static int speculation( { break; } - auto& spec = spec_ctx->candidate; + auto& shared = spec_ctx->candidate; bool match = true; - match_len = local_spec.size() - 1; - for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++) + match_len = local.size() - 1; + for (size_t i = 0; i < std::min(shared.size(), local.size()); i++) { - if (spec[i] != local_spec[i]) + if (shared[i] != local[i]) { match = false; match_len = i; @@ -151,23 +151,28 @@ static int speculation( break; } } - if (match) { - spec = local_spec; - } else { - local_spec = spec; + if (match && shared.size() < local.size()) + { + shared = local; + } + else + { + local = shared; } active = spec_ctx->active_id; } llama_batch_clear(batch); // TODO theoretically this can be empty? - for (size_t i = match_len; i < local_spec.size(); i++) { - llama_batch_add(batch, local_spec[i], i, { 0 }, true); + for (size_t i = match_len; i < local.size(); i++) + { + llama_batch_add(batch, local[i], i, { 0 }, true); } logit_idx = batch.n_tokens - 1; - if (llama_decode(ctx[active], batch)) { + if (llama_decode(ctx[active], batch) != 0) + { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } @@ -177,7 +182,11 @@ static int speculation( return 0; } -static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict) +static int target( + llama_model * model, + llama_context * ctx, + const llama_tokens& input, + size_t n_predict) { dbg_default(to_string(ctx, input.begin(), input.end())); // TODO: batch size @@ -300,8 +309,8 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens& const auto t_main_end = ggml_time_us(); - LOG_TEE("%s: decoded %zu tokens in %.2f s, speed: %.2f t/s\n", - __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n", + n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); llama_print_timings(ctx); fprintf(stderr, "\n");