feat: Add option for adding generation prompt

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
This commit is contained in:
teleprint-me 2024-05-13 13:12:38 -04:00
parent da96fdd15f
commit cfe659d90a
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -50,6 +50,7 @@ def render_chat_template(
chat_template: str,
bos_token: str,
eos_token: str,
add_generation_prompt: bool = False,
render_template: bool = False,
) -> str:
"""
@ -60,6 +61,7 @@ def render_chat_template(
render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.
"""
logger.debug(f"Render Template: {render_template}")
logger.debug(f"Add Generation Prompt: {add_generation_prompt}")
if render_template:
# Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
@ -81,6 +83,7 @@ def render_chat_template(
messages=messages,
bos_token=bos_token,
eos_token=eos_token,
add_generation_prompt=add_generation_prompt,
)
except jinja2.exceptions.UndefinedError:
# system message is incompatible with set format
@ -88,6 +91,7 @@ def render_chat_template(
messages=messages[1:],
bos_token=bos_token,
eos_token=eos_token,
add_generation_prompt=add_generation_prompt,
)
return formatted_template
@ -120,6 +124,12 @@ def main():
default="</s>",
help="Set a eos special token. Default is '</s>'.",
)
parser.add_argument(
"-g",
"--agp",
action="store_true",
help="Add generation prompt. Default is True.",
)
parser.add_argument(
"-v",
"--verbose",
@ -138,6 +148,7 @@ def main():
chat_template,
args.bos,
args.eos,
add_generation_prompt=args.agp,
render_template=args.render_template,
)
print(rendered_template) # noqa: NP100