gguf: Add example script for extracting chat template
This commit is contained in:
parent
6f1b63606f
commit
eac2e83f9f
1 changed files with 103 additions and 0 deletions
103
gguf-py/scripts/gguf-template.py
Normal file
103
gguf-py/scripts/gguf-template.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
"""
|
||||||
|
gguf-template.py - example script to extract the chat template from the model's metadata
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
|
||||||
|
# Necessary to load the local gguf package: unless NO_LOCAL_GGUF is set in the
# environment, put the directory two levels above this script at the front of
# sys.path, provided its parent contains a 'gguf-py' directory.
_gguf_py_dir = Path(__file__).parent.parent
if "NO_LOCAL_GGUF" not in os.environ and (_gguf_py_dir.parent / 'gguf-py').exists():
    sys.path.insert(0, str(_gguf_py_dir))
|
||||||
|
|
||||||
|
from gguf.constants import Keys
|
||||||
|
from gguf.gguf_reader import GGUFReader # noqa: E402
|
||||||
|
|
||||||
|
# Configure logging: a named logger for this script, with the root logger
# configured at INFO level.
logger = logging.getLogger("gguf-chat-template")
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_template(model_file: str) -> str:
    """
    Read the chat template stored in the metadata of a GGUF model file.

    Every metadata key found in the file is logged at INFO level before the
    chat template is looked up.

    Args:
        model_file (str): Path to the GGUF model file.

    Returns:
        str: The chat template decoded as UTF-8, or an empty string when the
        chat template field is absent (an error is logged in that case).
    """
    gguf_reader = GGUFReader(model_file)
    metadata = gguf_reader.fields

    # List every key the model exposes
    logger.info("Detected model metadata!")
    logger.info("Outputting available model fields:")
    for field_name in metadata:
        logger.info(field_name)

    # Look the template up through the Keys.Tokenizer.CHAT_TEMPLATE key
    template_field = metadata.get(Keys.Tokenizer.CHAT_TEMPLATE)
    if not template_field:
        logger.error("Chat template field not found in model metadata.")
        return ""

    # The final part of the field is taken as the raw template bytes
    raw_template = template_field.parts[-1]
    return raw_template.tobytes().decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def display_chat_template(chat_template: str, format_template: bool = False) -> None:
    """
    Display the chat template to standard output, optionally formatting it using Jinja2.

    The template text originates from model-file metadata, which may come from
    an untrusted source. It is therefore rendered inside Jinja2's immutable
    sandbox instead of a plain ``jinja2.Environment``, so a malicious template
    cannot reach unsafe attributes or mutate the objects passed to it.

    Args:
        chat_template (str): The extracted chat template.
        format_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.

    Raises:
        jinja2.TemplateError: If the template cannot be parsed or rendered; this
            includes ``jinja2.exceptions.SecurityError`` when the template
            attempts an operation the sandbox forbids.
    """
    logger.info("Format Template: %s", format_template)

    if not format_template:
        # Display the raw template
        print(chat_template)
        return

    # Imported here because `import jinja2` alone does not guarantee that the
    # `jinja2.sandbox` submodule is loaded.
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    # Render the template with a sample conversation and placeholder
    # 'bos_token' / 'eos_token' values.
    env = ImmutableSandboxedEnvironment(
        loader=jinja2.BaseLoader(),
        trim_blocks=True,
        lstrip_blocks=True,
    )
    logger.info(chat_template)
    template = env.from_string(chat_template)
    formatted_template = template.render(
        messages=[
            {"role": "system", "content": "I am a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        bos_token="[BOS]",
        eos_token="[EOS]",
    )
    print(formatted_template)
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
def main() -> None:
    """
    Command-line entry point: extract a GGUF model's chat template and print it.

    Arguments (parsed from ``sys.argv``):
        model_file: Path to the GGUF model file.
        --format:   Render the template with Jinja2 instead of printing it raw.

    Raises:
        SystemExit: With status 1 when no chat template could be extracted, so
            callers and shell pipelines can detect the failure instead of
            receiving an empty line and a success status.
    """
    parser = argparse.ArgumentParser(
        description="Extract chat template from a GGUF model file"
    )
    parser.add_argument("model_file", type=str, help="Path to the GGUF model file")
    parser.add_argument(
        "--format",
        action="store_true",
        help="Format the chat template using Jinja2",
    )

    args = parser.parse_args()

    chat_template = get_chat_template(args.model_file)
    if not chat_template:
        # get_chat_template() returns "" (after logging an error) when the
        # field is missing; there is nothing meaningful to display.
        sys.exit(1)

    display_chat_template(chat_template, format_template=args.format)


if __name__ == "__main__":
    main()
|
Loading…
Add table
Add a link
Reference in a new issue