gguf : add special tokens metadata for FIM/Infill (#6689)

This commit adds special token metadata for Fill-In-the-Middle
(FIM)/Infill to the GGUF model.

The motivation for this is that currently there is support for CodeLlama
but other models exist now like CodeGemma, but the different models use
different token ids for the special tokens and this commit allows for
supporting multiple models.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
This commit is contained in:
Daniel Bevenius 2024-04-16 08:13:13 +02:00 committed by GitHub
parent 7593639ce3
commit 4fbd8098e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 83 additions and 11 deletions

View file

@ -90,6 +90,11 @@ class Keys:
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"
# FIM/Infill special tokens constants
PREFIX_ID = "tokenizer.ggml.prefix_token_id"
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
EOT_ID = "tokenizer.ggml.eot_token_id"
#
@ -885,3 +890,7 @@ KEY_TOKENIZER_CLS_ID = Keys.Tokenizer.CLS_ID
KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID
KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON
KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV
KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID
KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID

View file

@ -469,6 +469,18 @@ class GGUFWriter:
def add_chat_template(self, value: str) -> None:
self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
def add_prefix_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.PREFIX_ID, id)
def add_suffix_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id)
def add_middle_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id)
def add_eot_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.EOT_ID, id)
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix: