feat: Add option for adding generation prompt
Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
This commit is contained in:
parent
da96fdd15f
commit
cfe659d90a
1 changed files with 11 additions and 0 deletions
|
@ -50,6 +50,7 @@ def render_chat_template(
|
||||||
chat_template: str,
|
chat_template: str,
|
||||||
bos_token: str,
|
bos_token: str,
|
||||||
eos_token: str,
|
eos_token: str,
|
||||||
|
add_generation_prompt: bool = False,
|
||||||
render_template: bool = False,
|
render_template: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -60,6 +61,7 @@ def render_chat_template(
|
||||||
render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.
|
render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.
|
||||||
"""
|
"""
|
||||||
logger.debug(f"Render Template: {render_template}")
|
logger.debug(f"Render Template: {render_template}")
|
||||||
|
logger.debug(f"Add Generation Prompt: {add_generation_prompt}")
|
||||||
|
|
||||||
if render_template:
|
if render_template:
|
||||||
# Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
|
# Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
|
||||||
|
@ -81,6 +83,7 @@ def render_chat_template(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
bos_token=bos_token,
|
bos_token=bos_token,
|
||||||
eos_token=eos_token,
|
eos_token=eos_token,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
)
|
)
|
||||||
except jinja2.exceptions.UndefinedError:
|
except jinja2.exceptions.UndefinedError:
|
||||||
# system message is incompatible with set format
|
# system message is incompatible with set format
|
||||||
|
@ -88,6 +91,7 @@ def render_chat_template(
|
||||||
messages=messages[1:],
|
messages=messages[1:],
|
||||||
bos_token=bos_token,
|
bos_token=bos_token,
|
||||||
eos_token=eos_token,
|
eos_token=eos_token,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
return formatted_template
|
return formatted_template
|
||||||
|
@ -120,6 +124,12 @@ def main():
|
||||||
default="</s>",
|
default="</s>",
|
||||||
help="Set a eos special token. Default is '</s>'.",
|
help="Set a eos special token. Default is '</s>'.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-g",
|
||||||
|
"--agp",
|
||||||
|
action="store_true",
|
||||||
|
help="Add generation prompt. Default is True.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-v",
|
"-v",
|
||||||
"--verbose",
|
"--verbose",
|
||||||
|
@ -138,6 +148,7 @@ def main():
|
||||||
chat_template,
|
chat_template,
|
||||||
args.bos,
|
args.bos,
|
||||||
args.eos,
|
args.eos,
|
||||||
|
add_generation_prompt=args.agp,
|
||||||
render_template=args.render_template,
|
render_template=args.render_template,
|
||||||
)
|
)
|
||||||
print(rendered_template) # noqa: NP100
|
print(rendered_template) # noqa: NP100
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue