Fix subtoken antiprompt scanning

This commit is contained in:
Justine Tunney 2023-05-12 08:55:40 -07:00
parent 80c174d494
commit e8de1e4766
No known key found for this signature in database
GPG key ID: BE714B4575D6E328
3 changed files with 33 additions and 5 deletions

View file

@ -29,6 +29,7 @@
#include "third_party/ggml/llama.h"
#include "libc/assert.h"
#include "libc/intrin/bits.h"
#include "libc/macros.internal.h"
#include "third_party/ggml/ggml.h"
#include "third_party/ggml/llama_util.h"
#include "third_party/libcxx/algorithm"
@ -225,6 +226,7 @@ struct llama_vocab {
std::unordered_map<token, id> token_to_id;
std::vector<token_score> id_to_token;
int longest_token;
};
struct llama_context {
@ -475,6 +477,7 @@ struct llama_file_loader {
hparams.ftype = (enum llama_ftype) file.read_u32();
}
void read_vocab() {
vocab.longest_token = 0;
vocab.id_to_token.resize(hparams.n_vocab);
for (uint32_t i = 0; i < hparams.n_vocab; i++) {
@ -487,6 +490,7 @@ struct llama_file_loader {
}
vocab.token_to_id[word] = i;
vocab.longest_token = MAX(vocab.longest_token, word.size());
auto & tok_score = vocab.id_to_token[i];
tok_score.tok = std::move(word);
@ -2755,6 +2759,10 @@ const char * llama_token_to_str(const struct llama_context * ctx, llama_token to
return ctx->vocab.id_to_token[token].tok.c_str();
}
int llama_longest_token(const struct llama_context * ctx) {
return ctx->vocab.longest_token;
}
llama_token llama_token_bos() {
return 1;
}

View file

@ -183,6 +183,9 @@ extern "C" {
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
// Returns number of bytes in the longest token string.
LLAMA_API int llama_longest_token(const struct llama_context * ctx);
// Special tokens
LLAMA_API llama_token llama_token_bos();
LLAMA_API llama_token llama_token_eos();

View file

@ -32,6 +32,7 @@
#include "libc/calls/struct/stat.h"
#include "libc/intrin/bits.h"
#include "libc/log/log.h"
#include "libc/macros.internal.h"
#include "libc/nexgen32e/x86feature.h"
#include "libc/stdio/stdio.h"
#include "libc/sysv/consts/map.h"
@ -117,6 +118,7 @@ static void remember_init() {
for (std::string & antiprompt : params.antiprompt) {
longest_antiprompt = std::max(longest_antiprompt, antiprompt.size());
}
longest_antiprompt += llama_longest_token(ctx) * 2;
}
static void remember_token(llama_token tok) {
@ -284,7 +286,7 @@ int main(int argc, char ** argv) {
}
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');
// params.prompt.insert(0, 1, ' ');
// tokenize the prompt
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
@ -757,17 +759,32 @@ int main(int argc, char ** argv) {
// --prompt 'Question: How old are you?\nAnswer: '
// --reverse-prompt $'\n'
//
is_antiprompt = has_antiprompt();
std::string ap_text;
std::string::size_type ap_index;
std::string::size_type ap_extra;
is_antiprompt = has_antiprompt(&ap_index, &ap_text);
// display text
bool got_newline = false;
if (!input_noecho) {
std::string printme;
for (auto id : embd) {
printf("%s", llama_token_to_str(ctx, id));
printme.append(llama_token_to_str(ctx, id));
}
if (is_antiprompt) {
ap_extra = last_output.size() - (ap_index + ap_text.size());
printme.erase(printme.size() - MIN(printme.size(), ap_extra));
}
if (printme.size()) {
got_newline = printme[printme.size() - 1] == '\n';
printf("%s", printme.c_str());
fflush(stdout);
}
fflush(stdout);
}
if (is_antiprompt && !params.interactive) {
printf("\n");
if (!got_newline) {
printf("\n");
}
break;
}
if (prompt_status == kPromptCompleted) {