fix: set ignore_merges only for llama-3
parent c7614930f3
commit 8a51d3b12d
2 changed files with 19 additions and 4 deletions
@@ -12200,12 +12200,14 @@ struct llm_tokenizer_bpe {
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         int final_prev_index = -1;
+        bool ignore_merges = false;

         std::vector<std::string> word_collection;
         switch (vocab.type) {
             case LLAMA_VOCAB_TYPE_BPE:
                 switch (vocab.type_pre) {
                     case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+                        ignore_merges = true;
                     case LLAMA_VOCAB_PRE_TYPE_DBRX:
                         word_collection = unicode_regex_split(text, {
                             // original regex from tokenizer.json
@@ -12292,7 +12294,7 @@ struct llm_tokenizer_bpe {
         symbols_final.clear();

         for (auto & word : word_collection) {
-            if (vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
                 llm_symbol sym;
                 sym.text = word.c_str();
                 sym.n = word.size();
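In effect, when ignore_merges is set and a pre-split word already exists as a vocabulary entry, the tokenizer emits it as a single symbol instead of re-deriving it through BPE merges, which can produce a different multi-token split. Below is a minimal standalone sketch of that fast path; tokenize_word and its character-split fallback are illustrative stand-ins, not the llama.cpp implementation.

// Sketch of the ignore_merges fast path; `vocab` stands in for vocab.token_to_id.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

static std::vector<std::string> tokenize_word(
        const std::string & word,
        const std::unordered_map<std::string, int> & vocab,
        bool ignore_merges) {
    // Fast path: the whole word is already a vocab entry, keep it intact.
    if (ignore_merges && vocab.find(word) != vocab.end()) {
        return { word };
    }
    // Fallback: split into single characters (real BPE would then merge pairs).
    std::vector<std::string> parts;
    for (char c : word) {
        parts.push_back(std::string(1, c));
    }
    return parts;
}

int main() {
    const std::unordered_map<std::string, int> vocab = { { "hello", 1 } };
    for (bool ignore : { false, true }) {
        printf("ignore_merges=%d -> %zu symbol(s)\n",
               ignore, tokenize_word("hello", vocab, ignore).size());
    }
    return 0;
}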
@@ -13,15 +13,27 @@
 #include <vector>

 int main(int argc, char **argv) {
-    if (argc < 2) {
-        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
+    if (argc < 2 || argc > 3) {
+        fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
         return 1;
     }

     const std::string fname = argv[1];
+    bool ignore_merges = false;
+    if (argc == 3) {
+        if (std::strcmp(argv[2], "--ignore-merges") != 0) {
+            fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
+            return 1;
+        }
+        ignore_merges = true;
+    }

     fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

+    if (ignore_merges) {
+        fprintf(stderr, "%s : ignoring merges for tokens inside vocab\n", __func__);
+    }

     llama_model * model;
     llama_context * ctx;
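With the flag parsing in place, the stricter check in the next hunks is opt-in: the test would presumably be run per vocab file as before, plus something like test-tokenizer-1-bpe <vocab-file> --ignore-merges for llama-3 style vocabs (the binary name is inferred from the usage string; this rendered diff does not show the file names).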
@@ -66,7 +78,7 @@ int main(int argc, char **argv) {
         try {
             auto cps = unicode_cpts_from_utf8(str);
             std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
-            if (tokens.size() > 1) {
+            if (ignore_merges && tokens.size() > 1) {
                 fprintf(stderr,
                     "%s : error: token %d detokenizes to '%s'(%zu) but "
                     "tokenization of this to multiple tokens: [",
@@ -76,6 +88,7 @@ int main(int argc, char **argv) {
                     fprintf(stderr, ", %d", tokens[i]);
                 }
                 fprintf(stderr, "]\n");
+                return 2;
             }
             std::string check = llama_detokenize_bpe(ctx, tokens);
             if (check != str) {
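Taken together, the test hunks make the single-token round-trip invariant opt-in: only under --ignore-merges must every vocabulary token re-tokenize to exactly one token, and a violation now exits with code 2. A toy version of that invariant check follows; detokenize and tokenize here are hypothetical stand-ins, not the llama.cpp API.

// Toy round-trip invariant check; detokenize/tokenize are hypothetical stand-ins.
#include <cstdio>
#include <string>
#include <vector>

static std::string detokenize(int token) {
    return token == 1 ? "hello" : "?";
}

static std::vector<int> tokenize(const std::string & s) {
    // Pretend "hello" is a vocab entry; everything else splits in two.
    return s == "hello" ? std::vector<int>{ 1 } : std::vector<int>{ 0, 0 };
}

int main() {
    const bool ignore_merges = true;
    for (int token : { 1 }) {
        const std::string str = detokenize(token);
        const std::vector<int> tokens = tokenize(str);
        // With ignore_merges, a vocab token must round-trip to one token.
        if (ignore_merges && tokens.size() > 1) {
            fprintf(stderr, "token %d re-tokenizes to %zu tokens\n", token, tokens.size());
            return 2;
        }
    }
    printf("round-trip ok\n");
    return 0;
}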