diff --git a/common/sampling.cpp b/common/sampling.cpp
index f90ac8b90..75e2e5d29 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return cur_p.data[cur_p.selected].id;
 }
 
-std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -342,23 +342,10 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) {
-    std::vector<int> idxs;
-    idxs.reserve(batch.n_tokens);
-
-    std::vector<llama_token> draft;
-    draft.reserve(batch.n_tokens);
-
-    for (int i = 0; i < batch.n_tokens; i++) {
-        if (batch.logits[i] == 0) {
-            continue;
-        }
-
-        if (idxs.size() > 0) {
-            GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]);
-            draft.push_back(batch.token[i]);
-        }
-        idxs.push_back(i);
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
     }
 
     return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
diff --git a/common/sampling.h b/common/sampling.h
index ba496ac27..f9b193ac8 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -71,9 +71,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
 
-std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false);
+// assume idxs == [ 0, 1, 2, ..., draft.size() ]
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 810fa93e4..eccba93e0 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -10,24 +10,19 @@
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct common_speculative {
-    struct common_speculative_params params;
-
-    llama_batch batch;
-
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
+    llama_batch batch;
     llama_tokens prompt;
 };
 
 struct common_speculative * common_speculative_init(
-        struct common_speculative_params params,
         struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .params = */ params,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
         /* .prompt = */ {},
     };
 
@@ -130,12 +125,11 @@ bool common_speculative_are_compatible(
     return true;
 }
 
-void common_speculative_add_draft(
+llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
-        struct llama_batch & batch_tgt,
+        struct common_speculative_params params,
         const llama_tokens & prompt_tgt,
-        llama_token id_last,
-        llama_token n_past_tgt) {
+        llama_token id_last) {
     auto & batch = spec->batch;
     auto & ctx   = spec->ctx;
     auto & smpl  = spec->smpl;
@@ -144,7 +138,7 @@ void common_speculative_add_draft(
     int reuse_i = 0;
     int reuse_n = 0;
 
-    const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
@@ -156,7 +150,7 @@ void common_speculative_add_draft(
             cur++;
         }
 
-        if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
+        if ((cur >= params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
             reuse_i = i;
             reuse_n = cur;
         }
@@ -207,8 +201,11 @@ void common_speculative_add_draft(
 
     common_sampler_reset(smpl);
 
+    llama_tokens result;
+    result.reserve(params.n_draft);
+
     // sample n_draft tokens from the draft model
-    for (int i = 0; i < spec->params.n_draft; ++i) {
+    for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
 
         common_sampler_sample(smpl, ctx, 0, true);
@@ -224,15 +221,15 @@ void common_speculative_add_draft(
         const llama_token id = cur_p->data[0].id;
 
         // only collect very high-confidence draft tokens
-        if (cur_p->data[0].p < spec->params.p_min) {
+        if (cur_p->data[0].p < params.p_min) {
             break;
         }
 
         common_sampler_accept(smpl, id, true);
 
-        common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
+        result.push_back(id);
 
-        if (batch_tgt.n_tokens > spec->params.n_draft) {
+        if (result.size() >= params.n_draft) {
             break;
         }
 
@@ -244,9 +241,5 @@ void common_speculative_add_draft(
         prompt.push_back(id);
     }
 
-    // don't waste time on small batches
-    // TODO: do not evaluate the draft model for that many rounds
-    if (batch_tgt.n_tokens < spec->params.n_min) {
-        batch_tgt.n_tokens = 1;
-    }
+    return result;
 }
diff --git a/common/speculative.h b/common/speculative.h
index b657b6229..9fb669fde 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -7,15 +7,12 @@ struct common_speculative;
 
 struct common_speculative_params {
     int n_draft = 16;
-    int n_min   = 5; // do not add drafts smaller than this, TODO: leave this to user?
     int n_reuse = 256;
 
     float p_min = 0.9f;
 };
 
-struct common_speculative * common_speculative_init(
-        struct common_speculative_params params,
-        struct llama_context * ctx_dft);
+struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
 
 void common_speculative_free(struct common_speculative * spec);
 
@@ -25,9 +22,8 @@ bool common_speculative_are_compatible(
 
 // sample up to n_draft tokens and add them to the batch using the draft model
 //
-void common_speculative_add_draft(
+llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
-        struct llama_batch & batch_tgt,
+        struct common_speculative_params params,
         const llama_tokens & prompt,
-        llama_token id_last,
-        llama_token n_past_tgt);
+        llama_token id_last);
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index cdfd5b886..d7e572cf8 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -13,6 +13,9 @@
 int main(int argc, char ** argv) {
     common_params params;
 
+    // minimum size of the draft to use
+    const int n_min = 5;
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
@@ -92,31 +95,29 @@ int main(int argc, char ** argv) {
     // everything until here is standard initialization
     // the relevant stuff for speculative decoding starts here
 
-    const int n_input = inp.size();
-
     const auto t_enc_start = ggml_time_us();
 
     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
     // eval the prompt
-    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
 
     // note: keep the last token separate!
     llama_token id_last = inp.back();
 
-    auto prompt_dft = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    // all tokens currently in the target context
+    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
 
     int n_past = inp.size() - 1;
 
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
-    params_spec.n_min   = 5;
     params_spec.n_reuse = 256;
     params_spec.p_min   = 0.9f;
 
-    struct common_speculative * spec = common_speculative_init(params_spec, ctx_dft);
+    struct common_speculative * spec = common_speculative_init(ctx_dft);
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
 
@@ -125,21 +126,30 @@ int main(int argc, char ** argv) {
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
-        // always have a token to evaluate from before
-        common_batch_clear(batch_tgt);
-        common_batch_add  (batch_tgt, id_last, n_past, { 0 }, true);
-
-        // optionally, append draft tokens to the target batch
+        // optionally, generate draft tokens that can be appended to the target batch
         //
         // this is the most important part of the speculation. the more probable tokens that are provided here
        // the better the performance will be. in theory, this computation can be performed asynchronously and even
        // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
        // from a cache or lookup tables.
         //
-        common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);
+        llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+
+        // always have a token to evaluate from before - id_last
+        common_batch_clear(batch_tgt);
+        common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true);
 
         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
         {
+            // do not waste time on small drafts
+            if (draft.size() < n_min) {
+                draft.clear();
+            }
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+            }
+
             //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
 
             llama_decode(ctx_tgt, batch_tgt);
@@ -152,11 +162,11 @@ int main(int argc, char ** argv) {
         // available logits from the batch and sample the next token until we run out of logits or the sampler
         // disagrees with the draft
         //
-        const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
+        const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft);
 
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
-        n_past    += ids.size();
+        n_past    += ids.size() - 1;
         n_drafted += batch_tgt.n_tokens - 1;
         n_accept  += ids.size() - 1;
 
@@ -192,7 +202,7 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str());
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
 
         {
             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
@@ -200,8 +210,8 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
         }
 
-        prompt_dft.push_back(id_last);
-        prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1);
+        prompt_tgt.push_back(id_last);
+        prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
 
         // remember the last accepted token for the next iteration
         id_last = id;
@@ -210,6 +220,8 @@ int main(int argc, char ** argv) {
 
     auto t_dec_end = ggml_time_us();
 
+    const int n_input = inp.size();
+
     LOG("\n\n");
 
     LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
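
Note for reviewers: the patch moves common_speculative_params from common_speculative_init() to the per-call common_speculative_gen_draft(), makes the speculator return the draft as llama_tokens instead of writing into the target batch, and leaves the n_min cutoff to the caller. The sketch below is not part of the patch: speculative_step() is a hypothetical helper that only condenses the loop body of examples/speculative-simple/speculative-simple.cpp to show how the refactored calls fit together; all other names are the ones used in the diff above.

#include "common.h"
#include "sampling.h"
#include "speculative.h"

#include "llama.h"

// hypothetical helper (illustration only): one speculative-decoding step with the refactored API.
// the speculator returns the draft tokens; the caller filters small drafts and builds the target batch.
static llama_tokens speculative_step(
        llama_context * ctx_tgt,
        common_sampler * smpl,
        common_speculative * spec,
        common_speculative_params params_spec,
        llama_batch & batch_tgt,
        const llama_tokens & prompt_tgt,
        llama_token id_last,
        int & n_past,
        int n_min) {
    // ask the draft model (or any other token source) for a continuation of id_last
    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    // the n_min cutoff now lives in the caller: skip drafts too small to be worth a target pass
    if ((int) draft.size() < n_min) {
        draft.clear();
    }

    // evaluate the target model on [id_last, draft0, ..., draftN-1]
    common_batch_clear(batch_tgt);
    common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true);
    for (size_t i = 0; i < draft.size(); ++i) {
        common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
    }
    llama_decode(ctx_tgt, batch_tgt);

    // sample against the draft only; returns the accepted prefix plus one corrected token (at least 1 token)
    return common_sampler_sample_n(smpl, ctx_tgt, draft);
}

As in the example above, the caller remains responsible for advancing n_past by ids.size() - 1, trimming the KV cache past n_past, and updating prompt_tgt and id_last from the returned tokens.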