From a0c5587401de35bb76183ce97172915526631d53 Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Fri, 3 Nov 2023 05:06:57 -0600
Subject: [PATCH] Expand simple-inference command handling. Clear KV cache when
 a sequence is killed. Allow parse_escapes to handle \x sequences
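
Illustrative command-mode session (syntax taken from the new /help text;
output abbreviated):

    /2          focus sequence 2
    /1,add foo  append 'foo' to sequence 1
    /1ae \x41\n append escaped text to sequence 1 without a leading space
    /3kill      stop sequence 3 and clear its entries from the KV cache
    /d 50       dump the last 50 tokens of the focused sequence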
---
 .../simple-inference/simple-inference.cpp     | 237 ++++++++++++------
 1 file changed, 162 insertions(+), 75 deletions(-)

diff --git a/examples/simple-inference/simple-inference.cpp b/examples/simple-inference/simple-inference.cpp
index 798f016b3..711268534 100644
--- a/examples/simple-inference/simple-inference.cpp
+++ b/examples/simple-inference/simple-inference.cpp
@@ -70,7 +70,7 @@ typedef struct seq_ctx {
     int32_t batch_idx;
     enum seq_state state;
     size_t n_remain;
-    size_t n_generated;
+    size_t n_toks; // Note: Does not include initial prompt size.
 
     llama_sampling_context *ctx_sampling;
     llama_token last_sampled;
@@ -121,7 +121,6 @@ typedef struct gen_ctx {
     ~gen_ctx();
     void dump_batches(const size_t prompt_start = 0);
     void dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset = 0);
-    void dump_batch(const size_t seq);
     void handle_seq(seq_ctx & sctx);
 #ifndef LLAMA_NO_SEQREP_SAMPLER
     void handle_seq_seqrep(seq_ctx & sctx);
@@ -227,8 +226,8 @@ static bool check_unsupported(const gpt_params * params) {
         nope = "prompt cache";
     else if (params->escape)
         nope = "prompt escaping";
-    else if (params->interactive || params->interactive_first || params->instruct)
-        nope = "interactive mode";
+    else if (params->interactive_first || params->instruct)
+        nope = "interactive first or instruct mode";
     else if (!params->input_prefix.empty() || !params->input_suffix.empty() || params->input_prefix_bos)
         nope = "input prefix or suffix";
     else if (params->hellaswag)
@@ -238,7 +237,7 @@ static bool check_unsupported(const gpt_params * params) {
     else if (!params->antiprompt.empty())
         nope = "reverse prompt";
     if (!nope.empty()) {
-        LOG_TEE("%s: error: We don't support %s here.\n", __func__, nope.c_str());
+        printf("%s: error: We don't support %s here.\n", __func__, nope.c_str());
         return false;
     }
     return true;
@@ -254,15 +253,15 @@ bool gen_ctx::init_params(const int argc, char ** argv) {
     }
 
     if (params.rope_freq_base != 10000.0) {
-        LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+        printf("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
     }
 
     if (params.rope_freq_scale != 1.0) {
-        LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+        printf("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
     }
 
     if (params.n_ctx < 8) {
-        LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
+        printf("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
     }
 
@@ -270,7 +269,7 @@ bool gen_ctx::init_params(const int argc, char ** argv) {
         params.seed = time(NULL);
     }
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.seed);
+    printf("%s: seed = %u\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -289,14 +288,14 @@ bool gen_ctx::init_model() {
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
     if (model == NULL) {
-        LOG_TEE("%s: error: unable to load model\n", __func__);
+        printf("%s: error: unable to load model\n", __func__);
         return false;
     }
 
     // print system information
     {
-        LOG_TEE("\n");
-        LOG_TEE("system_info: n_threads = %d / %d | %s\n",
+        printf("\n");
+        printf("system_info: n_threads = %d / %d | %s\n",
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
@@ -327,20 +326,20 @@ bool gen_ctx::init_prompt() {
     LOG("n_ctx: %d\n", n_ctx);
 
     if ((int) prompt_tokens.size() > n_ctx - 4) {
-        LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
+        printf("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
         return false;
     }
     prompt_size = prompt_tokens.size();
 
     if (params.verbose_prompt) {
-        LOG_TEE("\n");
-        LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-        LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
+        printf("\n");
+        printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+        printf("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
         for (int i = 0; i < (int) prompt_tokens.size(); i++) {
-            LOG_TEE("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
+            printf("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
         }
-        LOG_TEE("\n");
+        printf("\n");
     }
     return true;
 }
@@ -367,7 +366,7 @@ bool gen_ctx::init_handlers() {
 }
 
 bool gen_ctx::init_sampling() {
-    LOG_TEE("sampling: %s\n", llama_sampling_print(sparams).c_str());
+    printf("sampling: %s\n", llama_sampling_print(sparams).c_str());
 #ifndef LLAMA_NO_SEQREP_SAMPLER
     for (auto & sr_params : sparams.seqrep_params) {
         seqrep_sampler_params_dump(&sr_params);
@@ -476,7 +475,7 @@ bool gen_ctx::feed_prompt(const std::vector<llama_token> & tokens, llama_pos pos
 
         if (llama_decode(ctx, batch) != 0) {
             console::set_display(console::reset);
-            LOG_TEE("%s : failed to eval\n", __func__);
+            printf("%s : failed to eval\n", __func__);
             return false;
         }
         decode_count++;
@@ -534,7 +533,7 @@ void gen_ctx::dump_batches(const size_t prompt_start) {
         if (sctx.seq_id == focused_sequence) continue;
         printf("\n\n%s Result #%d (size: %zu", !first ? "====================" : "####################",
-            sctx.seq_id + 1, prompt_size + sctx.n_generated);
+            sctx.seq_id + 1, prompt_size + sctx.n_toks);
 #ifndef LLAMA_NO_SEQREP_SAMPLER
         printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
 #endif
@@ -545,7 +544,7 @@ void gen_ctx::dump_batches(const size_t prompt_start) {
     seq_ctx & sctx = ctxs_seq[focused_sequence];
     printf("\n\n%s Result #%d (size: %zu", !first ? "====================" : "####################",
-        sctx.seq_id + 1, prompt_size + sctx.n_generated);
+        sctx.seq_id + 1, prompt_size + sctx.n_toks);
 #ifndef LLAMA_NO_SEQREP_SAMPLER
     printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
 #endif
@@ -572,7 +571,7 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
                 fputs(token_str.c_str(), stdout);
                 fflush(stdout);
             }
-            sctx.n_generated++;
+            sctx.n_toks++;
             sctx.n_remain--;
             if (sctx.chunks.empty() || sctx.chunks.back().is_input) {
                 sctx.chunks.emplace_back(0, false, std::vector<llama_token>());
@@ -581,11 +580,11 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
             if (sctx.last_sampled == llama_token_eos(model) || sctx.n_remain == 0) {
                 sctx.state = SEQ_DONE;
                 sctx.batch_idx = -1;
-                // LOG_TEE(" [end of text]\n");
+                // printf(" [end of text]\n");
                 // break;
             } else {
                 sctx.batch_idx = batch.n_tokens;
-                llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_generated, {sctx.seq_id}, true);
+                llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_toks, {sctx.seq_id}, true);
             }
         }
         break;
@@ -600,19 +599,18 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
             const size_t remain = chunk.tokens.size() - chunk.consumed;
             const size_t to_consume = std::min(size_t(params.n_batch), remain);
             for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
-                llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_generated + i), {sctx.seq_id}, false);
+                llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_toks + i), {sctx.seq_id}, false);
             }
             chunk.consumed += to_consume;
             sctx.n_remain -= to_consume;
-            // FIXME: This a lie, we didn't generate it.
-            sctx.n_generated += to_consume;
+            sctx.n_toks += to_consume;
             if (chunk.consumed == chunk.tokens.size()) {
 #ifndef LLAMA_NO_SEQREP_SAMPLER
                 // FIXME: Move this logic to a more appropriate place.
                 for (size_t i = 0; i < chunk.consumed; i++) {
                     sctx.rewind_state.logit_slots.emplace_back(n_vocab);
                 }
-                sctx.high_water_mark = sctx.n_generated + 1;
+                sctx.high_water_mark = sctx.n_toks + 1;
 #endif
                 sctx.batch_idx = batch.n_tokens - 1;
                 batch.logits[sctx.batch_idx] = true;
@@ -631,29 +629,29 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
 
 #ifndef LLAMA_NO_SEQREP_SAMPLER
 void gen_ctx::handle_seq_seqrep(seq_ctx & sctx) {
-    if (sctx.n_generated > 0) {
-        seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_generated);
+    if (sctx.n_toks > 0) {
+        seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_toks);
         if (rw_slot.ctx_sampling == nullptr) {
             rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
         }
         llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
-        sctx.rewind_state.set_logits_slot(ctx, sctx.n_generated, sctx.batch_idx);
+        sctx.rewind_state.set_logits_slot(ctx, sctx.n_toks, sctx.batch_idx);
     } else {
         return;
     }
 
     std::vector<llama_token> seq_last_tokens;
-    seq_last_tokens.reserve(sctx.n_generated);
+    seq_last_tokens.reserve(sctx.n_toks);
    concat_chunks(sctx.chunks, seq_last_tokens, prompt_size);
 
     size_t rewind_distance = llama_seqrep_handle_rewind(
-        ctx, sctx.rewind_state, seq_last_tokens, sctx.n_generated, prompt_tokens,
+        ctx, sctx.rewind_state, seq_last_tokens, sctx.n_toks, prompt_tokens,
         sparams.seqrep_params, &sctx.high_water_mark, sctx.batch_idx);
     if (rewind_distance < 1) {
         return;
     }
-    GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?");
-    const size_t slot_idx = sctx.n_generated - rewind_distance;
+    GGML_ASSERT(rewind_distance <= sctx.n_toks && "Rewind index out of bounds somehow?");
+    const size_t slot_idx = sctx.n_toks - rewind_distance;
     const llama_token nl_id = llama_token_nl(model);
 
     seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(slot_idx);
@@ -676,10 +674,10 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
     }
 
     sctx.n_remain += rewind_distance;
-    sctx.n_generated -= rewind_distance;
+    sctx.n_toks -= rewind_distance;
     sctx.rewind_count++;
     sctx.rewind_tokens += rewind_distance;
-    llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_generated + 1, -1);
+    llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_toks + 1, -1);
     while (!sctx.chunks.empty() && rewind_distance > 0) {
         tokens_chunk & last_chunk = sctx.chunks.back();
         GGML_ASSERT(!last_chunk.is_input);
@@ -738,8 +736,9 @@ bool gen_ctx::go() {
     decode_time_last = std::max(int64_t(0), ggml_time_us() - decode_time_last);
     decode_time_total += decode_time_last;
 
+    // FIXME: Handle KV cache pressure better.
     if (decode_result != 0) {
-        LOG_TEE("%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
+        fprintf(stderr, "%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
             decode_result == 1 ? "couldn't find slot" : "unknown error");
         return false;
     }
@@ -752,7 +751,7 @@ static bool handle_commands(gen_ctx & gctx) {
 
     line.reserve(1024);
-    LOG_TEE("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
+    printf("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
     fflush(stdout);
     while (1) {
         printf("> ");
@@ -764,7 +763,7 @@ static bool handle_commands(gen_ctx & gctx) {
         }
         if (line.empty()) break;
         if (line.size() < 2 || line.front() != '/') {
-            LOG_TEE("\n- Bad command\n");
+            printf("\n- Bad command\n");
             continue;
         }
         size_t sep_idx = line.find(' ');
@@ -775,47 +774,65 @@ static bool handle_commands(gen_ctx & gctx) {
         } else {
             command = line.substr(1);
         }
+        for (char & c : command) c = std::tolower(c);
 
-        if (command == "help") {
-            LOG_TEE("- Availabe commands:\n");
-            LOG_TEE(" /add TEXT : Adds the specified text to the focused sequence. Alias: /a\n");
-            LOG_TEE(" /addline TEXT : Adds the specified text to the focused with a newline at the end. Alias: /al\n");
-            LOG_TEE(" /help : Show this help.\n");
-            LOG_TEE(" /kill N : Stop sequence N. Alia: /k\n");
-            LOG_TEE(" /list : List sequences and their state. Alias: /l\n");
-            LOG_TEE(" /N : Focus sequence N. Example /2 focus sequence 2.\n");
-            LOG_TEE(" /quit : Exit the program. Alias: /q\n");
-            LOG_TEE("- End listing\n");
+        if (command == "h" || command == "help") {
+            printf("- Help: For commands with [SEQ], optionally specify a sequence number here to set the target.\n");
+            printf("  If a sequence isn't specified, then the current focus is used if possible.\n");
+            printf("  A single punctuation character is allowed after the number.\n");
+            printf("  For example, '/1add hello' and '/1,add hello' both add 'hello' to sequence 1.\n");
+            printf("- Available commands:\n");
+            printf(" /[SEQ]add TEXT : Adds the specified text to the target sequence. Alias: /a\n");
+            printf(" /[SEQ]addesc TEXT : Same as /add but handles escapes (\\n, \\x20, etc) and tokenizes without a leading space. Alias: /ae\n");
+            printf(" /[SEQ]addline TEXT : Same as /add but appends a newline. Alias: /al\n");
+            printf(" /help : Show this help. Alias: /h\n");
+            printf(" /[SEQ]dump N : Dump the last N tokens of SEQ showing offsets from the end. Alias: /d\n");
+            printf(" /[SEQ]dumptokens N : Same as /dump but displays token IDs as well. Alias: /dt\n");
+            printf(" /[SEQ]kill : Stop sequence SEQ. Alias: /k\n");
+            printf(" /list : List sequences and their state. Alias: /l\n");
+            printf(" /[SEQ]focus : Focus sequence SEQ. Alias: Just use /1, /2, etc\n");
+            printf(" /[SEQ]print : Display the content of SEQ. Alias: /p\n");
+            printf(" /quit : Exit the program. Alias: /q\n");
+            printf("- End listing\n");
             continue;
         }
 
         if (command == "q" || command == "quit") return false;
 
+        llama_seq_id target = -1;
+
+        // Focus
         if (isdigit(command[0])) {
-            const int target = std::atoi(command.c_str());
+            char * parse_end = nullptr;
+            target = std::strtol(command.c_str(), &parse_end, 10);
             if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
-                LOG_TEE("! Focus: Bad seq id\n");
-            } else {
-                gctx.focused_sequence = llama_seq_id(target - 1);
+                printf("! Bad seq id\n");
+                continue;
             }
+            target--;
+            if (std::ispunct(*parse_end)) parse_end++;
+            command = std::string(parse_end);
+        }
+
+        if (command.empty() || command == "focus") {
+            if (target < 0) {
+                printf("! Focus: No sequence specified\n");
+                continue;
+            }
+            printf("- Focus changed from %d to %d\n", gctx.focused_sequence + 1, target + 1);
+            gctx.focused_sequence = llama_seq_id(target);
             continue;
         }
 
         if (command == "k" || command == "kill") {
-            const int target = std::atoi(rest.c_str());
-            if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
-                LOG_TEE("! Kill: Bad seq id\n");
-            } else if (target - 1 == gctx.focused_sequence) {
-                LOG_TEE("! Kill: Can't kill focus\n");
+            if (target < 0) {
+                printf("! Kill: No sequence specified\n");
+            } else if (target == gctx.focused_sequence) {
+                printf("! Kill: Can't kill focus\n");
             } else {
-                gctx.ctxs_seq[target - 1].state = SEQ_DONE;
+                printf("- Killed sequence %d\n", target + 1);
+                gctx.ctxs_seq[target].state = SEQ_DONE;
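+                // Passing -1 for both positions removes the sequence's entire
+                // range from the KV cache, freeing those cells for reuse.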
+                llama_kv_cache_seq_rm(gctx.ctx, target, -1, -1);
             }
             continue;
         }
 
         if (command == "l" || command == "list") {
-            LOG_TEE("- Listing %zu sequence%s:\n",
+            printf("- Listing %zu sequence%s:\n",
                 gctx.ctxs_seq.size(),
                 gctx.ctxs_seq.size() != 1 ? "s" : "");
             for (const seq_ctx & sctx : gctx.ctxs_seq) {
@@ -827,30 +844,37 @@ static bool handle_commands(gen_ctx & gctx) {
                     case SEQ_SHARE_PROMPT: label = "WAIT"; break;
                     default: GGML_ASSERT(false);
                 }
-                LOG_TEE(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
+                printf(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
                     sctx.seq_id == gctx.focused_sequence ? "*" : " ",
                     sctx.seq_id + 1, label.c_str(),
-                    sctx.n_generated, sctx.n_remain);
+                    sctx.n_toks, sctx.n_remain);
                 for (const tokens_chunk & chunk : sctx.chunks) {
                     if (chunk.is_input) {
-                        LOG_TEE("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
+                        printf("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
                     } else {
-                        LOG_TEE("GEN(%5zu), ", chunk.tokens.size());
+                        printf("GEN(%5zu), ", chunk.tokens.size());
                    }
                 }
-                LOG_TEE("\n");
+                printf("\n");
             }
             continue;
        }
 
-        if (command == "al" || command == "a" || command == "add" || command == "addline") {
-            seq_ctx & sctx = gctx.ctxs_seq[gctx.focused_sequence];
+        if (  command == "al" || command == "a" || command == "ae"
+           || command == "add" || command == "addline" || command == "addesc") {
+            bool is_special = false;
+            seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
 
-            if (command == "al" || command == "addline") rest.push_back('\n');
-            std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false);
+            if (command == "al" || command == "addline") {
+                rest.push_back('\n');
+            } else if (command == "ae" || command == "addesc") {
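+                // Expand escape sequences such as \n, \t and \x20 in place
+                // before the text is tokenized.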
+                process_escapes(rest);
+                is_special = true;
+            }
+            std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false, is_special);
             if (input_tokens.size() > sctx.n_remain) {
-                LOG_TEE("! Input is %zu token(s) but sequence %d only has space for %zu\n",
+                printf("! Input is %zu token(s) but sequence %d only has space for %zu\n",
                     input_tokens.size(), gctx.focused_sequence + 1, sctx.n_remain);
                 continue;
             }
@@ -867,7 +891,68 @@ static bool handle_commands(gen_ctx & gctx) {
             continue;
         }
 
-        LOG_TEE("! Bad command\n");
+        if (command == "p" || command == "print") {
+            seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
+            std::string label;
+            switch (sctx.state) {
+                case SEQ_DONE: label = "DONE"; break;
+                case SEQ_GENERATING: label = "LIVE"; break;
+                case SEQ_INPUT: label = "FEED"; break;
+                case SEQ_SHARE_PROMPT: label = "WAIT"; break;
+                default: GGML_ASSERT(false);
+            }
+
+            printf("- Showing sequence %3d%s: state %s, generated %5zu, remain %5zu. chunks: ",
+                sctx.seq_id + 1,
+                sctx.seq_id == gctx.focused_sequence ? "(focus)" : " ",
+                label.c_str(), sctx.n_toks, sctx.n_remain);
+            for (const tokens_chunk & chunk : sctx.chunks) {
+                if (chunk.is_input) {
+                    printf("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
+                } else {
+                    printf("GEN(%5zu), ", chunk.tokens.size());
+                }
+            }
+            printf("\n");
+            gctx.dump_chunks(sctx.chunks);
+            printf("\n- Done\n");
+            continue;
+        }
+
+        if (command == "d" || command == "dt" || command == "dump" || command == "dumptokens") {
+            seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
+            const bool with_id = command == "dt" || command == "dumptokens";
+            const size_t max_n = sctx.n_toks + gctx.prompt_size;
+            size_t dump_n = size_t(std::max(0, atoi(rest.c_str())));
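+            // atoi() yields 0 for a missing or unparseable count, which is
+            // treated as a request for the default of the last 200 tokens.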
+            if (dump_n == 0) dump_n = 200;
+            dump_n = std::min(dump_n, max_n);
+
+            printf("- Dumping last %zu token%s from sequence %d\n",
+                dump_n, dump_n != 1 ? "s" : "", sctx.seq_id + 1);
+
+            std::vector<llama_token> result;
+            result.reserve(dump_n);
+            concat_chunks(sctx.chunks, result, max_n - dump_n);
+            GGML_ASSERT(result.size() == dump_n);
+            for (size_t i = 0; i < dump_n; i++) {
+                const llama_token tid = result[i];
+                console::set_display(console::user_input);
+                printf("[%zu", dump_n - i);
+                if (with_id) {
+                    printf(",%d", tid);
+                }
+                fputs("]", stdout);
+                console::set_display(console::reset);
+                fputs(llama_token_to_piece(gctx.ctx, tid).c_str(), stdout);
+            }
+            console::set_display(console::reset);
+            printf("\n\n- Dump complete.\n");
+            continue;
+        }
+
+        printf("! Bad command\n");
     }
     return true;
 }
@@ -875,10 +960,12 @@ static bool handle_commands(gen_ctx & gctx) {
 
 int main(int argc, char ** argv) {
     gen_ctx gctx(argc, argv);
 
-    while (gctx.go() && !done) {
+    // This might look weird but done can get set while go() is running.
+    while (!done && gctx.go() && !done) {
         bool need_dump = gctx.params.n_parallel > 1 && gctx.decode_count % SI_DUMP_SEQUENCES_INTERVAL == 0;
 
         if (interrupted) {
-            if (!handle_commands(gctx)) break;
+            if (!gctx.params.interactive || !handle_commands(gctx)) break;
+            // Double check that ^C wasn't hit again.
+            if (done) break;
             interrupted = false;
             need_dump = true;