gguf: Add example script for extracting chat template
This commit is contained in:
parent
6f1b63606f
commit
eac2e83f9f
1 changed files with 103 additions and 0 deletions
103
gguf-py/scripts/gguf-template.py
Normal file
103
gguf-py/scripts/gguf-template.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
"""
|
||||
gguf-template.py - example script that extracts the chat template from a model's metadata
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import jinja2
|
||||
|
||||
# Put the in-tree gguf package ahead of any installed copy on sys.path.
# Setting NO_LOCAL_GGUF in the environment skips this step.
if "NO_LOCAL_GGUF" not in os.environ:
    if (Path(__file__).parent.parent.parent / 'gguf-py').exists():
        sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from gguf.constants import Keys
|
||||
from gguf.gguf_reader import GGUFReader # noqa: E402
|
||||
|
||||
# Script-wide logger; the root logger is configured at INFO level
logger = logging.getLogger("gguf-chat-template")
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def get_chat_template(model_file: str) -> str:
    """
    Read the chat template stored in the metadata of a GGUF model file.

    Every metadata key found by the reader is logged at INFO level before
    the chat template field is looked up.

    Args:
        model_file (str): Path to the GGUF model file.

    Returns:
        str: The chat template decoded as UTF-8, or an empty string (after
            logging an error) when the chat template field is not present.
    """
    gguf_reader = GGUFReader(model_file)

    # List every available key so a missing template is easy to diagnose
    logger.info("Detected model metadata!")
    logger.info("Outputting available model fields:")
    for field_name in gguf_reader.fields:
        logger.info(field_name)

    # Look the template up directly by its well-known metadata key
    template_field = gguf_reader.fields.get(Keys.Tokenizer.CHAT_TEMPLATE)
    if not template_field:
        logger.error("Chat template field not found in model metadata.")
        return ""

    # The last part of the field is treated as the raw template bytes
    raw_template = template_field.parts[-1]
    return raw_template.tobytes().decode("utf-8")
|
||||
|
||||
|
||||
def display_chat_template(chat_template: str, format_template: bool = False):
    """
    Display the chat template to standard output, optionally rendering it with Jinja2.

    The template text originates from a model file and must be treated as
    untrusted input, so rendering happens inside Jinja2's immutable sandbox
    rather than a plain Environment (which would let a crafted template
    reach arbitrary Python attributes during rendering).

    Args:
        chat_template (str): The extracted chat template.
        format_template (bool, optional): Whether to render the template using
            Jinja2 with a small sample conversation. Defaults to False.
    """
    logger.info("Format Template: %s", format_template)

    if not format_template:
        # Display the raw template
        print(chat_template)
        return

    # Imported here because `import jinja2` alone does not load the sandbox submodule
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    # Sandboxed environment: blocks unsafe attribute access and mutation from template code
    env = ImmutableSandboxedEnvironment(
        loader=jinja2.BaseLoader(),
        trim_blocks=True,
        lstrip_blocks=True,
    )
    logger.info(chat_template)
    template = env.from_string(chat_template)

    # Render with a sample conversation and placeholder 'bos_token' / 'eos_token' values
    formatted_template = template.render(
        messages=[
            {"role": "system", "content": "I am a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        bos_token="[BOS]",
        eos_token="[EOS]",
    )
    print(formatted_template)
|
||||
|
||||
|
||||
# Example usage:
|
||||
def main() -> None:
    """
    Command-line entry point.

    Parses the model path and the optional --format flag, extracts the chat
    template from the model file, and prints it (raw or rendered).
    """
    arg_parser = argparse.ArgumentParser(
        description="Extract chat template from a GGUF model file"
    )
    arg_parser.add_argument("model_file", type=str, help="Path to the GGUF model file")
    arg_parser.add_argument(
        "--format",
        action="store_true",
        help="Format the chat template using Jinja2",
    )
    cli_args = arg_parser.parse_args()

    # Extract first, then hand the result to the display routine
    display_chat_template(
        get_chat_template(cli_args.model_file),
        format_template=cli_args.format,
    )


# Run only when executed as a script, not when imported
if __name__ == "__main__":
    main()
|
Loading…
Add table
Add a link
Reference in a new issue