Expand simple-inference command handling.

Clear KV cache when a sequence is killed.

Allow process_escapes to handle \x escape sequences.
This commit is contained in:
KerfuffleV2 2023-11-03 05:06:57 -06:00
parent a10f7cd087
commit a0c5587401

View file

@ -70,7 +70,7 @@ typedef struct seq_ctx {
int32_t batch_idx;
enum seq_state state;
size_t n_remain;
size_t n_generated;
size_t n_toks; // Note: Does not include initial prompt size.
llama_sampling_context *ctx_sampling;
llama_token last_sampled;
@ -121,7 +121,6 @@ typedef struct gen_ctx {
~gen_ctx();
void dump_batches(const size_t prompt_start = 0);
void dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset = 0);
void dump_batch(const size_t seq);
void handle_seq(seq_ctx & sctx);
#ifndef LLAMA_NO_SEQREP_SAMPLER
void handle_seq_seqrep(seq_ctx & sctx);
@ -227,8 +226,8 @@ static bool check_unsupported(const gpt_params * params) {
nope = "prompt cache";
else if (params->escape)
nope = "prompt escaping";
else if (params->interactive || params->interactive_first || params->instruct)
nope = "interactive mode";
else if (params->interactive_first || params->instruct)
nope = "interactive first or instruct mode";
else if (!params->input_prefix.empty() || !params->input_suffix.empty() || params->input_prefix_bos)
nope = "input prefix or suffix";
else if (params->hellaswag)
@ -238,7 +237,7 @@ static bool check_unsupported(const gpt_params * params) {
else if (!params->antiprompt.empty())
nope = "reverse prompt";
if (!nope.empty()) {
LOG_TEE("%s: error: We don't support %s here.\n", __func__, nope.c_str());
printf("%s: error: We don't support %s here.\n", __func__, nope.c_str());
return false;
}
return true;
@ -254,15 +253,15 @@ bool gen_ctx::init_params(const int argc, char ** argv) {
}
if (params.rope_freq_base != 10000.0) {
LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
printf("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
}
if (params.rope_freq_scale != 1.0) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
printf("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
}
if (params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
printf("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
}
@ -270,7 +269,7 @@ bool gen_ctx::init_params(const int argc, char ** argv) {
params.seed = time(NULL);
}
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
printf("%s: seed = %u\n", __func__, params.seed);
std::mt19937 rng(params.seed);
if (params.random_prompt) {
@ -289,14 +288,14 @@ bool gen_ctx::init_model() {
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
printf("%s: error: unable to load model\n", __func__);
return false;
}
// print system information
{
LOG_TEE("\n");
LOG_TEE("system_info: n_threads = %d / %d | %s\n",
printf("\n");
printf("system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
@ -327,20 +326,20 @@ bool gen_ctx::init_prompt() {
LOG("n_ctx: %d\n", n_ctx);
if ((int) prompt_tokens.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
printf("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
return false;
}
prompt_size = prompt_tokens.size();
if (params.verbose_prompt) {
LOG_TEE("\n");
LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
printf("\n");
printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
printf("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
LOG_TEE("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
printf("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
LOG_TEE("\n");
printf("\n");
}
return true;
}
@ -367,7 +366,7 @@ bool gen_ctx::init_handlers() {
}
bool gen_ctx::init_sampling() {
LOG_TEE("sampling: %s\n", llama_sampling_print(sparams).c_str());
printf("sampling: %s\n", llama_sampling_print(sparams).c_str());
#ifndef LLAMA_NO_SEQREP_SAMPLER
for (auto & sr_params : sparams.seqrep_params) {
seqrep_sampler_params_dump(&sr_params);
@ -476,7 +475,7 @@ bool gen_ctx::feed_prompt(const std::vector<llama_token> & tokens, llama_pos pos
if (llama_decode(ctx, batch) != 0) {
console::set_display(console::reset);
LOG_TEE("%s : failed to eval\n", __func__);
printf("%s : failed to eval\n", __func__);
return false;
}
decode_count++;
@ -534,7 +533,7 @@ void gen_ctx::dump_batches(const size_t prompt_start) {
if (sctx.seq_id == focused_sequence) continue;
printf("\n\n%s Result #%d (size: %zu",
!first ? "====================" : "####################",
sctx.seq_id + 1, prompt_size + sctx.n_generated);
sctx.seq_id + 1, prompt_size + sctx.n_toks);
#ifndef LLAMA_NO_SEQREP_SAMPLER
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
#endif
@ -545,7 +544,7 @@ void gen_ctx::dump_batches(const size_t prompt_start) {
seq_ctx & sctx = ctxs_seq[focused_sequence];
printf("\n\n%s Result #%d (size: %zu",
!first ? "====================" : "####################",
sctx.seq_id + 1, prompt_size + sctx.n_generated);
sctx.seq_id + 1, prompt_size + sctx.n_toks);
#ifndef LLAMA_NO_SEQREP_SAMPLER
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
#endif
@ -572,7 +571,7 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
fputs(token_str.c_str(), stdout);
fflush(stdout);
}
sctx.n_generated++;
sctx.n_toks++;
sctx.n_remain--;
if (sctx.chunks.empty() || sctx.chunks.back().is_input) {
sctx.chunks.emplace_back(0, false, std::vector<llama_token>());
@ -581,11 +580,11 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
if (sctx.last_sampled == llama_token_eos(model) || sctx.n_remain == 0) {
sctx.state = SEQ_DONE;
sctx.batch_idx = -1;
// LOG_TEE(" [end of text]\n");
// printf(" [end of text]\n");
// break;
} else {
sctx.batch_idx = batch.n_tokens;
llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_generated, {sctx.seq_id}, true);
llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_toks, {sctx.seq_id}, true);
}
} break;
@ -600,19 +599,18 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
const size_t remain = chunk.tokens.size() - chunk.consumed;
const size_t to_consume = std::min(size_t(params.n_batch), remain);
for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_generated + i), {sctx.seq_id}, false);
llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_toks + i), {sctx.seq_id}, false);
}
chunk.consumed += to_consume;
sctx.n_remain -= to_consume;
// FIXME: This a lie, we didn't generate it.
sctx.n_generated += to_consume;
sctx.n_toks += to_consume;
if (chunk.consumed == chunk.tokens.size()) {
#ifndef LLAMA_NO_SEQREP_SAMPLER
// FIXME: Move this logic to a more appropriate place.
for (size_t i = 0; i < chunk.consumed; i++) {
sctx.rewind_state.logit_slots.emplace_back(n_vocab);
}
sctx.high_water_mark = sctx.n_generated + 1;
sctx.high_water_mark = sctx.n_toks + 1;
#endif
sctx.batch_idx = batch.n_tokens - 1;
batch.logits[sctx.batch_idx] = true;
@ -631,29 +629,29 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
#ifndef LLAMA_NO_SEQREP_SAMPLER
void gen_ctx::handle_seq_seqrep(seq_ctx & sctx) {
if (sctx.n_generated > 0) {
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_generated);
if (sctx.n_toks > 0) {
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_toks);
if (rw_slot.ctx_sampling == nullptr) {
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
}
llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
sctx.rewind_state.set_logits_slot(ctx, sctx.n_generated, sctx.batch_idx);
sctx.rewind_state.set_logits_slot(ctx, sctx.n_toks, sctx.batch_idx);
} else {
return;
}
std::vector<llama_token> seq_last_tokens;
seq_last_tokens.reserve(sctx.n_generated);
seq_last_tokens.reserve(sctx.n_toks);
concat_chunks(sctx.chunks, seq_last_tokens, prompt_size);
size_t rewind_distance =
llama_seqrep_handle_rewind(
ctx, sctx.rewind_state, seq_last_tokens, sctx.n_generated, prompt_tokens,
ctx, sctx.rewind_state, seq_last_tokens, sctx.n_toks, prompt_tokens,
sparams.seqrep_params, &sctx.high_water_mark, sctx.batch_idx);
if (rewind_distance < 1) {
return;
}
GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?");
const size_t slot_idx = sctx.n_generated - rewind_distance;
GGML_ASSERT(rewind_distance <= sctx.n_toks && "Rewind index out of bounds somehow?");
const size_t slot_idx = sctx.n_toks - rewind_distance;
const llama_token nl_id = llama_token_nl(model);
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(slot_idx);
@ -676,10 +674,10 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
}
sctx.n_remain += rewind_distance;
sctx.n_generated -= rewind_distance;
sctx.n_toks -= rewind_distance;
sctx.rewind_count++;
sctx.rewind_tokens += rewind_distance;
llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_generated + 1, -1);
llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_toks + 1, -1);
while (!sctx.chunks.empty() && rewind_distance > 0) {
tokens_chunk & last_chunk = sctx.chunks.back();
GGML_ASSERT(!last_chunk.is_input);
@ -738,8 +736,9 @@ bool gen_ctx::go() {
decode_time_last = std::max(int64_t(0), ggml_time_us() - decode_time_last);
decode_time_total += decode_time_last;
// FIXME: Handle KV cache pressure better.
if (decode_result != 0) {
LOG_TEE("%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
fprintf(stderr, "%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
decode_result == 1 ? "couldn't find slot" : "unknown error");
return false;
}
@ -752,7 +751,7 @@ static bool handle_commands(gen_ctx & gctx) {
line.reserve(1024);
LOG_TEE("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
printf("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
fflush(stdout);
while (1) {
printf("> ");
@ -764,7 +763,7 @@ static bool handle_commands(gen_ctx & gctx) {
}
if (line.empty()) break;
if (line.size() < 2 || line.front() != '/') {
LOG_TEE("\n- Bad command\n");
printf("\n- Bad command\n");
continue;
}
size_t sep_idx = line.find(' ');
@ -775,47 +774,65 @@ static bool handle_commands(gen_ctx & gctx) {
} else {
command = line.substr(1);
}
for (char & c : command) c = std::tolower(c);
if (command == "help") {
LOG_TEE("- Availabe commands:\n");
LOG_TEE(" /add TEXT : Adds the specified text to the focused sequence. Alias: /a\n");
LOG_TEE(" /addline TEXT : Adds the specified text to the focused with a newline at the end. Alias: /al\n");
LOG_TEE(" /help : Show this help.\n");
LOG_TEE(" /kill N : Stop sequence N. Alia: /k\n");
LOG_TEE(" /list : List sequences and their state. Alias: /l\n");
LOG_TEE(" /N : Focus sequence N. Example /2 focus sequence 2.\n");
LOG_TEE(" /quit : Exit the program. Alias: /q\n");
LOG_TEE("- End listing\n");
if (command == "h" || command == "help") {
printf("- Help: For commands with [SEQ], optionally specify a sequence number here to set the target.\n");
printf(" If sequence isn't specified, then the current focus is used if possible.\n");
printf(" One of any punctuation character is allowed after the number.\n");
printf(" For example, '/1add hello' and '/1,add hello' both add 'hello' to sequence 1.\n");
printf("- Available commands:\n");
printf(" /[SEQ]add TEXT : Adds the specified text to the focused sequence. Alias: /a\n");
printf(" /[SEQ]addesc TEXT : Same as /add but handles escapes (\\n, \\x20, etc) and tokenizes without a leading space. Alias: /ae\n");
printf(" /[SEQ]addline TEXT : Same as /add but appends a newline. Alias: /al\n");
printf(" /help : Show this help. Alias: /h\n");
printf(" /[SEQ]dump N : Dump the last N tokens of SEQ showing offsets from the end. Alias: /d\n");
printf(" /[SEQ]dumptokens N : Same as /dump but displays token IDs as well. Alias: /dt\n");
printf(" /[SEQ]kill : Stop sequence SEQ. Alia: /k\n");
printf(" /list : List sequences and their state. Alias: /l\n");
printf(" /[SEQ]focus : Focus sequence SEQ. Alias: Just use /1, /2, etc\n");
printf(" /[SEQ]print : Display the content of SEQ. Alias: /p\n");
printf(" /quit : Exit the program. Alias: /q\n");
printf("- End listing\n");
continue;
}
if (command == "q" || command == "quit") return false;
llama_seq_id target = -1;
// Focus
if (isdigit(command[0])) {
const int target = std::atoi(command.c_str());
char * parse_end = nullptr;
target = std::strtol(command.c_str(), &parse_end, 10);
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
LOG_TEE("! Focus: Bad seq id\n");
} else {
gctx.focused_sequence = llama_seq_id(target - 1);
printf("! Bad seq id\n");
continue;
}
target--;
if (std::ispunct(*parse_end)) parse_end++;
command = std::string(parse_end);
}
if (command.empty() || command == "focus") {
printf("- Focus changed from %d to %d\n", gctx.focused_sequence + 1, target + 1);
gctx.focused_sequence = llama_seq_id(target);
continue;
}
if (command == "k" || command == "kill") {
const int target = std::atoi(rest.c_str());
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
LOG_TEE("! Kill: Bad seq id\n");
} else if (target - 1 == gctx.focused_sequence) {
LOG_TEE("! Kill: Can't kill focus\n");
if (target == gctx.focused_sequence) {
printf("! Kill: Can't kill focus\n");
} else {
gctx.ctxs_seq[target - 1].state = SEQ_DONE;
printf("- Killed sequence %d\n", target + 1);
gctx.ctxs_seq[target].state = SEQ_DONE;
llama_kv_cache_seq_rm(gctx.ctx, target, -1, -1);
}
continue;
}
if (command == "l" || command == "list") {
LOG_TEE("- Listing %zu sequence%s:\n",
printf("- Listing %zu sequence%s:\n",
gctx.ctxs_seq.size(),
gctx.ctxs_seq.size() != 1 ? "s" : "");
for (const seq_ctx & sctx : gctx.ctxs_seq) {
@ -827,30 +844,37 @@ static bool handle_commands(gen_ctx & gctx) {
case SEQ_SHARE_PROMPT: label = "WAIT"; break;
default: GGML_ASSERT(false);
}
LOG_TEE(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
printf(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
sctx.seq_id == gctx.focused_sequence ? "*" : " ",
sctx.seq_id + 1, label.c_str(),
sctx.n_generated, sctx.n_remain);
sctx.n_toks, sctx.n_remain);
for (const tokens_chunk & chunk : sctx.chunks) {
if (chunk.is_input) {
LOG_TEE("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
printf("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
} else {
LOG_TEE("GEN(%5zu), ", chunk.tokens.size());
printf("GEN(%5zu), ", chunk.tokens.size());
}
}
LOG_TEE("\n");
printf("\n");
}
continue;
}
if (command == "al" || command == "a" || command == "add" || command == "addline") {
seq_ctx & sctx = gctx.ctxs_seq[gctx.focused_sequence];
if ( command == "al" || command == "a" || command == "ae"
|| command == "add" || command == "addline" || command == "addesc") {
bool is_special = false;
seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
if (command == "al" || command == "addline") rest.push_back('\n');
std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false);
if (command == "al" || command == "addline") {
rest.push_back('\n');
} else if (command == "ae" || command == "addesc") {
process_escapes(rest);
is_special = true;
}
std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false, is_special);
if (input_tokens.size() > sctx.n_remain) {
LOG_TEE("! Input is %zu token(s) but sequence %d only has space for %zu\n",
printf("! Input is %zu token(s) but sequence %d only has space for %zu\n",
input_tokens.size(), gctx.focused_sequence + 1, sctx.n_remain);
continue;
}
@ -867,7 +891,68 @@ static bool handle_commands(gen_ctx & gctx) {
continue;
}
LOG_TEE("! Bad command\n");
if (command == "p" || command == "print") {
seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
std::string label;
switch (sctx.state) {
case SEQ_DONE: label = "DONE"; break;
case SEQ_GENERATING: label = "LIVE"; break;
case SEQ_INPUT: label = "FEED"; break;
case SEQ_SHARE_PROMPT: label = "WAIT"; break;
default: GGML_ASSERT(false);
}
printf("- Showing sequence %3d%s: state %s, generated %5zu, remain %5zu. chunks: ",
sctx.seq_id + 1,
sctx.seq_id == gctx.focused_sequence ? "(focus)" : " ",
label.c_str(), sctx.n_toks, sctx.n_remain);
for (const tokens_chunk & chunk : sctx.chunks) {
if (chunk.is_input) {
printf("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
} else {
printf("GEN(%5zu), ", chunk.tokens.size());
}
}
printf("\n");
gctx.dump_chunks(sctx.chunks);
printf("\n- Done\n");
continue;
}
if (command == "d" || command == "dt" || command == "dump" || command == "dumptokens") {
seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
const bool with_id = command == "dt" || command == "dumptokens";
const size_t max_n = sctx.n_toks + gctx.prompt_size;
size_t dump_n = size_t(std::max(0, atoi(rest.c_str())));
if (dump_n == 0) dump_n = 200;
dump_n = std::min(dump_n, max_n);
printf("- Dumping last %zu token%s from sequence %d\n",
dump_n, dump_n != 1 ? "s" : "", target + 1);
std::vector<llama_token> result;
result.reserve(dump_n);
concat_chunks(sctx.chunks, result, max_n - dump_n);
GGML_ASSERT(result.size() == dump_n);
for (size_t i = 0; i < dump_n; i++) {
const llama_token tid = result[i];
console::set_display(console::user_input);
printf("[%zu", dump_n - i);
if (with_id) {
printf(",%d", tid);
}
fputs("]", stdout);
console::set_display(console::reset);
fputs(llama_token_to_piece(gctx.ctx, tid).c_str(), stdout);
}
console::set_display(console::reset);
printf("\n\n- Dump complete.\n");
continue;
}
printf("! Bad command\n");
}
return true;
}
@ -875,10 +960,12 @@ static bool handle_commands(gen_ctx & gctx) {
int main(int argc, char ** argv) {
gen_ctx gctx(argc, argv);
while (gctx.go() && !done) {
// This might look weird but done can get set while go() is running.
while (!done && gctx.go() && !done) {
bool need_dump = gctx.params.n_parallel > 1 && gctx.decode_count % SI_DUMP_SEQUENCES_INTERVAL == 0;
if (interrupted) {
if (!handle_commands(gctx)) break;
if (!gctx.params.interactive || !handle_commands(gctx)) break;
// Double check that ^C wasn't hit again.
if (done) break;
interrupted = false;
need_dump = true;