diff --git a/third_party/ggml/llama.cc b/third_party/ggml/llama.cc
index ae098d550..7de386066 100644
--- a/third_party/ggml/llama.cc
+++ b/third_party/ggml/llama.cc
@@ -29,6 +29,7 @@
 #include "third_party/ggml/llama.h"
 #include "libc/assert.h"
 #include "libc/intrin/bits.h"
+#include "libc/macros.internal.h"
 #include "third_party/ggml/ggml.h"
 #include "third_party/ggml/llama_util.h"
 #include "third_party/libcxx/algorithm"
@@ -225,6 +226,7 @@ struct llama_vocab {
 
     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
+    int longest_token;
 };
 
 struct llama_context {
@@ -475,6 +477,7 @@ struct llama_file_loader {
         hparams.ftype = (enum llama_ftype) file.read_u32();
     }
     void read_vocab() {
+        vocab.longest_token = 0;
         vocab.id_to_token.resize(hparams.n_vocab);
 
         for (uint32_t i = 0; i < hparams.n_vocab; i++) {
@@ -487,6 +490,7 @@
             }
 
             vocab.token_to_id[word] = i;
+            vocab.longest_token = MAX(vocab.longest_token, word.size());
 
             auto & tok_score = vocab.id_to_token[i];
             tok_score.tok = std::move(word);
@@ -2755,6 +2759,10 @@ const char * llama_token_to_str(const struct llama_context * ctx, llama_token to
     return ctx->vocab.id_to_token[token].tok.c_str();
 }
 
+int llama_longest_token(const struct llama_context * ctx) {
+    return ctx->vocab.longest_token;
+}
+
 llama_token llama_token_bos() {
     return 1;
 }
diff --git a/third_party/ggml/llama.h b/third_party/ggml/llama.h
index 6b3c143d5..041794768 100644
--- a/third_party/ggml/llama.h
+++ b/third_party/ggml/llama.h
@@ -183,6 +183,9 @@ extern "C" {
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
 
+    // Returns number of bytes in the longest token string.
+    LLAMA_API int llama_longest_token(const struct llama_context * ctx);
+
     // Special tokens
     LLAMA_API llama_token llama_token_bos();
     LLAMA_API llama_token llama_token_eos();
diff --git a/third_party/ggml/main.cc b/third_party/ggml/main.cc
index 3efb2204e..65aba0571 100644
--- a/third_party/ggml/main.cc
+++ b/third_party/ggml/main.cc
@@ -32,6 +32,7 @@
 #include "libc/calls/struct/stat.h"
 #include "libc/intrin/bits.h"
 #include "libc/log/log.h"
+#include "libc/macros.internal.h"
 #include "libc/nexgen32e/x86feature.h"
 #include "libc/stdio/stdio.h"
 #include "libc/sysv/consts/map.h"
@@ -117,6 +118,7 @@ static void remember_init() {
     for (std::string & antiprompt : params.antiprompt) {
         longest_antiprompt = std::max(longest_antiprompt, antiprompt.size());
     }
+    longest_antiprompt += llama_longest_token(ctx) * 2;
 }
 
 static void remember_token(llama_token tok) {
@@ -284,7 +286,7 @@ int main(int argc, char ** argv) {
     }
 
     // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
+    // params.prompt.insert(0, 1, ' ');
 
     // tokenize the prompt
     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
@@ -757,17 +759,32 @@
         //   --prompt 'Question: How old are you?\nAnswer: '
         //   --reverse-prompt $'\n'
         //
-        is_antiprompt = has_antiprompt();
+        std::string ap_text;
+        std::string::size_type ap_index;
+        std::string::size_type ap_extra;
+        is_antiprompt = has_antiprompt(&ap_index, &ap_text);
 
         // display text
+        bool got_newline = false;
         if (!input_noecho) {
+            std::string printme;
             for (auto id : embd) {
-                printf("%s", llama_token_to_str(ctx, id));
+                printme.append(llama_token_to_str(ctx, id));
+            }
+            if (is_antiprompt) {
+                ap_extra = last_output.size() - (ap_index + ap_text.size());
+                printme.erase(printme.size() - MIN(printme.size(), ap_extra));
+            }
+            if (printme.size()) {
+                got_newline = printme[printme.size() - 1] == '\n';
+                printf("%s", printme.c_str());
+                fflush(stdout);
             }
-            fflush(stdout);
         }
         if (is_antiprompt && !params.interactive) {
-            printf("\n");
+            if (!got_newline) {
+                printf("\n");
+            }
            break;
        }
        if (prompt_status == kPromptCompleted) {
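
Why remember_init() pads with llama_longest_token(ctx) * 2: generated text is appended to last_output one whole token at a time, so by the time an antiprompt is fully present it can be preceded and followed by up to one token's worth of extra bytes emitted in the same step, and the scan window has to cover that slack on both sides. Below is a minimal sketch of such a windowed scan. It is illustrative only: has_antiprompt_sketch, antiprompts, and last_output as a parameter are hypothetical stand-ins, since the body of the patch's actual has_antiprompt() (with its new two-output signature) is not part of this diff.

#include <algorithm>
#include <string>
#include <vector>

struct llama_context;                                   // opaque, from llama.h
extern "C" int llama_longest_token(const struct llama_context * ctx);

// Hypothetical stand-in for main.cc's has_antiprompt(&ap_index, &ap_text).
static bool has_antiprompt_sketch(const struct llama_context * ctx,
                                  const std::string & last_output,
                                  const std::vector<std::string> & antiprompts,
                                  std::string::size_type * ap_index,
                                  std::string * ap_text) {
    // Same arithmetic as remember_init(): the widest antiprompt plus one
    // longest-token of slack on each side.
    std::string::size_type window = 0;
    for (const std::string & ap : antiprompts) {
        window = std::max(window, ap.size());
    }
    window += llama_longest_token(ctx) * 2;
    std::string::size_type start =
        last_output.size() > window ? last_output.size() - window : 0;
    for (const std::string & ap : antiprompts) {
        std::string::size_type pos = last_output.find(ap, start);
        if (pos != std::string::npos) {
            *ap_index = pos;  // where the match begins in last_output
            *ap_text = ap;    // which antiprompt matched
            return true;
        }
    }
    return false;
}

The two outputs feed the display path above: ap_extra = last_output.size() - (ap_index + ap_text.size()) counts the bytes generated past the end of the match, and printme.erase(printme.size() - MIN(printme.size(), ap_extra)) clips exactly those bytes so the token that completed the antiprompt does not spill onto the screen; the got_newline check then avoids printing a redundant blank line when the clipped output already ends with one.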