diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8153a71fb..5c453a57e 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -528,7 +528,20 @@ int main(int argc, char ** argv) { exit(1); } - bool should_show_special_tokens = sparams.grammar.empty(); + // Create the pipe for special token handling + int stok_pipe[2] = {0}; + if (pipe(stok_pipe) == -1) { + fprintf(stderr, "%s: failed to initialize special token output stream\n", __func__); + exit(1); + } + + close(stok_pipe[0]); // Read Special Token Not In Use + + FILE *special_token_stream_output_fd = fdopen(stok_pipe[1], "w"); + if (special_token_stream_output_fd == NULL) { + fprintf(stderr, "%s: failed to open special token output stream\n", __func__); + exit(1); + } while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict @@ -742,18 +755,31 @@ int main(int argc, char ** argv) { // display text if (input_echo && display) { for (auto id : embd) { - const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation && should_show_special_tokens); - printf("%s", token_str.c_str()); + const std::string token_str = llama_token_to_piece(ctx, id); + // Console/Stream Output + if (llama_token_is_control_token(llama_get_model(ctx), id)) { + // Stream Output Token To Special Token Output + fprintf(special_token_stream_output_fd, "%s", token_str.c_str()); + } else { + // Stream Output Token To Standard Output + fprintf(stdout, "%s", token_str.c_str()); + } + + // Record Displayed Tokens To Log + // Note: Generated tokens are created one by one hence this check if (embd.size() > 1) { + // Incoming Requested Tokens input_tokens.push_back(id); } else { + // Outgoing Generated Tokens output_tokens.push_back(id); output_ss << token_str; } } fflush(stdout); } + // reset color to default if there is no pending user input if (input_echo && (int) embd_inp.size() == n_consumed) { console::set_display(console::reset); @@ -908,7 +934,7 @@ int main(int argc, char ** argv) { for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; output_tokens.push_back(token); - output_ss << llama_token_to_piece(ctx, token, should_show_special_tokens); + output_ss << llama_token_to_piece(ctx, token); } n_remain -= line_inp.size(); @@ -957,6 +983,8 @@ int main(int argc, char ** argv) { llama_sampling_free(ctx_sampling); llama_backend_free(); + fclose(special_token_stream_output_fd); + #ifndef LOG_DISABLE_LOGS LOG_TEE("Log end\n"); #endif // LOG_DISABLE_LOGS diff --git a/llama.cpp b/llama.cpp index b752ddc6b..f41c6e5b6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17634,6 +17634,10 @@ bool llama_token_is_eog(const struct llama_model * model, llama_token token) { ); } +bool llama_token_is_control_token(const struct llama_model * model, llama_token token) { + return llama_is_control_token(model->vocab, token); +} + llama_token llama_token_bos(const struct llama_model * model) { return model->vocab.special_bos_id; } diff --git a/llama.h b/llama.h index 612e32c4e..7cacb3d64 100644 --- a/llama.h +++ b/llama.h @@ -816,6 +816,9 @@ extern "C" { // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); + // Identify if Token Id is a control token or a render-able token + LLAMA_API bool llama_token_is_control_token(const struct llama_model * model, llama_token token); + // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence