Expand simple-inference command support

Fix Makefile header dep

Rebase on master
KerfuffleV2 2023-11-02 04:31:39 -06:00
parent 11fa3dfd69
commit 34175b0b0c
3 changed files with 99 additions and 16 deletions

View file

@@ -563,7 +563,7 @@ ifndef LLAMA_NO_SEQREP_SAMPLER
 COMMON_H_DEFS += common/seqrep-sampler.h
 COMMON_DEPS += seqrep-sampler.o
-seqrep-sampler.o: common/seqrep-sampler.cpp $(COMMON_H_DEPS)
+seqrep-sampler.o: common/seqrep-sampler.cpp common/seqrep-sampler.h $(COMMON_H_DEPS)
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 endif
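
This one-line Makefile change is the "header dep" fix from the commit message: the header was only listed in COMMON_H_DEFS, which the object rule never references as a prerequisite, so editing common/seqrep-sampler.h did not trigger a rebuild of seqrep-sampler.o. Naming the header explicitly in the rule closes that gap.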

View file

@@ -135,6 +135,10 @@ struct seqrep_logit_info {
     // Return top k token_data by logit.
     std::vector<llama_token_data> top_k(const float * const logits, const size_t k);
+    seqrep_logit_info(const int n_vocab, const std::vector<llama_token_data> & token_data = {})
+        : n_vocab(n_vocab)
+        , token_data(token_data)
+    {}
 };

 struct seqrep_rewind_slot {
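
The new constructor exists so callers can cheaply reserve an empty per-token logit slot; the batching hunk further down does exactly that via emplace_back(n_vocab). A minimal usage sketch (the vocab size is illustrative, and the include path assumes the header's location in this tree):

    #include <vector>

    #include "common/seqrep-sampler.h"

    int main() {
        const int n_vocab = 32000;                 // illustrative vocabulary size
        std::vector<seqrep_logit_info> logit_slots;
        logit_slots.emplace_back(n_vocab);         // empty slot; token_data defaults to {}
        return 0;
    }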

View file

@@ -75,8 +75,6 @@ typedef struct seq_ctx {
     llama_token last_sampled;
     std::vector<tokens_chunk> chunks;
-    // std::vector<llama_token> pending_input;
-    // std::vector<llama_token> output;
 #ifndef LLAMA_NO_SEQREP_SAMPLER
     size_t high_water_mark;
     struct seqrep_rewind_state rewind_state;
@@ -145,7 +143,7 @@ static void concat_chunks(const std::vector<tokens_chunk> & chunks, std::vector<
             continue;
         }
-        const size_t chunk_offset = start_offset - offset;
+        const size_t chunk_offset = offset < start_offset ? start_offset - offset : 0;
         const size_t chunk_size = chunk.tokens.size() - chunk_offset;
         const llama_token * tp = chunk.tokens.data() + chunk_offset;
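
The ternary guards against unsigned underflow: the offsets are size_t, so when a chunk begins at or past start_offset the old expression start_offset - offset wrapped around to a huge value. A self-contained illustration with made-up numbers:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t offset = 12, start_offset = 8; // this chunk starts past the requested offset
        // Old form: start_offset - offset would wrap to SIZE_MAX - 3 and index far out of bounds.
        const size_t chunk_offset = offset < start_offset ? start_offset - offset : 0;
        printf("chunk_offset = %zu\n", chunk_offset); // prints 0: copy this chunk from its start
        return 0;
    }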
@@ -595,16 +593,31 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
         GGML_ASSERT(!sctx.chunks.empty());
         tokens_chunk & chunk = sctx.chunks.back();
         GGML_ASSERT(chunk.is_input);
-        GGML_ASSERT(chunk.consumed < chunk.tokens.size());
+        GGML_ASSERT(!chunk.tokens.empty());
+        const size_t remain = chunk.tokens.size() - chunk.consumed;
+        const size_t to_consume = std::min(size_t(params.n_batch), remain);
         for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
-            llama_batch_add(batch, chunk.tokens[i], llama_pos(i), {sctx.seq_id}, false);
+            llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_generated + i), {sctx.seq_id}, false);
         }
         chunk.consumed += to_consume;
+        sctx.n_remain -= to_consume;
+        // FIXME: This is a lie, we didn't generate it.
+        sctx.n_generated += to_consume;
         if (chunk.consumed == chunk.tokens.size()) {
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+            // FIXME: Move this logic to a more appropriate place.
+            for (size_t i = 0; i < chunk.consumed; i++) {
+                sctx.rewind_state.logit_slots.emplace_back(n_vocab);
+            }
+            sctx.high_water_mark = sctx.n_generated + 1;
+#endif
             sctx.batch_idx = batch.n_tokens - 1;
             batch.logits[sctx.batch_idx] = true;
             sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
+            sctx.chunks.back().tokens.reserve(sctx.n_remain);
             sctx.state = SEQ_GENERATING;
         } else {
             sctx.batch_idx = -1;
         }
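
The llama_batch_add change is the substantive fix here: a token's position must be absolute within the sequence (prompt plus everything fed or generated so far), not its index inside the current chunk, or its KV-cache entries land on top of the prompt's. A standalone sketch of the position math with made-up numbers:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t prompt_size = 100; // tokens already evaluated for the prompt
        const size_t n_generated = 40;  // tokens fed or generated since then
        for (size_t i = 0; i < 3; i++) {
            // Old: pos was just i (0, 1, 2), colliding with the prompt's cache slots.
            // New: pos continues the sequence (140, 141, 142).
            printf("token %zu -> pos %zu\n", i, prompt_size + n_generated + i);
        }
        return 0;
    }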
@@ -638,7 +651,6 @@
     if (rewind_distance < 1) {
         return;
     }
-    // if (sctx.seq_id != 0) printf("<%d:%zu>", sctx.seq_id + 1, rewind_distance);
     GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?");
     const size_t slot_idx = sctx.n_generated - rewind_distance;
     const llama_token nl_id = llama_token_nl(model);
@@ -649,14 +661,12 @@
     if (sctx.seq_id == focused_sequence) {
         console::set_display(console::error);
         fputs("\u3010", stdout);
-        // printf("%zu,%zu,%zu", rewind_distance, sctx.n_generated, sctx.generated_tokens.size());
         for (size_t i = seq_last_tokens.size() - rewind_distance; i < seq_last_tokens.size(); i++) {
             if (seq_last_tokens[i] == nl_id) {
                 fputs("\\n", stdout);
                 continue;
             }
             const std::string token_str = llama_token_to_piece(ctx, seq_last_tokens[i]);
-            // fputs("|", stdout);
             fputs(token_str.c_str(), stdout);
         }
         fputs("\u3011", stdout);
@@ -701,8 +711,10 @@ bool gen_ctx::go() {
         sctx.batch_idx = batch.n_tokens - 1;
         sctx.state = SEQ_GENERATING;
         if (sctx.seq_id == 0) {
+            sctx.chunks.back().consumed = prompt_size;
+            sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
         } else {
             sctx.chunks.front().consumed = prompt_size;
             llama_kv_cache_seq_cp(ctx, 0, sctx.seq_id, 0, prompt_size);
         }
 #ifndef LLAMA_NO_SEQREP_SAMPLER
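
The else branch is the prompt-sharing path: sequence 0 evaluates the shared prompt once, and every other sequence copies those KV-cache cells instead of re-evaluating it. A sketch of the pattern (the helper name is mine, not part of the diff):

    #include "llama.h"

    // Hypothetical helper: after sequence 0 has evaluated prompt_size prompt
    // tokens, point sequences 1..n_seqs-1 at the same KV-cache cells.
    static void share_prompt(llama_context * ctx, int n_seqs, size_t prompt_size) {
        for (llama_seq_id dst = 1; dst < n_seqs; dst++) {
            llama_kv_cache_seq_cp(ctx, 0, dst, 0, llama_pos(prompt_size));
        }
    }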
@@ -739,10 +751,11 @@ static bool handle_commands(gen_ctx & gctx) {
     line.reserve(1024);
     puts("");
+    LOG_TEE("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
-    fflush(stdout);
     while (1) {
         printf("> ");
+        fflush(stdout);
         console::readline(line, false);
         console::set_display(console::reset);
         while (!line.empty() && std::isspace(line.back())) {
@@ -750,7 +763,7 @@ static bool handle_commands(gen_ctx & gctx) {
         }
         if (line.empty()) break;
         if (line.size() < 2 || line.front() != '/') {
-            printf("\n* Bad command\n");
+            LOG_TEE("\n- Bad command\n");
             continue;
         }
         size_t sep_idx = line.find(' ');
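
From this hunk onward, user-facing diagnostics in the command loop move from bare printf to llama.cpp's LOG_TEE macro, which echoes to both the console and the log file, and the messages adopt a consistent prefix: "-" for status output, "!" for errors.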
@@ -762,32 +775,98 @@ static bool handle_commands(gen_ctx & gctx) {
             command = line.substr(1);
         }
-        if (command == "quit") return false;
+        if (command == "help") {
+            LOG_TEE("- Available commands:\n");
+            LOG_TEE(" /add TEXT : Adds the specified text to the focused sequence. Alias: /a\n");
+            LOG_TEE(" /addline TEXT : Adds the specified text to the focused sequence with a newline at the end. Alias: /al\n");
+            LOG_TEE(" /help : Show this help.\n");
+            LOG_TEE(" /kill N : Stop sequence N. Alias: /k\n");
+            LOG_TEE(" /list : List sequences and their state. Alias: /l\n");
+            LOG_TEE(" /N : Focus sequence N. Example: /2 focuses sequence 2.\n");
+            LOG_TEE(" /quit : Exit the program. Alias: /q\n");
+            LOG_TEE("- End listing\n");
+            continue;
+        }
+        if (command == "q" || command == "quit") return false;
         // Focus
         if (isdigit(command[0])) {
             const int target = std::atoi(command.c_str());
             if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
-                printf("\n* Focus: Bad seq id\n");
+                LOG_TEE("! Focus: Bad seq id\n");
             } else {
                 gctx.focused_sequence = llama_seq_id(target - 1);
             }
             continue;
         }
-        if (command == "kill") {
+        if (command == "k" || command == "kill") {
             const int target = std::atoi(rest.c_str());
             if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
-                printf("\n* Kill: Bad seq id\n");
+                LOG_TEE("! Kill: Bad seq id\n");
             } else if (target - 1 == gctx.focused_sequence) {
-                printf("\n* Kill: Can't kill focus\n");
+                LOG_TEE("! Kill: Can't kill focus\n");
             } else {
                 gctx.ctxs_seq[target - 1].state = SEQ_DONE;
             }
             continue;
         }
-        printf("\n* Bad command\n");
+        if (command == "l" || command == "list") {
+            LOG_TEE("- Listing %zu sequence%s:\n",
+                gctx.ctxs_seq.size(),
+                gctx.ctxs_seq.size() != 1 ? "s" : "");
+            for (const seq_ctx & sctx : gctx.ctxs_seq) {
+                std::string label;
+                switch (sctx.state) {
+                    case SEQ_DONE: label = "DONE"; break;
+                    case SEQ_GENERATING: label = "LIVE"; break;
+                    case SEQ_INPUT: label = "FEED"; break;
+                    case SEQ_SHARE_PROMPT: label = "WAIT"; break;
+                    default: GGML_ASSERT(false);
+                }
+                LOG_TEE(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
+                    sctx.seq_id == gctx.focused_sequence ? "*" : " ",
+                    sctx.seq_id + 1, label.c_str(),
+                    sctx.n_generated, sctx.n_remain);
+                for (const tokens_chunk & chunk : sctx.chunks) {
+                    if (chunk.is_input) {
+                        LOG_TEE("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
+                    } else {
+                        LOG_TEE("GEN(%5zu), ", chunk.tokens.size());
+                    }
+                }
+                LOG_TEE("\n");
+            }
+            continue;
+        }
+        if (command == "al" || command == "a" || command == "add" || command == "addline") {
+            seq_ctx & sctx = gctx.ctxs_seq[gctx.focused_sequence];
+            if (command == "al" || command == "addline") rest.push_back('\n');
+            std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false);
+            if (input_tokens.size() > sctx.n_remain) {
+                LOG_TEE("! Input is %zu token(s) but sequence %d only has space for %zu\n",
+                    input_tokens.size(), gctx.focused_sequence + 1, sctx.n_remain);
+                continue;
+            }
+            if (!sctx.chunks.back().is_input) {
+                sctx.chunks.emplace_back(true, 0, input_tokens);
+            } else {
+                tokens_chunk & chunk = sctx.chunks.back();
+                const size_t old_size = chunk.tokens.size();
+                chunk.tokens.resize(old_size + input_tokens.size());
+                std::copy(input_tokens.begin(), input_tokens.end(), chunk.tokens.begin() + old_size);
+            }
+            sctx.state = SEQ_INPUT;
+            continue;
+        }
+        LOG_TEE("! Bad command\n");
     }
     return true;
 }
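
For reference, a hypothetical command-mode session exercising the new commands (all values illustrative, output reconstructed from the format strings above):

    > /list
    - Listing 2 sequences:
     *  1 (LIVE): generated    37, remain    91. chunks: INP(  100,  100), GEN(   37), 
        2 (LIVE): generated    42, remain    86. chunks: INP(  100,  100), GEN(   42), 
    > /kill 1
    ! Kill: Can't kill focus
    > /2
    > /kill 1
    > /q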