diff --git a/gguf-py/scripts/gguf-template.py b/gguf-py/scripts/gguf-template.py index 78d952475..ef2b78148 100644 --- a/gguf-py/scripts/gguf-template.py +++ b/gguf-py/scripts/gguf-template.py @@ -50,6 +50,7 @@ def render_chat_template( chat_template: str, bos_token: str, eos_token: str, + add_generation_prompt: bool = False, render_template: bool = False, ) -> str: """ @@ -60,6 +61,7 @@ def render_chat_template( render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False. """ logger.debug(f"Render Template: {render_template}") + logger.debug(f"Add Generation Prompt: {add_generation_prompt}") if render_template: # Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token' @@ -81,6 +83,7 @@ def render_chat_template( messages=messages, bos_token=bos_token, eos_token=eos_token, + add_generation_prompt=add_generation_prompt, ) except jinja2.exceptions.UndefinedError: # system message is incompatible with set format @@ -88,6 +91,7 @@ def render_chat_template( messages=messages[1:], bos_token=bos_token, eos_token=eos_token, + add_generation_prompt=add_generation_prompt, ) return formatted_template @@ -120,6 +124,12 @@ def main(): default="", help="Set a eos special token. Default is ''.", ) + parser.add_argument( + "-g", + "--agp", + action="store_true", + help="Add generation prompt. Default is True.", + ) parser.add_argument( "-v", "--verbose", @@ -138,6 +148,7 @@ def main(): chat_template, args.bos, args.eos, + add_generation_prompt=args.agp, render_template=args.render_template, ) print(rendered_template) # noqa: NP100