lookup : fix token positions in the draft batch

This commit is contained in:
Georgi Gerganov 2023-12-17 16:47:26 +02:00
parent 1b26d7151a
commit 5b27975479
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 28 additions and 16 deletions

View file

@@ -240,3 +240,4 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
 // Dump the KV cache view showing individual sequences in each cell (long output).
 void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);

View file

@@ -19,6 +19,8 @@ int main(int argc, char ** argv){
 // length of the candidate / draft sequence, if match is found
 const int n_draft = 10;
+const bool dump_kv_cache = params.dump_kv_cache;
 #ifndef LOG_DISABLE_LOGS
 log_set_target(log_filename_generator("lookup", "log"));
 LOG_TEE("Log start\n");
@@ -76,13 +78,22 @@ int main(int argc, char ** argv){
 struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
-std::vector<llama_token> draft(n_draft);
+std::vector<llama_token> draft;
 llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
+// debug
+struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
 const auto t_dec_start = ggml_time_us();
-while(true){
+while (true) {
+// debug
+if (dump_kv_cache) {
+llama_kv_cache_view_update(ctx, &kvc_view);
+dump_kv_cache_view_seqs(kvc_view, 40);
+}
 // print current draft sequence
 LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
@@ -135,7 +146,7 @@ int main(int argc, char ** argv){
 break;
 }
-if (n_predict > params.n_predict || has_eos) {
+if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) {
 break;
 }
@@ -164,11 +175,11 @@ int main(int argc, char ** argv){
 if (match) {
 const int startIdx = i + ngram_size;
 const int endIdx = startIdx + n_draft;
-if (endIdx < inp_size){
+if (endIdx < inp_size) {
 for (int j = startIdx; j < endIdx; ++j) {
 LOG(" - draft candidate %d: %d\n", j, inp[j]);
 draft.push_back(inp[j]);
-llama_batch_add(batch_tgt, inp[j], n_past + j + 1, { 0 }, true);
+llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
 ++n_drafted;
 }
 return;