Added colors to distinguish drafted tokens (--color). Updated README

This commit is contained in:
Leon Ericsson 2023-12-17 13:04:46 +01:00
parent 45b8032b9c
commit 1b26d7151a
2 changed files with 30 additions and 11 deletions

View file

@ -0,0 +1,13 @@
# llama.cpp/examples/lookup
Demonstration of Prompt Lookup Decoding
https://github.com/apoorvumang/prompt-lookup-decoding
The two key parameters for lookup decoding are `max_ngram_size` and `n_draft`. The first, determines how many ngrams to use when searching through the prompt for a match and the second specifies how many subsequent tokens to draft if a match is found.
More info:
https://github.com/ggerganov/llama.cpp/pull/4484
https://github.com/ggerganov/llama.cpp/issues/4226

View file

@ -88,19 +88,16 @@ int main(int argc, char ** argv){
int i_dft = 0;
while (true) {
//LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
// sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
llama_sampling_accept(ctx_sampling, ctx, id, true);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
const std::string token_str = llama_token_to_piece(ctx, id);
printf("%s", token_str.c_str());
fflush(stdout);
if (!params.use_color) {
printf("%s", token_str.c_str());
}
if (id == llama_token_eos(model)) {
has_eos = true;
@ -114,9 +111,21 @@ int main(int argc, char ** argv){
++n_accept;
++n_past;
++i_dft;
inp.push_back(id);
if (params.use_color) {
// color accepted draft token
printf("\033[34m%s\033[0m", token_str.c_str());
fflush(stdout);
}
continue;
}
if (params.use_color) {
printf("%s", token_str.c_str());
}
fflush(stdout);
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
@ -176,9 +185,6 @@ int main(int argc, char ** argv){
++n_past;
draft.erase(draft.begin());
// we have our draft!
}
auto t_dec_end = ggml_time_us();