From 79619527b07b1145312464c682134bb33b2c3f1a Mon Sep 17 00:00:00 2001
From: vvhg1
Date: Sat, 30 Sep 2023 20:56:24 +0200
Subject: [PATCH] brought infill up to current main code

---
 examples/infill/infill.cpp | 82 +++++++++++++++++++-------------------
 1 file changed, 40 insertions(+), 42 deletions(-)

diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 56ff0a1e1..9ec75ce42 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -41,7 +41,8 @@ static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 
-void write_logfile(
+
+static void write_logfile(
     const llama_context * ctx, const gpt_params & params, const llama_model * model,
     const std::vector<llama_token> & input_tokens, const std::string & output,
     const std::vector<llama_token> & output_tokens
@@ -86,7 +87,7 @@ void write_logfile(
 }
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
-void sigint_handler(int signo) {
+static void sigint_handler(int signo) {
     if (signo == SIGINT) {
         if (!is_interacting) {
             is_interacting = true;
@@ -118,7 +119,7 @@ int main(int argc, char ** argv) {
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });
 
-    if (params.perplexity) {
+    if (params.logits_all) {
         printf("\n************\n");
         printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
         printf("************\n\n");
@@ -133,6 +134,11 @@ int main(int argc, char ** argv) {
         return 0;
     }
+
+    if (params.n_ctx != 0 && params.n_ctx < 8) {
+        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
+        params.n_ctx = 8;
+    }
 
     if (params.instruct) {
         printf("\n************\n");
         printf("%s: please use the 'main' tool for instruct mode\n", __func__);
@@ -169,12 +175,12 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
-    if (params.rope_freq_base != 10000.0) {
-        LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    if (params.rope_freq_base != 0.0) {
+        LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
     }
 
-    if (params.rope_freq_scale != 1.0) {
-        LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    if (params.rope_freq_scale != 0.0) {
+        LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
     }
 
     LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
@@ -210,32 +216,21 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int n_ctx_train = llama_n_ctx_train(ctx);
-    if (params.n_ctx > n_ctx_train) {
+    const int n_ctx_train = llama_n_ctx_train(model);
+    const int n_ctx = llama_n_ctx(ctx);
+    LOG("n_ctx: %d\n", n_ctx);
+
+    if (n_ctx > n_ctx_train) {
         LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
-                __func__, n_ctx_train, params.n_ctx);
-    } else if (params.n_ctx < 8) {
-        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
-        params.n_ctx = 8;
+                __func__, n_ctx_train, n_ctx);
     }
 
     // print system information
     {
         LOG_TEE("\n");
-        LOG_TEE("system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        LOG_TEE("%s\n", get_system_info(params).c_str());
     }
-
-    // export the cgraph and exit
-    if (params.export_cgraph) {
-        llama_eval_export(ctx, "llama.ggml");
-        llama_free(ctx);
-        llama_free_model(model);
-
-        return 0;
-    }
-
-    const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
@@ -276,9 +271,6 @@ int main(int argc, char ** argv) {
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
     }
 
-    const int n_ctx = llama_n_ctx(ctx);
-    LOG("n_ctx: %d\n", n_ctx);
-
     if ((int) embd_inp.size() > n_ctx - 4) {
         LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
         return 1;
@@ -428,7 +420,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(model);
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -436,7 +428,7 @@ int main(int argc, char ** argv) {
     while (n_remain != 0 || params.interactive) {
         // predict
         if (!embd.empty()) {
-            // Note: n_ctx - 4 here is to match the logic for commandLine prompt handling via
+            // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
             // --prompt or --file which uses the same value.
             int max_embd_size = n_ctx - 4;
 
@@ -461,18 +453,23 @@ int main(int argc, char ** argv) {
                     break;
                 }
 
-                const int n_left = n_past - params.n_keep;
-                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
+                const int n_left    = n_past - params.n_keep - 1;
+                const int n_discard = n_left/2;
 
-                // always keep the first token - BOS
-                n_past = std::max(1, params.n_keep);
-                n_past_guidance = std::max(1, params.n_keep + guidance_offset);
+                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                    n_past, n_left, n_ctx, params.n_keep, n_discard);
+
+                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+                n_past -= n_discard;
+
+                if (ctx_guidance) {
+                    n_past_guidance -= n_discard;
+                }
 
                 LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
 
-                // insert n_left/2 tokens at the start of embd from last_tokens
-                embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
-
                 LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
 
             }
@@ -509,7 +506,7 @@ int main(int argc, char ** argv) {
 
                 for (int i = 0; i < input_size; i += params.n_batch) {
                     int n_eval = std::min(input_size - i, params.n_batch);
-                    if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+                    if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
                         LOG_TEE("%s : failed to eval\n", __func__);
                         return 1;
                     }
@@ -526,7 +523,7 @@ int main(int argc, char ** argv) {
 
                 LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
 
-                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
                     LOG_TEE("%s : failed to eval\n", __func__);
                     return 1;
                 }
@@ -741,7 +738,7 @@ int main(int argc, char ** argv) {
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
        // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
-        if (params.interactive && n_remain <= 0 && params.n_predict >= 0 ) {
+        if (params.interactive && n_remain <= 0 && params.n_predict >= 0) {
             n_remain = params.n_predict;
             is_interacting = true;
         }
@@ -769,3 +766,4 @@ int main(int argc, char ** argv) {
 
     return 0;
 }
+
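
For reference, the @@ -461,18 +453,23 @@ hunk above replaces the old "re-evaluate half the window from last_tokens" swap with the KV-cache based context shift used by the main example. Below is a minimal, illustrative sketch (not part of the patch) of the position bookkeeping that shift performs; a plain std::vector stands in for the KV cache, and the real llama_kv_cache_seq_rm / llama_kv_cache_seq_shift calls are only referenced in comments. The names n_past, n_keep, n_left and n_discard follow the diff.

    // Illustrative sketch only: models the position arithmetic of the context shift.
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_ctx  = 16; // pretend context window
        const int n_keep = 4;  // tokens always kept at the front of the context

        // fill the stand-in "KV cache" until the window is full
        std::vector<int> cache(n_ctx);
        for (int i = 0; i < n_ctx; ++i) cache[i] = 1000 + i; // token id at position i
        int n_past = n_ctx;

        // same arithmetic as the hunk
        const int n_left    = n_past - n_keep - 1;
        const int n_discard = n_left/2;

        // llama_kv_cache_seq_rm(ctx, 0, n_keep + 1, n_keep + n_discard + 1):
        // drop the cells at positions [n_keep + 1, n_keep + n_discard + 1)
        cache.erase(cache.begin() + n_keep + 1, cache.begin() + n_keep + 1 + n_discard);

        // llama_kv_cache_seq_shift(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard):
        // the surviving tail moves left by n_discard positions, which erasing from a
        // contiguous vector already models; only the past-token count needs updating
        n_past -= n_discard;

        printf("n_left = %d, n_discard = %d, n_past = %d, cached = %zu tokens\n",
               n_left, n_discard, n_past, cache.size());
        return 0;
    }

Unlike the removed code, nothing is re-inserted into embd or re-evaluated: generation simply continues with n_past reduced by n_discard, and the guidance context (if any) is reduced by the same amount.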