speculative : refactor and add a simpler example
ggml-ci
This commit is contained in:
parent
1bb30bf28c
commit
71fc16bb6c
10 changed files with 513 additions and 1 deletions
|
@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
|
||||||
ngram-cache.h
|
ngram-cache.h
|
||||||
sampling.cpp
|
sampling.cpp
|
||||||
sampling.h
|
sampling.h
|
||||||
|
speculative.cpp
|
||||||
|
speculative.h
|
||||||
)
|
)
|
||||||
|
|
||||||
if (BUILD_SHARED_LIBS)
|
if (BUILD_SHARED_LIBS)
|
||||||
|
|
|
@ -320,6 +320,28 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||||
return cur_p.data[cur_p.selected].id;
|
return cur_p.data[cur_p.selected].id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
|
||||||
|
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
||||||
|
|
||||||
|
std::vector<llama_token> result;
|
||||||
|
result.reserve(idxs.size());
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
for (; i < draft.size(); i++) {
|
||||||
|
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||||
|
|
||||||
|
if (draft[i] != id) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.push_back(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first));
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||||
return llama_sampler_get_seed(gsmpl->chain);
|
return llama_sampler_get_seed(gsmpl->chain);
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,6 +60,8 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
|
||||||
//
|
//
|
||||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
||||||
|
|
||||||
|
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
|
||||||
|
|
||||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
// helpers
|
// helpers
|
||||||
|
|
159
common/speculative.cpp
Normal file
159
common/speculative.cpp
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
#include "speculative.h"
|
||||||
|
|
||||||
|
#include "log.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "sampling.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct seq_draft {
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_speculative {
|
||||||
|
struct common_speculative_params params;
|
||||||
|
|
||||||
|
llama_batch batch_dft;
|
||||||
|
|
||||||
|
struct common_sampler * smpl;
|
||||||
|
|
||||||
|
std::vector<int> i_batch_tgt;
|
||||||
|
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_speculative * common_speculative_init(struct common_speculative_params params) {
|
||||||
|
auto * result = new common_speculative {
|
||||||
|
/* .params = */ params,
|
||||||
|
/* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
|
||||||
|
/* .smpl = */ nullptr,
|
||||||
|
/* .i_batch_tgt = */ {},
|
||||||
|
/* .tokens = */ {},
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: optimize or pass from outside?
|
||||||
|
#if 0
|
||||||
|
{
|
||||||
|
common_sampler_params sparams;
|
||||||
|
sparams.no_perf = false;
|
||||||
|
|
||||||
|
sparams.top_k = 40;
|
||||||
|
sparams.top_p = 0.9;
|
||||||
|
|
||||||
|
sparams.samplers = {
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_K,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_P,
|
||||||
|
COMMON_SAMPLER_TYPE_INFILL,
|
||||||
|
};
|
||||||
|
|
||||||
|
result->smpl = common_sampler_init(params.model_dft, sparams);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
common_sampler_params sparams;
|
||||||
|
sparams.no_perf = false;
|
||||||
|
|
||||||
|
sparams.top_k = 10;
|
||||||
|
|
||||||
|
sparams.samplers = {
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_K,
|
||||||
|
};
|
||||||
|
|
||||||
|
result->smpl = common_sampler_init(params.model_dft, sparams);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_free(struct common_speculative * spec) {
|
||||||
|
common_sampler_free(spec->smpl);
|
||||||
|
|
||||||
|
llama_batch_free(spec->batch_dft);
|
||||||
|
|
||||||
|
delete spec;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
|
||||||
|
llama_kv_cache_clear(spec->params.ctx_dft);
|
||||||
|
|
||||||
|
// TODO: error handling
|
||||||
|
llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_add_draft(
|
||||||
|
struct common_speculative * spec,
|
||||||
|
struct llama_batch & batch_tgt,
|
||||||
|
llama_token id_last,
|
||||||
|
int n_past) {
|
||||||
|
spec->tokens.clear();
|
||||||
|
|
||||||
|
spec->i_batch_tgt.clear();
|
||||||
|
spec->i_batch_tgt.push_back(0);
|
||||||
|
|
||||||
|
common_sampler_reset(spec->smpl);
|
||||||
|
|
||||||
|
common_batch_clear(spec->batch_dft);
|
||||||
|
common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
|
||||||
|
|
||||||
|
llama_decode(spec->params.ctx_dft, spec->batch_dft);
|
||||||
|
|
||||||
|
// sample n_draft tokens from the draft model
|
||||||
|
for (int i = 0; i < spec->params.n_draft; ++i) {
|
||||||
|
common_batch_clear(spec->batch_dft);
|
||||||
|
|
||||||
|
common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);
|
||||||
|
|
||||||
|
const auto * cur_p = common_sampler_get_candidates(spec->smpl);
|
||||||
|
|
||||||
|
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||||
|
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||||
|
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// add drafted token for each sequence
|
||||||
|
const llama_token id = cur_p->data[0].id;
|
||||||
|
|
||||||
|
// only collect very high-confidence draft tokens
|
||||||
|
if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_sampler_accept(spec->smpl, id, true);
|
||||||
|
|
||||||
|
spec->tokens.push_back(id);
|
||||||
|
|
||||||
|
// add unique drafted tokens to the target batch
|
||||||
|
spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||||
|
|
||||||
|
common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
|
||||||
|
|
||||||
|
if (batch_tgt.n_tokens > spec->params.n_draft) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);
|
||||||
|
|
||||||
|
// evaluate the drafted tokens on the draft model
|
||||||
|
llama_decode(spec->params.ctx_dft, spec->batch_dft);
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't waste time on small batches
|
||||||
|
// TODO: do not evaluate the draft model for tha many rounds
|
||||||
|
if (batch_tgt.n_tokens < spec->params.n_min) {
|
||||||
|
batch_tgt.n_tokens = 1;
|
||||||
|
spec->tokens.resize(0);
|
||||||
|
spec->i_batch_tgt.resize(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// print current draft sequences
|
||||||
|
LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token> common_speculative_sample(
|
||||||
|
struct common_speculative * spec,
|
||||||
|
struct common_sampler * smpl,
|
||||||
|
struct llama_context * ctx_tgt) {
|
||||||
|
return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
|
||||||
|
}
|
33
common/speculative.h
Normal file
33
common/speculative.h
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct common_speculative;
|
||||||
|
|
||||||
|
struct common_speculative_params {
|
||||||
|
int n_draft = 16;
|
||||||
|
int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user?
|
||||||
|
|
||||||
|
struct llama_model * model_dft = nullptr;
|
||||||
|
|
||||||
|
struct llama_context * ctx_dft = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_speculative * common_speculative_init(struct common_speculative_params params);
|
||||||
|
|
||||||
|
void common_speculative_free(struct common_speculative * spec);
|
||||||
|
|
||||||
|
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);
|
||||||
|
|
||||||
|
void common_speculative_add_draft(
|
||||||
|
struct common_speculative * spec,
|
||||||
|
struct llama_batch & batch_tgt,
|
||||||
|
llama_token id_last,
|
||||||
|
int n_past);
|
||||||
|
|
||||||
|
std::vector<llama_token> common_speculative_sample(
|
||||||
|
struct common_speculative * spec,
|
||||||
|
struct common_sampler * smpl,
|
||||||
|
struct llama_context * ctx_tgt);
|
|
@ -50,5 +50,6 @@ else()
|
||||||
add_subdirectory(simple)
|
add_subdirectory(simple)
|
||||||
add_subdirectory(simple-chat)
|
add_subdirectory(simple-chat)
|
||||||
add_subdirectory(speculative)
|
add_subdirectory(speculative)
|
||||||
|
add_subdirectory(speculative-simple)
|
||||||
add_subdirectory(tokenize)
|
add_subdirectory(tokenize)
|
||||||
endif()
|
endif()
|
||||||
|
|
5
examples/speculative-simple/CMakeLists.txt
Normal file
5
examples/speculative-simple/CMakeLists.txt
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
set(TARGET llama-speculative-simple)
|
||||||
|
add_executable(${TARGET} speculative-simple.cpp)
|
||||||
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
3
examples/speculative-simple/README.md
Normal file
3
examples/speculative-simple/README.md
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
# llama.cpp/examples/speculative-simple
|
||||||
|
|
||||||
|
Demonstration of basic greedy speculative decoding
|
285
examples/speculative-simple/speculative-simple.cpp
Normal file
285
examples/speculative-simple/speculative-simple.cpp
Normal file
|
@ -0,0 +1,285 @@
|
||||||
|
#include "arg.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "sampling.h"
|
||||||
|
#include "speculative.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||||
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||||
|
|
||||||
|
struct seq_draft {
|
||||||
|
std::vector<int> i_batch_tgt;
|
||||||
|
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
|
||||||
|
struct common_sampler * smpl = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
int main(int argc, char ** argv) {
|
||||||
|
common_params params;
|
||||||
|
|
||||||
|
// needed to get candidate probs even for temp <= 0.0
|
||||||
|
params.sparams.n_probs = 128;
|
||||||
|
|
||||||
|
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.n_predict < -1) {
|
||||||
|
LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_init();
|
||||||
|
|
||||||
|
if (params.model_draft.empty()) {
|
||||||
|
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// init llama.cpp
|
||||||
|
llama_backend_init();
|
||||||
|
llama_numa_init(params.numa);
|
||||||
|
|
||||||
|
llama_model * model_tgt = NULL;
|
||||||
|
llama_model * model_dft = NULL;
|
||||||
|
|
||||||
|
llama_context * ctx_tgt = NULL;
|
||||||
|
llama_context * ctx_dft = NULL;
|
||||||
|
|
||||||
|
// load the target model
|
||||||
|
common_init_result llama_init_tgt = common_init_from_params(params);
|
||||||
|
model_tgt = llama_init_tgt.model;
|
||||||
|
ctx_tgt = llama_init_tgt.context;
|
||||||
|
|
||||||
|
// load the draft model
|
||||||
|
params.model = params.model_draft;
|
||||||
|
params.n_gpu_layers = params.n_gpu_layers_draft;
|
||||||
|
if (params.draft_cpuparams.n_threads > 0) {
|
||||||
|
params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
|
||||||
|
common_init_result llama_init_dft = common_init_from_params(params);
|
||||||
|
model_dft = llama_init_dft.model;
|
||||||
|
ctx_dft = llama_init_dft.context;
|
||||||
|
|
||||||
|
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
|
||||||
|
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
|
||||||
|
|
||||||
|
const bool vocab_type_dft = llama_vocab_type(model_dft);
|
||||||
|
LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
|
||||||
|
|
||||||
|
if (vocab_type_tgt != vocab_type_dft) {
|
||||||
|
LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
|
||||||
|
LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
|
||||||
|
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
|
||||||
|
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
|
||||||
|
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
|
||||||
|
) {
|
||||||
|
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const int n_vocab_tgt = llama_n_vocab(model_tgt);
|
||||||
|
const int n_vocab_dft = llama_n_vocab(model_dft);
|
||||||
|
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
||||||
|
? n_vocab_tgt - n_vocab_dft
|
||||||
|
: n_vocab_dft - n_vocab_tgt;
|
||||||
|
|
||||||
|
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||||
|
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
||||||
|
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||||
|
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||||
|
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
|
||||||
|
const char * token_text_dft = llama_token_get_text(model_dft, i);
|
||||||
|
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||||
|
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
|
||||||
|
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
|
||||||
|
common_token_to_piece(ctx_tgt, i).c_str(),
|
||||||
|
common_token_to_piece(ctx_dft, i).c_str());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Tokenize the prompt
|
||||||
|
std::vector<llama_token> inp;
|
||||||
|
inp = common_tokenize(ctx_tgt, params.prompt, true, true);
|
||||||
|
|
||||||
|
const int max_context_size = llama_n_ctx(ctx_tgt);
|
||||||
|
const int max_tokens_list_size = max_context_size - 4;
|
||||||
|
|
||||||
|
if ((int) inp.size() > max_tokens_list_size) {
|
||||||
|
LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG("\n\n");
|
||||||
|
|
||||||
|
for (auto id : inp) {
|
||||||
|
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_input = inp.size();
|
||||||
|
|
||||||
|
const auto t_enc_start = ggml_time_us();
|
||||||
|
|
||||||
|
// eval the prompt
|
||||||
|
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
|
||||||
|
|
||||||
|
// note: keep the last token separate!
|
||||||
|
llama_token id_last = inp.back();
|
||||||
|
|
||||||
|
int n_past = inp.size() - 1;
|
||||||
|
|
||||||
|
// how many tokens to draft each time
|
||||||
|
int n_draft = params.n_draft;
|
||||||
|
|
||||||
|
int n_predict = 0;
|
||||||
|
int n_drafted = 0;
|
||||||
|
int n_accept = 0;
|
||||||
|
|
||||||
|
// used to determine end of generation
|
||||||
|
bool has_eos = false;
|
||||||
|
|
||||||
|
// target model sampling context
|
||||||
|
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
|
||||||
|
|
||||||
|
// init the speculator
|
||||||
|
struct common_speculative_params params_spec;
|
||||||
|
params_spec.n_draft = n_draft;
|
||||||
|
params_spec.n_min = 5;
|
||||||
|
params_spec.model_dft = model_dft;
|
||||||
|
params_spec.ctx_dft = ctx_dft;
|
||||||
|
|
||||||
|
struct common_speculative * spec = common_speculative_init(params_spec);
|
||||||
|
|
||||||
|
// feed the prompt to the speculator
|
||||||
|
common_speculative_set_prompt(spec, inp.data(), n_input - 1);
|
||||||
|
|
||||||
|
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||||
|
|
||||||
|
const auto t_enc_end = ggml_time_us();
|
||||||
|
|
||||||
|
const auto t_dec_start = ggml_time_us();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
// always have a token to evaluate from before
|
||||||
|
common_batch_clear(batch_tgt);
|
||||||
|
common_batch_add (batch_tgt, id_last, n_past, { 0 }, true);
|
||||||
|
|
||||||
|
// optionally, append draft tokens to the target batch
|
||||||
|
common_speculative_add_draft(spec, batch_tgt, id_last, n_past);
|
||||||
|
|
||||||
|
// evaluate the target model on the drafted tokens
|
||||||
|
{
|
||||||
|
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
|
||||||
|
|
||||||
|
llama_decode(ctx_tgt, batch_tgt);
|
||||||
|
}
|
||||||
|
|
||||||
|
// process the full target batch and return the accepted token based on the target sampler
|
||||||
|
const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);
|
||||||
|
|
||||||
|
n_past += ids.size();
|
||||||
|
n_drafted += batch_tgt.n_tokens - 1;
|
||||||
|
n_accept += ids.size() - 1;
|
||||||
|
|
||||||
|
// process the accepted tokens and update contexts
|
||||||
|
{
|
||||||
|
llama_token id;
|
||||||
|
std::string token_str;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < ids.size(); ++i) {
|
||||||
|
id = ids[i];
|
||||||
|
|
||||||
|
++n_predict;
|
||||||
|
|
||||||
|
if (llama_token_is_eog(model_tgt, id)) {
|
||||||
|
has_eos = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
token_str = common_token_to_piece(ctx_tgt, id);
|
||||||
|
|
||||||
|
if (params.use_color && i + 1 < ids.size()) {
|
||||||
|
LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
|
||||||
|
} else {
|
||||||
|
LOG("%s", token_str.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
||||||
|
|
||||||
|
{
|
||||||
|
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
|
||||||
|
|
||||||
|
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
|
||||||
|
llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
id_last = id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto t_dec_end = ggml_time_us();
|
||||||
|
|
||||||
|
LOG("\n\n");
|
||||||
|
|
||||||
|
LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
||||||
|
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
||||||
|
|
||||||
|
LOG_INF("\n");
|
||||||
|
LOG_INF("n_draft = %d\n", n_draft);
|
||||||
|
LOG_INF("n_predict = %d\n", n_predict);
|
||||||
|
LOG_INF("n_drafted = %d\n", n_drafted);
|
||||||
|
LOG_INF("n_accept = %d\n", n_accept);
|
||||||
|
LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
|
||||||
|
|
||||||
|
LOG_INF("\n");
|
||||||
|
LOG_INF("draft:\n\n");
|
||||||
|
|
||||||
|
llama_perf_context_print(ctx_dft);
|
||||||
|
|
||||||
|
LOG_INF("\n");
|
||||||
|
LOG_INF("target:\n\n");
|
||||||
|
common_perf_print(ctx_tgt, smpl);
|
||||||
|
|
||||||
|
common_sampler_free(smpl);
|
||||||
|
common_speculative_free(spec);
|
||||||
|
|
||||||
|
llama_free(ctx_tgt);
|
||||||
|
llama_free_model(model_tgt);
|
||||||
|
|
||||||
|
llama_free(ctx_dft);
|
||||||
|
llama_free_model(model_dft);
|
||||||
|
|
||||||
|
llama_backend_free();
|
||||||
|
|
||||||
|
LOG("\n\n");
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -12,7 +12,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||||
|
|
||||||
struct seq_draft {
|
struct seq_draft {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue