refactor: Add logging debug and clean up logger implementation

This commit is contained in:
teleprint-me 2024-05-12 21:07:05 -04:00
parent 6be3576e01
commit 214e9e6f0b
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -20,19 +20,17 @@ from gguf.constants import Keys
from gguf.gguf_reader import GGUFReader # noqa: E402 from gguf.gguf_reader import GGUFReader # noqa: E402
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gguf-chat-template") logger = logging.getLogger("gguf-chat-template")
def get_chat_template(model_file: str, verbose: bool = False) -> str: def get_chat_template(model_file: str) -> str:
reader = GGUFReader(model_file) reader = GGUFReader(model_file)
# Available keys # Available keys
logger.info("Detected model metadata!") logger.debug("Detected model metadata!")
if verbose: logger.debug("Outputting available model fields:")
logger.info("Outputting available model fields:") for key in reader.fields.keys():
for key in reader.fields.keys(): logger.debug(key)
logger.info(key)
# Access the 'chat_template' field directly using its key # Access the 'chat_template' field directly using its key
chat_template_field = reader.fields.get(Keys.Tokenizer.CHAT_TEMPLATE) chat_template_field = reader.fields.get(Keys.Tokenizer.CHAT_TEMPLATE)
@ -57,7 +55,7 @@ def display_chat_template(
chat_template (str): The extracted chat template. chat_template (str): The extracted chat template.
render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False. render_template (bool, optional): Whether to format the template using Jinja2. Defaults to False.
""" """
logger.info(f"Format Template: {render_template}") logger.debug(f"Render Template: {render_template}")
if render_template: if render_template:
# Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token' # Render the formatted template using Jinja2 with a context that includes 'bos_token' and 'eos_token'
@ -126,8 +124,12 @@ def main():
) )
args = parser.parse_args() args = parser.parse_args()
model_file = args.model_file if args.verbose:
chat_template = get_chat_template(model_file, args.verbose) logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
chat_template = get_chat_template(args.model_file)
display_chat_template( display_chat_template(
chat_template, args.bos, args.eos, render_template=args.render_template chat_template, args.bos, args.eos, render_template=args.render_template
) )