From 50ea1ef7c8639609593017215604c3a3a84987d4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Dec 2023 18:04:30 +0200 Subject: [PATCH] lookup : final touches --- examples/lookup/README.md | 2 +- examples/lookup/lookup.cpp | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/lookup/README.md b/examples/lookup/README.md index 03a772c45..5bfb0de93 100644 --- a/examples/lookup/README.md +++ b/examples/lookup/README.md @@ -4,7 +4,7 @@ Demonstration of Prompt Lookup Decoding https://github.com/apoorvumang/prompt-lookup-decoding -The two key parameters for lookup decoding are `max_ngram_size` and `n_draft`. The first, determines how many ngrams to use when searching through the prompt for a match and the second specifies how many subsequent tokens to draft if a match is found. +The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft`. The first two determine the size of the ngrams to search for in the prompt for a match. The latter specifies how many subsequent tokens to draft if a match is found. More info: diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index ab1be0a32..d8de7dd38 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -9,12 +9,13 @@ int main(int argc, char ** argv){ gpt_params params; - if(gpt_params_parse(argc, argv, params) == false){ + if (!gpt_params_parse(argc, argv, params)) { return 1; } - // maximum n-grams to search for in prompt - const int max_ngram_size = 3; + // max/min n-grams size to search for in prompt + const int ngram_max = 4; + const int ngram_min = 1; // length of the candidate / draft sequence, if match is found const int n_draft = params.n_draft; @@ -160,7 +161,7 @@ int main(int argc, char ** argv){ // generate n_pred tokens through prompt lookup auto prompt_lookup = [&]() -> void { int inp_size = inp.size(); - for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ + for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){ const llama_token * ngram = &inp[inp_size - ngram_size]; for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {