lookup : fix token positions in the draft batch
This commit is contained in:
parent
1b26d7151a
commit
5b27975479
2 changed files with 28 additions and 16 deletions
|
@ -239,4 +239,5 @@ void dump_non_result_info_yaml(
|
||||||
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
|
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
|
||||||
|
|
||||||
// Dump the KV cache view showing individual sequences in each cell (long output).
|
// Dump the KV cache view showing individual sequences in each cell (long output).
|
||||||
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
|
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,8 @@ int main(int argc, char ** argv){
|
||||||
// length of the candidate / draft sequence, if match is found
|
// length of the candidate / draft sequence, if match is found
|
||||||
const int n_draft = 10;
|
const int n_draft = 10;
|
||||||
|
|
||||||
|
const bool dump_kv_cache = params.dump_kv_cache;
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("lookup", "log"));
|
log_set_target(log_filename_generator("lookup", "log"));
|
||||||
LOG_TEE("Log start\n");
|
LOG_TEE("Log start\n");
|
||||||
|
@ -37,7 +39,7 @@ int main(int argc, char ** argv){
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
const bool add_bos = llama_should_add_bos_token(model);
|
const bool add_bos = llama_should_add_bos_token(model);
|
||||||
LOG("add_bos tgt: %d\n", add_bos);
|
LOG("add_bos tgt: %d\n", add_bos);
|
||||||
|
|
||||||
std::vector<llama_token> inp;
|
std::vector<llama_token> inp;
|
||||||
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
|
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
|
||||||
|
|
||||||
|
@ -69,24 +71,33 @@ int main(int argc, char ** argv){
|
||||||
int n_predict = 0;
|
int n_predict = 0;
|
||||||
int n_drafted = 0;
|
int n_drafted = 0;
|
||||||
int n_accept = 0;
|
int n_accept = 0;
|
||||||
|
|
||||||
int n_past = inp.size();
|
int n_past = inp.size();
|
||||||
|
|
||||||
bool has_eos = false;
|
bool has_eos = false;
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
|
|
||||||
std::vector<llama_token> draft(n_draft);
|
std::vector<llama_token> draft;
|
||||||
|
|
||||||
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
|
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
|
||||||
|
|
||||||
|
// debug
|
||||||
|
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
|
||||||
|
|
||||||
const auto t_dec_start = ggml_time_us();
|
const auto t_dec_start = ggml_time_us();
|
||||||
|
|
||||||
while(true){
|
while (true) {
|
||||||
|
// debug
|
||||||
|
if (dump_kv_cache) {
|
||||||
|
llama_kv_cache_view_update(ctx, &kvc_view);
|
||||||
|
dump_kv_cache_view_seqs(kvc_view, 40);
|
||||||
|
}
|
||||||
|
|
||||||
// print current draft sequence
|
// print current draft sequence
|
||||||
LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
|
LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
|
||||||
|
|
||||||
int i_dft = 0;
|
int i_dft = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
|
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
|
||||||
|
@ -120,13 +131,13 @@ int main(int argc, char ** argv){
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.use_color) {
|
if (params.use_color) {
|
||||||
printf("%s", token_str.c_str());
|
printf("%s", token_str.c_str());
|
||||||
}
|
}
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
|
|
||||||
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
||||||
|
|
||||||
draft.clear();
|
draft.clear();
|
||||||
|
@ -135,7 +146,7 @@ int main(int argc, char ** argv){
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_predict > params.n_predict || has_eos) {
|
if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,9 +160,9 @@ int main(int argc, char ** argv){
|
||||||
// generate n_pred tokens through prompt lookup
|
// generate n_pred tokens through prompt lookup
|
||||||
auto prompt_lookup = [&]() -> void {
|
auto prompt_lookup = [&]() -> void {
|
||||||
int inp_size = inp.size();
|
int inp_size = inp.size();
|
||||||
for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){
|
for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){
|
||||||
const llama_token * ngram = &inp[inp_size - ngram_size];
|
const llama_token * ngram = &inp[inp_size - ngram_size];
|
||||||
|
|
||||||
for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
|
for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
|
||||||
bool match = true;
|
bool match = true;
|
||||||
for (int j = 0; j < ngram_size; ++j) {
|
for (int j = 0; j < ngram_size; ++j) {
|
||||||
|
@ -164,11 +175,11 @@ int main(int argc, char ** argv){
|
||||||
if (match) {
|
if (match) {
|
||||||
const int startIdx = i + ngram_size;
|
const int startIdx = i + ngram_size;
|
||||||
const int endIdx = startIdx + n_draft;
|
const int endIdx = startIdx + n_draft;
|
||||||
if (endIdx < inp_size){
|
if (endIdx < inp_size) {
|
||||||
for (int j = startIdx; j < endIdx; ++j) {
|
for (int j = startIdx; j < endIdx; ++j) {
|
||||||
LOG(" - draft candidate %d: %d\n", j, inp[j]);
|
LOG(" - draft candidate %d: %d\n", j, inp[j]);
|
||||||
draft.push_back(inp[j]);
|
draft.push_back(inp[j]);
|
||||||
llama_batch_add(batch_tgt, inp[j], n_past + j + 1, { 0 }, true);
|
llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
|
||||||
++n_drafted;
|
++n_drafted;
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
@ -180,7 +191,7 @@ int main(int argc, char ** argv){
|
||||||
};
|
};
|
||||||
|
|
||||||
prompt_lookup();
|
prompt_lookup();
|
||||||
|
|
||||||
llama_decode(ctx, batch_tgt);
|
llama_decode(ctx, batch_tgt);
|
||||||
++n_past;
|
++n_past;
|
||||||
|
|
||||||
|
@ -215,4 +226,4 @@ int main(int argc, char ** argv){
|
||||||
fprintf(stderr, "\n\n");
|
fprintf(stderr, "\n\n");
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue