Expand simple-inference command support
Fix Makefile header dep Rebase on master
This commit is contained in:
parent
11fa3dfd69
commit
34175b0b0c
3 changed files with 99 additions and 16 deletions
2
Makefile
2
Makefile
|
@ -563,7 +563,7 @@ ifndef LLAMA_NO_SEQREP_SAMPLER
|
|||
COMMON_H_DEFS += common/seqrep-sampler.h
|
||||
COMMON_DEPS += seqrep-sampler.o
|
||||
|
||||
seqrep-sampler.o: common/seqrep-sampler.cpp $(COMMON_H_DEPS)
|
||||
seqrep-sampler.o: common/seqrep-sampler.cpp common/seqrep-sampler.h $(COMMON_H_DEPS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
endif
|
||||
|
||||
|
|
|
@ -135,6 +135,10 @@ struct seqrep_logit_info {
|
|||
// Return top k token_data by logit.
|
||||
std::vector<llama_token_data> top_k(const float * const logits, const size_t k);
|
||||
|
||||
seqrep_logit_info(const int n_vocab, const std::vector<llama_token_data> & token_data = {})
|
||||
: n_vocab(n_vocab)
|
||||
, token_data(token_data)
|
||||
{}
|
||||
};
|
||||
|
||||
struct seqrep_rewind_slot {
|
||||
|
|
|
@ -75,8 +75,6 @@ typedef struct seq_ctx {
|
|||
|
||||
llama_token last_sampled;
|
||||
std::vector<tokens_chunk> chunks;
|
||||
// std::vector<llama_token> pending_input;
|
||||
// std::vector<llama_token> output;
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
size_t high_water_mark;
|
||||
struct seqrep_rewind_state rewind_state;
|
||||
|
@ -145,7 +143,7 @@ static void concat_chunks(const std::vector<tokens_chunk> & chunks, std::vector<
|
|||
continue;
|
||||
}
|
||||
|
||||
const size_t chunk_offset = start_offset - offset;
|
||||
const size_t chunk_offset = offset < start_offset ? start_offset - offset : 0;
|
||||
const size_t chunk_size = chunk.tokens.size() - chunk_offset;
|
||||
const llama_token * tp = chunk.tokens.data() + chunk_offset;
|
||||
|
||||
|
@ -595,16 +593,31 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
|
|||
GGML_ASSERT(!sctx.chunks.empty());
|
||||
tokens_chunk & chunk = sctx.chunks.back();
|
||||
GGML_ASSERT(chunk.is_input);
|
||||
GGML_ASSERT(chunk.consumed < chunk.tokens.size());
|
||||
GGML_ASSERT(!chunk.tokens.empty());
|
||||
|
||||
const size_t remain = chunk.tokens.size() - chunk.consumed;
|
||||
const size_t to_consume = std::min(size_t(params.n_batch), remain);
|
||||
for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
|
||||
llama_batch_add(batch, chunk.tokens[i], llama_pos(i), {sctx.seq_id}, false);
|
||||
llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_generated + i), {sctx.seq_id}, false);
|
||||
}
|
||||
chunk.consumed += to_consume;
|
||||
sctx.n_remain -= to_consume;
|
||||
// FIXME: This a lie, we didn't generate it.
|
||||
sctx.n_generated += to_consume;
|
||||
if (chunk.consumed == chunk.tokens.size()) {
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
// FIXME: Move this logic to a more appropriate place.
|
||||
for (size_t i = 0; i < chunk.consumed; i++) {
|
||||
sctx.rewind_state.logit_slots.emplace_back(n_vocab);
|
||||
}
|
||||
sctx.high_water_mark = sctx.n_generated + 1;
|
||||
#endif
|
||||
sctx.batch_idx = batch.n_tokens - 1;
|
||||
batch.logits[sctx.batch_idx] = true;
|
||||
sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
|
||||
sctx.chunks.back().tokens.reserve(sctx.n_remain);
|
||||
sctx.state = SEQ_GENERATING;
|
||||
} else {
|
||||
sctx.batch_idx = -1;
|
||||
}
|
||||
|
@ -638,7 +651,6 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
|
|||
if (rewind_distance < 1) {
|
||||
return;
|
||||
}
|
||||
// if (sctx.seq_id != 0) printf("<%d:%zu>", sctx.seq_id + 1, rewind_distance);
|
||||
GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?");
|
||||
const size_t slot_idx = sctx.n_generated - rewind_distance;
|
||||
const llama_token nl_id = llama_token_nl(model);
|
||||
|
@ -649,14 +661,12 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
|
|||
if (sctx.seq_id == focused_sequence) {
|
||||
console::set_display(console::error);
|
||||
fputs("\u3010", stdout);
|
||||
// printf("%zu,%zu,%zu", rewind_distance, sctx.n_generated, sctx.generated_tokens.size());
|
||||
for (size_t i = seq_last_tokens.size() - rewind_distance; i < seq_last_tokens.size(); i++) {
|
||||
if (seq_last_tokens[i] == nl_id) {
|
||||
fputs("\\n", stdout);
|
||||
continue;
|
||||
}
|
||||
const std::string token_str = llama_token_to_piece(ctx, seq_last_tokens[i]);
|
||||
// fputs("|", stdout);
|
||||
fputs(token_str.c_str(), stdout);
|
||||
}
|
||||
fputs("\u3011", stdout);
|
||||
|
@ -701,8 +711,10 @@ bool gen_ctx::go() {
|
|||
sctx.batch_idx = batch.n_tokens - 1;
|
||||
sctx.state = SEQ_GENERATING;
|
||||
if (sctx.seq_id == 0) {
|
||||
sctx.chunks.back().consumed = prompt_size;
|
||||
sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
|
||||
} else {
|
||||
sctx.chunks.front().consumed = prompt_size;
|
||||
llama_kv_cache_seq_cp(ctx, 0, sctx.seq_id, 0, prompt_size);
|
||||
}
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
|
@ -739,10 +751,11 @@ static bool handle_commands(gen_ctx & gctx) {
|
|||
line.reserve(1024);
|
||||
|
||||
|
||||
puts("");
|
||||
LOG_TEE("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
|
||||
fflush(stdout);
|
||||
while (1) {
|
||||
printf("> ");
|
||||
fflush(stdout);
|
||||
console::readline(line, false);
|
||||
console::set_display(console::reset);
|
||||
while (!line.empty() && std::isspace(line.back())) {
|
||||
|
@ -750,7 +763,7 @@ static bool handle_commands(gen_ctx & gctx) {
|
|||
}
|
||||
if (line.empty()) break;
|
||||
if (line.size() < 2 || line.front() != '/') {
|
||||
printf("\n* Bad command\n");
|
||||
LOG_TEE("\n- Bad command\n");
|
||||
continue;
|
||||
}
|
||||
size_t sep_idx = line.find(' ');
|
||||
|
@ -762,32 +775,98 @@ static bool handle_commands(gen_ctx & gctx) {
|
|||
command = line.substr(1);
|
||||
}
|
||||
|
||||
if (command == "quit") return false;
|
||||
if (command == "help") {
|
||||
LOG_TEE("- Availabe commands:\n");
|
||||
LOG_TEE(" /add TEXT : Adds the specified text to the focused sequence. Alias: /a\n");
|
||||
LOG_TEE(" /addline TEXT : Adds the specified text to the focused with a newline at the end. Alias: /al\n");
|
||||
LOG_TEE(" /help : Show this help.\n");
|
||||
LOG_TEE(" /kill N : Stop sequence N. Alia: /k\n");
|
||||
LOG_TEE(" /list : List sequences and their state. Alias: /l\n");
|
||||
LOG_TEE(" /N : Focus sequence N. Example /2 focus sequence 2.\n");
|
||||
LOG_TEE(" /quit : Exit the program. Alias: /q\n");
|
||||
LOG_TEE("- End listing\n");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (command == "q" || command == "quit") return false;
|
||||
|
||||
// Focus
|
||||
if (isdigit(command[0])) {
|
||||
const int target = std::atoi(command.c_str());
|
||||
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
|
||||
printf("\n* Focus: Bad seq id\n");
|
||||
LOG_TEE("! Focus: Bad seq id\n");
|
||||
} else {
|
||||
gctx.focused_sequence = llama_seq_id(target - 1);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (command == "kill") {
|
||||
if (command == "k" || command == "kill") {
|
||||
const int target = std::atoi(rest.c_str());
|
||||
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
|
||||
printf("\n* Kill: Bad seq id\n");
|
||||
LOG_TEE("! Kill: Bad seq id\n");
|
||||
} else if (target - 1 == gctx.focused_sequence) {
|
||||
printf("\n* Kill: Can't kill focus\n");
|
||||
LOG_TEE("! Kill: Can't kill focus\n");
|
||||
} else {
|
||||
gctx.ctxs_seq[target - 1].state = SEQ_DONE;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
printf("\n* Bad command\n");
|
||||
if (command == "l" || command == "list") {
|
||||
LOG_TEE("- Listing %zu sequence%s:\n",
|
||||
gctx.ctxs_seq.size(),
|
||||
gctx.ctxs_seq.size() != 1 ? "s" : "");
|
||||
for (const seq_ctx & sctx : gctx.ctxs_seq) {
|
||||
std::string label;
|
||||
switch (sctx.state) {
|
||||
case SEQ_DONE: label = "DONE"; break;
|
||||
case SEQ_GENERATING: label = "LIVE"; break;
|
||||
case SEQ_INPUT: label = "FEED"; break;
|
||||
case SEQ_SHARE_PROMPT: label = "WAIT"; break;
|
||||
default: GGML_ASSERT(false);
|
||||
}
|
||||
LOG_TEE(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
|
||||
sctx.seq_id == gctx.focused_sequence ? "*" : " ",
|
||||
sctx.seq_id + 1, label.c_str(),
|
||||
sctx.n_generated, sctx.n_remain);
|
||||
for (const tokens_chunk & chunk : sctx.chunks) {
|
||||
if (chunk.is_input) {
|
||||
LOG_TEE("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
|
||||
|
||||
} else {
|
||||
LOG_TEE("GEN(%5zu), ", chunk.tokens.size());
|
||||
}
|
||||
}
|
||||
LOG_TEE("\n");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (command == "al" || command == "a" || command == "add" || command == "addline") {
|
||||
seq_ctx & sctx = gctx.ctxs_seq[gctx.focused_sequence];
|
||||
|
||||
if (command == "al" || command == "addline") rest.push_back('\n');
|
||||
std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false);
|
||||
if (input_tokens.size() > sctx.n_remain) {
|
||||
LOG_TEE("! Input is %zu token(s) but sequence %d only has space for %zu\n",
|
||||
input_tokens.size(), gctx.focused_sequence + 1, sctx.n_remain);
|
||||
continue;
|
||||
}
|
||||
if (!sctx.chunks.back().is_input) {
|
||||
sctx.chunks.emplace_back(true, 0, input_tokens);
|
||||
} else {
|
||||
tokens_chunk & chunk = sctx.chunks.back();
|
||||
const size_t old_size = chunk.tokens.size();
|
||||
|
||||
chunk.tokens.resize(old_size + input_tokens.size());
|
||||
std::copy(input_tokens.begin(), input_tokens.end(), chunk.tokens.begin() + old_size);
|
||||
}
|
||||
sctx.state = SEQ_INPUT;
|
||||
continue;
|
||||
}
|
||||
|
||||
LOG_TEE("! Bad command\n");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue