brought infill up to current main code

vvhg1 2023-09-30 20:56:24 +02:00
parent ac0f1752ea
commit 79619527b0


@@ -41,7 +41,8 @@ static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 
-void write_logfile(
+static void write_logfile(
     const llama_context * ctx, const gpt_params & params, const llama_model * model,
     const std::vector<llama_token> & input_tokens, const std::string & output,
     const std::vector<llama_token> & output_tokens
@@ -86,7 +87,7 @@ void write_logfile(
 }
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
-void sigint_handler(int signo) {
+static void sigint_handler(int signo) {
     if (signo == SIGINT) {
         if (!is_interacting) {
             is_interacting = true;
@@ -118,7 +119,7 @@ int main(int argc, char ** argv) {
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });
 
-    if (params.perplexity) {
+    if (params.logits_all) {
         printf("\n************\n");
         printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
         printf("************\n\n");
@@ -133,6 +134,11 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
+    if (params.n_ctx != 0 && params.n_ctx < 8) {
+        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
+        params.n_ctx = 8;
+    }
+
     if (params.instruct) {
         printf("\n************\n");
         printf("%s: please use the 'main' tool for instruct mode\n", __func__);
@@ -169,12 +175,12 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
-    if (params.rope_freq_base != 10000.0) {
-        LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    if (params.rope_freq_base != 0.0) {
+        LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
     }
 
-    if (params.rope_freq_scale != 1.0) {
-        LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    if (params.rope_freq_scale != 0.0) {
+        LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
     }
 
     LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
@@ -210,32 +216,21 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int n_ctx_train = llama_n_ctx_train(ctx);
-    if (params.n_ctx > n_ctx_train) {
+    const int n_ctx_train = llama_n_ctx_train(model);
+    const int n_ctx = llama_n_ctx(ctx);
+    LOG("n_ctx: %d\n", n_ctx);
+
+    if (n_ctx > n_ctx_train) {
         LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
-                __func__, n_ctx_train, params.n_ctx);
-    } else if (params.n_ctx < 8) {
-        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
-        params.n_ctx = 8;
+                __func__, n_ctx_train, n_ctx);
     }
 
     // print system information
     {
         LOG_TEE("\n");
-        LOG_TEE("system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        LOG_TEE("%s\n", get_system_info(params).c_str());
     }
 
-    // export the cgraph and exit
-    if (params.export_cgraph) {
-        llama_eval_export(ctx, "llama.ggml");
-        llama_free(ctx);
-        llama_free_model(model);
-        return 0;
-    }
-
-    const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
@@ -276,9 +271,6 @@ int main(int argc, char ** argv) {
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
     }
 
-    const int n_ctx = llama_n_ctx(ctx);
-    LOG("n_ctx: %d\n", n_ctx);
-
     if ((int) embd_inp.size() > n_ctx - 4) {
         LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
         return 1;
@@ -428,7 +420,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(model);
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -436,7 +428,7 @@ int main(int argc, char ** argv) {
     while (n_remain != 0 || params.interactive) {
         // predict
         if (!embd.empty()) {
-            // Note: n_ctx - 4 here is to match the logic for commandLine prompt handling via
+            // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
             // --prompt or --file which uses the same value.
             int max_embd_size = n_ctx - 4;
@@ -461,18 +453,23 @@ int main(int argc, char ** argv) {
                     break;
                 }
 
-                const int n_left = n_past - params.n_keep;
-                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
-
-                // always keep the first token - BOS
-                n_past = std::max(1, params.n_keep);
-                n_past_guidance = std::max(1, params.n_keep + guidance_offset);
+                const int n_left    = n_past - params.n_keep - 1;
+                const int n_discard = n_left/2;
+
+                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                    n_past, n_left, n_ctx, params.n_keep, n_discard);
+
+                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+                n_past -= n_discard;
+
+                if (ctx_guidance) {
+                    n_past_guidance -= n_discard;
+                }
 
                 LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
 
-                // insert n_left/2 tokens at the start of embd from last_tokens
-                embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
-
                 LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
             }
@@ -509,7 +506,7 @@ int main(int argc, char ** argv) {
                 for (int i = 0; i < input_size; i += params.n_batch) {
                     int n_eval = std::min(input_size - i, params.n_batch);
-                    if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+                    if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
                         LOG_TEE("%s : failed to eval\n", __func__);
                         return 1;
                     }
@@ -526,7 +523,7 @@ int main(int argc, char ** argv) {
                 LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
 
-                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
                     LOG_TEE("%s : failed to eval\n", __func__);
                     return 1;
                 }
@@ -769,3 +766,4 @@ int main(int argc, char ** argv) {
     return 0;
 }
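
The most substantive change above is the context-swap logic: instead of resetting n_past and re-inserting tokens from last_tokens, the code now evicts the oldest half of the non-kept tokens directly from the KV cache and shifts the remainder down. Below is a minimal sketch of that pattern, using only the calls that appear in the diff (llama_kv_cache_seq_rm, llama_kv_cache_seq_shift); the helper name shift_context and its exact parameters are illustrative, not part of the commit.

// Sketch: drop the oldest half of the evictable tokens and slide the rest left,
// mirroring the context-swap block added in this commit.
static void shift_context(llama_context * ctx, int & n_past, int n_keep) {
    const int n_left    = n_past - n_keep - 1;   // tokens past the kept prefix (and BOS)
    const int n_discard = n_left / 2;            // evict the oldest half of them

    // remove the evicted span from sequence 0 of the KV cache
    llama_kv_cache_seq_rm   (ctx, 0, n_keep + 1, n_keep + 1 + n_discard);
    // shift the surviving tokens so their positions stay contiguous
    llama_kv_cache_seq_shift(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard);

    n_past -= n_discard;   // generation continues from the compacted position
}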