refactor: Use render template instead of format

This commit is contained in:
teleprint-me 2024-05-12 20:37:27 -04:00
parent 8b9ed888bc
commit 668c7ee6c5
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -46,17 +46,17 @@ def get_chat_template(model_file: str) -> str:
return "" return ""
def display_chat_template(chat_template: str, format_template: bool = False): def display_chat_template(chat_template: str, render_template: bool = False):
""" """
Display the chat template to standard output, optionally formatting it using Jinja2. Display the chat template to standard output, optionally formatting it using Jinja2.
Args: Args:
chat_template (str): The extracted chat template. chat_template (str): The extracted chat template.
format_template (bool, optional): Whether to format the template using Jinja2. Defaults to False. render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.
""" """
logger.info(f"Format Template: {format_template}") logger.info(f"Format Template: {render_template}")
if format_template: if render_template:
# Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token' # Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
env = jinja2.Environment( env = jinja2.Environment(
loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True
@ -98,9 +98,10 @@ def main():
) )
parser.add_argument("model_file", type=str, help="Path to the GGUF model file") parser.add_argument("model_file", type=str, help="Path to the GGUF model file")
parser.add_argument( parser.add_argument(
"--format", "-r",
"--render-template",
action="store_true", action="store_true",
help="Format the chat template using Jinja2", help="Render the chat template using Jinja2",
) )
args = parser.parse_args() args = parser.parse_args()
@ -108,7 +109,7 @@ def main():
model_file = args.model_file model_file = args.model_file
chat_template = get_chat_template(model_file) chat_template = get_chat_template(model_file)
display_chat_template(chat_template, format_template=args.format) display_chat_template(chat_template, render_template=args.render_template)
if __name__ == "__main__": if __name__ == "__main__":