diff --git a/common/common.cpp b/common/common.cpp
index 5dae42814..680f06990 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1068,6 +1068,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.rpc_servers = argv[i];
         return true;
     }
+    if (arg == "--rpcd") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.rpc_servers_draft = argv[i];
+        return true;
+    }
     if (arg == "--no-mmap") {
         params.use_mmap = false;
         return true;
diff --git a/common/common.h b/common/common.h
index 2d5772d19..62b7b05e3 100644
--- a/common/common.h
+++ b/common/common.h
@@ -83,6 +83,7 @@ struct gpt_params {
     int32_t yarn_orig_ctx = 0; // YaRN original context length
     float   defrag_thold  = -1.0f; // KV cache defragmentation threshold
     std::string rpc_servers = ""; // comma separated list of RPC servers
+    std::string rpc_servers_draft = ""; // comma separated list of RPC servers used for the draft model
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
diff --git a/examples/duo/README.md b/examples/duo/README.md
index a34868788..56b1dad3b 100644
--- a/examples/duo/README.md
+++ b/examples/duo/README.md
@@ -1 +1,15 @@
-## duo
\ No newline at end of file
+## duo
+
+A minimal example. The following pieces are not implemented here, but can be added separately:
+* tree-based speculation
+* correct sampling
+* support for more than 2 instances
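+
+Example invocation (a sketch: the binary name, model paths, and RPC endpoints below are placeholders; `--rpcd` must list exactly two comma-separated draft servers):
+
+```bash
+./duo -m target-model.gguf -md draft-model.gguf \
+      --rpc host-a:50052,host-b:50052 \
+      --rpcd host-a:50052,host-b:50052 \
+      -p "I believe the meaning of life is" -n 128
+```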
diff --git a/examples/duo/duo.cpp b/examples/duo/duo.cpp
index d94a2f34a..e9dc3bd11 100644
--- a/examples/duo/duo.cpp
+++ b/examples/duo/duo.cpp
@@ -7,54 +7,169 @@
 #include 
 #include 
 
+static void dbg_color(const std::string & s, const std::string & fg)
+{
+    static const std::string kReset = "\033[0m";
+    static const std::string bold[] = { "", "\033[1m" };
+    static size_t index = 0;
+    std::cout << bold[index] << fg << s << kReset << std::flush;
+    index = 1 - index;
+}
+
+static void dbg_accepted(const std::string & accepted)
+{
+    static const std::string kGreen = "\033[32m";
+    dbg_color(accepted, kGreen);
+}
+
+static void dbg_not_matched(const std::string & not_matched)
+{
+    dbg_color(not_matched, "");
+}
+
+static void dbg_rejected(const std::string & rejected)
+{
+    static const std::string kRed = "\033[31m";
+    dbg_color(rejected, kRed);
+}
+
 using llama_tokens = std::vector<llama_token>;
 
 struct speculation_context
 {
-    llama_tokens speculation;
-    int32_t instance_id;
+    llama_tokens candidate;
+    int32_t active_id;
     std::mutex mtx;
+    bool done;
 };
 
 speculation_context spec_ctx;
 
 static void split_done_cb(int split)
 {
-    //fprintf(stderr, "split done: %d\n", split);
     if (split == 1 || split == 2)
     {
         std::lock_guard<std::mutex> guard(spec_ctx.mtx);
-        spec_ctx.instance_id = 3 - split;
+        spec_ctx.active_id = 2 - split;
     }
 }
 
-int main(int argc, char ** argv) {
-    gpt_params params;
-
-    if (gpt_params_parse(argc, argv, params) == false) {
-        return 1;
+// this ignores all the other sampling criteria
+static std::vector<llama_token> greedy_tokens(
+    llama_model * model,
+    llama_context * ctx,
+    int32_t from_idx,
+    int32_t to_idx)
+{
+    auto n_vocab = llama_n_vocab(model);
+    std::vector<llama_token> res;
+    if (n_vocab <= 0)
+    {
+        return res;
     }
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
+    for (int idx = from_idx; idx < to_idx; idx++)
+    {
+        auto * logits = llama_get_logits_ith(ctx, idx);
+        llama_token new_token_id = 0;
+        for (llama_token token_id = 1; token_id < n_vocab; token_id++)
+        {
+            if (logits[token_id] > logits[new_token_id])
+            {
+                new_token_id = token_id;
+            }
+        }
+
+        res.push_back(new_token_id);
     }
-
-    llama_backend_init();
-    llama_numa_init(params.numa);
+    return res;
+}
 
-    llama_model * model = nullptr;
-    llama_context * ctx = nullptr;
-    params.cb_split_done = split_done_cb;
-    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+static int speculation(
+    std::vector<llama_model *> model,
+    speculation_context * spec_ctx,
+    std::vector<llama_context *> ctx,
+    std::vector<llama_token> input /* copy here */) {
 
-    llama_tokens input = llama_tokenize(ctx, params.prompt, true);
-    const size_t n_input = input.size();
+    int32_t active = 1;
 
-    // print the prompt token-by-token
-    for (auto id : input) {
-        fprintf(stdout, "%s", llama_token_to_piece(ctx, id).c_str());
+    llama_batch batch = llama_batch_init(512, 0, 1);
+
+    for (size_t i = 0; i < input.size(); i++)
+    {
+        llama_batch_add(batch, input[i], i, { 0 }, false);
+    }
+
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx[active], batch) != 0) {
+        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
+    int logit_idx = batch.n_tokens - 1;
+    std::vector<llama_token> local_spec = input;
+    size_t match_len;
+
+    while (true) {
+        auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
+        if (next_tokens.size() != 1) {
+            fprintf(stderr, "invalid next tokens\n");
+            return 1;
         }
-        fflush(stdout);
+        local_spec.push_back(next_tokens[0]);
+
+        {
+            std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+            if (spec_ctx->done)
+            {
+                break;
+            }
+            auto& spec = spec_ctx->candidate;
+            bool match = true;
+            match_len = local_spec.size() - 1;
+            for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
+            {
+                if (spec[i] != local_spec[i])
+                {
+                    match = false;
+                    match_len = i;
+                    // here we need to clear both contexts
+                    llama_kv_cache_seq_rm(ctx[0], 0, i, -1);
+                    llama_kv_cache_seq_rm(ctx[1], 0, i, -1);
+                    break;
+                }
+            }
+            if (match) {
+                spec = local_spec;
+            } else {
+                local_spec = spec;
+            }
+            active = spec_ctx->active_id;
+        }
+
+        llama_batch_clear(batch);
+        // TODO: theoretically this can be empty?
+        for (size_t i = match_len; i < local_spec.size(); i++) {
+            llama_batch_add(batch, local_spec[i], i, { 0 }, true);
+        }
+
+        logit_idx = batch.n_tokens - 1;
+
+        if (llama_decode(ctx[active], batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+    }
+
+    llama_batch_free(batch);
+    return 0;
+}
+
+static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
+{
+    // TODO: batch size
     llama_batch batch = llama_batch_init(512, 0, 1);
 
     // evaluate the initial prompt
@@ -69,80 +184,211 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    int n_cur = batch.n_tokens;
-    int n_decode = 0;
+    // how many tokens are currently accepted
+    // TODO: rename to n_accepted
+    size_t n_cur = input.size();
+    size_t n_decode = 0;
 
     const auto t_main_start = ggml_time_us();
 
     // we'll use logits from this position to determine next token
-    int logit_idx = batch.n_tokens - 1;
+    int logits_from = batch.n_tokens - 1;
+    int logits_to   = batch.n_tokens;
 
-    while (n_decode <= params.n_predict) {
-        // sample the next token
+    llama_tokens input_seq, next_tokens, output;
+    input_seq.push_back(input.back());
+
+    while (n_decode <= n_predict)
+    {
+        next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
+        if (next_tokens.size() != input_seq.size())
         {
-            auto n_vocab = llama_n_vocab(model);
-            auto * logits = llama_get_logits_ith(ctx, logit_idx);
-
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
-
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-            }
-
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-            // sample the most likely token
-            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
-
-            // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id) || n_decode >= params.n_predict) {
-                break;
-            }
-
-            fprintf(stdout, "%s", llama_token_to_piece(ctx, new_token_id).c_str());
-            fflush(stdout);
-
-            // prepare the next batch
-            llama_batch_clear(batch);
-
-            // push this new token for next evaluation
-            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
-
-            // we still use the 'original' token to sample on next iteration
-            logit_idx = batch.n_tokens - 1;
-
-            n_decode += 1;
+            fprintf(stderr, "invalid next tokens\n");
+            return 1;
         }
 
-        n_cur += 1;
+        size_t next_tokens_pos = n_cur;
+        // we always accept at least one new token
+        n_cur += 1;
+        n_decode += 1;
+        for (size_t i = 0; i + 1 < input_seq.size(); i++)
+        {
+            if (next_tokens[i] == input_seq[i + 1])
+            {
+                n_cur += 1;
+                n_decode += 1;
+            }
+            else
+            {
+                // reject. next_tokens[i] is the last correct one.
+                next_tokens.erase(next_tokens.begin() + i + 1, next_tokens.end());
+                break;
+            }
+        }
 
-        // evaluate the current batch with the transformer model
+        // empty the non-matching portion of the KV cache
+        // n_cur is incremented at least once and will be > 0
+        llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);
+
+        bool done = false;
+        for (size_t i = 0; i < next_tokens.size(); i++)
+        {
+            // TODO: what should we do here, is this correct?
+            if (next_tokens[i] == llama_token_eos(model) || llama_token_is_eog(model, next_tokens[i]))
+            {
+                done = true;
+                next_tokens.erase(next_tokens.begin() + i, next_tokens.end());
+                break;
+            }
+        }
+        output.insert(output.end(), next_tokens.begin(), next_tokens.end());
+
+        {
+            std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
+            auto & spec = spec_ctx.candidate;
+            size_t n_match = 0;
+            for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++)
+            {
+                if (next_tokens[i] == spec[i + next_tokens_pos])
+                {
+                    n_match++;
+                }
+                else
+                {
+                    break;
+                }
+            }
+
+            std::string accepted = "";
+            for (size_t i = next_tokens_pos; i < next_tokens_pos + n_match; i++)
+            {
+                accepted += llama_token_to_piece(ctx, spec[i]);
+            }
+            dbg_accepted(accepted);
+            if (n_match != next_tokens.size())
+            {
+                std::string rejected = "";
+                for (size_t i = next_tokens_pos + n_match; i < spec.size(); i++)
+                {
+                    rejected += llama_token_to_piece(ctx, spec[i]);
+                }
+                dbg_rejected(rejected);
+                std::string not_matched = "";
+                for (size_t i = n_match; i < next_tokens.size(); i++)
+                {
+                    not_matched += llama_token_to_piece(ctx, next_tokens[i]);
+                }
+                dbg_not_matched(not_matched);
+            }
+
+            // remove non-matched tokens
+            if (n_match != next_tokens.size())
+            {
+                spec.erase(spec.begin() + next_tokens_pos, spec.end());
+                for (const auto tok: next_tokens)
+                {
+                    spec.push_back(tok);
+                }
+            }
+            input_seq.assign(spec.begin() + n_cur - 1, spec.end());
+        }
+
+        if (n_decode >= n_predict || done)
+        {
+            break;
+        }
+
+        llama_batch_clear(batch);
+        for (size_t i = 0; i < input_seq.size(); i++)
+        {
+            llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
+        }
 
         if (llama_decode(ctx, batch)) {
             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
 
-        // remove the cached entries from mock tokens
-        llama_kv_cache_seq_rm(ctx, 0, n_cur, -1);
+        logits_from = 0;
+        logits_to   = input_seq.size();
     }
 
-    LOG_TEE("\n");
-
     const auto t_main_end = ggml_time_us();
 
-    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+    LOG_TEE("%s: decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    //llama_print_timings(ctx);
-
+    llama_print_timings(ctx);
     fprintf(stderr, "\n");
+
+    {
+        std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
+        spec_ctx.done = true;
+    }
 
     llama_batch_free(batch);
+    return 0;
+}
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
+    }
+
+    // parse the two draft (speculation) RPC instances
+    std::string draft_rpcs = params.rpc_servers_draft;
+    size_t i = draft_rpcs.find(',');
+    if (i == std::string::npos || draft_rpcs.find(',', i + 1) != std::string::npos)
+    {
+        fprintf(stderr, "--rpcd must contain exactly two servers\n");
+        return 1;
+    }
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // main model and context
+    llama_model * model = nullptr;
+    llama_context * ctx = nullptr;
+    params.cb_split_done = split_done_cb;
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+
+    llama_tokens input = llama_tokenize(ctx, params.prompt, true);
+    spec_ctx.candidate = input;
+
+    // prepare draft model and contexts. No need for two model instances?
+    std::vector<llama_model *> draft_models = {nullptr, nullptr};
+    std::vector<llama_context *> draft_ctx = {nullptr, nullptr};
+
+    params.model = params.model_draft;
+    params.n_gpu_layers = params.n_gpu_layers_draft;
+    if (params.n_threads_draft > 0)
+    {
+        params.n_threads = params.n_threads_draft;
+    }
+    params.n_threads_batch = params.n_threads_batch_draft;
+
+    params.rpc_servers = draft_rpcs.substr(0, i);
+    std::tie(draft_models[0], draft_ctx[0]) = llama_init_from_gpt_params(params);
+    params.rpc_servers = draft_rpcs.substr(i + 1);
+    std::tie(draft_models[1], draft_ctx[1]) = llama_init_from_gpt_params(params);
+    std::thread spec_thread = std::thread(speculation, draft_models, &spec_ctx, draft_ctx, input);
+
+    target(model, ctx, input, params.n_predict);
+
+    spec_thread.join();
+    llama_free(ctx);
+    llama_free(draft_ctx[0]);
+    llama_free(draft_ctx[1]);
+    llama_free_model(model);
+    llama_free_model(draft_models[0]);
+    llama_free_model(draft_models[1]);
     llama_backend_free();
 
     return 0;
-}
+}
\ No newline at end of file