Add show token count

This commit is contained in:
pudepiedj 2024-01-11 10:08:58 +00:00
parent 9289306e7a
commit c22c70454c
3 changed files with 10 additions and 2 deletions

View file

@ -616,6 +616,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.ppl_stride = std::stoi(argv[i]); params.ppl_stride = std::stoi(argv[i]);
} else if (arg == "-stc" || arg == "--show_token_count") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.token_interval = std::stoi(argv[i]);
} else if (arg == "--ppl-output-type") { } else if (arg == "--ppl-output-type") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -926,6 +932,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -stc N --show_token_count N\n");
printf(" show consumed tokens every N tokens\n");
printf("\n"); printf("\n");
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_print_usage(); log_print_usage();

View file

@ -64,6 +64,7 @@ struct gpt_params {
int32_t n_beams = 0; // if non-zero then use beam search of given width. int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
int32_t token_interval = 512; // show token count every 512 tokens
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f; // YaRN low correction dim float yarn_beta_fast = 32.0f; // YaRN low correction dim
@ -240,4 +241,3 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
// Dump the KV cache view showing individual sequences in each cell (long output). // Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);

View file

@ -613,7 +613,7 @@ int main(int argc, char ** argv) {
LOG("n_past = %d\n", n_past); LOG("n_past = %d\n", n_past);
// I added the next two lines on 20240110 // I added the next two lines on 20240110
if (n_past % 256 == 0) { if (n_past % params.token_interval == 0) {
printf("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx); printf("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
} }
} }