refactor: Remove rename from display to render and return result instead of printing

This commit is contained in:
teleprint-me 2024-05-12 21:17:04 -04:00
parent 214e9e6f0b
commit f8bb223924
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -45,9 +45,9 @@ def get_chat_template(model_file: str) -> str:
return "" return ""
def display_chat_template( def render_chat_template(
chat_template: str, bos_token: str, eos_token: str, render_template: bool = False chat_template: str, bos_token: str, eos_token: str, render_template: bool = False
): ) -> str:
""" """
Display the chat template to standard output, optionally formatting it using Jinja2. Display the chat template to standard output, optionally formatting it using Jinja2.
@ -86,10 +86,10 @@ def display_chat_template(
eos_token=eos_token, eos_token=eos_token,
) )
print(formatted_template) return formatted_template
else: else:
# Display the raw template # Display the raw template
print(chat_template) return chat_template
# Example usage: # Example usage:
@ -130,9 +130,10 @@ def main():
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
chat_template = get_chat_template(args.model_file) chat_template = get_chat_template(args.model_file)
display_chat_template( rendered_template = render_chat_template(
chat_template, args.bos, args.eos, render_template=args.render_template chat_template, args.bos, args.eos, render_template=args.render_template
) )
print(rendered_template)
if __name__ == "__main__": if __name__ == "__main__":