Support converting models with multiple chat templates
Adds the following metadata: * tokenizer.chat_templates * tokenizer.chat_template.<name1> * tokenizer.chat_template.<name2> * tokenizer.chat_template.<...> Here `tokenizer.chat_templates` is an array of the template names (excluding `default`); the `default` template is stored under the regular `tokenizer.chat_template` key.
This commit is contained in:
parent
67fac4b95f
commit
06808a3d0d
3 changed files with 32 additions and 3 deletions
|
@ -90,6 +90,8 @@ class Keys:
|
||||||
HF_JSON = "tokenizer.huggingface.json"
|
HF_JSON = "tokenizer.huggingface.json"
|
||||||
RWKV = "tokenizer.rwkv.world"
|
RWKV = "tokenizer.rwkv.world"
|
||||||
CHAT_TEMPLATE = "tokenizer.chat_template"
|
CHAT_TEMPLATE = "tokenizer.chat_template"
|
||||||
|
CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}"
|
||||||
|
CHAT_TEMPLATES = "tokenizer.chat_templates"
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
|
@ -6,7 +6,8 @@ import struct
|
||||||
import tempfile
|
import tempfile
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from io import BufferedWriter
|
from io import BufferedWriter
|
||||||
from typing import IO, Any, Sequence
|
from typing import IO, Any, Sequence, Mapping
|
||||||
|
from string import ascii_letters, digits
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -466,7 +467,33 @@ class GGUFWriter:
|
||||||
def add_add_space_prefix(self, value: bool) -> None:
|
def add_add_space_prefix(self, value: bool) -> None:
|
||||||
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
|
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
|
||||||
|
|
||||||
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
    """Write chat-template metadata to the GGUF file.

    A plain string is stored under ``tokenizer.chat_template``.  A list of
    ``{'name': ..., 'template': ...}`` mappings (the HuggingFace
    ``tokenizer_config.json`` multi-template format) stores each named
    template under ``tokenizer.chat_template.<name>``, records the names in
    the ``tokenizer.chat_templates`` array, and stores the template named
    ``default`` (if present) under the regular ``tokenizer.chat_template``
    key.

    Entries with an empty (post-filter) name or a missing template are
    silently skipped.
    """
    if isinstance(value, list):
        template_default = None
        template_names = set()

        # Allowing non-alphanumerical characters in a template name is
        # probably not a good idea (the name becomes part of a metadata
        # key), so filter them out.  The allowed alphabet is loop-invariant:
        # build it once as a set for O(1) per-character membership instead
        # of rebuilding a 63-element list for every character.
        allowed_chars = set(ascii_letters + digits + '_')

        for choice in value:
            name = choice.get('name', '')
            template = choice.get('template')

            name = ''.join(c for c in name if c in allowed_chars)

            if name and template is not None:
                if name == 'default':
                    template_default = template
                else:
                    template_names.add(name)
                    self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)

        if template_names:
            self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))

        if template_default is None:
            # Only named templates were provided; nothing to write under
            # the regular chat-template key.
            return

        # Fall through and store the default template under the plain key.
        value = template_default

    self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
|
||||||
|
|
||||||
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
|
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
|
||||||
|
|
|
@ -141,7 +141,7 @@ class SpecialVocab:
|
||||||
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
||||||
tokenizer_config = json.load(f)
|
tokenizer_config = json.load(f)
|
||||||
chat_template = tokenizer_config.get('chat_template')
|
chat_template = tokenizer_config.get('chat_template')
|
||||||
if chat_template is None or isinstance(chat_template, str):
|
if chat_template is None or isinstance(chat_template, (str, list)):
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue