llama.cpp : fix spm whitespace escaping + clean up

This commit is contained in:
klosax 2023-08-26 12:08:34 +02:00 committed by GitHub
parent bae5c5f679
commit d52894602d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1635,7 +1635,7 @@ static void llm_load_hparams(
} }
// TODO: This should probably be in llama.h // TODO: This should probably be in llama.h
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape); static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos);
static void llm_load_vocab( static void llm_load_vocab(
llama_model_loader & ml, llama_model_loader & ml,
@ -1737,7 +1737,7 @@ static void llm_load_vocab(
} }
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false, false)[0]; vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false)[0];
// special tokens // special tokens
GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
@ -3027,14 +3027,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
} }
static std::string llama_escape_whitespace(const std::string& text) { static std::string llama_escape_whitespace(const std::string& text) {
std::string result = "\xe2\x96\x81"; std::string result = text;
for (size_t offs = 0; offs < text.length(); ++offs) { replace_all(result, " ", "\xe2\x96\x81");
if (text[offs] == ' ') {
result += "\xe2\x96\x81";
} else {
result += text[offs];
}
}
return result; return result;
} }
@ -3219,7 +3213,7 @@ struct llm_bigram_bpe {
}; };
struct llm_tokenizer_bpe { struct llm_tokenizer_bpe {
llm_tokenizer_bpe(const llama_vocab & vocab, bool g2ws): vocab(vocab) { flag_g2ws = g2ws; } llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1; int final_prev_index = -1;
@ -3371,8 +3365,6 @@ private:
return words; return words;
} }
bool flag_g2ws = false;
const llama_vocab & vocab; const llama_vocab & vocab;
std::vector<llm_symbol> symbols; std::vector<llm_symbol> symbols;
@ -3381,39 +3373,26 @@ private:
llm_bigram_bpe::queue work_queue; llm_bigram_bpe::queue work_queue;
}; };
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) { static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos) {
std::vector<llama_vocab::id> output; std::vector<llama_vocab::id> output;
if (raw_text.empty()) { if (raw_text.empty()) {
return output; return output;
} }
if (bos && vocab.special_bos_id != -1) {
output.push_back(vocab.special_bos_id);
}
switch (vocab.type) { switch (vocab.type) {
case LLAMA_VOCAB_TYPE_SPM: case LLAMA_VOCAB_TYPE_SPM:
{ {
llm_tokenizer_spm tokenizer(vocab); llm_tokenizer_spm tokenizer(vocab);
tokenizer.tokenize(llama_escape_whitespace(raw_text), output);
if (bos) {
output.push_back(vocab.special_bos_id);
}
std::string text;
if (escape) {
text = llama_escape_whitespace(raw_text);
} else {
text = raw_text;
}
tokenizer.tokenize(text, output);
} break; } break;
case LLAMA_VOCAB_TYPE_BPE: case LLAMA_VOCAB_TYPE_BPE:
{ {
llm_tokenizer_bpe tokenizer(vocab, escape); llm_tokenizer_bpe tokenizer(vocab);
if (bos && vocab.special_bos_id != -1) {
output.push_back(vocab.special_bos_id);
}
tokenizer.tokenize(raw_text, output); tokenizer.tokenize(raw_text, output);
} break; } break;
}; };
@ -6095,8 +6074,7 @@ int llama_tokenize_with_model(
llama_token * tokens, llama_token * tokens,
int n_max_tokens, int n_max_tokens,
bool add_bos) { bool add_bos) {
auto escape = llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM; auto res = llama_tokenize_internal(model->vocab, text, add_bos);
auto res = llama_tokenize_internal(model->vocab, text, add_bos, escape);
if (n_max_tokens < (int) res.size()) { if (n_max_tokens < (int) res.size()) {
LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);