Expand simple-inference command handling.
Clear KV cache when a sequence is killed. Allow parse_escapes to handle \x sequences.
This commit is contained in:
parent
a10f7cd087
commit
a0c5587401
1 changed file with 162 additions and 75 deletions
|
@ -70,7 +70,7 @@ typedef struct seq_ctx {
|
||||||
int32_t batch_idx;
|
int32_t batch_idx;
|
||||||
enum seq_state state;
|
enum seq_state state;
|
||||||
size_t n_remain;
|
size_t n_remain;
|
||||||
size_t n_generated;
|
size_t n_toks; // Note: Does not include initial prompt size.
|
||||||
llama_sampling_context *ctx_sampling;
|
llama_sampling_context *ctx_sampling;
|
||||||
|
|
||||||
llama_token last_sampled;
|
llama_token last_sampled;
|
||||||
|
@ -121,7 +121,6 @@ typedef struct gen_ctx {
|
||||||
~gen_ctx();
|
~gen_ctx();
|
||||||
void dump_batches(const size_t prompt_start = 0);
|
void dump_batches(const size_t prompt_start = 0);
|
||||||
void dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset = 0);
|
void dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset = 0);
|
||||||
void dump_batch(const size_t seq);
|
|
||||||
void handle_seq(seq_ctx & sctx);
|
void handle_seq(seq_ctx & sctx);
|
||||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||||
void handle_seq_seqrep(seq_ctx & sctx);
|
void handle_seq_seqrep(seq_ctx & sctx);
|
||||||
|
@ -227,8 +226,8 @@ static bool check_unsupported(const gpt_params * params) {
|
||||||
nope = "prompt cache";
|
nope = "prompt cache";
|
||||||
else if (params->escape)
|
else if (params->escape)
|
||||||
nope = "prompt escaping";
|
nope = "prompt escaping";
|
||||||
else if (params->interactive || params->interactive_first || params->instruct)
|
else if (params->interactive_first || params->instruct)
|
||||||
nope = "interactive mode";
|
nope = "interactive first or instruct mode";
|
||||||
else if (!params->input_prefix.empty() || !params->input_suffix.empty() || params->input_prefix_bos)
|
else if (!params->input_prefix.empty() || !params->input_suffix.empty() || params->input_prefix_bos)
|
||||||
nope = "input prefix or suffix";
|
nope = "input prefix or suffix";
|
||||||
else if (params->hellaswag)
|
else if (params->hellaswag)
|
||||||
|
@ -238,7 +237,7 @@ static bool check_unsupported(const gpt_params * params) {
|
||||||
else if (!params->antiprompt.empty())
|
else if (!params->antiprompt.empty())
|
||||||
nope = "reverse prompt";
|
nope = "reverse prompt";
|
||||||
if (!nope.empty()) {
|
if (!nope.empty()) {
|
||||||
LOG_TEE("%s: error: We don't support %s here.\n", __func__, nope.c_str());
|
printf("%s: error: We don't support %s here.\n", __func__, nope.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -254,15 +253,15 @@ bool gen_ctx::init_params(const int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.rope_freq_base != 10000.0) {
|
if (params.rope_freq_base != 10000.0) {
|
||||||
LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
|
printf("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.rope_freq_scale != 1.0) {
|
if (params.rope_freq_scale != 1.0) {
|
||||||
LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
|
printf("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.n_ctx < 8) {
|
if (params.n_ctx < 8) {
|
||||||
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
|
printf("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
|
||||||
params.n_ctx = 8;
|
params.n_ctx = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -270,7 +269,7 @@ bool gen_ctx::init_params(const int argc, char ** argv) {
|
||||||
params.seed = time(NULL);
|
params.seed = time(NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
|
printf("%s: seed = %u\n", __func__, params.seed);
|
||||||
|
|
||||||
std::mt19937 rng(params.seed);
|
std::mt19937 rng(params.seed);
|
||||||
if (params.random_prompt) {
|
if (params.random_prompt) {
|
||||||
|
@ -289,14 +288,14 @@ bool gen_ctx::init_model() {
|
||||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
if (model == NULL) {
|
if (model == NULL) {
|
||||||
LOG_TEE("%s: error: unable to load model\n", __func__);
|
printf("%s: error: unable to load model\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// print system information
|
// print system information
|
||||||
{
|
{
|
||||||
LOG_TEE("\n");
|
printf("\n");
|
||||||
LOG_TEE("system_info: n_threads = %d / %d | %s\n",
|
printf("system_info: n_threads = %d / %d | %s\n",
|
||||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -327,20 +326,20 @@ bool gen_ctx::init_prompt() {
|
||||||
LOG("n_ctx: %d\n", n_ctx);
|
LOG("n_ctx: %d\n", n_ctx);
|
||||||
|
|
||||||
if ((int) prompt_tokens.size() > n_ctx - 4) {
|
if ((int) prompt_tokens.size() > n_ctx - 4) {
|
||||||
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
|
printf("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
prompt_size = prompt_tokens.size();
|
prompt_size = prompt_tokens.size();
|
||||||
|
|
||||||
if (params.verbose_prompt) {
|
if (params.verbose_prompt) {
|
||||||
LOG_TEE("\n");
|
printf("\n");
|
||||||
LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
||||||
LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
|
printf("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
|
||||||
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
|
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
|
||||||
LOG_TEE("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
printf("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -367,7 +366,7 @@ bool gen_ctx::init_handlers() {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool gen_ctx::init_sampling() {
|
bool gen_ctx::init_sampling() {
|
||||||
LOG_TEE("sampling: %s\n", llama_sampling_print(sparams).c_str());
|
printf("sampling: %s\n", llama_sampling_print(sparams).c_str());
|
||||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||||
for (auto & sr_params : sparams.seqrep_params) {
|
for (auto & sr_params : sparams.seqrep_params) {
|
||||||
seqrep_sampler_params_dump(&sr_params);
|
seqrep_sampler_params_dump(&sr_params);
|
||||||
|
@ -476,7 +475,7 @@ bool gen_ctx::feed_prompt(const std::vector<llama_token> & tokens, llama_pos pos
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0) {
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
console::set_display(console::reset);
|
console::set_display(console::reset);
|
||||||
LOG_TEE("%s : failed to eval\n", __func__);
|
printf("%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
decode_count++;
|
decode_count++;
|
||||||
|
@ -534,7 +533,7 @@ void gen_ctx::dump_batches(const size_t prompt_start) {
|
||||||
if (sctx.seq_id == focused_sequence) continue;
|
if (sctx.seq_id == focused_sequence) continue;
|
||||||
printf("\n\n%s Result #%d (size: %zu",
|
printf("\n\n%s Result #%d (size: %zu",
|
||||||
!first ? "====================" : "####################",
|
!first ? "====================" : "####################",
|
||||||
sctx.seq_id + 1, prompt_size + sctx.n_generated);
|
sctx.seq_id + 1, prompt_size + sctx.n_toks);
|
||||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||||
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
|
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
|
||||||
#endif
|
#endif
|
||||||
|
@ -545,7 +544,7 @@ void gen_ctx::dump_batches(const size_t prompt_start) {
|
||||||
seq_ctx & sctx = ctxs_seq[focused_sequence];
|
seq_ctx & sctx = ctxs_seq[focused_sequence];
|
||||||
printf("\n\n%s Result #%d (size: %zu",
|
printf("\n\n%s Result #%d (size: %zu",
|
||||||
!first ? "====================" : "####################",
|
!first ? "====================" : "####################",
|
||||||
sctx.seq_id + 1, prompt_size + sctx.n_generated);
|
sctx.seq_id + 1, prompt_size + sctx.n_toks);
|
||||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||||
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
|
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
|
||||||
#endif
|
#endif
|
||||||
|
@ -572,7 +571,7 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
|
||||||
fputs(token_str.c_str(), stdout);
|
fputs(token_str.c_str(), stdout);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
sctx.n_generated++;
|
sctx.n_toks++;
|
||||||
sctx.n_remain--;
|
sctx.n_remain--;
|
||||||
if (sctx.chunks.empty() || sctx.chunks.back().is_input) {
|
if (sctx.chunks.empty() || sctx.chunks.back().is_input) {
|
||||||
sctx.chunks.emplace_back(0, false, std::vector<llama_token>());
|
sctx.chunks.emplace_back(0, false, std::vector<llama_token>());
|
||||||
|
@ -581,11 +580,11 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
|
||||||
if (sctx.last_sampled == llama_token_eos(model) || sctx.n_remain == 0) {
|
if (sctx.last_sampled == llama_token_eos(model) || sctx.n_remain == 0) {
|
||||||
sctx.state = SEQ_DONE;
|
sctx.state = SEQ_DONE;
|
||||||
sctx.batch_idx = -1;
|
sctx.batch_idx = -1;
|
||||||
// LOG_TEE(" [end of text]\n");
|
// printf(" [end of text]\n");
|
||||||
// break;
|
// break;
|
||||||
} else {
|
} else {
|
||||||
sctx.batch_idx = batch.n_tokens;
|
sctx.batch_idx = batch.n_tokens;
|
||||||
llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_generated, {sctx.seq_id}, true);
|
llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_toks, {sctx.seq_id}, true);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
||||||
|
@ -600,19 +599,18 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
|
||||||
const size_t remain = chunk.tokens.size() - chunk.consumed;
|
const size_t remain = chunk.tokens.size() - chunk.consumed;
|
||||||
const size_t to_consume = std::min(size_t(params.n_batch), remain);
|
const size_t to_consume = std::min(size_t(params.n_batch), remain);
|
||||||
for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
|
for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
|
||||||
llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_generated + i), {sctx.seq_id}, false);
|
llama_batch_add(batch, chunk.tokens[i], llama_pos(prompt_size + sctx.n_toks + i), {sctx.seq_id}, false);
|
||||||
}
|
}
|
||||||
chunk.consumed += to_consume;
|
chunk.consumed += to_consume;
|
||||||
sctx.n_remain -= to_consume;
|
sctx.n_remain -= to_consume;
|
||||||
// FIXME: This a lie, we didn't generate it.
|
sctx.n_toks += to_consume;
|
||||||
sctx.n_generated += to_consume;
|
|
||||||
if (chunk.consumed == chunk.tokens.size()) {
|
if (chunk.consumed == chunk.tokens.size()) {
|
||||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||||
// FIXME: Move this logic to a more appropriate place.
|
// FIXME: Move this logic to a more appropriate place.
|
||||||
for (size_t i = 0; i < chunk.consumed; i++) {
|
for (size_t i = 0; i < chunk.consumed; i++) {
|
||||||
sctx.rewind_state.logit_slots.emplace_back(n_vocab);
|
sctx.rewind_state.logit_slots.emplace_back(n_vocab);
|
||||||
}
|
}
|
||||||
sctx.high_water_mark = sctx.n_generated + 1;
|
sctx.high_water_mark = sctx.n_toks + 1;
|
||||||
#endif
|
#endif
|
||||||
sctx.batch_idx = batch.n_tokens - 1;
|
sctx.batch_idx = batch.n_tokens - 1;
|
||||||
batch.logits[sctx.batch_idx] = true;
|
batch.logits[sctx.batch_idx] = true;
|
||||||
|
@ -631,29 +629,29 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
|
||||||
|
|
||||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||||
void gen_ctx::handle_seq_seqrep(seq_ctx & sctx) {
|
void gen_ctx::handle_seq_seqrep(seq_ctx & sctx) {
|
||||||
if (sctx.n_generated > 0) {
|
if (sctx.n_toks > 0) {
|
||||||
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_generated);
|
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_toks);
|
||||||
if (rw_slot.ctx_sampling == nullptr) {
|
if (rw_slot.ctx_sampling == nullptr) {
|
||||||
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
|
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
}
|
}
|
||||||
llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
|
llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
|
||||||
sctx.rewind_state.set_logits_slot(ctx, sctx.n_generated, sctx.batch_idx);
|
sctx.rewind_state.set_logits_slot(ctx, sctx.n_toks, sctx.batch_idx);
|
||||||
} else {
|
} else {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::vector<llama_token> seq_last_tokens;
|
std::vector<llama_token> seq_last_tokens;
|
||||||
seq_last_tokens.reserve(sctx.n_generated);
|
seq_last_tokens.reserve(sctx.n_toks);
|
||||||
concat_chunks(sctx.chunks, seq_last_tokens, prompt_size);
|
concat_chunks(sctx.chunks, seq_last_tokens, prompt_size);
|
||||||
|
|
||||||
size_t rewind_distance =
|
size_t rewind_distance =
|
||||||
llama_seqrep_handle_rewind(
|
llama_seqrep_handle_rewind(
|
||||||
ctx, sctx.rewind_state, seq_last_tokens, sctx.n_generated, prompt_tokens,
|
ctx, sctx.rewind_state, seq_last_tokens, sctx.n_toks, prompt_tokens,
|
||||||
sparams.seqrep_params, &sctx.high_water_mark, sctx.batch_idx);
|
sparams.seqrep_params, &sctx.high_water_mark, sctx.batch_idx);
|
||||||
if (rewind_distance < 1) {
|
if (rewind_distance < 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?");
|
GGML_ASSERT(rewind_distance <= sctx.n_toks && "Rewind index out of bounds somehow?");
|
||||||
const size_t slot_idx = sctx.n_generated - rewind_distance;
|
const size_t slot_idx = sctx.n_toks - rewind_distance;
|
||||||
const llama_token nl_id = llama_token_nl(model);
|
const llama_token nl_id = llama_token_nl(model);
|
||||||
|
|
||||||
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(slot_idx);
|
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(slot_idx);
|
||||||
|
@ -676,10 +674,10 @@ void gen_ctx::handle_seq(seq_ctx & sctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
sctx.n_remain += rewind_distance;
|
sctx.n_remain += rewind_distance;
|
||||||
sctx.n_generated -= rewind_distance;
|
sctx.n_toks -= rewind_distance;
|
||||||
sctx.rewind_count++;
|
sctx.rewind_count++;
|
||||||
sctx.rewind_tokens += rewind_distance;
|
sctx.rewind_tokens += rewind_distance;
|
||||||
llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_generated + 1, -1);
|
llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_toks + 1, -1);
|
||||||
while (!sctx.chunks.empty() && rewind_distance > 0) {
|
while (!sctx.chunks.empty() && rewind_distance > 0) {
|
||||||
tokens_chunk & last_chunk = sctx.chunks.back();
|
tokens_chunk & last_chunk = sctx.chunks.back();
|
||||||
GGML_ASSERT(!last_chunk.is_input);
|
GGML_ASSERT(!last_chunk.is_input);
|
||||||
|
@ -738,8 +736,9 @@ bool gen_ctx::go() {
|
||||||
decode_time_last = std::max(int64_t(0), ggml_time_us() - decode_time_last);
|
decode_time_last = std::max(int64_t(0), ggml_time_us() - decode_time_last);
|
||||||
decode_time_total += decode_time_last;
|
decode_time_total += decode_time_last;
|
||||||
|
|
||||||
|
// FIXME: Handle KV cache pressure better.
|
||||||
if (decode_result != 0) {
|
if (decode_result != 0) {
|
||||||
LOG_TEE("%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
|
fprintf(stderr, "%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
|
||||||
decode_result == 1 ? "couldn't find slot" : "unknown error");
|
decode_result == 1 ? "couldn't find slot" : "unknown error");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -752,7 +751,7 @@ static bool handle_commands(gen_ctx & gctx) {
|
||||||
line.reserve(1024);
|
line.reserve(1024);
|
||||||
|
|
||||||
|
|
||||||
LOG_TEE("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
|
printf("\n- Entering command mode. Use /help for help, blank line to exit. Focused sequence: %d\n", gctx.focused_sequence + 1);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
while (1) {
|
while (1) {
|
||||||
printf("> ");
|
printf("> ");
|
||||||
|
@ -764,7 +763,7 @@ static bool handle_commands(gen_ctx & gctx) {
|
||||||
}
|
}
|
||||||
if (line.empty()) break;
|
if (line.empty()) break;
|
||||||
if (line.size() < 2 || line.front() != '/') {
|
if (line.size() < 2 || line.front() != '/') {
|
||||||
LOG_TEE("\n- Bad command\n");
|
printf("\n- Bad command\n");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
size_t sep_idx = line.find(' ');
|
size_t sep_idx = line.find(' ');
|
||||||
|
@ -775,47 +774,65 @@ static bool handle_commands(gen_ctx & gctx) {
|
||||||
} else {
|
} else {
|
||||||
command = line.substr(1);
|
command = line.substr(1);
|
||||||
}
|
}
|
||||||
|
for (char & c : command) c = std::tolower(c);
|
||||||
|
|
||||||
if (command == "help") {
|
if (command == "h" || command == "help") {
|
||||||
LOG_TEE("- Availabe commands:\n");
|
printf("- Help: For commands with [SEQ], optionally specify a sequence number here to set the target.\n");
|
||||||
LOG_TEE(" /add TEXT : Adds the specified text to the focused sequence. Alias: /a\n");
|
printf(" If sequence isn't specified, then the current focus is used if possible.\n");
|
||||||
LOG_TEE(" /addline TEXT : Adds the specified text to the focused with a newline at the end. Alias: /al\n");
|
printf(" One of any punctuation character is allowed after the number.\n");
|
||||||
LOG_TEE(" /help : Show this help.\n");
|
printf(" For example, '/1add hello' and '/1,add hello' both add 'hello' to sequence 1.\n");
|
||||||
LOG_TEE(" /kill N : Stop sequence N. Alia: /k\n");
|
printf("- Available commands:\n");
|
||||||
LOG_TEE(" /list : List sequences and their state. Alias: /l\n");
|
printf(" /[SEQ]add TEXT : Adds the specified text to the focused sequence. Alias: /a\n");
|
||||||
LOG_TEE(" /N : Focus sequence N. Example /2 focus sequence 2.\n");
|
printf(" /[SEQ]addesc TEXT : Same as /add but handles escapes (\\n, \\x20, etc) and tokenizes without a leading space. Alias: /ae\n");
|
||||||
LOG_TEE(" /quit : Exit the program. Alias: /q\n");
|
printf(" /[SEQ]addline TEXT : Same as /add but appends a newline. Alias: /al\n");
|
||||||
LOG_TEE("- End listing\n");
|
printf(" /help : Show this help. Alias: /h\n");
|
||||||
|
printf(" /[SEQ]dump N : Dump the last N tokens of SEQ showing offsets from the end. Alias: /d\n");
|
||||||
|
printf(" /[SEQ]dumptokens N : Same as /dump but displays token IDs as well. Alias: /dt\n");
|
||||||
|
printf(" /[SEQ]kill : Stop sequence SEQ. Alia: /k\n");
|
||||||
|
printf(" /list : List sequences and their state. Alias: /l\n");
|
||||||
|
printf(" /[SEQ]focus : Focus sequence SEQ. Alias: Just use /1, /2, etc\n");
|
||||||
|
printf(" /[SEQ]print : Display the content of SEQ. Alias: /p\n");
|
||||||
|
printf(" /quit : Exit the program. Alias: /q\n");
|
||||||
|
printf("- End listing\n");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (command == "q" || command == "quit") return false;
|
if (command == "q" || command == "quit") return false;
|
||||||
|
|
||||||
|
llama_seq_id target = -1;
|
||||||
|
|
||||||
// Focus
|
// Focus
|
||||||
if (isdigit(command[0])) {
|
if (isdigit(command[0])) {
|
||||||
const int target = std::atoi(command.c_str());
|
char * parse_end = nullptr;
|
||||||
|
target = std::strtol(command.c_str(), &parse_end, 10);
|
||||||
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
|
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
|
||||||
LOG_TEE("! Focus: Bad seq id\n");
|
printf("! Bad seq id\n");
|
||||||
} else {
|
continue;
|
||||||
gctx.focused_sequence = llama_seq_id(target - 1);
|
|
||||||
}
|
}
|
||||||
|
target--;
|
||||||
|
if (std::ispunct(*parse_end)) parse_end++;
|
||||||
|
command = std::string(parse_end);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (command.empty() || command == "focus") {
|
||||||
|
printf("- Focus changed from %d to %d\n", gctx.focused_sequence + 1, target + 1);
|
||||||
|
gctx.focused_sequence = llama_seq_id(target);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (command == "k" || command == "kill") {
|
if (command == "k" || command == "kill") {
|
||||||
const int target = std::atoi(rest.c_str());
|
if (target == gctx.focused_sequence) {
|
||||||
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
|
printf("! Kill: Can't kill focus\n");
|
||||||
LOG_TEE("! Kill: Bad seq id\n");
|
|
||||||
} else if (target - 1 == gctx.focused_sequence) {
|
|
||||||
LOG_TEE("! Kill: Can't kill focus\n");
|
|
||||||
} else {
|
} else {
|
||||||
gctx.ctxs_seq[target - 1].state = SEQ_DONE;
|
printf("- Killed sequence %d\n", target + 1);
|
||||||
|
gctx.ctxs_seq[target].state = SEQ_DONE;
|
||||||
|
llama_kv_cache_seq_rm(gctx.ctx, target, -1, -1);
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (command == "l" || command == "list") {
|
if (command == "l" || command == "list") {
|
||||||
LOG_TEE("- Listing %zu sequence%s:\n",
|
printf("- Listing %zu sequence%s:\n",
|
||||||
gctx.ctxs_seq.size(),
|
gctx.ctxs_seq.size(),
|
||||||
gctx.ctxs_seq.size() != 1 ? "s" : "");
|
gctx.ctxs_seq.size() != 1 ? "s" : "");
|
||||||
for (const seq_ctx & sctx : gctx.ctxs_seq) {
|
for (const seq_ctx & sctx : gctx.ctxs_seq) {
|
||||||
|
@ -827,30 +844,37 @@ static bool handle_commands(gen_ctx & gctx) {
|
||||||
case SEQ_SHARE_PROMPT: label = "WAIT"; break;
|
case SEQ_SHARE_PROMPT: label = "WAIT"; break;
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
LOG_TEE(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
|
printf(" %s%3d (%s): generated %5zu, remain %5zu. chunks: ",
|
||||||
sctx.seq_id == gctx.focused_sequence ? "*" : " ",
|
sctx.seq_id == gctx.focused_sequence ? "*" : " ",
|
||||||
sctx.seq_id + 1, label.c_str(),
|
sctx.seq_id + 1, label.c_str(),
|
||||||
sctx.n_generated, sctx.n_remain);
|
sctx.n_toks, sctx.n_remain);
|
||||||
for (const tokens_chunk & chunk : sctx.chunks) {
|
for (const tokens_chunk & chunk : sctx.chunks) {
|
||||||
if (chunk.is_input) {
|
if (chunk.is_input) {
|
||||||
LOG_TEE("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
|
printf("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
LOG_TEE("GEN(%5zu), ", chunk.tokens.size());
|
printf("GEN(%5zu), ", chunk.tokens.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOG_TEE("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (command == "al" || command == "a" || command == "add" || command == "addline") {
|
if ( command == "al" || command == "a" || command == "ae"
|
||||||
seq_ctx & sctx = gctx.ctxs_seq[gctx.focused_sequence];
|
|| command == "add" || command == "addline" || command == "addesc") {
|
||||||
|
bool is_special = false;
|
||||||
|
seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
|
||||||
|
|
||||||
if (command == "al" || command == "addline") rest.push_back('\n');
|
if (command == "al" || command == "addline") {
|
||||||
std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false);
|
rest.push_back('\n');
|
||||||
|
} else if (command == "ae" || command == "addesc") {
|
||||||
|
process_escapes(rest);
|
||||||
|
is_special = true;
|
||||||
|
}
|
||||||
|
std::vector<llama_token> input_tokens = ::llama_tokenize(gctx.model, rest, false, is_special);
|
||||||
if (input_tokens.size() > sctx.n_remain) {
|
if (input_tokens.size() > sctx.n_remain) {
|
||||||
LOG_TEE("! Input is %zu token(s) but sequence %d only has space for %zu\n",
|
printf("! Input is %zu token(s) but sequence %d only has space for %zu\n",
|
||||||
input_tokens.size(), gctx.focused_sequence + 1, sctx.n_remain);
|
input_tokens.size(), gctx.focused_sequence + 1, sctx.n_remain);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -867,7 +891,68 @@ static bool handle_commands(gen_ctx & gctx) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("! Bad command\n");
|
if (command == "p" || command == "print") {
|
||||||
|
seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
|
||||||
|
std::string label;
|
||||||
|
switch (sctx.state) {
|
||||||
|
case SEQ_DONE: label = "DONE"; break;
|
||||||
|
case SEQ_GENERATING: label = "LIVE"; break;
|
||||||
|
case SEQ_INPUT: label = "FEED"; break;
|
||||||
|
case SEQ_SHARE_PROMPT: label = "WAIT"; break;
|
||||||
|
default: GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("- Showing sequence %3d%s: state %s, generated %5zu, remain %5zu. chunks: ",
|
||||||
|
sctx.seq_id + 1,
|
||||||
|
sctx.seq_id == gctx.focused_sequence ? "(focus)" : " ",
|
||||||
|
label.c_str(), sctx.n_toks, sctx.n_remain);
|
||||||
|
for (const tokens_chunk & chunk : sctx.chunks) {
|
||||||
|
if (chunk.is_input) {
|
||||||
|
printf("INP(%5zu,%5zu), ", chunk.tokens.size(), chunk.consumed);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
printf("GEN(%5zu), ", chunk.tokens.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
gctx.dump_chunks(sctx.chunks);
|
||||||
|
printf("\n- Done\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (command == "d" || command == "dt" || command == "dump" || command == "dumptokens") {
|
||||||
|
seq_ctx & sctx = gctx.ctxs_seq[target < 0 ? gctx.focused_sequence : target];
|
||||||
|
const bool with_id = command == "dt" || command == "dumptokens";
|
||||||
|
const size_t max_n = sctx.n_toks + gctx.prompt_size;
|
||||||
|
size_t dump_n = size_t(std::max(0, atoi(rest.c_str())));
|
||||||
|
if (dump_n == 0) dump_n = 200;
|
||||||
|
dump_n = std::min(dump_n, max_n);
|
||||||
|
|
||||||
|
printf("- Dumping last %zu token%s from sequence %d\n",
|
||||||
|
dump_n, dump_n != 1 ? "s" : "", target + 1);
|
||||||
|
|
||||||
|
std::vector<llama_token> result;
|
||||||
|
result.reserve(dump_n);
|
||||||
|
concat_chunks(sctx.chunks, result, max_n - dump_n);
|
||||||
|
GGML_ASSERT(result.size() == dump_n);
|
||||||
|
for (size_t i = 0; i < dump_n; i++) {
|
||||||
|
const llama_token tid = result[i];
|
||||||
|
console::set_display(console::user_input);
|
||||||
|
printf("[%zu", dump_n - i);
|
||||||
|
if (with_id) {
|
||||||
|
printf(",%d", tid);
|
||||||
|
}
|
||||||
|
fputs("]", stdout);
|
||||||
|
console::set_display(console::reset);
|
||||||
|
fputs(llama_token_to_piece(gctx.ctx, tid).c_str(), stdout);
|
||||||
|
|
||||||
|
}
|
||||||
|
console::set_display(console::reset);
|
||||||
|
printf("\n\n- Dump complete.\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("! Bad command\n");
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -875,10 +960,12 @@ static bool handle_commands(gen_ctx & gctx) {
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gen_ctx gctx(argc, argv);
|
gen_ctx gctx(argc, argv);
|
||||||
|
|
||||||
while (gctx.go() && !done) {
|
// This might look weird but done can get set while go() is running.
|
||||||
|
while (!done && gctx.go() && !done) {
|
||||||
bool need_dump = gctx.params.n_parallel > 1 && gctx.decode_count % SI_DUMP_SEQUENCES_INTERVAL == 0;
|
bool need_dump = gctx.params.n_parallel > 1 && gctx.decode_count % SI_DUMP_SEQUENCES_INTERVAL == 0;
|
||||||
if (interrupted) {
|
if (interrupted) {
|
||||||
if (!handle_commands(gctx)) break;
|
if (!gctx.params.interactive || !handle_commands(gctx)) break;
|
||||||
|
// Double check that ^C wasn't hit again.
|
||||||
if (done) break;
|
if (done) break;
|
||||||
interrupted = false;
|
interrupted = false;
|
||||||
need_dump = true;
|
need_dump = true;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue