diff --git a/examples/lookup/CMakeLists.txt b/examples/lookup/CMakeLists.txt
new file mode 100644
index 000000000..c060b8f56
--- /dev/null
+++ b/examples/lookup/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET lookup)
+add_executable(${TARGET} lookup.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/lookup/README.md b/examples/lookup/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
new file mode 100644
index 000000000..e7174ffe4
--- /dev/null
+++ b/examples/lookup/lookup.cpp
@@ -0,0 +1,113 @@
+#include "common.h"
+#include "llama.h"
+
+#include <cstdint>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+/*
+def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
+    input_length = input_ids.size(1)
+
+    for ngram_size in range(max_ngram_size, 0, -1):
+        # Extract the last n tokens as our search ngram
+        ngram = input_ids[0, -ngram_size:].tolist()
+
+        # Create sliding windows of size ngram_size
+        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
+
+        # Convert ngram to a tensor for comparison
+        ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)
+
+        # Find where the windows match the ngram
+        matches = (windows == ngram_tensor).all(dim=2)
+
+        # Get the indices of matches
+        match_indices = matches.nonzero(as_tuple=True)[1]
+
+        # Iterate through match indices to find a valid continuation
+        for idx in match_indices:
+            start_idx = idx + ngram_size
+            end_idx = start_idx + num_pred_tokens
+            # Ensure we don't go beyond the length of input_ids and avoid self-match
+            if end_idx <= input_length and start_idx < input_length - ngram_size:
+                return input_ids[0, start_idx:end_idx]
+
+    # If no match is found, return an empty tensor
+    return torch.tensor([], dtype=torch.long, device=input_ids.device)
+*/
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    // maximum n-gram size to search for in the prompt
+    const int max_ngram_size = 3;
+
+    // length of the candidate sequence, if a match is found
+    const int num_pred_tokens = 10;
+
+#ifndef LOG_DISABLE_LOGS
+    log_set_target(log_filename_generator("lookup", "log"));
+    LOG_TEE("Log start\n");
+    log_dump_cmdline(argc, argv);
+#endif // LOG_DISABLE_LOGS
+
+    // init llama.cpp
+    llama_backend_init(params.numa);
+
+    llama_model * model = NULL;
+    llama_context * ctx = NULL;
+
+    // load the model
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+
+    // tokenize the prompt
+    const bool add_bos = llama_should_add_bos_token(model);
+    LOG("add_bos tgt: %d\n", add_bos);
+
+    std::vector<llama_token> inp;
+    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+
+    const int max_context_size     = llama_n_ctx(ctx);
+    const int max_tokens_list_size = max_context_size - 4;
+
+    if ((int) inp.size() > max_tokens_list_size) {
+        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
+        return 1;
+    }
+
+    fprintf(stderr, "\n\n");
+
+    for (auto id : inp) {
+        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+    }
+
+    fflush(stderr);
+
+    const int n_input = inp.size();
+
+    const auto t_enc_start = ggml_time_us();
+
+    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0,           0));
+    llama_decode(ctx, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0));
+
+    const auto t_enc_end = ggml_time_us();
+
+    int n_accept = 0;
+
+    int n_past = inp.size();
+
+    bool has_eos = false;
+
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+
+    const auto t_dec_start = ggml_time_us();
+
+
+
+}
\ No newline at end of file
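
The Python snippet quoted in the comment block of lookup.cpp is the reference for the candidate search that prompt lookup decoding relies on. Below is a minimal C++ sketch of that same search, included only for illustration: the C++ signature, the plain std::vector<llama_token> interface, and the std::equal-based window scan are assumptions of this sketch, not the code the PR adds in the elided remainder of lookup.cpp.

// Illustrative sketch only: match the tail of the token history against earlier
// positions and return the tokens that followed the match as draft candidates,
// mirroring the Python reference in the comment above.
#include "llama.h"   // llama_token

#include <algorithm>
#include <vector>

static std::vector<llama_token> find_candidate_pred_tokens(
        const std::vector<llama_token> & inp,
        int max_ngram_size  = 3,
        int num_pred_tokens = 10) {
    const int n = (int) inp.size();

    // try the longest n-gram first, then fall back to shorter ones
    for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
        if (n < ngram_size) {
            continue;
        }

        // the last ngram_size tokens are the pattern we search for
        const llama_token * ngram = inp.data() + n - ngram_size;

        // slide a window of the same size over the history
        for (int idx = 0; idx + ngram_size <= n; ++idx) {
            if (!std::equal(ngram, ngram + ngram_size, inp.data() + idx)) {
                continue;
            }

            const int start = idx + ngram_size;
            const int end   = start + num_pred_tokens;

            // require a full continuation and skip the trailing self-match
            if (end <= n && start < n - ngram_size) {
                return std::vector<llama_token>(inp.begin() + start, inp.begin() + end);
            }
        }
    }

    // no match: return an empty candidate list
    return {};
}

In prompt lookup decoding, the candidates returned by such a search are verified in a single batched call to the target model and only the accepted prefix is kept, which is what makes the n-gram lookup a drop-in replacement for a separate draft model.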