main : add trace log

Georgi Gerganov 2023-08-29 15:12:37 +03:00
parent c72d344c1a
commit ecdf113c69
2 changed files with 90 additions and 50 deletions

View file

@@ -11,6 +11,6 @@ cd ..
 #
 # "--keep 48" is based on the contents of prompts/chat-with-bob.txt
 #
-./main -m ./models/7B/ggml-model-q4_0.bin -c 512 -b 1024 -n 256 --keep 48 \
+./main -m ./models/llama-7b/ggml-model-q4_0.gguf -c 512 -b 1024 -n 256 --keep 48 \
     --repeat_penalty 1.0 --color -i \
     -r "User:" -f prompts/chat-with-bob.txt

View file

@@ -505,15 +505,17 @@ int main(int argc, char ** argv) {
         if (embd.size() > 0) {
             // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
             // --prompt or --file which uses the same value.
-            auto max_embd_size = n_ctx - 4;
+            int max_embd_size = n_ctx - 4;
+
             // Ensure the input doesn't exceed the context size by truncating embd if necessary.
-            if ((int)embd.size() > max_embd_size) {
-                auto skipped_tokens = embd.size() - max_embd_size;
+            if ((int) embd.size() > max_embd_size) {
+                const int skipped_tokens = (int) embd.size() - max_embd_size;
+                embd.resize(max_embd_size);
+
                 console::set_display(console::error);
-                printf("<<input too long: skipped %zu token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
+                printf("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
                 console::set_display(console::reset);
                 fflush(stdout);
-                embd.resize(max_embd_size);
             }

             // infinite text generation via context swapping
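(Aside: stripped of llama.cpp types, the truncation step above boils down to the following self-contained sketch; the vector of plain ints and the hard-coded sizes are illustrative only, not the actual token type or context size.)

    #include <cstdio>
    #include <vector>

    // Sketch of the input-truncation step: work out how many tokens overflow the
    // context window, clip the batch, and report how many were dropped.
    int main() {
        const int n_ctx         = 512;
        const int max_embd_size = n_ctx - 4;      // same "n_ctx - 4" margin as above

        std::vector<int> embd(600, 0);            // pretend 600 tokens arrived at once

        if ((int) embd.size() > max_embd_size) {
            const int skipped_tokens = (int) embd.size() - max_embd_size;
            embd.resize(max_embd_size);
            printf("<<input too long: skipped %d token%s>>\n",
                   skipped_tokens, skipped_tokens != 1 ? "s" : "");
        }
        return 0;
    }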
@@ -522,28 +524,26 @@ int main(int argc, char ** argv) {
             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
             if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
                 if (params.n_predict == -2) {
-                    LOG_TEE("\n\n%s: context full, stopping generation\n", __func__);
+                    LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
                     break;
                 }

                 const int n_left = n_past - params.n_keep;
+                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
+
                 // always keep the first token - BOS
                 n_past = std::max(1, params.n_keep);
                 n_past_guidance = std::max(1, params.n_keep + guidance_offset);

+                LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+
                 // insert n_left/2 tokens at the start of embd from last_n_tokens
                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());

-                // stop saving session if we run out of context
-                path_session.clear();
-
-                //printf("\n---\n");
-                //printf("resetting: '");
-                //for (int i = 0; i < (int) embd.size(); i++) {
-                //    printf("%s", llama_token_to_piece(ctx, embd[i]));
-                //}
-                //printf("'\n");
-                //printf("\n---\n");
+                LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+
+                LOG("clear session path\n");
+                path_session.clear();
             }

             // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
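(Aside: the swap arithmetic in this hunk keeps the first n_keep tokens and re-feeds the last n_left/2 tokens of the history. A rough standalone sketch follows; it ignores the tokens still pending in embd, which the real code also re-inserts.)

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Context-swap bookkeeping: keep the first n_keep tokens, then carry over the
    // most recent n_left/2 tokens so generation can continue past the window.
    int main() {
        const int n_ctx  = 512;
        const int n_keep = 48;
        int       n_past = n_ctx;                    // the context is full

        std::vector<int> last_tokens(n_ctx);         // stand-in for last_n_tokens
        for (int i = 0; i < n_ctx; ++i) last_tokens[i] = i;

        const int n_left = n_past - n_keep;          // tokens eligible for swapping
        n_past = std::max(1, n_keep);                // always keep at least the BOS token

        std::vector<int> carried(last_tokens.end() - n_left/2, last_tokens.end());

        printf("n_left = %d, carrying over %zu tokens, n_past reset to %d\n",
               n_left, carried.size(), n_past);
        return 0;
    }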
@@ -573,7 +573,7 @@ int main(int argc, char ** argv) {
             if (ctx_guidance) {
                 int input_size = 0;
-                llama_token* input_buf = NULL;
+                llama_token * input_buf = NULL;

                 if (n_past_guidance < (int) guidance_inp.size()) {
                     // Guidance context should have the same data with these modifications:
@@ -591,11 +591,8 @@ int main(int argc, char ** argv) {
                     input_buf = embd_guidance.data();
                     input_size = embd_guidance.size();
-                    //LOG_TEE("\n---------------------\n");
-                    //for (int i = 0; i < (int) embd_guidance.size(); i++) {
-                        //LOG_TEE("%s", llama_token_to_piece(ctx, embd_guidance[i]));
-                    //}
-                    //LOG_TEE("\n---------------------\n");
+
+                    LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
                 } else {
                     input_buf = embd.data();
                     input_size = embd.size();
@@ -617,11 +614,17 @@ int main(int argc, char ** argv) {
                 if (n_eval > params.n_batch) {
                     n_eval = params.n_batch;
                 }
+
+                LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+
                 if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
                     LOG_TEE("%s : failed to eval\n", __func__);
                     return 1;
                 }
+
                 n_past += n_eval;
+
+                LOG("n_past = %d\n", n_past);
             }

             if (embd.size() > 0 && !path_session.empty()) {
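(Aside: the two new LOG calls land inside the batched evaluation loop. Its chunking, with a dummy evaluation step standing in for llama_eval, works roughly like this:)

    #include <cstdio>
    #include <vector>

    // Chunked evaluation sketch: walk the pending tokens in steps of n_batch,
    // advancing n_past after every chunk, mirroring the loop around llama_eval.
    int main() {
        const int n_batch = 8;
        int       n_past  = 0;

        std::vector<int> embd(21, 0);   // 21 pending tokens -> chunks of 8, 8, 5

        for (int i = 0; i < (int) embd.size(); i += n_batch) {
            int n_eval = (int) embd.size() - i;
            if (n_eval > n_batch) {
                n_eval = n_batch;
            }
            // ... evaluate embd[i .. i + n_eval) here ...
            n_past += n_eval;
            printf("evaluated %d tokens, n_past = %d\n", n_eval, n_past);
        }
        return 0;
    }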
@@ -634,7 +637,6 @@ int main(int argc, char ** argv) {
         embd_guidance.clear();

         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            // out of user input, sample next token
             const float temp = params.temp;
             const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
             const float top_p = params.top_p;
@@ -653,6 +655,8 @@ int main(int argc, char ** argv) {
             if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
                 need_to_save_session = false;
                 llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+
+                LOG("saved session to %s\n", path_session.c_str());
             }

             llama_token id = 0;
@@ -672,55 +676,68 @@ int main(int argc, char ** argv) {
                     candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
                 }

-                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+                llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

                 if (ctx_guidance) {
-                    llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale);
+                    llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
                 }

                 // Apply penalties
                 float nl_logit = logits[llama_token_nl(ctx)];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
-                llama_sample_repetition_penalty(ctx, &candidates_p,
+                llama_sample_repetition_penalty(ctx, &cur_p,
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                     last_n_repeat, repeat_penalty);
-                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+                llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                     last_n_repeat, alpha_frequency, alpha_presence);
                 if (!penalize_nl) {
-                    for (size_t idx = 0; idx < candidates_p.size; idx++) {
-                        if (candidates_p.data[idx].id == llama_token_nl(ctx)) {
-                            candidates_p.data[idx].logit = nl_logit;
+                    for (size_t idx = 0; idx < cur_p.size; idx++) {
+                        if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                            cur_p.data[idx].logit = nl_logit;
                             break;
                         }
                     }
                 }

                 if (grammar != NULL) {
-                    llama_sample_grammar(ctx, &candidates_p, grammar);
+                    llama_sample_grammar(ctx, &cur_p, grammar);
                 }

                 if (temp <= 0) {
                     // Greedy sampling
-                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                    id = llama_sample_token_greedy(ctx, &cur_p);
                 } else {
                     if (mirostat == 1) {
                         static float mirostat_mu = 2.0f * mirostat_tau;
                         const int mirostat_m = 100;
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                        llama_sample_temperature(ctx, &cur_p, temp);
+                        id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
                     } else if (mirostat == 2) {
                         static float mirostat_mu = 2.0f * mirostat_tau;
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                        llama_sample_temperature(ctx, &cur_p, temp);
+                        id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                     } else {
                         // Temperature sampling
-                        llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-                        llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                        llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                        llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token(ctx, &candidates_p);
+                        llama_sample_top_k      (ctx, &cur_p, top_k, 1);
+                        llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
+                        llama_sample_typical    (ctx, &cur_p, typical_p, 1);
+                        llama_sample_top_p      (ctx, &cur_p, top_p, 1);
+                        llama_sample_temperature(ctx, &cur_p, temp);
+
+                        {
+                            const int n_top = 10;
+                            LOG("top %d candidates:\n", n_top);
+
+                            for (int i = 0; i < n_top; i++) {
+                                const llama_token id = cur_p.data[i].id;
+                                LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+                            }
+                        }
+
+                        id = llama_sample_token(ctx, &cur_p);
+
+                        LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
                     }
                 }
                 // printf("`%d`", candidates_p.size);
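(Aside: the added block simply prints the first ten entries of cur_p. As a self-contained illustration of that kind of trace, over a plain array of (id, probability) pairs with an explicit sort for clarity:)

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Sketch of a "top N candidates" trace: order a small candidate list by
    // probability and print the leading entries.
    struct candidate {
        int   id;
        float p;    // probability assigned to this candidate
    };

    int main() {
        std::vector<candidate> cur = {
            {11, 0.05f}, {42, 0.40f}, {7, 0.25f}, {99, 0.30f},
        };

        std::sort(cur.begin(), cur.end(),
                  [](const candidate & a, const candidate & b) { return a.p > b.p; });

        const int n_top = std::min<int>(3, (int) cur.size());
        printf("top %d candidates:\n", n_top);
        for (int i = 0; i < n_top; ++i) {
            printf(" - %5d (%.3f)\n", cur[i].id, cur[i].p);
        }
        return 0;
    }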
@@ -731,9 +748,10 @@ int main(int argc, char ** argv) {
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
+
+                LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_n_tokens));
             }

-            // add it to the context
             embd.push_back(id);

             // echo this to console
@@ -741,8 +759,11 @@ int main(int argc, char ** argv) {
             // decrement remaining sampling budget
             --n_remain;

+            LOG("n_remain: %d\n", n_remain);
         } else {
             // some user input remains from prompt or interaction, forward it to processing
+            LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
+
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
                 last_n_tokens.erase(last_n_tokens.begin());
@@ -770,13 +791,12 @@ int main(int argc, char ** argv) {
             fflush(stdout);
         }

         // reset color to default if we there is no pending user input
-        if (input_echo && (int)embd_inp.size() == n_consumed) {
+        if (input_echo && (int) embd_inp.size() == n_consumed) {
             console::set_display(console::reset);
         }

         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
             // check for reverse prompt
             if (params.antiprompt.size()) {
                 std::string last_output;
@@ -794,7 +814,7 @@ int main(int argc, char ** argv) {
                         ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
                         : 0;

-                    if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
+                    if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
                         if (params.interactive) {
                             is_interacting = true;
                             console::set_display(console::user_input);
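(Aside: the change in this hunk only swaps antiprompt.c_str() for the std::string overload of find; the surrounding search_start_pos logic, which restricts the reverse-prompt check to the tail of the output, can be exercised on its own as below. The padding value here is illustrative; main.cpp derives it from token lengths.)

    #include <cstdio>
    #include <string>

    // Sketch of the reverse-prompt check: only the tail of the generated text is
    // searched, so a match that scrolled far back in the output is ignored.
    int main() {
        const std::string last_output   = "...assistant text...\nUser:";
        const std::string antiprompt    = "User:";
        const size_t      extra_padding = 2;

        const size_t search_start_pos =
            last_output.length() > antiprompt.length() + extra_padding
                ? last_output.length() - (antiprompt.length() + extra_padding)
                : 0;

        if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
            printf("reverse prompt found near the end of the output\n");
        }
        return 0;
    }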
@@ -804,10 +824,16 @@ int main(int argc, char ** argv) {
                         break;
                     }
                 }
+
+                if (is_antiprompt) {
+                    LOG("found antiprompt: %s\n", last_output.c_str());
+                }
             }

             // deal with end of text token in interactive mode
             if (last_n_tokens.back() == llama_token_eos(ctx)) {
+                LOG("found EOS token\n");
+
                 if (params.interactive) {
                     if (params.antiprompt.size() != 0) {
                         // tokenize and inject first reverse prompt
@@ -826,16 +852,20 @@ int main(int argc, char ** argv) {
             }

             if (n_past > 0 && is_interacting) {
+                LOG("waiting for user input\n");
+
                 if (params.instruct) {
                     printf("\n> ");
                 }

                 if (params.input_prefix_bos) {
+                    LOG("adding input prefix BOS token\n");
                     embd_inp.push_back(llama_token_bos(ctx));
                 }

                 std::string buffer;
                 if (!params.input_prefix.empty()) {
+                    LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                     buffer += params.input_prefix;
                     printf("%s", buffer.c_str());
                 }
@@ -855,23 +885,30 @@ int main(int argc, char ** argv) {
                 if (buffer.length() > 1) {
                     // append input suffix if any
                     if (!params.input_suffix.empty()) {
+                        LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                         buffer += params.input_suffix;
                         printf("%s", params.input_suffix.c_str());
                     }

+                    LOG("buffer: '%s'\n", buffer.c_str());
+
                     const size_t original_size = embd_inp.size();

                     // instruct mode: insert instruction prefix
                     if (params.instruct && !is_antiprompt) {
+                        LOG("inserting instruction prefix\n");
                         n_consumed = embd_inp.size();
                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
                     }

-                    auto line_inp = ::llama_tokenize(ctx, buffer, false);
+                    const auto line_inp = ::llama_tokenize(ctx, buffer, false);
+                    LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
+
                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

                     // instruct mode: insert response suffix
                     if (params.instruct) {
+                        LOG("inserting instruction suffix\n");
                         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
                     }
@@ -882,6 +919,9 @@ int main(int argc, char ** argv) {
                     }

                     n_remain -= line_inp.size();
+                    LOG("n_remain: %d\n", n_remain);
+                } else {
+                    LOG("empty line, passing control back\n");
                 }

                 input_echo = false; // do not echo this again