feat: Add option for adding generation prompt
Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
This commit is contained in:
parent
da96fdd15f
commit
cfe659d90a
1 changed files with 11 additions and 0 deletions
|
@ -50,6 +50,7 @@ def render_chat_template(
|
|||
chat_template: str,
|
||||
bos_token: str,
|
||||
eos_token: str,
|
||||
add_generation_prompt: bool = False,
|
||||
render_template: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
|
@ -60,6 +61,7 @@ def render_chat_template(
|
|||
render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.
|
||||
"""
|
||||
logger.debug(f"Render Template: {render_template}")
|
||||
logger.debug(f"Add Generation Prompt: {add_generation_prompt}")
|
||||
|
||||
if render_template:
|
||||
# Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
|
||||
|
@ -81,6 +83,7 @@ def render_chat_template(
|
|||
messages=messages,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
except jinja2.exceptions.UndefinedError:
|
||||
# system message is incompatible with set format
|
||||
|
@ -88,6 +91,7 @@ def render_chat_template(
|
|||
messages=messages[1:],
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
|
||||
return formatted_template
|
||||
|
@ -120,6 +124,12 @@ def main():
|
|||
default="</s>",
|
||||
help="Set a eos special token. Default is '</s>'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--agp",
|
||||
action="store_true",
|
||||
help="Add generation prompt. Default is True.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
|
@ -138,6 +148,7 @@ def main():
|
|||
chat_template,
|
||||
args.bos,
|
||||
args.eos,
|
||||
add_generation_prompt=args.agp,
|
||||
render_template=args.render_template,
|
||||
)
|
||||
print(rendered_template) # noqa: NP100
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue