commit e8de1e4766
parent 80c174d494

Fix subtoken antiprompt scanning

3 changed files with 33 additions and 5 deletions
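What the commit addresses: an antiprompt (reverse prompt) can end in the middle of a decoded token, so a check performed only at whole-token boundaries can miss the match, or echo bytes past it. A minimal illustration of the difference, with made-up strings (not code from the repo):

#include <cassert>
#include <string>

int main() {
    const std::string antiprompt = "\n";    // e.g. --reverse-prompt $'\n'
    const std::string token = "?\nAnswer";  // hypothetical decoded token
    assert(token != antiprompt);                          // whole-token check: miss
    assert(token.find(antiprompt) != std::string::npos);  // subtoken scan: hit
    return 0;
}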
third_party/ggml/llama.cc (vendored) | 8 ++++++++

--- a/third_party/ggml/llama.cc
+++ b/third_party/ggml/llama.cc
@@ -29,6 +29,7 @@
 #include "third_party/ggml/llama.h"
 #include "libc/assert.h"
 #include "libc/intrin/bits.h"
+#include "libc/macros.internal.h"
 #include "third_party/ggml/ggml.h"
 #include "third_party/ggml/llama_util.h"
 #include "third_party/libcxx/algorithm"
@@ -225,6 +226,7 @@ struct llama_vocab {

     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
+    int longest_token;
 };

 struct llama_context {
@@ -475,6 +477,7 @@ struct llama_file_loader {
         hparams.ftype = (enum llama_ftype) file.read_u32();
     }
     void read_vocab() {
+        vocab.longest_token = 0;
         vocab.id_to_token.resize(hparams.n_vocab);

         for (uint32_t i = 0; i < hparams.n_vocab; i++) {
@@ -487,6 +490,7 @@ struct llama_file_loader {
             }

             vocab.token_to_id[word] = i;
+            vocab.longest_token = MAX(vocab.longest_token, word.size());

             auto & tok_score = vocab.id_to_token[i];
             tok_score.tok = std::move(word);
@@ -2755,6 +2759,10 @@ const char * llama_token_to_str(const struct llama_context * ctx, llama_token to
     return ctx->vocab.id_to_token[token].tok.c_str();
 }

+int llama_longest_token(const struct llama_context * ctx) {
+    return ctx->vocab.longest_token;
+}
+
 llama_token llama_token_bos() {
     return 1;
 }
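The llama.cc side is pure bookkeeping: record the longest token's byte length while the vocabulary loads, and expose it through the new llama_longest_token() accessor. A standalone sketch of the same idea, using a hypothetical Vocab type rather than the repo's llama_vocab:

#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for llama_vocab: track the longest token's byte
// length at load time so callers can size byte-level scan windows later.
struct Vocab {
    std::vector<std::string> id_to_token;
    int longest_token = 0;

    void add(std::string word) {
        longest_token = std::max(longest_token, (int)word.size());
        id_to_token.push_back(std::move(word));
    }
};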
third_party/ggml/llama.h (vendored) | 3 +++

--- a/third_party/ggml/llama.h
+++ b/third_party/ggml/llama.h
@@ -183,6 +183,9 @@ extern "C" {
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);

+    // Returns number of bytes in the longest token string.
+    LLAMA_API int llama_longest_token(const struct llama_context * ctx);
+
     // Special tokens
     LLAMA_API llama_token llama_token_bos();
     LLAMA_API llama_token llama_token_eos();
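The header comment states the contract: the return value is the byte length of the longest token string. The main.cc hunk below uses it to widen the antiprompt scan window; here is a hedged sketch of that sizing logic, assuming a ctx and a list of antiprompt strings:

#include "third_party/ggml/llama.h"

#include <algorithm>
#include <string>
#include <vector>

// Sketch of the sizing in remember_init() below: the window must hold the
// longest antiprompt plus up to one partial token of slack on each side,
// because a match can begin and end mid-token.
static size_t scan_window_bytes(const struct llama_context * ctx,
                                const std::vector<std::string> & antiprompts) {
    size_t longest = 0;
    for (const std::string & ap : antiprompts) {
        longest = std::max(longest, ap.size());
    }
    return longest + llama_longest_token(ctx) * 2;
}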
third_party/ggml/main.cc (vendored) | 27 ++++++++++++++++++++++-----

--- a/third_party/ggml/main.cc
+++ b/third_party/ggml/main.cc
@@ -32,6 +32,7 @@
 #include "libc/calls/struct/stat.h"
 #include "libc/intrin/bits.h"
 #include "libc/log/log.h"
+#include "libc/macros.internal.h"
 #include "libc/nexgen32e/x86feature.h"
 #include "libc/stdio/stdio.h"
 #include "libc/sysv/consts/map.h"
@@ -117,6 +118,7 @@ static void remember_init() {
    for (std::string & antiprompt : params.antiprompt) {
        longest_antiprompt = std::max(longest_antiprompt, antiprompt.size());
    }
+   longest_antiprompt += llama_longest_token(ctx) * 2;
 }

 static void remember_token(llama_token tok) {
@@ -284,7 +286,7 @@ int main(int argc, char ** argv) {
     }

     // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
+    // params.prompt.insert(0, 1, ' ');

     // tokenize the prompt
     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
@@ -757,17 +759,32 @@ int main(int argc, char ** argv) {
         // --prompt 'Question: How old are you?\nAnswer: '
         // --reverse-prompt $'\n'
         //
-        is_antiprompt = has_antiprompt();
+        std::string ap_text;
+        std::string::size_type ap_index;
+        std::string::size_type ap_extra;
+        is_antiprompt = has_antiprompt(&ap_index, &ap_text);

         // display text
+        bool got_newline = false;
         if (!input_noecho) {
+            std::string printme;
             for (auto id : embd) {
-                printf("%s", llama_token_to_str(ctx, id));
+                printme.append(llama_token_to_str(ctx, id));
             }
-            fflush(stdout);
+            if (is_antiprompt) {
+                ap_extra = last_output.size() - (ap_index + ap_text.size());
+                printme.erase(printme.size() - MIN(printme.size(), ap_extra));
+            }
+            if (printme.size()) {
+                got_newline = printme[printme.size() - 1] == '\n';
+                printf("%s", printme.c_str());
+                fflush(stdout);
+            }
         }
         if (is_antiprompt && !params.interactive) {
-            printf("\n");
+            if (!got_newline) {
+                printf("\n");
+            }
             break;
         }
         if (prompt_status == kPromptCompleted) {
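The subtle part of the main.cc change is the trim arithmetic: when has_antiprompt() reports a match of ap_text at ap_index inside last_output, the ap_extra bytes generated past the end of the match must be erased from the not-yet-printed buffer so the echo stops exactly at the antiprompt. A self-contained rehearsal of that arithmetic with made-up strings (std::min stands in for the MIN macro):

#include <algorithm>
#include <cstdio>
#include <string>

int main() {
    // Everything generated so far; the antiprompt was found inside it.
    std::string last_output = "Answer: 42\nUser: extra";
    std::string ap_text = "User:";
    std::string::size_type ap_index = last_output.find(ap_text);

    // Bytes generated past the end of the antiprompt match.
    std::string::size_type ap_extra =
        last_output.size() - (ap_index + ap_text.size());

    // printme holds the not-yet-echoed tail of last_output; drop the
    // overshoot so nothing after the antiprompt is printed.
    std::string printme = "\nUser: extra";
    printme.erase(printme.size() - std::min(printme.size(), ap_extra));

    printf("%s", printme.c_str());  // echoes "\nUser:" and stops there
    return 0;
}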