diff --git a/examples/duo/duo.cpp b/examples/duo/duo.cpp
index e9dc3bd11..fa9cecc46 100644
--- a/examples/duo/duo.cpp
+++ b/examples/duo/duo.cpp
@@ -18,21 +18,30 @@ static void dbg_color(const std::string & s, const std::string & fg)
 
 static void dbg_accepted(const std::string & accepted)
 {
-    static const std::string kGreen = "\033[32m";
-    dbg_color(accepted, kGreen);
+    dbg_color(accepted, /* green */ "\033[32m");
 }
 
-static void dbg_not_matched(const std::string & accepted)
+static void dbg_default(const std::string & accepted)
 {
     dbg_color(accepted, "");
 }
 
 static void dbg_rejected(const std::string & rejected)
 {
-    static const std::string kRed = "\033[31m";
-    dbg_color(rejected, kRed);
+    dbg_color(rejected, /* red */ "\033[31m");
 }
 
+template<typename Iterator>
+static std::string to_string(llama_context * ctx, Iterator from, Iterator to)
+{
+    std::string res = "";
+    for (auto it = from; it != to; ++it)
+    {
+        res += llama_token_to_piece(ctx, *it);
+    }
+    return res;
+}
+
 using llama_tokens = std::vector<llama_token>;
 
 struct speculation_context
@@ -93,97 +102,97 @@ static int speculation(
 
     int32_t active = 1;
 
-    llama_batch batch = llama_batch_init(512, 0, 1);
-
-    for (size_t i = 0; i < input.size(); i++)
-    {
-        llama_batch_add(batch, input[i], i, { 0 }, false);
-    }
-
-    batch.logits[batch.n_tokens - 1] = true;
-
-    if (llama_decode(ctx[active], batch) != 0) {
-        LOG_TEE("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
-
-    int logit_idx = batch.n_tokens - 1;
-    std::vector<llama_token> local_spec = input;
-    size_t match_len;
-
-    while (true) {
-        auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
-        if (next_tokens.size() != 1) {
-            fprintf(stderr, "invalid next tokens\n");
-            return 1;
-        }
-
-        local_spec.push_back(next_tokens[0]);
-
-        {
-            std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
-            if (spec_ctx->done)
-            {
-                break;
-            }
-            auto& spec = spec_ctx->candidate;
-            bool match = true;
-            match_len = local_spec.size() - 1;
-            for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
-            {
-                if (spec[i] != local_spec[i])
-                {
-                    match = false;
-                    match_len = i;
-                    // here we need to clear both contexts
-                    llama_kv_cache_seq_rm(ctx[0], 0, i, -1);
-                    llama_kv_cache_seq_rm(ctx[1], 0, i, -1);
-                    break;
-                }
-            }
-            if (match) {
-                spec = local_spec;
-            } else {
-                local_spec = spec;
-            }
-            active = spec_ctx->active_id;
-        }
-
-        llama_batch_clear(batch);
-        // TODO theoretically this can be empty?
-        for (size_t i = match_len; i < local_spec.size(); i++) {
-            llama_batch_add(batch, local_spec[i], i, { 0 }, true);
-        }
-
-        logit_idx = batch.n_tokens - 1;
-
-        if (llama_decode(ctx[active], batch)) {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
-    }
-
-    llama_batch_free(batch);
-    return 0;
-}
-
-static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
-{
-    // TODO: batch size
     llama_batch batch = llama_batch_init(512, 0, 1);
 
-    // evaluate the initial prompt
-    for (size_t i = 0; i < input.size(); i++) {
+    for (size_t i = 0; i < input.size(); i++)
+    {
         llama_batch_add(batch, input[i], i, { 0 }, false);
     }
 
     batch.logits[batch.n_tokens - 1] = true;
 
-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode(ctx[active], batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
         return 1;
     }
 
+    int logit_idx = batch.n_tokens - 1;
+    std::vector<llama_token> local_spec = input;
+    size_t match_len;
+
+    // TODO: here we need to not generate too many and wait
+    while (true) {
+        auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
+        if (next_tokens.size() != 1) {
+            fprintf(stderr, "invalid next tokens\n");
+            return 1;
+        }
+
+        local_spec.push_back(next_tokens[0]);
+
+        {
+            std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+            if (spec_ctx->done)
+            {
+                break;
+            }
+            auto& spec = spec_ctx->candidate;
+            bool match = true;
+            match_len = local_spec.size() - 1;
+            for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
+            {
+                if (spec[i] != local_spec[i])
+                {
+                    match = false;
+                    match_len = i;
+                    // here we need to clear both contexts
+                    llama_kv_cache_seq_rm(ctx[0], 0, i, -1);
+                    llama_kv_cache_seq_rm(ctx[1], 0, i, -1);
+                    break;
+                }
+            }
+            if (match) {
+                spec = local_spec;
+            } else {
+                local_spec = spec;
+            }
+            active = spec_ctx->active_id;
+        }
+
+        llama_batch_clear(batch);
+        // TODO theoretically this can be empty?
+        for (size_t i = match_len; i < local_spec.size(); i++) {
+            llama_batch_add(batch, local_spec[i], i, { 0 }, true);
+        }
+
+        logit_idx = batch.n_tokens - 1;
+
+        if (llama_decode(ctx[active], batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+    }
+
+    llama_batch_free(batch);
+    return 0;
+}
+
+static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
+{
+    dbg_default(to_string(ctx, input.begin(), input.end()));
+    // TODO: batch size
+    llama_batch batch = llama_batch_init(512, 0, 1);
+    for (size_t i = 0; i < input.size(); i++)
+    {
+        llama_batch_add(batch, input[i], i, { 0 }, false);
+    }
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch) != 0) {
+        fprintf(stderr, "llama_decode() failed\n");
+        return 1;
+    }
+
     // how many tokens are currently accepted
     // TODO: rename to n_accepted
     size_t n_cur = input.size();
@@ -195,7 +204,7 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
     int logits_from = batch.n_tokens - 1;
     int logits_to = batch.n_tokens;
 
-    llama_tokens input_seq, next_tokens, output;
+    llama_tokens input_seq, next_tokens;
     input_seq.push_back(input.back());
 
     while (n_decode <= n_predict)
@@ -241,7 +250,6 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
                 break;
             }
         }
-        output.insert(output.end(), next_tokens.begin(), next_tokens.end());
 
         {
             std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
@@ -259,31 +267,11 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
             }
         }
 
-        std::string accepted = "";
-        for (size_t i = next_tokens_pos; i < next_tokens_pos + n_match; i++)
-        {
-            accepted += llama_token_to_piece(ctx, spec[i]);
-        }
-        dbg_accepted(accepted);
-        if (n_match != next_tokens.size())
-        {
-            std::string rejected = "";
-            for (size_t i = next_tokens_pos + n_match; i < spec.size(); i++)
-            {
-                rejected += llama_token_to_piece(ctx, spec[i]);
-            }
-            dbg_rejected(rejected);
-            std::string not_matched = "";
-            for (size_t i = n_match; i < next_tokens.size(); i++)
-            {
-                not_matched += llama_token_to_piece(ctx, next_tokens[i]);
-            }
-            dbg_not_matched(not_matched);
-        }
-
-        // remove non-matched tokens
+        dbg_accepted(to_string(ctx, spec.begin() + next_tokens_pos, spec.begin() + next_tokens_pos + n_match));
         if (n_match != next_tokens.size())
         {
+            dbg_rejected(to_string(ctx, spec.begin() + next_tokens_pos + n_match, spec.end()));
+            dbg_default(to_string(ctx, next_tokens.begin() + n_match, next_tokens.end()));
             spec.erase(spec.begin() + next_tokens_pos, spec.end());
             for (const auto tok: next_tokens)
             {
@@ -337,7 +325,6 @@ int main(int argc, char ** argv) {
         params.seed = time(NULL);
     }
 
-    // parse 2 speculation rpc instances
    std::string draft_rpcs = params.rpc_servers_draft;
     size_t i = draft_rpcs.find(',');
     if (i == std::string::npos || draft_rpcs.find(',', i + 1) != std::string::npos)
@@ -360,7 +347,7 @@ int main(int argc, char ** argv) {
 
     // prepare draft model and contexts. No need for two model instances?
     std::vector<llama_model *> draft_models = {nullptr, nullptr};
-    std::vector<llama_context*> draft_ctx = {nullptr, nullptr};
+    std::vector<llama_context *> draft_ctx = {nullptr, nullptr};
 
     params.model = params.model_draft;
     params.n_gpu_layers = params.n_gpu_layers_draft;