diff --git a/Makefile b/Makefile index e947d720a..ab1f61263 100644 --- a/Makefile +++ b/Makefile @@ -563,7 +563,7 @@ ifndef LLAMA_NO_SEQREP_SAMPLER COMMON_H_DEFS += common/seqrep-sampler.h COMMON_DEPS += seqrep-sampler.o -seqrep-sampler.o: common/seqrep-sampler.cpp $(COMMON_H_DEPS) +seqrep-sampler.o: common/seqrep-sampler.cpp common/seqrep-sampler.h $(COMMON_H_DEPS) $(CXX) $(CXXFLAGS) -c $< -o $@ endif diff --git a/common/seqrep-sampler.h b/common/seqrep-sampler.h index 273a33150..f56b904ae 100644 --- a/common/seqrep-sampler.h +++ b/common/seqrep-sampler.h @@ -135,6 +135,10 @@ struct seqrep_logit_info { // Return top k token_data by logit. std::vector top_k(const float * const logits, const size_t k); + seqrep_logit_info(const int n_vocab, const std::vector & token_data = {}) + : n_vocab(n_vocab) + , token_data(token_data) + {} }; struct seqrep_rewind_slot { diff --git a/examples/simple-inference/simple-inference.cpp b/examples/simple-inference/simple-inference.cpp index d56cc401e..3d1dd7af7 100644 --- a/examples/simple-inference/simple-inference.cpp +++ b/examples/simple-inference/simple-inference.cpp @@ -75,8 +75,6 @@ typedef struct seq_ctx { llama_token last_sampled; std::vector chunks; - // std::vector pending_input; - // std::vector output; #ifndef LLAMA_NO_SEQREP_SAMPLER size_t high_water_mark; struct seqrep_rewind_state rewind_state; @@ -145,7 +143,7 @@ static void concat_chunks(const std::vector & chunks, std::vector< continue; } - const size_t chunk_offset = start_offset - offset; + const size_t chunk_offset = offset < start_offset ? start_offset - offset : 0; const size_t chunk_size = chunk.tokens.size() - chunk_offset; const llama_token * tp = chunk.tokens.data() + chunk_offset; @@ -595,16 +593,31 @@ void gen_ctx::handle_seq(seq_ctx & sctx) { GGML_ASSERT(!sctx.chunks.empty()); tokens_chunk & chunk = sctx.chunks.back(); GGML_ASSERT(chunk.is_input); + GGML_ASSERT(chunk.consumed < chunk.tokens.size()); + GGML_ASSERT(!chunk.tokens.empty()); const size_t remain = chunk.tokens.size() - chunk.consumed; const size_t to_consume = std::min(size_t(params.n_batch), remain); for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) { - llama_batch_add(batch, chunk.tokens[i], llama_pos(i), {sctx.seq_id}, false); + llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_generated + i), {sctx.seq_id}, false); } chunk.consumed += to_consume; + sctx.n_remain -= to_consume; + // FIXME: This a lie, we didn't generate it. + sctx.n_generated += to_consume; if (chunk.consumed == chunk.tokens.size()) { +#ifndef LLAMA_NO_SEQREP_SAMPLER + // FIXME: Move this logic to a more appropriate place. + for (size_t i = 0; i < chunk.consumed; i++) { + sctx.rewind_state.logit_slots.emplace_back(n_vocab); + } + sctx.high_water_mark = sctx.n_generated + 1; +#endif sctx.batch_idx = batch.n_tokens - 1; batch.logits[sctx.batch_idx] = true; + sctx.chunks.emplace_back(false, 0, std::vector()); + sctx.chunks.back().tokens.reserve(sctx.n_remain); + sctx.state = SEQ_GENERATING; } else { sctx.batch_idx = -1; } @@ -638,7 +651,6 @@ void gen_ctx::handle_seq(seq_ctx & sctx) { if (rewind_distance < 1) { return; } - // if (sctx.seq_id != 0) printf("<%d:%zu>", sctx.seq_id + 1, rewind_distance); GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?"); const size_t slot_idx = sctx.n_generated - rewind_distance; const llama_token nl_id = llama_token_nl(model); @@ -649,14 +661,12 @@ void gen_ctx::handle_seq(seq_ctx & sctx) { if (sctx.seq_id == focused_sequence) { console::set_display(console::error); fputs("\u3010", stdout); - // printf("%zu,%zu,%zu", rewind_distance, sctx.n_generated, sctx.generated_tokens.size()); for (size_t i = seq_last_tokens.size() - rewind_distance; i < seq_last_tokens.size(); i++) { if (seq_last_tokens[i] == nl_id) { fputs("\\n", stdout); continue; } const std::string token_str = llama_token_to_piece(ctx, seq_last_tokens[i]); - // fputs("|", stdout); fputs(token_str.c_str(), stdout); } fputs("\u3011", stdout); @@ -701,8 +711,10 @@ bool gen_ctx::go() { sctx.batch_idx = batch.n_tokens - 1; sctx.state = SEQ_GENERATING; if (sctx.seq_id == 0) { + sctx.chunks.back().consumed = prompt_size; sctx.chunks.emplace_back(false, 0, std::vector()); } else { + sctx.chunks.front().consumed = prompt_size; llama_kv_cache_seq_cp(ctx, 0, sctx.seq_id, 0, prompt_size); } #ifndef LLAMA_NO_SEQREP_SAMPLER @@ -739,10 +751,11 @@ static bool handle_commands(gen_ctx & gctx) { line.reserve(1024); - puts(""); + LOG_TEE("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1); fflush(stdout); while (1) { printf("> "); + fflush(stdout); console::readline(line, false); console::set_display(console::reset); while (!line.empty() && std::isspace(line.back())) { @@ -750,7 +763,7 @@ static bool handle_commands(gen_ctx & gctx) { } if (line.empty()) break; if (line.size() < 2 || line.front() != '/') { - printf("\n* Bad command\n"); + LOG_TEE("\n- Bad command\n"); continue; } size_t sep_idx = line.find(' '); @@ -762,32 +775,98 @@ static bool handle_commands(gen_ctx & gctx) { command = line.substr(1); } - if (command == "quit") return false; + if (command == "help") { + LOG_TEE("- Availabe commands:\n"); + LOG_TEE(" /add TEXT : Adds the specified text to the focused sequence. Alias: /a\n"); + LOG_TEE(" /addline TEXT : Adds the specified text to the focused with a newline at the end. Alias: /al\n"); + LOG_TEE(" /help : Show this help.\n"); + LOG_TEE(" /kill N : Stop sequence N. Alia: /k\n"); + LOG_TEE(" /list : List sequences and their state. Alias: /l\n"); + LOG_TEE(" /N : Focus sequence N. Example /2 focus sequence 2.\n"); + LOG_TEE(" /quit : Exit the program. Alias: /q\n"); + LOG_TEE("- End listing\n"); + continue; + } + + if (command == "q" || command == "quit") return false; // Focus if (isdigit(command[0])) { const int target = std::atoi(command.c_str()); if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) { - printf("\n* Focus: Bad seq id\n"); + LOG_TEE("! Focus: Bad seq id\n"); } else { gctx.focused_sequence = llama_seq_id(target - 1); } continue; } - if (command == "kill") { + if (command == "k" || command == "kill") { const int target = std::atoi(rest.c_str()); if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) { - printf("\n* Kill: Bad seq id\n"); + LOG_TEE("! Kill: Bad seq id\n"); } else if (target - 1 == gctx.focused_sequence) { - printf("\n* Kill: Can't kill focus\n"); + LOG_TEE("! Kill: Can't kill focus\n"); } else { gctx.ctxs_seq[target - 1].state = SEQ_DONE; } continue; } - printf("\n* Bad command\n"); + if (command == "l" || command == "list") { + LOG_TEE("- Listing %zu sequence%s:\n", + gctx.ctxs_seq.size(), + gctx.ctxs_seq.size() != 1 ? "s" : ""); + for (const seq_ctx & sctx : gctx.ctxs_seq) { + std::string label; + switch (sctx.state) { + case SEQ_DONE: label = "DONE"; break; + case SEQ_GENERATING: label = "LIVE"; break; + case SEQ_INPUT: label = "FEED"; break; + case SEQ_SHARE_PROMPT: label = "WAIT"; break; + default: GGML_ASSERT(false); + } + LOG_TEE(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ", + sctx.seq_id == gctx.focused_sequence ? "*" : " ", + sctx.seq_id + 1, label.c_str(), + sctx.n_generated, sctx.n_remain); + for (const tokens_chunk & chunk : sctx.chunks) { + if (chunk.is_input) { + LOG_TEE("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed); + + } else { + LOG_TEE("GEN(%5zu), ", chunk.tokens.size()); + } + } + LOG_TEE("\n"); + } + continue; + } + + if (command == "al" || command == "a" || command == "add" || command == "addline") { + seq_ctx & sctx = gctx.ctxs_seq[gctx.focused_sequence]; + + if (command == "al" || command == "addline") rest.push_back('\n'); + std::vector input_tokens = ::llama_tokenize(gctx.model, rest, false); + if (input_tokens.size() > sctx.n_remain) { + LOG_TEE("! Input is %zu token(s) but sequence %d only has space for %zu\n", + input_tokens.size(), gctx.focused_sequence + 1, sctx.n_remain); + continue; + } + if (!sctx.chunks.back().is_input) { + sctx.chunks.emplace_back(true, 0, input_tokens); + } else { + tokens_chunk & chunk = sctx.chunks.back(); + const size_t old_size = chunk.tokens.size(); + + chunk.tokens.resize(old_size + input_tokens.size()); + std::copy(input_tokens.begin(), input_tokens.end(), chunk.tokens.begin() + old_size); + } + sctx.state = SEQ_INPUT; + continue; + } + + LOG_TEE("! Bad command\n"); } return true; }