Expand simple-inference command support

Fix Makefile header dep
Rebase on master
parent 11fa3dfd69
commit 34175b0b0c

3 changed files with 99 additions and 16 deletions

Makefile (2 changes)
Makefile

@@ -563,7 +563,7 @@ ifndef LLAMA_NO_SEQREP_SAMPLER
 COMMON_H_DEFS += common/seqrep-sampler.h
 COMMON_DEPS += seqrep-sampler.o

-seqrep-sampler.o: common/seqrep-sampler.cpp $(COMMON_H_DEPS)
+seqrep-sampler.o: common/seqrep-sampler.cpp common/seqrep-sampler.h $(COMMON_H_DEPS)
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 endif

@@ -135,6 +135,10 @@ struct seqrep_logit_info {
     // Return top k token_data by logit.
     std::vector<llama_token_data> top_k(const float * const logits, const size_t k);

+    seqrep_logit_info(const int n_vocab, const std::vector<llama_token_data> & token_data = {})
+        : n_vocab(n_vocab)
+        , token_data(token_data)
+    {}
 };

 struct seqrep_rewind_slot {
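The new constructor (in the seqrep sampler header, judging by the struct name and the Makefile dependency above) lets a `seqrep_logit_info` be built from just a vocabulary size, with the stored token data defaulting to empty; the `handle_seq` hunk further down relies on this to back-fill `rewind_state.logit_slots` for tokens that were fed as input rather than sampled. A minimal sketch of the idiom, using a stand-in struct rather than the real type (the `llama_token` alias and the `main` below are illustrative assumptions):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-ins for the llama.h types; only what the sketch needs.
using llama_token = int32_t;
struct llama_token_data { llama_token id; float logit; float p; };

// Only the two members initialized by the new constructor; the real struct
// carries more state (e.g. the top_k() helper shown in the hunk).
struct seqrep_logit_info_sketch {
    int n_vocab;
    std::vector<llama_token_data> token_data;

    seqrep_logit_info_sketch(const int n_vocab, const std::vector<llama_token_data> & token_data = {})
        : n_vocab(n_vocab)
        , token_data(token_data)
    {}
};

int main() {
    // Placeholder slot with no stored logits, as used when back-filling input tokens.
    seqrep_logit_info_sketch empty_slot(32000);
    // Slot with explicit token data, as a sampler would record after a real decode.
    seqrep_logit_info_sketch sampled(32000, { {0, 1.5f, 0.0f} });
    printf("empty_slot entries: %zu, sampled entries: %zu\n",
           empty_slot.token_data.size(), sampled.token_data.size());
    return 0;
}
```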
@@ -75,8 +75,6 @@ typedef struct seq_ctx {

     llama_token last_sampled;
     std::vector<tokens_chunk> chunks;
-    // std::vector<llama_token> pending_input;
-    // std::vector<llama_token> output;
 #ifndef LLAMA_NO_SEQREP_SAMPLER
     size_t high_water_mark;
     struct seqrep_rewind_state rewind_state;
@@ -145,7 +143,7 @@ static void concat_chunks(const std::vector<tokens_chunk> & chunks, std::vector<
             continue;
         }

-        const size_t chunk_offset = start_offset - offset;
+        const size_t chunk_offset = offset < start_offset ? start_offset - offset : 0;
         const size_t chunk_size = chunk.tokens.size() - chunk_offset;
         const llama_token * tp = chunk.tokens.data() + chunk_offset;

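The `chunk_offset` change guards against unsigned underflow: `start_offset - offset` is `size_t` arithmetic, so it wraps to a huge value whenever the chunk begins at or beyond the requested `start_offset`, and the huge offset then produces an out-of-range read. The ternary clamps it to 0 so the whole chunk is copied instead. A self-contained sketch of the difference (the values are invented, only the arithmetic matches the diff):

```cpp
#include <cstdio>
#include <cstddef>

int main() {
    const size_t start_offset = 5; // caller wants tokens starting at absolute position 5
    const size_t offset       = 8; // but this chunk already begins at absolute position 8

    // Old: unsigned subtraction wraps around to a huge value.
    const size_t buggy   = start_offset - offset;
    // New: clamp to 0 so the whole chunk is taken from its beginning.
    const size_t clamped = offset < start_offset ? start_offset - offset : 0;

    printf("old chunk_offset: %zu\n", buggy);   // e.g. 18446744073709551613 on 64-bit size_t
    printf("new chunk_offset: %zu\n", clamped); // 0
    return 0;
}
```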
@@ -595,16 +593,31 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
             GGML_ASSERT(!sctx.chunks.empty());
             tokens_chunk & chunk = sctx.chunks.back();
             GGML_ASSERT(chunk.is_input);
+            GGML_ASSERT(chunk.consumed < chunk.tokens.size());
+            GGML_ASSERT(!chunk.tokens.empty());

             const size_t remain = chunk.tokens.size() - chunk.consumed;
             const size_t to_consume = std::min(size_t(params.n_batch), remain);
             for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
-                llama_batch_add(batch, chunk.tokens[i], llama_pos(i), {sctx.seq_id}, false);
+                llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_generated + i), {sctx.seq_id}, false);
             }
             chunk.consumed += to_consume;
+            sctx.n_remain -= to_consume;
+            // FIXME: This is a lie, we didn't generate it.
+            sctx.n_generated += to_consume;
             if (chunk.consumed == chunk.tokens.size()) {
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+                // FIXME: Move this logic to a more appropriate place.
+                for (size_t i = 0; i < chunk.consumed; i++) {
+                    sctx.rewind_state.logit_slots.emplace_back(n_vocab);
+                }
+                sctx.high_water_mark = sctx.n_generated + 1;
+#endif
                 sctx.batch_idx = batch.n_tokens - 1;
                 batch.logits[sctx.batch_idx] = true;
+                sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
+                sctx.chunks.back().tokens.reserve(sctx.n_remain);
+                sctx.state = SEQ_GENERATING;
             } else {
                 sctx.batch_idx = -1;
             }
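The `llama_batch_add` change makes the batch position absolute within the sequence: once `/add` can queue input after generation has started, the chunk-local index `i` no longer equals the token's position in the KV cache, so the position is computed from the prompt length plus what the sequence has already produced or consumed. A rough standalone sketch of that bookkeeping (the struct, sizes, and token ids are invented; the real code calls `llama_batch_add` on a `llama_batch`):

```cpp
#include <cstdio>
#include <cstddef>
#include <vector>

// Invented stand-in for the pending-input chunk kept in the example's seq_ctx.
struct pending_chunk {
    std::vector<int> tokens;
    size_t consumed = 0;
};

int main() {
    const size_t prompt_size = 100; // tokens already in the KV cache from the prompt
    size_t n_generated       = 40;  // tokens produced (or consumed as input) since then

    pending_chunk chunk;
    chunk.tokens = {11, 22, 33, 44, 55}; // text queued with /add, already tokenized

    const size_t n_batch    = 8;
    const size_t remain     = chunk.tokens.size() - chunk.consumed;
    const size_t to_consume = remain < n_batch ? remain : n_batch;

    for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
        // The old code passed llama_pos(i): that restarts at 0 for every chunk
        // and collides with cache cells already occupied by the prompt.
        const size_t pos = prompt_size + n_generated + i;
        printf("token %d -> KV position %zu\n", chunk.tokens[i], pos);
    }
    chunk.consumed += to_consume;
    n_generated    += to_consume; // counted as "generated", per the FIXME in the hunk
    return 0;
}
```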
@@ -638,7 +651,6 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
             if (rewind_distance < 1) {
                 return;
             }
-            // if (sctx.seq_id != 0) printf("<%d:%zu>", sctx.seq_id + 1, rewind_distance);
             GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?");
             const size_t slot_idx = sctx.n_generated - rewind_distance;
             const llama_token nl_id = llama_token_nl(model);
@@ -649,14 +661,12 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
             if (sctx.seq_id == focused_sequence) {
                 console::set_display(console::error);
                 fputs("\u3010", stdout);
-                // printf("%zu,%zu,%zu", rewind_distance, sctx.n_generated, sctx.generated_tokens.size());
                 for (size_t i = seq_last_tokens.size() - rewind_distance; i < seq_last_tokens.size(); i++) {
                     if (seq_last_tokens[i] == nl_id) {
                         fputs("\\n", stdout);
                         continue;
                     }
                     const std::string token_str = llama_token_to_piece(ctx, seq_last_tokens[i]);
-                    // fputs("|", stdout);
                     fputs(token_str.c_str(), stdout);
                 }
                 fputs("\u3011", stdout);
@@ -701,8 +711,10 @@ bool gen_ctx::go() {
         sctx.batch_idx = batch.n_tokens - 1;
         sctx.state = SEQ_GENERATING;
         if (sctx.seq_id == 0) {
+            sctx.chunks.back().consumed = prompt_size;
             sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
         } else {
+            sctx.chunks.front().consumed = prompt_size;
             llama_kv_cache_seq_cp(ctx, 0, sctx.seq_id, 0, prompt_size);
         }
 #ifndef LLAMA_NO_SEQREP_SAMPLER
@@ -739,10 +751,11 @@ static bool handle_commands(gen_ctx & gctx) {
     line.reserve(1024);


-    puts("");
+    LOG_TEE("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
     fflush(stdout);
     while (1) {
         printf("> ");
+        fflush(stdout);
         console::readline(line, false);
         console::set_display(console::reset);
         while (!line.empty() && std::isspace(line.back())) {
@@ -750,7 +763,7 @@ static bool handle_commands(gen_ctx & gctx) {
         }
         if (line.empty()) break;
         if (line.size() < 2 || line.front() != '/') {
-            printf("\n* Bad command\n");
+            LOG_TEE("\n- Bad command\n");
             continue;
         }
         size_t sep_idx = line.find(' ');
@@ -762,32 +775,98 @@ static bool handle_commands(gen_ctx & gctx) {
             command = line.substr(1);
         }

-        if (command == "quit") return false;
+        if (command == "help") {
+            LOG_TEE("- Available commands:\n");
+            LOG_TEE(" /add TEXT : Adds the specified text to the focused sequence. Alias: /a\n");
+            LOG_TEE(" /addline TEXT : Adds the specified text to the focused sequence with a newline at the end. Alias: /al\n");
+            LOG_TEE(" /help : Show this help.\n");
+            LOG_TEE(" /kill N : Stop sequence N. Alias: /k\n");
+            LOG_TEE(" /list : List sequences and their state. Alias: /l\n");
+            LOG_TEE(" /N : Focus sequence N. Example: /2 focuses sequence 2.\n");
+            LOG_TEE(" /quit : Exit the program. Alias: /q\n");
+            LOG_TEE("- End listing\n");
+            continue;
+        }
+
+        if (command == "q" || command == "quit") return false;

         // Focus
         if (isdigit(command[0])) {
             const int target = std::atoi(command.c_str());
             if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
-                printf("\n* Focus: Bad seq id\n");
+                LOG_TEE("! Focus: Bad seq id\n");
             } else {
                 gctx.focused_sequence = llama_seq_id(target - 1);
             }
             continue;
         }

-        if (command == "kill") {
+        if (command == "k" || command == "kill") {
             const int target = std::atoi(rest.c_str());
             if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
-                printf("\n* Kill: Bad seq id\n");
+                LOG_TEE("! Kill: Bad seq id\n");
             } else if (target - 1 == gctx.focused_sequence) {
-                printf("\n* Kill: Can't kill focus\n");
+                LOG_TEE("! Kill: Can't kill focus\n");
             } else {
                 gctx.ctxs_seq[target - 1].state = SEQ_DONE;
             }
             continue;
         }

-        printf("\n* Bad command\n");
+        if (command == "l" || command == "list") {
+            LOG_TEE("- Listing %zu sequence%s:\n",
+                gctx.ctxs_seq.size(),
+                gctx.ctxs_seq.size() != 1 ? "s" : "");
+            for (const seq_ctx & sctx : gctx.ctxs_seq) {
+                std::string label;
+                switch (sctx.state) {
+                    case SEQ_DONE: label = "DONE"; break;
+                    case SEQ_GENERATING: label = "LIVE"; break;
+                    case SEQ_INPUT: label = "FEED"; break;
+                    case SEQ_SHARE_PROMPT: label = "WAIT"; break;
+                    default: GGML_ASSERT(false);
+                }
+                LOG_TEE(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
+                    sctx.seq_id == gctx.focused_sequence ? "*" : " ",
+                    sctx.seq_id + 1, label.c_str(),
+                    sctx.n_generated, sctx.n_remain);
+                for (const tokens_chunk & chunk : sctx.chunks) {
+                    if (chunk.is_input) {
+                        LOG_TEE("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
+
+                    } else {
+                        LOG_TEE("GEN(%5zu), ", chunk.tokens.size());
+                    }
+                }
+                LOG_TEE("\n");
+            }
+            continue;
+        }
+
+        if (command == "al" || command == "a" || command == "add" || command == "addline") {
+            seq_ctx & sctx = gctx.ctxs_seq[gctx.focused_sequence];
+
+            if (command == "al" || command == "addline") rest.push_back('\n');
+            std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false);
+            if (input_tokens.size() > sctx.n_remain) {
+                LOG_TEE("! Input is %zu token(s) but sequence %d only has space for %zu\n",
+                    input_tokens.size(), gctx.focused_sequence + 1, sctx.n_remain);
+                continue;
+            }
+            if (!sctx.chunks.back().is_input) {
+                sctx.chunks.emplace_back(true, 0, input_tokens);
+            } else {
+                tokens_chunk & chunk = sctx.chunks.back();
+                const size_t old_size = chunk.tokens.size();
+
+                chunk.tokens.resize(old_size + input_tokens.size());
+                std::copy(input_tokens.begin(), input_tokens.end(), chunk.tokens.begin() + old_size);
+            }
+            sctx.state = SEQ_INPUT;
+            continue;
+        }
+
+        LOG_TEE("! Bad command\n");
     }
     return true;
 }
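Taken together, the command handler now understands /help, /a(dd), /al (addline), /k(ill), /l(ist), /q(uit), and /N for focusing a sequence. Based on the format strings in the hunk above, a session might look roughly like this (the sequence counts, sizes, and exact spacing are invented for illustration):

```
> /list
- Listing 2 sequences:
 *  1 (LIVE): generated    42, remain   214. chunks: INP(  100,  100), GEN(   42),
    2 (FEED): generated    38, remain   216. chunks: INP(  100,  100), GEN(   34), INP(    6,    2),
> /kill 1
! Kill: Can't kill focus
> /2
> /al And then she said
>
```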