Fix subtoken antiprompt scanning

parent 80c174d494
commit e8de1e4766

3 changed files with 33 additions and 5 deletions
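Taken together, the three diffs below make the reverse-prompt check robust to matches that begin or end inside a token: llama.cc records the byte length of the longest vocabulary token while loading the vocab, llama.h exposes it through a new llama_longest_token() accessor, and main.cc widens its antiprompt scan window by twice that length and trims any freshly decoded text that runs past a matched reverse prompt before echoing it. Below is a minimal sketch of that trimming step; the helper name is invented, and the variable names simply mirror the main.cc hunk.

// Minimal sketch, not the vendored code. `last_output` is the recently
// generated text being scanned, `ap_index`/`ap_text` locate the matched
// reverse prompt inside it, and `printme` is the text pending display.
// The helper itself is hypothetical.
#include <algorithm>
#include <string>

static void trim_past_antiprompt(std::string &printme,
                                 const std::string &last_output,
                                 std::string::size_type ap_index,
                                 const std::string &ap_text) {
  // bytes generated after the end of the antiprompt match
  std::string::size_type ap_extra =
      last_output.size() - (ap_index + ap_text.size());
  // never erase more than what is actually pending for display
  printme.erase(printme.size() - std::min(printme.size(), ap_extra));
}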
third_party/ggml/llama.cc (vendored): 8 changed lines
@@ -29,6 +29,7 @@
 #include "third_party/ggml/llama.h"
 #include "libc/assert.h"
 #include "libc/intrin/bits.h"
+#include "libc/macros.internal.h"
 #include "third_party/ggml/ggml.h"
 #include "third_party/ggml/llama_util.h"
 #include "third_party/libcxx/algorithm"
@@ -225,6 +226,7 @@ struct llama_vocab {

     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
+    int longest_token;
 };

 struct llama_context {
@@ -475,6 +477,7 @@ struct llama_file_loader {
         hparams.ftype = (enum llama_ftype) file.read_u32();
     }
     void read_vocab() {
+        vocab.longest_token = 0;
         vocab.id_to_token.resize(hparams.n_vocab);

         for (uint32_t i = 0; i < hparams.n_vocab; i++) {
@@ -487,6 +490,7 @@ struct llama_file_loader {
             }

             vocab.token_to_id[word] = i;
+            vocab.longest_token = MAX(vocab.longest_token, word.size());

             auto & tok_score = vocab.id_to_token[i];
             tok_score.tok = std::move(word);
@@ -2755,6 +2759,10 @@ const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) {
     return ctx->vocab.id_to_token[token].tok.c_str();
 }

+int llama_longest_token(const struct llama_context * ctx) {
+    return ctx->vocab.longest_token;
+}
+
 llama_token llama_token_bos() {
     return 1;
 }
third_party/ggml/llama.h (vendored): 3 changed lines
@@ -183,6 +183,9 @@ extern "C" {
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);

+    // Returns number of bytes in the longest token string.
+    LLAMA_API int llama_longest_token(const struct llama_context * ctx);
+
     // Special tokens
     LLAMA_API llama_token llama_token_bos();
     LLAMA_API llama_token llama_token_eos();
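The new accessor returns the byte length of the longest token in the loaded vocabulary. Its use, per the main.cc hunks below, is to pad the antiprompt scan window, since a reverse prompt can begin or end in the middle of a single generated token. Here is a hedged sketch of that computation as a standalone helper; the function name and parameters are inventions for illustration, whereas the real code updates a global inside remember_init().

// Hedged sketch of the window padding done in remember_init() below; the
// llama context and the list of reverse prompts are parameters here rather
// than the globals the real program uses.
#include <algorithm>
#include <string>
#include <vector>
#include "third_party/ggml/llama.h"

static size_t antiprompt_window(const struct llama_context *ctx,
                                const std::vector<std::string> &antiprompts) {
  size_t longest = 0;
  for (const std::string &a : antiprompts) {
    longest = std::max(longest, a.size());
  }
  // slack of one longest token on either side, since a match can start or
  // end partway through a generated token
  return longest + llama_longest_token(ctx) * 2;
}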
third_party/ggml/main.cc (vendored): 27 changed lines
@@ -32,6 +32,7 @@
 #include "libc/calls/struct/stat.h"
 #include "libc/intrin/bits.h"
 #include "libc/log/log.h"
+#include "libc/macros.internal.h"
 #include "libc/nexgen32e/x86feature.h"
 #include "libc/stdio/stdio.h"
 #include "libc/sysv/consts/map.h"
@@ -117,6 +118,7 @@ static void remember_init() {
   for (std::string & antiprompt : params.antiprompt) {
     longest_antiprompt = std::max(longest_antiprompt, antiprompt.size());
   }
+  longest_antiprompt += llama_longest_token(ctx) * 2;
 }

 static void remember_token(llama_token tok) {
@@ -284,7 +286,7 @@ int main(int argc, char ** argv) {
   }

   // Add a space in front of the first character to match OG llama tokenizer behavior
-  params.prompt.insert(0, 1, ' ');
+  // params.prompt.insert(0, 1, ' ');

   // tokenize the prompt
   auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
@@ -757,17 +759,32 @@ int main(int argc, char ** argv) {
       // --prompt 'Question: How old are you?\nAnswer: '
       // --reverse-prompt $'\n'
       //
-      is_antiprompt = has_antiprompt();
+      std::string ap_text;
+      std::string::size_type ap_index;
+      std::string::size_type ap_extra;
+      is_antiprompt = has_antiprompt(&ap_index, &ap_text);

       // display text
+      bool got_newline = false;
       if (!input_noecho) {
+        std::string printme;
         for (auto id : embd) {
-          printf("%s", llama_token_to_str(ctx, id));
+          printme.append(llama_token_to_str(ctx, id));
         }
+        if (is_antiprompt) {
+          ap_extra = last_output.size() - (ap_index + ap_text.size());
+          printme.erase(printme.size() - MIN(printme.size(), ap_extra));
+        }
+        if (printme.size()) {
+          got_newline = printme[printme.size() - 1] == '\n';
+          printf("%s", printme.c_str());
+          fflush(stdout);
+        }
-        fflush(stdout);
       }
       if (is_antiprompt && !params.interactive) {
-        printf("\n");
+        if (!got_newline) {
+          printf("\n");
+        }
         break;
       }
       if (prompt_status == kPromptCompleted) {
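To make the arithmetic in the last hunk concrete: the bytes of last_output that follow the end of the matched reverse prompt (ap_extra of them) are taken to also be the final bytes of the freshly decoded text, so erasing that many bytes from the tail of printme stops the echoed output right after the antiprompt; got_newline then records whether the trimmed text already ended in a newline, so the non-interactive exit path prints at most one. A toy, self-contained demonstration with invented values:

// Toy demonstration of the erase arithmetic, values invented for
// illustration. last_output ends with the reverse prompt "User:" plus a
// stray subtoken fragment " h"; the same fragment sits at the end of the
// text pending display, so it is trimmed before printing.
#include <algorithm>
#include <cassert>
#include <string>

int main() {
  std::string ap_text = "User:";
  std::string last_output = "...assistant reply\nUser: h";
  std::string printme = "reply\nUser: h";   // decoded this iteration

  std::string::size_type ap_index = last_output.find(ap_text);
  std::string::size_type ap_extra =
      last_output.size() - (ap_index + ap_text.size());   // 2 bytes: " h"

  printme.erase(printme.size() - std::min(printme.size(), ap_extra));
  assert(printme == "reply\nUser:");   // output stops at the antiprompt
  return 0;
}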