patch: Handle how templates are rendered if no system prompt is allowed

This commit is contained in:
teleprint-me 2024-05-12 20:17:35 -04:00
parent 4a018e706f
commit 8b9ed888bc
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -59,21 +59,32 @@ def display_chat_template(chat_template: str, format_template: bool = False):
if format_template:
# Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
env = jinja2.Environment(
loader=jinja2.BaseLoader(),
trim_blocks=True,
lstrip_blocks=True,
loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True
)
logger.info(chat_template)
template = env.from_string(chat_template)
formatted_template = template.render(
messages=[
messages = [
{"role": "system", "content": "I am a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hello! How may I assist you today?"},
],
bos_token="[BOS]",
eos_token="[EOS]",
]
bos_token = "<s>"
eos_token = "</s>"
try:
formatted_template = template.render(
messages=messages,
bos_token=bos_token,
eos_token=eos_token,
)
except jinja2.exceptions.UndefinedError:
# system message is incompatible with set format
formatted_template = template.render(
messages=messages[1:],
bos_token=bos_token,
eos_token=eos_token,
)
print(formatted_template)
else:
# Display the raw template