From 71fc16bb6cd92b842f1fb7425e3db48e86ef3e07 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 15 Nov 2024 08:20:28 +0200
Subject: [PATCH] speculative : refactor and add a simpler example

ggml-ci
---
 common/CMakeLists.txt                         |   2 +
 common/sampling.cpp                           |  22 ++
 common/sampling.h                             |   2 +
 common/speculative.cpp                        | 159 ++++++++++
 common/speculative.h                          |  33 ++
 examples/CMakeLists.txt                       |   1 +
 examples/speculative-simple/CMakeLists.txt    |   5 +
 examples/speculative-simple/README.md         |   3 +
 .../speculative-simple/speculative-simple.cpp | 285 ++++++++++++++++++
 examples/speculative/speculative.cpp          |   2 +-
 10 files changed, 513 insertions(+), 1 deletion(-)
 create mode 100644 common/speculative.cpp
 create mode 100644 common/speculative.h
 create mode 100644 examples/speculative-simple/CMakeLists.txt
 create mode 100644 examples/speculative-simple/README.md
 create mode 100644 examples/speculative-simple/speculative-simple.cpp

diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 5ab1ffa19..62a8a7db5 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
     ngram-cache.h
     sampling.cpp
     sampling.h
+    speculative.cpp
+    speculative.h
     )

 if (BUILD_SHARED_LIBS)
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 7922fde47..fe1ef5bf9 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -320,6 +320,28 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return cur_p.data[cur_p.selected].id;
 }

+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        if (draft[i] != id) {
+            break;
+        }
+
+        result.push_back(id);
+    }
+
+    result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first));
+
+    return result;
+}
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
     return llama_sampler_get_seed(gsmpl->chain);
 }
diff --git a/common/sampling.h b/common/sampling.h
index d37f25ad3..9e61690aa 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -60,6 +60,8 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 //
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

 // helpers
diff --git a/common/speculative.cpp b/common/speculative.cpp
new file mode 100644
index 000000000..d16cc3c8e
--- /dev/null
+++ b/common/speculative.cpp
@@ -0,0 +1,159 @@
+#include "speculative.h"
+
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+
+#include <algorithm>
+
+struct seq_draft {
+};
+
+struct common_speculative {
+    struct common_speculative_params params;
+
+    llama_batch batch_dft;
+
+    struct common_sampler * smpl;
+
+    std::vector<int> i_batch_tgt;
+
+    std::vector<llama_token> tokens;
+};
+
+struct common_speculative * common_speculative_init(struct common_speculative_params params) {
+    auto * result = new common_speculative {
+        /* .params      = */ params,
+        /* .batch_dft   = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
+        /* .smpl        = */ nullptr,
+        /* .i_batch_tgt = */ {},
+        /* .tokens      = */ {},
+    };
+
+    // TODO: optimize or pass from outside?
+#if 0
+    {
+        common_sampler_params sparams;
+        sparams.no_perf = false;
+
+        sparams.top_k = 40;
+        sparams.top_p = 0.9;
+
+        sparams.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+            COMMON_SAMPLER_TYPE_TOP_P,
+            COMMON_SAMPLER_TYPE_INFILL,
+        };
+
+        result->smpl = common_sampler_init(params.model_dft, sparams);
+    }
+#else
+    {
+        common_sampler_params sparams;
+        sparams.no_perf = false;
+
+        sparams.top_k = 10;
+
+        sparams.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+        };
+
+        result->smpl = common_sampler_init(params.model_dft, sparams);
+    }
+#endif
+
+    result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1);
+
+    return result;
+}
+
+void common_speculative_free(struct common_speculative * spec) {
+    common_sampler_free(spec->smpl);
+
+    llama_batch_free(spec->batch_dft);
+
+    delete spec;
+}
+
+void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
+    llama_kv_cache_clear(spec->params.ctx_dft);
+
+    // TODO: error handling
+    llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
+}
+
+void common_speculative_add_draft(
+        struct common_speculative * spec,
+        struct llama_batch & batch_tgt,
+        llama_token id_last,
+        int n_past) {
+    spec->tokens.clear();
+
+    spec->i_batch_tgt.clear();
+    spec->i_batch_tgt.push_back(0);
+
+    common_sampler_reset(spec->smpl);
+
+    common_batch_clear(spec->batch_dft);
+    common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
+
+    llama_decode(spec->params.ctx_dft, spec->batch_dft);
+
+    // sample n_draft tokens from the draft model
+    for (int i = 0; i < spec->params.n_draft; ++i) {
+        common_batch_clear(spec->batch_dft);
+
+        common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);
+
+        const auto * cur_p = common_sampler_get_candidates(spec->smpl);
+
+        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) {
+            break;
+        }
+
+        common_sampler_accept(spec->smpl, id, true);
+
+        spec->tokens.push_back(id);
+
+        // add unique drafted tokens to the target batch
+        spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
+
+        common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
+
+        if (batch_tgt.n_tokens > spec->params.n_draft) {
+            break;
+        }
+
+        common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);
+
+        // evaluate the drafted tokens on the draft model
+        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    }
+
+    // don't waste time on small batches
+    // TODO: do not evaluate the draft model for that many rounds
+    if (batch_tgt.n_tokens < spec->params.n_min) {
+        batch_tgt.n_tokens = 1;
+        spec->tokens.resize(0);
+        spec->i_batch_tgt.resize(1);
+    }
+
+    // print current draft sequences
+    LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
+}
+
+std::vector<llama_token> common_speculative_sample(
+        struct common_speculative * spec,
+        struct common_sampler * smpl,
+        struct llama_context * ctx_tgt) {
+    return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
+}
diff --git a/common/speculative.h b/common/speculative.h
new file mode 100644
index 000000000..0952e5e70
--- /dev/null
+++ b/common/speculative.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "llama.h"
+
+#include <vector>
+
+struct common_speculative;
+
+struct common_speculative_params {
+    int n_draft = 16;
+    int n_min   = 5; // do not add drafts smaller than this, TODO: leave this to user?
+
+    struct llama_model * model_dft = nullptr;
+
+    struct llama_context * ctx_dft = nullptr;
+};
+
+struct common_speculative * common_speculative_init(struct common_speculative_params params);
+
+void common_speculative_free(struct common_speculative * spec);
+
+void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);
+
+void common_speculative_add_draft(
+        struct common_speculative * spec,
+        struct llama_batch & batch_tgt,
+        llama_token id_last,
+        int n_past);
+
+std::vector<llama_token> common_speculative_sample(
+        struct common_speculative * spec,
+        struct common_sampler * smpl,
+        struct llama_context * ctx_tgt);
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index d63a96c1c..9bd099d4e 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -50,5 +50,6 @@ else()
     add_subdirectory(simple)
     add_subdirectory(simple-chat)
     add_subdirectory(speculative)
+    add_subdirectory(speculative-simple)
     add_subdirectory(tokenize)
 endif()
diff --git a/examples/speculative-simple/CMakeLists.txt b/examples/speculative-simple/CMakeLists.txt
new file mode 100644
index 000000000..7a3a141c2
--- /dev/null
+++ b/examples/speculative-simple/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-speculative-simple)
+add_executable(${TARGET} speculative-simple.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/speculative-simple/README.md b/examples/speculative-simple/README.md
new file mode 100644
index 000000000..6f3d6dc15
--- /dev/null
+++ b/examples/speculative-simple/README.md
@@ -0,0 +1,3 @@
+# llama.cpp/examples/speculative-simple
+
+Demonstration of basic greedy speculative decoding
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
new file mode 100644
index 000000000..31a09e61d
--- /dev/null
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -0,0 +1,285 @@
+#include "arg.h"
+#include "common.h"
+#include "sampling.h"
+#include "speculative.h"
+#include "log.h"
+#include "llama.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
+struct seq_draft {
+    std::vector<int> i_batch_tgt;
+
+    std::vector<llama_token> tokens;
+
+    struct common_sampler * smpl = nullptr;
+};
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    // needed to get candidate probs even for temp <= 0.0
+    params.sparams.n_probs = 128;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
+        return 1;
+    }
+
+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
+
+    common_init();
+
+    if (params.model_draft.empty()) {
+        LOG_ERR("%s: --model-draft is required\n", __func__);
+        return 1;
+    }
+
+    // init llama.cpp
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model_tgt = NULL;
+    llama_model * model_dft = NULL;
+
+    llama_context * ctx_tgt = NULL;
+    llama_context * ctx_dft = NULL;
+
+    // load the target model
+    common_init_result llama_init_tgt = common_init_from_params(params);
+    model_tgt = llama_init_tgt.model;
+    ctx_tgt   = llama_init_tgt.context;
+
+    // load the draft model
+    params.model        = params.model_draft;
+    params.n_gpu_layers = params.n_gpu_layers_draft;
+    if (params.draft_cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
+    }
+
+    params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
+    common_init_result llama_init_dft = common_init_from_params(params);
+    model_dft = llama_init_dft.model;
+    ctx_dft   = llama_init_dft.context;
+
+    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
+
+    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
+        LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+        return 1;
+    }
+
+    if (
+        llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return 1;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+        const int vocab_diff  = n_vocab_tgt > n_vocab_dft
+            ? n_vocab_tgt - n_vocab_dft
+            : n_vocab_dft - n_vocab_tgt;
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+            LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return 1;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
+                LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return 1;
+            }
+        }
+    }
+
+
+    // Tokenize the prompt
+    std::vector<llama_token> inp;
+    inp = common_tokenize(ctx_tgt, params.prompt, true, true);
+
+    const int max_context_size     = llama_n_ctx(ctx_tgt);
+    const int max_tokens_list_size = max_context_size - 4;
+
+    if ((int) inp.size() > max_tokens_list_size) {
+        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
+        return 1;
+    }
+
+    LOG("\n\n");
+
+    for (auto id : inp) {
+        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
+    }
+
+    const int n_input = inp.size();
+
+    const auto t_enc_start = ggml_time_us();
+
+    // eval the prompt
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
+
+    // note: keep the last token separate!
+    llama_token id_last = inp.back();
+
+    int n_past = inp.size() - 1;
+
+    // how many tokens to draft each time
+    int n_draft = params.n_draft;
+
+    int n_predict = 0;
+    int n_drafted = 0;
+    int n_accept  = 0;
+
+    // used to determine end of generation
+    bool has_eos = false;
+
+    // target model sampling context
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
+
+    // init the speculator
+    struct common_speculative_params params_spec;
+    params_spec.n_draft   = n_draft;
+    params_spec.n_min     = 5;
+    params_spec.model_dft = model_dft;
+    params_spec.ctx_dft   = ctx_dft;
+
+    struct common_speculative * spec = common_speculative_init(params_spec);
+
+    // feed the prompt to the speculator
+    common_speculative_set_prompt(spec, inp.data(), n_input - 1);
+
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
+
+    const auto t_enc_end = ggml_time_us();
+
+    const auto t_dec_start = ggml_time_us();
+
+    while (true) {
+        // always have a token to evaluate from before
+        common_batch_clear(batch_tgt);
+        common_batch_add  (batch_tgt, id_last, n_past, { 0 }, true);
+
+        // optionally, append draft tokens to the target batch
+        common_speculative_add_draft(spec, batch_tgt, id_last, n_past);
+
+        // evaluate the target model on the drafted tokens
+        {
+            //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
+
+            llama_decode(ctx_tgt, batch_tgt);
+        }
+
+        // process the full target batch and return the accepted tokens based on the target sampler
+        const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);
+
+        n_past    += ids.size();
+        n_drafted += batch_tgt.n_tokens - 1;
+        n_accept  += ids.size() - 1;
+
+        // process the accepted tokens and update contexts
+        {
+            llama_token id;
+            std::string token_str;
+
+            for (size_t i = 0; i < ids.size(); ++i) {
+                id = ids[i];
+
+                ++n_predict;
+
+                if (llama_token_is_eog(model_tgt, id)) {
+                    has_eos = true;
+                    break;
+                }
+
+                token_str = common_token_to_piece(ctx_tgt, id);
+
+                if (params.use_color && i + 1 < ids.size()) {
+                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+                } else {
+                    LOG("%s", token_str.c_str());
+                }
+            }
+
+            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+                break;
+            }
+
+            LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
+
+            {
+                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+                llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
+            }
+
+            id_last = id;
+        }
+    }
+
+    auto t_dec_end = ggml_time_us();
+
+    LOG("\n\n");
+
+    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
+
+    LOG_INF("\n");
+    LOG_INF("n_draft   = %d\n", n_draft);
+    LOG_INF("n_predict = %d\n", n_predict);
+    LOG_INF("n_drafted = %d\n", n_drafted);
+    LOG_INF("n_accept  = %d\n", n_accept);
+    LOG_INF("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+    LOG_INF("\n");
+    LOG_INF("draft:\n\n");
+
+    llama_perf_context_print(ctx_dft);
+
+    LOG_INF("\n");
+    LOG_INF("target:\n\n");
+    common_perf_print(ctx_tgt, smpl);
+
+    common_sampler_free(smpl);
+    common_speculative_free(spec);
+
+    llama_free(ctx_tgt);
+    llama_free_model(model_tgt);
+
+    llama_free(ctx_dft);
+    llama_free_model(model_dft);
+
+    llama_backend_free();
+
+    LOG("\n\n");
+
+    return 0;
+}
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 6cafd8a83..207b8ea34 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -12,7 +12,7 @@
 #include <string>
 #include <vector>

-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

 struct seq_draft {