From d3f0c7166adfa952237e0f437a5344362d8256d4 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Mon, 5 Aug 2024 09:38:01 +0200 Subject: [PATCH] Stop the generation when <|eom_id|> token is encountered - needed for Llama 3.1 tool call support (#8858) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * gguf-py, llama : add constants and methods related to Llama-3.1 <|eom_id|> token * llama : find Llama-3.1 <|eom_id|> token id during vocab loading * llama-vocab : add Llama-3.1 <|eom_id|> token to the set of tokens stopping the generation --------- Co-authored-by: Stanisław Szymczyk --- gguf-py/gguf/constants.py | 2 ++ gguf-py/gguf/gguf_writer.py | 3 +++ src/llama-vocab.cpp | 7 ++++++- src/llama-vocab.h | 2 ++ src/llama.cpp | 14 ++++++++++++++ 5 files changed, 27 insertions(+), 1 deletion(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index e343c2ef1..59ffd92ea 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -161,6 +161,7 @@ class Keys: SUFFIX_ID = "tokenizer.ggml.suffix_token_id" MIDDLE_ID = "tokenizer.ggml.middle_token_id" EOT_ID = "tokenizer.ggml.eot_token_id" + EOM_ID = "tokenizer.ggml.eom_token_id" class Adapter: TYPE = "adapter.type" @@ -1327,3 +1328,4 @@ KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID +KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 2e0b335ee..76385a828 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -828,6 +828,9 @@ class GGUFWriter: def add_eot_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.EOT_ID, id) + def add_eom_token_id(self, id: int) -> None: + self.add_uint32(Keys.Tokenizer.EOM_ID, id) + def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: pack_prefix = '' if not skip_pack_prefix: diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 133094904..9be076f6d 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1444,7 +1444,8 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) { return token != -1 && ( token == llama_token_eos_impl(vocab) || - token == llama_token_eot_impl(vocab) + token == llama_token_eot_impl(vocab) || + token == llama_token_eom_impl(vocab) ); } @@ -1500,6 +1501,10 @@ llama_token llama_token_eot_impl(const struct llama_vocab & vocab) { return vocab.special_eot_id; } +llama_token llama_token_eom_impl(const struct llama_vocab & vocab) { + return vocab.special_eom_id; +} + int32_t llama_tokenize_impl( const struct llama_vocab & vocab, const char * text, diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 30b565d55..7adfc16da 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -45,6 +45,7 @@ struct llama_vocab { id special_suffix_id = -1; id special_middle_id = -1; id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token + id special_eom_id = -1; // tokenizer flags bool tokenizer_add_space_prefix = false; @@ -101,6 +102,7 @@ llama_token llama_token_prefix_impl(const struct llama_vocab & vocab); llama_token llama_token_middle_impl(const struct llama_vocab & vocab); llama_token llama_token_suffix_impl(const struct llama_vocab & vocab); llama_token llama_token_eot_impl (const struct llama_vocab & vocab); +llama_token llama_token_eom_impl (const struct llama_vocab & vocab); int32_t llama_tokenize_impl( const struct llama_vocab & vocab, diff --git a/src/llama.cpp b/src/llama.cpp index ff234565d..a7b1c9ebd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -359,6 +359,7 @@ enum llm_kv { LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, LLM_KV_TOKENIZER_EOT_ID, + LLM_KV_TOKENIZER_EOM_ID, LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, @@ -456,6 +457,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, @@ -5583,6 +5585,7 @@ static void llm_load_vocab( { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id }, { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id }, { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, + { LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id }, }; for (const auto & it : special_token_types) { @@ -5635,6 +5638,17 @@ static void llm_load_vocab( } } } + + // find EOM token: "<|eom_id|>" + // + // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID + // for now, we apply this workaround to find the EOM token based on its text + if (vocab.special_eom_id == -1) { + const auto & t = vocab.token_to_id.find("<|eom_id|>"); + if (t != vocab.token_to_id.end()) { + vocab.special_eom_id = t->second; + } + } } // build special tokens cache