feat: Add sane defaults and options for setting special tokens

This commit is contained in:
teleprint-me 2024-05-12 20:48:29 -04:00
parent fa0b0b10cc
commit 6be3576e01
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -47,7 +47,9 @@ def get_chat_template(model_file: str, verbose: bool = False) -> str:
return "" return ""
def display_chat_template(chat_template: str, render_template: bool = False): def display_chat_template(
chat_template: str, bos_token: str, eos_token: str, render_template: bool = False
):
""" """
Display the chat template to standard output, optionally formatting it using Jinja2. Display the chat template to standard output, optionally formatting it using Jinja2.
@ -68,9 +70,9 @@ def display_chat_template(chat_template: str, render_template: bool = False):
{"role": "system", "content": "I am a helpful assistant."}, {"role": "system", "content": "I am a helpful assistant."},
{"role": "user", "content": "Hello!"}, {"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hello! How may I assist you today?"}, {"role": "assistant", "content": "Hello! How may I assist you today?"},
{"role": "user", "content": "Can you tell me what pickled mayonnaise is?"},
{"role": "assistant", "content": "Certainly! What would you like to know about it?"},
] ]
bos_token = "<s>"
eos_token = "</s>"
try: try:
formatted_template = template.render( formatted_template = template.render(
@ -104,6 +106,18 @@ def main():
action="store_true", action="store_true",
help="Render the chat template using Jinja2", help="Render the chat template using Jinja2",
) )
parser.add_argument(
"-b",
"--bos",
default="<s>",
help="Set a bos special token. Default is '<s>'.",
)
parser.add_argument(
"-e",
"--eos",
default="</s>",
help="Set a eos special token. Default is '</s>'.",
)
parser.add_argument( parser.add_argument(
"-v", "-v",
"--verbose", "--verbose",
@ -114,7 +128,9 @@ def main():
model_file = args.model_file model_file = args.model_file
chat_template = get_chat_template(model_file, args.verbose) chat_template = get_chat_template(model_file, args.verbose)
display_chat_template(chat_template, render_template=args.render_template) display_chat_template(
chat_template, args.bos, args.eos, render_template=args.render_template
)
if __name__ == "__main__": if __name__ == "__main__":