main : skip printing token healing prefix twice

This commit is contained in:
mare5x 2024-05-03 21:56:11 +02:00
parent 7d0cc78bc3
commit d4cbccb103

View file

@ -509,6 +509,7 @@ int main(int argc, char ** argv) {
int n_consumed = 0; int n_consumed = 0;
int n_session_consumed = 0; int n_session_consumed = 0;
int n_past_guidance = 0; int n_past_guidance = 0;
int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix
std::vector<int> input_tokens; g_input_tokens = &input_tokens; std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens; std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@ -745,7 +746,16 @@ int main(int argc, char ** argv) {
if (input_echo && display) { if (input_echo && display) {
for (auto id : embd) { for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id); const std::string token_str = llama_token_to_piece(ctx, id);
// Suppress printing while generating token healing prefix (only for interactive mode; kinda hacky...)
if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
printf("%s", token_str.substr(n_bytes_to_skip).c_str());
n_bytes_to_skip = 0;
} else if (n_bytes_to_skip > 0) {
n_bytes_to_skip -= token_str.size();
} else {
printf("%s", token_str.c_str()); printf("%s", token_str.c_str());
}
if (embd.size() > 1) { if (embd.size() > 1) {
input_tokens.push_back(id); input_tokens.push_back(id);
@ -939,6 +949,7 @@ int main(int argc, char ** argv) {
if (token_healing_n_removed > 0) { if (token_healing_n_removed > 0) {
// Set new prefix after an interaction // Set new prefix after an interaction
ctx_sampling->token_healing_prefix = token_healing_prefix; ctx_sampling->token_healing_prefix = token_healing_prefix;
n_bytes_to_skip = ctx_sampling->token_healing_prefix.size();
} }
} }
is_interacting = false; is_interacting = false;