diff --git a/examples/pydantic-models-to-grammar-examples.py b/examples/pydantic-models-to-grammar-examples.py index 69efd3cec..a8a4919cf 100644 --- a/examples/pydantic-models-to-grammar-examples.py +++ b/examples/pydantic-models-to-grammar-examples.py @@ -8,7 +8,7 @@ import requests from pydantic import BaseModel, Field import importlib -pydantic_models_to_grammar = importlib.import_module("pydantic-models-to-grammar") +from pydantic_models_to_grammar import generate_gbnf_grammar_and_documentation # Function to get completion on the llama.cpp server with grammar. def create_completion(prompt, grammar): @@ -70,7 +70,7 @@ class Calculator(BaseModel): # outer_object_content is the name of outer object content. # model_prefix is the optional prefix for models in the documentation. (Default="Output Model") # fields_prefix is the prefix for the model fields in the documentation. (Default="Output Fields") -gbnf_grammar, documentation = pydantic_models_to_grammar.generate_gbnf_grammar_and_documentation( +gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( pydantic_model_list=[SendMessageToUser, Calculator], outer_object_name="function", outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters") @@ -122,7 +122,7 @@ class Book(BaseModel): # We need no additional parameters other than our list of pydantic models. -gbnf_grammar, documentation = pydantic_models_to_grammar.generate_gbnf_grammar_and_documentation([Book]) +gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation([Book]) system_message = "You are an advanced AI, tasked to create a dataset entry in JSON for a Book. The following is the expected output model:\n\n" + documentation diff --git a/examples/pydantic-models-to-grammar.py b/examples/pydantic_models_to_grammar.py similarity index 98% rename from examples/pydantic-models-to-grammar.py rename to examples/pydantic_models_to_grammar.py index 1d2c1686b..41b98fdc1 100644 --- a/examples/pydantic-models-to-grammar.py +++ b/examples/pydantic_models_to_grammar.py @@ -5,7 +5,7 @@ from inspect import isclass, getdoc from types import NoneType from pydantic import BaseModel, create_model, Field -from typing import Any, Type, List, get_args, get_origin, Tuple, Union, Optional +from typing import Any, Type, List, get_args, get_origin, Tuple, Union, Optional, _GenericAlias from enum import Enum from typing import get_type_hints, Callable import re @@ -290,7 +290,7 @@ def generate_gbnf_rule_for_type(model_name, field_name, enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}" rules.append(enum_rule) gbnf_type, rules = model_name + "-" + field_name, rules - elif get_origin(field_type) == list: # Array + elif get_origin(field_type) == list or field_type == list: # Array element_type = get_args(field_type)[0] element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name, f"{field_name}-element", @@ -301,7 +301,7 @@ def generate_gbnf_rule_for_type(model_name, field_name, rules.append(array_rule) gbnf_type, rules = model_name + "-" + field_name, rules - elif get_origin(field_type) == set: # Array + elif get_origin(field_type) == set or field_type == set: # Array element_type = get_args(field_type)[0] element_rule_name, additional_rules = generate_gbnf_rule_for_type(model_name, f"{field_name}-element", @@ -335,7 +335,16 @@ def generate_gbnf_rule_for_type(model_name, field_name, union_rules = [] for union_type in union_types: - if not issubclass(union_type, NoneType): + if isinstance(union_type, _GenericAlias): + union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(model_name, + field_name, union_type, + False, + processed_models, created_rules) + union_rules.append(union_gbnf_type) + rules.extend(union_rules_list) + + + elif not issubclass(union_type, NoneType): union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type(model_name, field_name, union_type, False,