initial commit, going through initializations
This commit is contained in:
parent 23b5e12eb5
commit cae8f50b1a
3 changed files with 118 additions and 0 deletions
examples/lookup/CMakeLists.txt (Normal file, 5 additions)
@@ -0,0 +1,5 @@
+set(TARGET lookup)
+add_executable(${TARGET} lookup.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
examples/lookup/README.md (Normal file, 0 additions)
examples/lookup/lookup.cpp (Normal file, 113 additions)
@@ -0,0 +1,113 @@
+#include "common.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+/*
+def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
+    input_length = input_ids.size(1)
+
+    for ngram_size in range(max_ngram_size, 0, -1):
+        # Extract the last n tokens as our search ngram
+        ngram = input_ids[0, -ngram_size:].tolist()
+
+        # Create sliding windows of size ngram_size
+        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
+
+        # Convert ngram to a tensor for comparison
+        ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)
+
+        # Find where the windows match the ngram
+        matches = (windows == ngram_tensor).all(dim=2)
+
+        # Get the indices of matches
+        match_indices = matches.nonzero(as_tuple=True)[1]
+
+        # Iterate through match indices to find a valid continuation
+        for idx in match_indices:
+            start_idx = idx + ngram_size
+            end_idx = start_idx + num_pred_tokens
+            # Ensure we don't go beyond the length of input_ids and avoid self-match
+            if end_idx <= input_length and start_idx < input_length - ngram_size:
+                return input_ids[0, start_idx:end_idx]
+
+    # If no match is found, return an empty tensor
+    return torch.tensor([], dtype=torch.long, device=input_ids.device)
+*/
+
+int main(int argc, char ** argv){
+    gpt_params params;
+
+    if(gpt_params_parse(argc, argv, params) == false){
+        return 1;
+    }
+
+    // maximum n-grams to search for in prompt
+    const int max_ngram_size = 3;
+
+    // length of the candidate sequence, if match is found
+    const int num_pred_tokens = 10;
+
+#ifndef LOG_DISABLE_LOGS
+    log_set_target(log_filename_generator("lookup", "log"));
+    LOG_TEE("Log start\n");
+    log_dump_cmdline(argc, argv);
+#endif // LOG_DISABLE_LOGS
+
+    // init llama.cpp
+    llama_backend_init(params.numa);
+
+    llama_model * model = NULL;
+    llama_context * ctx = NULL;
+
+    // load the model
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+
+    // tokenize the prompt
+    const bool add_bos = llama_should_add_bos_token(model);
+    LOG("add_bos tgt: %d\n", add_bos);
+
+    std::vector<llama_token> inp;
+    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+
+    const int max_context_size = llama_n_ctx(ctx);
+    const int max_tokens_list_size = max_context_size - 4;
+
+    if ((int) inp.size() > max_tokens_list_size) {
+        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
+        return 1;
+    }
+
+    fprintf(stderr, "\n\n");
+
+    for (auto id : inp) {
+        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+    }
+
+    fflush(stderr);
+
+    const int n_input = inp.size();
+
+    const auto t_enc_start = ggml_time_us();
+
+    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
+    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+
+    const auto t_enc_end = ggml_time_us();
+
+    int n_accept = 0;
+
+    int n_past = inp.size();
+
+    bool has_eos = false;
+
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+
+    const auto t_dec_start = ggml_time_us();
+
+}
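The Python block quoted in the comment above is the reference for the prompt-lookup idea; the C++ that actually performs the search is not part of the hunk shown here. Below is a minimal, self-contained sketch of that n-gram search as a loose C++ translation, assuming plain int token ids in place of tensors; find_candidate_pred_tokens, max_ngram_size, and num_pred_tokens mirror the pseudocode and are not llama.cpp API.

// Sketch only, not part of the commit: find the continuation of the last
// n-gram of input_ids somewhere earlier in input_ids, largest n-gram first.
#include <algorithm>
#include <cstdio>
#include <vector>

static std::vector<int> find_candidate_pred_tokens(const std::vector<int> & input_ids,
                                                   int max_ngram_size = 3,
                                                   int num_pred_tokens = 10) {
    const int n = (int) input_ids.size();

    for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
        if (n < 2*ngram_size) {
            continue; // need room for the pattern plus one earlier occurrence
        }

        // the last ngram_size tokens are the search pattern
        const int * ngram = input_ids.data() + n - ngram_size;

        // slide a window over the earlier tokens, excluding the trailing
        // pattern itself so we never return the self-match
        for (int i = 0; i + ngram_size <= n - ngram_size; ++i) {
            bool match = true;
            for (int j = 0; j < ngram_size; ++j) {
                if (input_ids[i + j] != ngram[j]) { match = false; break; }
            }
            if (!match) {
                continue;
            }

            // return up to num_pred_tokens tokens that followed the match
            const int start = i + ngram_size;
            const int end   = std::min(start + num_pred_tokens, n);
            return std::vector<int>(input_ids.begin() + start, input_ids.begin() + end);
        }
    }

    // no match found: empty candidate list
    return {};
}

int main() {
    const std::vector<int> inp = {1, 2, 3, 4, 5, 1, 2, 3};
    for (int id : find_candidate_pred_tokens(inp)) {
        printf("%d ", id); // prints "4 5 1 2 3": the continuation of the earlier "1 2 3"
    }
    printf("\n");
    return 0;
}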