speculative : refactor and add a simpler example

ggml-ci
Georgi Gerganov 2024-11-15 08:20:28 +02:00
parent 1bb30bf28c
commit 71fc16bb6c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
10 changed files with 513 additions and 1 deletion

common/CMakeLists.txt

@@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
    ngram-cache.h
    sampling.cpp
    sampling.h
    speculative.cpp
    speculative.h
    )

if (BUILD_SHARED_LIBS)

common/sampling.cpp

@@ -320,6 +320,28 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
    return cur_p.data[cur_p.selected].id;
}

// sample at the given batch indices while the sampled tokens match the draft;
// always returns at least one token: the accepted prefix of the draft plus one
// token sampled at the first position where the target disagrees (or at the
// position right after the full draft)
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        if (draft[i] != id) {
            break;
        }

        result.push_back(id);
    }

    result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first));

    return result;
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
    return llama_sampler_get_seed(gsmpl->chain);
}
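To make the acceptance rule concrete, here is a small self-contained sketch that is not part of the patch: plain int stands in for llama_token, and the hypothetical target_sample() replaces common_sampler_sample(), reading the target's choice for each batch index from a fixed vector. It shows that the returned sequence is always the matching prefix of the draft plus exactly one token chosen by the target.

#include <cstdio>
#include <vector>

// hypothetical stand-in for "sample from the target model at batch index idx":
// here the target's choices are simply read from a fixed vector
static int target_sample(const std::vector<int> & target_choice, int idx) {
    return target_choice[idx];
}

// same control flow as common_sampler_sample_n above, minus the llama/sampler plumbing
static std::vector<int> sample_n(const std::vector<int> & target_choice, const std::vector<int> & draft) {
    std::vector<int> result;

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const int id = target_sample(target_choice, (int) i);

        if (draft[i] != id) {
            break; // the target disagrees -> stop accepting draft tokens
        }

        result.push_back(id);
    }

    // always sample one token past the accepted prefix
    result.push_back(target_sample(target_choice, (int) i));

    return result;
}

int main() {
    // the draft proposed {5, 6, 7, 8}; the target would pick {5, 6, 9, ...}
    const std::vector<int> draft         = {5, 6, 7, 8};
    const std::vector<int> target_choice = {5, 6, 9, 1, 2}; // one entry per batch position (draft.size() + 1)

    const std::vector<int> out = sample_n(target_choice, draft);

    // prints: 5 6 9  -> two draft tokens accepted, plus the target's own token at the mismatch
    for (int id : out) {
        printf("%d ", id);
    }
    printf("\n");

    return 0;
}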

common/sampling.h

@@ -60,6 +60,8 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
//

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

// generalized version of common_sampler_sample: samples at each index in idxs while the
// drafted tokens match, then samples one more; returns between 1 and idxs.size() tokens
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers

common/speculative.cpp (new file)

@@ -0,0 +1,159 @@
#include "speculative.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include <vector>
struct seq_draft {
};
struct common_speculative {
struct common_speculative_params params;
llama_batch batch_dft;
struct common_sampler * smpl;
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
};
struct common_speculative * common_speculative_init(struct common_speculative_params params) {
auto * result = new common_speculative {
/* .params = */ params,
/* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
/* .smpl = */ nullptr,
/* .i_batch_tgt = */ {},
/* .tokens = */ {},
};
// TODO: optimize or pass from outside?
#if 0
{
common_sampler_params sparams;
sparams.no_perf = false;
sparams.top_k = 40;
sparams.top_p = 0.9;
sparams.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_INFILL,
};
result->smpl = common_sampler_init(params.model_dft, sparams);
}
#else
{
common_sampler_params sparams;
sparams.no_perf = false;
sparams.top_k = 10;
sparams.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
result->smpl = common_sampler_init(params.model_dft, sparams);
}
#endif
result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1);
return result;
}
void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
llama_batch_free(spec->batch_dft);
delete spec;
}
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
llama_kv_cache_clear(spec->params.ctx_dft);
// TODO: error handling
llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
}
void common_speculative_add_draft(
struct common_speculative * spec,
struct llama_batch & batch_tgt,
llama_token id_last,
int n_past) {
spec->tokens.clear();
spec->i_batch_tgt.clear();
spec->i_batch_tgt.push_back(0);
common_sampler_reset(spec->smpl);
common_batch_clear(spec->batch_dft);
common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
llama_decode(spec->params.ctx_dft, spec->batch_dft);
// sample n_draft tokens from the draft model
for (int i = 0; i < spec->params.n_draft; ++i) {
common_batch_clear(spec->batch_dft);
common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);
const auto * cur_p = common_sampler_get_candidates(spec->smpl);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) {
break;
}
common_sampler_accept(spec->smpl, id, true);
spec->tokens.push_back(id);
// add unique drafted tokens to the target batch
spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
if (batch_tgt.n_tokens > spec->params.n_draft) {
break;
}
common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model
llama_decode(spec->params.ctx_dft, spec->batch_dft);
}
// don't waste time on small batches
// TODO: do not evaluate the draft model for tha many rounds
if (batch_tgt.n_tokens < spec->params.n_min) {
batch_tgt.n_tokens = 1;
spec->tokens.resize(0);
spec->i_batch_tgt.resize(1);
}
// print current draft sequences
LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
}
std::vector<llama_token> common_speculative_sample(
struct common_speculative * spec,
struct common_sampler * smpl,
struct llama_context * ctx_tgt) {
return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
}
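As an illustrative aside (none of the names below appear in the patch), the bookkeeping done by common_speculative_add_draft can be traced with plain vectors: the target batch ends up holding the last accepted token followed by the drafted tokens at consecutive positions, while i_batch_tgt records the index of each of those tokens inside the batch; that index list is exactly what common_speculative_sample later passes to common_sampler_sample_n as idxs.

#include <cstdio>
#include <vector>

int main() {
    // hypothetical state after the prompt: last accepted token 42 at position n_past = 10,
    // and the draft model proposed the tokens {7, 8, 9}
    const int n_past  = 10;
    const int id_last = 42;
    const std::vector<int> draft = {7, 8, 9};

    std::vector<int> batch_tokens; // token ids in batch_tgt, in order
    std::vector<int> batch_pos;    // their positions in the sequence
    std::vector<int> i_batch_tgt;  // index of each token to verify inside batch_tgt

    // the batch always starts with the last accepted token ...
    batch_tokens.push_back(id_last);
    batch_pos   .push_back(n_past);
    i_batch_tgt .push_back(0);

    // ... followed by the drafted tokens at consecutive positions
    for (size_t i = 0; i < draft.size(); ++i) {
        i_batch_tgt .push_back((int) batch_tokens.size());
        batch_tokens.push_back(draft[i]);
        batch_pos   .push_back(n_past + (int) i + 1);
    }

    // prints: token  42 @ pos 10 (batch idx 0), token   7 @ pos 11 (batch idx 1), ...
    for (size_t i = 0; i < batch_tokens.size(); ++i) {
        printf("token %3d @ pos %2d (batch idx %d)\n", batch_tokens[i], batch_pos[i], i_batch_tgt[i]);
    }

    return 0;
}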

common/speculative.h (new file)

@@ -0,0 +1,33 @@
#pragma once

#include "llama.h"

#include <vector>

struct common_speculative;

struct common_speculative_params {
    int n_draft = 16;
    int n_min   = 5; // do not add drafts smaller than this, TODO: leave this to user?

    struct llama_model   * model_dft = nullptr;
    struct llama_context * ctx_dft   = nullptr;
};

struct common_speculative * common_speculative_init(struct common_speculative_params params);

void common_speculative_free(struct common_speculative * spec);

// clear the draft model's KV cache and evaluate the prompt tokens on it
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);

// sample up to params.n_draft tokens from the draft model and append them to batch_tgt,
// recording their batch indices for later verification
void common_speculative_add_draft(
        struct common_speculative * spec,
        struct llama_batch & batch_tgt,
        llama_token id_last,
        int n_past);

// verify the draft against the target model: returns the accepted draft prefix
// plus one token chosen by the target sampler
std::vector<llama_token> common_speculative_sample(
        struct common_speculative * spec,
        struct common_sampler * smpl,
        struct llama_context * ctx_tgt);
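For orientation, here is a condensed sketch of the intended call sequence for this API, distilled from the speculative-simple example added below. It is a sketch only, not part of the patch: the target and draft models (model_tgt, model_dft), their contexts (ctx_tgt, ctx_dft), the target sampler smpl and the tokenized prompt inp are assumed to have been set up exactly as in that example.

struct common_speculative_params params_spec;
params_spec.n_draft   = 16;
params_spec.n_min     = 5;
params_spec.model_dft = model_dft;
params_spec.ctx_dft   = ctx_dft;

struct common_speculative * spec = common_speculative_init(params_spec);

// the last prompt token is kept back; everything before it is evaluated up front
llama_token id_last = inp.back();
int         n_past  = (int) inp.size() - 1;

// evaluate the prompt (minus the last token) on the target model and feed it to the speculator
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), (int32_t) inp.size() - 1));
common_speculative_set_prompt(spec, inp.data(), (int32_t) inp.size() - 1);

llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

while (true) {
    // the target batch always starts with the last accepted token ...
    common_batch_clear(batch_tgt);
    common_batch_add (batch_tgt, id_last, n_past, { 0 }, true);

    // ... optionally followed by tokens proposed by the draft model
    common_speculative_add_draft(spec, batch_tgt, id_last, n_past);

    llama_decode(ctx_tgt, batch_tgt);

    // verify the draft: ids is the accepted prefix plus one token chosen by the target
    const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);

    n_past += (int) ids.size();
    id_last = ids.back();

    // drop any rejected draft tokens from both KV caches
    llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
    llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);

    if (llama_token_is_eog(model_tgt, id_last)) {
        break;
    }
}

common_speculative_free(spec);
llama_batch_free(batch_tgt);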

examples/CMakeLists.txt

@@ -50,5 +50,6 @@ else()
    add_subdirectory(simple)
    add_subdirectory(simple-chat)
    add_subdirectory(speculative)
    add_subdirectory(speculative-simple)
    add_subdirectory(tokenize)
endif()

examples/speculative-simple/CMakeLists.txt (new file)

@@ -0,0 +1,5 @@
set(TARGET llama-speculative-simple)
add_executable(${TARGET} speculative-simple.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/speculative-simple/README.md (new file)

@@ -0,0 +1,3 @@
# llama.cpp/examples/speculative-simple

Demonstration of basic greedy speculative decoding.

examples/speculative-simple/speculative-simple.cpp (new file)

@@ -0,0 +1,285 @@
#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "speculative.h"
#include "log.h"
#include "llama.h"
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
struct common_sampler * smpl = nullptr;
};
int main(int argc, char ** argv) {
common_params params;
// needed to get candidate probs even for temp <= 0.0
params.sparams.n_probs = 128;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
return 1;
}
if (params.n_predict < -1) {
LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
return 1;
}
common_init();
if (params.model_draft.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
// init llama.cpp
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model_tgt = NULL;
llama_model * model_dft = NULL;
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the target model
common_init_result llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt.model;
ctx_tgt = llama_init_tgt.context;
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
if (params.draft_cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
}
params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
common_init_result llama_init_dft = common_init_from_params(params);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(model_dft);
LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return 1;
}
if (
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
) {
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
return 1;
}
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
common_token_to_piece(ctx_tgt, i).c_str(),
common_token_to_piece(ctx_dft, i).c_str());
return 1;
}
}
}
// Tokenize the prompt
std::vector<llama_token> inp;
inp = common_tokenize(ctx_tgt, params.prompt, true, true);
const int max_context_size = llama_n_ctx(ctx_tgt);
const int max_tokens_list_size = max_context_size - 4;
if ((int) inp.size() > max_tokens_list_size) {
LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
return 1;
}
LOG("\n\n");
for (auto id : inp) {
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
}
const int n_input = inp.size();
const auto t_enc_start = ggml_time_us();
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
// note: keep the last token separate!
llama_token id_last = inp.back();
int n_past = inp.size() - 1;
// how many tokens to draft each time
int n_draft = params.n_draft;
int n_predict = 0;
int n_drafted = 0;
int n_accept = 0;
// used to determine end of generation
bool has_eos = false;
// target model sampling context
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
// init the speculator
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft;
params_spec.n_min = 5;
params_spec.model_dft = model_dft;
params_spec.ctx_dft = ctx_dft;
struct common_speculative * spec = common_speculative_init(params_spec);
// feed the prompt to the speculator
common_speculative_set_prompt(spec, inp.data(), n_input - 1);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
const auto t_enc_end = ggml_time_us();
const auto t_dec_start = ggml_time_us();
while (true) {
// always have a token to evaluate from before
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past, { 0 }, true);
// optionally, append draft tokens to the target batch
common_speculative_add_draft(spec, batch_tgt, id_last, n_past);
// evaluate the target model on the drafted tokens
{
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
}
// process the full target batch and return the accepted token based on the target sampler
const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);
n_past += ids.size();
n_drafted += batch_tgt.n_tokens - 1;
n_accept += ids.size() - 1;
// process the accepted tokens and update contexts
{
llama_token id;
std::string token_str;
for (size_t i = 0; i < ids.size(); ++i) {
id = ids[i];
++n_predict;
if (llama_token_is_eog(model_tgt, id)) {
has_eos = true;
break;
}
token_str = common_token_to_piece(ctx_tgt, id);
if (params.use_color && i + 1 < ids.size()) {
LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
} else {
LOG("%s", token_str.c_str());
}
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
break;
}
LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
}
id_last = id;
}
}
auto t_dec_end = ggml_time_us();
LOG("\n\n");
LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_INF("\n");
LOG_INF("n_draft = %d\n", n_draft);
LOG_INF("n_predict = %d\n", n_predict);
LOG_INF("n_drafted = %d\n", n_drafted);
LOG_INF("n_accept = %d\n", n_accept);
LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_INF("\n");
LOG_INF("draft:\n\n");
llama_perf_context_print(ctx_dft);
LOG_INF("\n");
LOG_INF("target:\n\n");
common_perf_print(ctx_tgt, smpl);
common_sampler_free(smpl);
common_speculative_free(spec);
llama_free(ctx_tgt);
llama_free_model(model_tgt);
llama_free(ctx_dft);
llama_free_model(model_dft);
llama_backend_free();
LOG("\n\n");
return 0;
}

examples/speculative/speculative.cpp

@@ -12,7 +12,7 @@
#include <string>
#include <vector>

-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128

#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

struct seq_draft {